From 54d77968b092bd060411cf6bbbfd098fd17114c4 Mon Sep 17 00:00:00 2001
From: wangzeyangyi <tomzwang11@gmail.com>
Date: Fri, 29 Apr 2022 14:19:46 +0800
Subject: [PATCH] fix _adam_opt weight_decay type and DropPath inference path

Register weight_decay with _adam_opt as "Tensor" rather than "Number",
matching the type actually passed at dispatch time, and update the
docstring to agree. Rework DropPath.construct so that outside training
it returns its input unchanged instead of referencing a mask that is
only defined under self.training.

---
 official/cv/vit/src/optimizer.py | 4 ++--
 official/cv/vit/src/vit.py       | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/official/cv/vit/src/optimizer.py b/official/cv/vit/src/optimizer.py
index 7519eb8dd..3c7ba7769 100644
--- a/official/cv/vit/src/optimizer.py
+++ b/official/cv/vit/src/optimizer.py
@@ -66,7 +66,7 @@ _adam_opt = ops.MultitypeFuncGraph("adam_opt")
 _scaler_one = Tensor(1, ms.int32)
 
 
-@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
+@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                     "Tensor", "Bool", "Bool")
 def _update_run_op(beta1_power, beta2_power, beta1, beta2, eps, lr, weight_decay, param, \
                    m, v, gradient, decay_flag, optim_filter):
@@ -78,7 +78,7 @@ def _update_run_op(beta1_power, beta2_power, beta1, beta2, eps, lr, weight_decay
         beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
         eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
         lr (Tensor): Learning rate.
-        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
+        weight_decay (Tensor): Weight decay. Should be equal to or greater than 0.
         param (Tensor): Parameters.
         m (Tensor): m value of parameters.
         v (Tensor): v value of parameters.
diff --git a/official/cv/vit/src/vit.py b/official/cv/vit/src/vit.py
index 6d2739149..5bbf03e6b 100644
--- a/official/cv/vit/src/vit.py
+++ b/official/cv/vit/src/vit.py
@@ -84,11 +84,11 @@ class DropPath(Cell):
         self.dropout = Dropout(self.keep_prob)
 
     def construct(self, x):
-        x_shape = self.shape(x)
         if self.training:
             x_shape = self.shape(x) # B N C
             mask = self.ones((x_shape[0], 1, 1), ms.float32)
-        return self.dropout(mask)*x
+            x = self.dropout(mask)*x
+        return x
 
 
 class BatchDense(Cell):
-- 
GitLab
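
Note on the optimizer change: ops.MultitypeFuncGraph selects an
implementation by matching the declared type signature against the
runtime argument types, and the Adam update passes weight_decay in as a
Tensor, so the seventh slot must be declared "Tensor" for dispatch to
succeed. A minimal sketch of the dispatch mechanism (the scale graph
below is illustrative, not part of the patch):

import mindspore as ms
from mindspore import ops, Tensor

scale = ops.MultitypeFuncGraph("scale")

@scale.register("Tensor", "Tensor")
def _scale_by_tensor(factor, x):
    return factor * x

@scale.register("Number", "Tensor")
def _scale_by_number(factor, x):
    return factor * x

# Dispatch follows the runtime types: a Tensor factor selects the
# ("Tensor", "Tensor") overload, a Python float the ("Number", "Tensor")
# one. A Tensor argument in a slot declared only as "Number" finds no
# matching signature, which is the mismatch the patch fixes.
out_t = scale(Tensor(0.5, ms.float32), Tensor([2.0], ms.float32))
out_n = scale(0.5, Tensor([2.0], ms.float32))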
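
Note on the DropPath change: mask was assigned only under self.training,
yet the old return statement referenced it unconditionally, so the cell
broke outside training; the patch also drops the duplicate shape
computation. After the fix, evaluation is an identity. A minimal NumPy
sketch of the intended stochastic-depth behavior (function name and
arguments are illustrative):

import numpy as np

def drop_path(x, keep_prob, training, rng=None):
    # Stochastic depth over a (B, N, C) batch: drop whole samples with
    # probability 1 - keep_prob and rescale survivors by 1 / keep_prob,
    # which is what Dropout(keep_prob) applied to a ones-valued
    # (B, 1, 1) tensor computes in the patched construct.
    if training:
        if rng is None:
            rng = np.random.default_rng()
        mask = (rng.random((x.shape[0], 1, 1)) < keep_prob) / keep_prob
        x = mask * x
    return x  # inference: identity, matching the patched return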