diff --git a/official/cv/densenet/modelarts/train_start.py b/official/cv/densenet/modelarts/train_start.py
index f3a8b7d7071061fd1781fd1cac20af1a7d866dfa..3248f05a64a76aa813e42525d53e85fc08b641a7 100644
--- a/official/cv/densenet/modelarts/train_start.py
+++ b/official/cv/densenet/modelarts/train_start.py
@@ -149,7 +149,7 @@ def _export_air(ckpt_dir):
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
- input_arr = Tensor(np.zeros([1, 3, 224, 224],
+ input_arr = Tensor(np.zeros([1, 3, config.image_size[0], config.image_size[1]],
np.float32))
print("Start export air.")
export(net, input_arr, file_name=config.file_name,
diff --git a/research/cv/RCAN/script/run_ascend_distribute.sh b/research/cv/RCAN/script/run_ascend_distribute.sh
index 2bebe962b0fb34381e79dc7b24c2f426c7890a7c..9f09e5942ebfc25c1c25a3eee762c1357d488737 100644
--- a/research/cv/RCAN/script/run_ascend_distribute.sh
+++ b/research/cv/RCAN/script/run_ascend_distribute.sh
@@ -59,7 +59,7 @@ for ((i = 0; i < ${DEVICE_NUM}; i++)); do
nohup python train.py \
--batch_size 16 \
--lr 1e-4 \
- --scale 2+3+4 \
+ --scale 2 \
--task_id 0 \
--dir_data $PATH2 \
--epochs 500 \
diff --git a/research/cv/RCAN/src/rcan_model.py b/research/cv/RCAN/src/rcan_model.py
index 6ffbf6680aa9513e1485e2715fbf3f1a77332b02..49517596ca36c1b3e0250b00fef9d5ffad56b09e 100644
--- a/research/cv/RCAN/src/rcan_model.py
+++ b/research/cv/RCAN/src/rcan_model.py
@@ -86,12 +86,11 @@ class Upsampler(nn.Cell):
"""rcan"""
super(Upsampler, self).__init__()
m = []
- for s in scale:
- if (s & (s - 1)) == 0:
- for _ in range(int(math.log(s, 2))):
- m.append(SmallUpSampler(conv, 2, n_feats, has_bias=has_bias))
- elif s == 3:
- m.append(SmallUpSampler(conv, 3, n_feats, has_bias=has_bias))
+ if (scale & (scale - 1)) == 0:
+ for _ in range(int(math.log(scale, 2))):
+ m.append(SmallUpSampler(conv, 2, n_feats, has_bias=has_bias))
+ elif scale == 3:
+ m.append(SmallUpSampler(conv, 3, n_feats, has_bias=has_bias))
self.net = nn.SequentialCell(m)
def construct(self, x):
@@ -186,7 +185,7 @@ class RCAN(nn.Cell):
n_feats = args.n_feats
kernel_size = 3
reduction = args.reduction
- scale = args.scale
+ scale = args.scale[0]
self.dytpe = mstype.float16
# RGB mean for DIV2K