Commit b61c3b97 authored by ShawnXuan, committed by Juncheng

Fix Adam m/v initializer (#2082)

* zero constant initializer for Adam m and v

* make of_format

* init Adam m, v, beta1_t, and beta2_t

* use value instead of initializer

* const float& -> const float

* update
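
For context (this note and the sketch below are not part of the commit): Adam keeps per-parameter moving averages m and v of the gradient and squared gradient, and both are defined to start at zero, which is what the new `0.f` constant initializer encodes. A minimal standalone sketch of the standard Adam step, independent of OneFlow's actual kernels:

#include <cmath>
#include <cstddef>
#include <vector>

// Minimal sketch of the standard Adam update (Kingma & Ba, 2015), not OneFlow's kernel.
// m and v are exponential moving averages and must start at 0; beta1_t and beta2_t
// hold beta1^t and beta2^t, so before the first step they equal beta1 and beta2.
void AdamStep(std::vector<float>& param, const std::vector<float>& grad,
              std::vector<float>& m, std::vector<float>& v,
              float& beta1_t, float& beta2_t,
              float lr = 1e-3f, float beta1 = 0.9f, float beta2 = 0.999f, float eps = 1e-8f) {
  for (std::size_t i = 0; i < param.size(); ++i) {
    m[i] = beta1 * m[i] + (1.f - beta1) * grad[i];            // first moment, initialized to 0
    v[i] = beta2 * v[i] + (1.f - beta2) * grad[i] * grad[i];  // second moment, initialized to 0
    const float m_hat = m[i] / (1.f - beta1_t);               // bias correction uses beta1^t
    const float v_hat = v[i] / (1.f - beta2_t);               // bias correction uses beta2^t
    param[i] -= lr * m_hat / (std::sqrt(v_hat) + eps);
  }
  beta1_t *= beta1;  // advance beta1^t -> beta1^(t+1) for the next step
  beta2_t *= beta2;
}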
@@ -5,10 +5,13 @@ namespace oneflow {
 namespace {

 OperatorConf GenerateAdamHelperVariableOpConf(const VariableOp& op, const std::string& name,
-                                              JobBuilder* job_builder) {
+                                              const float initial_value, JobBuilder* job_builder) {
   OperatorConf helper_variable_op(op.op_conf());
   helper_variable_op.set_name(op.op_name() + "-" + name);
   helper_variable_op.mutable_variable_conf()->set_out("out");
+  InitializerConf constant_initializer;
+  constant_initializer.mutable_constant_conf()->set_value(initial_value);
+  *(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer;
   BindTwoVariableOpObnSbpConf(helper_variable_op.name(), op.op_name(), job_builder);
   return helper_variable_op;
 }
@@ -24,8 +27,8 @@ void SetScalarShapeAndSbpConf(OperatorConf* op_conf, JobBuilder* job_builder) {
 void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_conf,
                              JobBuilder* job_builder, const LogicalBlobId& diff_lbi_of_var_out,
                              const LogicalBlobId& total_loss_instance_num_lbi) {
-  const OperatorConf& m_var = GenerateAdamHelperVariableOpConf(op, "m", job_builder);
-  const OperatorConf& v_var = GenerateAdamHelperVariableOpConf(op, "v", job_builder);
+  const OperatorConf& m_var = GenerateAdamHelperVariableOpConf(op, "m", 0.f, job_builder);
+  const OperatorConf& v_var = GenerateAdamHelperVariableOpConf(op, "v", 0.f, job_builder);
   job_builder->AddOps(parallel_conf, {m_var, v_var});

   OperatorConf mdupdt_op;
@@ -35,9 +38,10 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_
       GlobalJobDesc().job_conf().train_conf().model_update_conf();
   OperatorConf beta1_t_var;
   OperatorConf beta2_t_var;
-  if (mdupdt_op_conf->user_conf().adam_conf().do_bias_correction()) {
-    beta1_t_var = GenerateAdamHelperVariableOpConf(op, "beta1_t", job_builder);
-    beta2_t_var = GenerateAdamHelperVariableOpConf(op, "beta2_t", job_builder);
+  const AdamModelUpdateConf& adam_conf = mdupdt_op_conf->user_conf().adam_conf();
+  if (adam_conf.do_bias_correction()) {
+    beta1_t_var = GenerateAdamHelperVariableOpConf(op, "beta1_t", adam_conf.beta1(), job_builder);
+    beta2_t_var = GenerateAdamHelperVariableOpConf(op, "beta2_t", adam_conf.beta2(), job_builder);
     job_builder->AddOps(parallel_conf, {beta1_t_var, beta2_t_var});
     SetScalarShapeAndSbpConf(&beta1_t_var, job_builder);
     SetScalarShapeAndSbpConf(&beta2_t_var, job_builder);
@@ -45,7 +49,7 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_
   ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, mdupdt_op_conf);
   mdupdt_op_conf->set_m(m_var.name() + "/out");
   mdupdt_op_conf->set_v(v_var.name() + "/out");
-  if (mdupdt_op_conf->user_conf().adam_conf().do_bias_correction()) {
+  if (adam_conf.do_bias_correction()) {
     mdupdt_op_conf->set_beta1_t(beta1_t_var.name() + "/out");
     mdupdt_op_conf->set_beta2_t(beta2_t_var.name() + "/out");
   }
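
A usage note on the bias-correction part of the change (again, an illustration rather than OneFlow code): initializing beta1_t and beta2_t to adam_conf.beta1() and adam_conf.beta2() means they hold beta1^t and beta2^t at the first step t = 1, assuming the Adam update op scales each accumulator by its beta after every step. A tiny check of that invariant:

#include <cassert>
#include <cmath>

int main() {
  const float beta1 = 0.9f;  // example value standing in for adam_conf.beta1()
  float beta1_t = beta1;     // initial value written by the new constant initializer
  for (int t = 1; t <= 5; ++t) {
    // At step t the accumulator holds beta1^t, so 1 - beta1_t is the bias-correction denominator.
    assert(std::fabs(beta1_t - std::pow(beta1, t)) < 1e-6);
    beta1_t *= beta1;  // assumed per-step scaling performed by the update op
  }
  return 0;
}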