diff --git a/official/cv/vit/src/optimizer.py b/official/cv/vit/src/optimizer.py
index 7519eb8ddd8fb44cd47f5fdf27f7d1657ee2a8fd..3c7ba7769018abd0f9c077447a2117daa5fd475b 100644
--- a/official/cv/vit/src/optimizer.py
+++ b/official/cv/vit/src/optimizer.py
@@ -66,7 +66,7 @@
 _adam_opt = ops.MultitypeFuncGraph("adam_opt")
 _scaler_one = Tensor(1, ms.int32)
 
-@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
+@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                     "Tensor", "Bool", "Bool")
 def _update_run_op(beta1_power, beta2_power, beta1, beta2, eps, lr, weight_decay, param, \
                    m, v, gradient, decay_flag, optim_filter):
@@ -78,7 +78,7 @@ def _update_run_op(beta1_power, beta2_power, beta1, beta2, eps, lr, weight_decay
         beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
         eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
         lr (Tensor): Learning rate.
-        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
+        weight_decay (Tensor): Weight decay. Should be equal to or greater than 0.
         param (Tensor): Parameters.
         m (Tensor): m value of parameters.
         v (Tensor): v value of parameters.
diff --git a/official/cv/vit/src/vit.py b/official/cv/vit/src/vit.py
index 6d27391498a39dad8cc4b89662d7b7b848ee39f3..5bbf03e6bc3629f154ae4503b35dc6338051e1a5 100644
--- a/official/cv/vit/src/vit.py
+++ b/official/cv/vit/src/vit.py
@@ -84,11 +84,11 @@ class DropPath(Cell):
         self.dropout = Dropout(self.keep_prob)
 
     def construct(self, x):
-        x_shape = self.shape(x)
         if self.training:
             x_shape = self.shape(x) # B N C
             mask = self.ones((x_shape[0], 1, 1), ms.float32)
-        return self.dropout(mask)*x
+            x = self.dropout(mask)*x
+        return x
 
 
 class BatchDense(Cell):
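
Note on the vit.py hunk: the old construct computed mask only under `if self.training:` but applied it unconditionally in the return, so the eval path referenced an undefined `mask` (and the outer `x_shape = self.shape(x)` was redundant). The fix applies the drop-path mask inside the training branch and otherwise returns x unchanged. A minimal self-contained sketch of the fixed behavior follows; the `__init__` below is an assumption for illustration (this hunk only shows the `self.dropout` line of it), and `DropPathSketch` is a hypothetical name:

    import mindspore as ms
    from mindspore import ops
    from mindspore.nn import Cell, Dropout

    class DropPathSketch(Cell):
        """Stochastic depth: drop entire samples with probability drop_prob (hypothetical ctor)."""
        def __init__(self, drop_prob=0.1):
            super().__init__()
            self.keep_prob = 1.0 - drop_prob
            self.dropout = Dropout(self.keep_prob)  # MindSpore 1.x signature: keep probability
            self.shape = ops.Shape()
            self.ones = ops.Ones()

        def construct(self, x):
            if self.training:
                x_shape = self.shape(x)  # B N C
                # One Bernoulli draw per sample, broadcast over tokens and channels;
                # Dropout rescales the surviving samples by 1/keep_prob.
                mask = self.ones((x_shape[0], 1, 1), ms.float32)
                x = self.dropout(mask) * x
            return x  # eval mode: identity (the old code hit an undefined `mask` here)

Training vs. eval is toggled the usual way, e.g. `DropPathSketch(0.1).set_train(True)`. The optimizer.py hunk is the matching type fix: `weight_decay` is registered with `MultitypeFuncGraph` as a "Tensor" instead of a "Number" so dispatch succeeds when weight decay is passed per-group as a Tensor, and the docstring is updated to agree.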