Skip to content
Snippets Groups Projects
Unverified Commit e3f3c7e6 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!2600 vit minor fix

Merge pull request !2600 from wangzeyangyi/vit
parents f1a59135 54d77968
No related branches found
No related tags found
No related merge requests found
......@@ -66,7 +66,7 @@ _adam_opt = ops.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, ms.int32)
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_op(beta1_power, beta2_power, beta1, beta2, eps, lr, weight_decay, param, \
m, v, gradient, decay_flag, optim_filter):
......@@ -78,7 +78,7 @@ def _update_run_op(beta1_power, beta2_power, beta1, beta2, eps, lr, weight_decay
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
lr (Tensor): Learning rate.
weight_decay (Number): Weight decay. Should be equal to or greater than 0.
weight_decay (Tensor): Weight decay. Should be equal to or greater than 0.
param (Tensor): Parameters.
m (Tensor): m value of parameters.
v (Tensor): v value of parameters.
......
......@@ -84,11 +84,11 @@ class DropPath(Cell):
self.dropout = Dropout(self.keep_prob)
def construct(self, x):
x_shape = self.shape(x)
if self.training:
x_shape = self.shape(x) # B N C
mask = self.ones((x_shape[0], 1, 1), ms.float32)
return self.dropout(mask)*x
x = self.dropout(mask)*x
return x
class BatchDense(Cell):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment