Skip to content
Snippets Groups Projects
Commit 54d77968 authored by wangzeyangyi's avatar wangzeyangyi
Browse files

minor fix

parent 576559d9
No related branches found
No related tags found
No related merge requests found
......@@ -66,7 +66,7 @@ _adam_opt = ops.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, ms.int32)
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_op(beta1_power, beta2_power, beta1, beta2, eps, lr, weight_decay, param, \
m, v, gradient, decay_flag, optim_filter):
......@@ -78,7 +78,7 @@ def _update_run_op(beta1_power, beta2_power, beta1, beta2, eps, lr, weight_decay
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
lr (Tensor): Learning rate.
weight_decay (Number): Weight decay. Should be equal to or greater than 0.
weight_decay (Tensor): Weight decay. Should be equal to or greater than 0.
param (Tensor): Parameters.
m (Tensor): m value of parameters.
v (Tensor): v value of parameters.
......
......@@ -84,11 +84,11 @@ class DropPath(Cell):
self.dropout = Dropout(self.keep_prob)
def construct(self, x):
x_shape = self.shape(x)
if self.training:
x_shape = self.shape(x) # B N C
mask = self.ones((x_shape[0], 1, 1), ms.float32)
return self.dropout(mask)*x
x = self.dropout(mask)*x
return x
class BatchDense(Cell):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment