diff --git a/official/cv/resnet/golden_stick/pruner/scop/train.py b/official/cv/resnet/golden_stick/pruner/scop/train.py index 0aea13af94237aa8dc8b9764d51bfed3aff18975..cce1be35da52ff4099d1307b05fb45dd11f113ea 100644 --- a/official/cv/resnet/golden_stick/pruner/scop/train.py +++ b/official/cv/resnet/golden_stick/pruner/scop/train.py @@ -224,7 +224,7 @@ def train_net(): net = resnet(class_num=config.class_num) # apply golden-stick algo - algo_kf = PrunerKfCompressAlgo({}) + algo_kf = PrunerKfCompressAlgo(config) load_fp32_ckpt(net) net = algo_kf.apply(net) lr = get_lr(lr_init=config.lr_init, @@ -261,7 +261,7 @@ def train_ft(net): batch_size=config.batch_size, train_image_size=config.train_image_size, eval_image_size=config.eval_image_size, target=config.device_target, distribute=config.run_distribute) - algo_ft = PrunerFtCompressAlgo(prune_rate=config.prune_rate) + algo_ft = PrunerFtCompressAlgo(config) net = algo_ft.apply(net) load_pretrained_ckpt(net) lr_ft_new = ms.Tensor(get_lr(lr_init=config.lr_init,