Skip to content
Snippets Groups Projects
Unverified Commit ef22c291 authored by OuYang Yu's avatar OuYang Yu Committed by GitHub
Browse files

fix lamb grad functor (#4163)

parent aec3cd03
No related branches found
No related tags found
No related merge requests found
......@@ -101,7 +101,7 @@ struct LambGradFunctor {
CastScaleRegularizeGradientFunctor<T, G>()(*model_diff, model_val, scale, l1, l2);
const T next_m = beta1 * *m + (1 - beta1) * model_diff_t;
const T next_v = beta2 * *v + (1 - beta2) * model_diff_t * model_diff_t;
*adam_diff = (next_m / (1 - *beta1_t)) / std::sqrt(next_v / (1 - *beta2_t) + epsilon);
*adam_diff = (next_m / (1 - *beta1_t)) / (std::sqrt(next_v / (1 - *beta2_t)) + epsilon);
*m = next_m;
*v = next_v;
}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment