From 54d77968b092bd060411cf6bbbfd098fd17114c4 Mon Sep 17 00:00:00 2001 From: wangzeyangyi <tomzwang11@gmail.com> Date: Fri, 29 Apr 2022 14:19:46 +0800 Subject: [PATCH] minor fix --- official/cv/vit/src/optimizer.py | 4 ++-- official/cv/vit/src/vit.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/official/cv/vit/src/optimizer.py b/official/cv/vit/src/optimizer.py index 7519eb8dd..3c7ba7769 100644 --- a/official/cv/vit/src/optimizer.py +++ b/official/cv/vit/src/optimizer.py @@ -66,7 +66,7 @@ _adam_opt = ops.MultitypeFuncGraph("adam_opt") _scaler_one = Tensor(1, ms.int32) -@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", +@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") def _update_run_op(beta1_power, beta2_power, beta1, beta2, eps, lr, weight_decay, param, \ m, v, gradient, decay_flag, optim_filter): @@ -78,7 +78,7 @@ def _update_run_op(beta1_power, beta2_power, beta1, beta2, eps, lr, weight_decay beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0). eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. lr (Tensor): Learning rate. - weight_decay (Number): Weight decay. Should be equal to or greater than 0. + weight_decay (Tensor): Weight decay. Should be equal to or greater than 0. param (Tensor): Parameters. m (Tensor): m value of parameters. v (Tensor): v value of parameters. diff --git a/official/cv/vit/src/vit.py b/official/cv/vit/src/vit.py index 6d2739149..5bbf03e6b 100644 --- a/official/cv/vit/src/vit.py +++ b/official/cv/vit/src/vit.py @@ -84,11 +84,11 @@ class DropPath(Cell): self.dropout = Dropout(self.keep_prob) def construct(self, x): - x_shape = self.shape(x) if self.training: x_shape = self.shape(x) # B N C mask = self.ones((x_shape[0], 1, 1), ms.float32) - return self.dropout(mask)*x + x = self.dropout(mask)*x + return x class BatchDense(Cell): -- GitLab