diff --git a/research/cv/RCAN/src/rcan_model.py b/research/cv/RCAN/src/rcan_model.py index 0fd5020768cfc18ef43945bcb4557f2fb31270f7..6ffbf6680aa9513e1485e2715fbf3f1a77332b02 100644 --- a/research/cv/RCAN/src/rcan_model.py +++ b/research/cv/RCAN/src/rcan_model.py @@ -86,11 +86,12 @@ class Upsampler(nn.Cell): """rcan""" super(Upsampler, self).__init__() m = [] - if (scale & (scale - 1)) == 0: - for _ in range(int(math.log(scale, 2))): - m.append(SmallUpSampler(conv, 2, n_feats, has_bias=has_bias)) - elif scale == 3: - m.append(SmallUpSampler(conv, 3, n_feats, has_bias=has_bias)) + for s in scale: + if (s & (s - 1)) == 0: + for _ in range(int(math.log(s, 2))): + m.append(SmallUpSampler(conv, 2, n_feats, has_bias=has_bias)) + elif s == 3: + m.append(SmallUpSampler(conv, 3, n_feats, has_bias=has_bias)) self.net = nn.SequentialCell(m) def construct(self, x):