diff --git a/oneflow/core/actor/normal_forward_compute_actor.cpp b/oneflow/core/actor/normal_forward_compute_actor.cpp
index 7b4c165541daceaad41140e6956ce4928cf49aa3..ec1bae38ffb58c59f217c44b822b8271361ab13b 100644
--- a/oneflow/core/actor/normal_forward_compute_actor.cpp
+++ b/oneflow/core/actor/normal_forward_compute_actor.cpp
@@ -159,8 +159,7 @@ void NormalForwardCompActor::AsyncInitModelAndConstBuf() {
   for (const ExecKernel& exec_kernel : exec_kernel_vec()) {
     KernelCtx kernel_ctx = GenDefaultKernelCtx();
     exec_kernel.kernel->InitModelAndConstBuf(
-        kernel_ctx, parallel_ctx(), Global<SnapshotMgr>::Get()->GetReadableSnapshot(),
-        [&](const std::string& bn_in_op) {
+        kernel_ctx, parallel_ctx(), nullptr, [&](const std::string& bn_in_op) {
           const LogicalBlobId& lbi = exec_kernel.kernel->BnInOp2Lbi(bn_in_op);
           Blob* blob = nullptr;
           if (model_regst_) { blob = model_regst_->GetBlobByLbi(lbi); }
diff --git a/oneflow/core/job/job.proto b/oneflow/core/job/job.proto
index adfa7e687757fc3795962921019e25cbb60fe50f..b5f777e2f1b44eb6f5376861a50ee0320e74119c 100644
--- a/oneflow/core/job/job.proto
+++ b/oneflow/core/job/job.proto
@@ -18,9 +18,11 @@ message TrainConf {
   repeated string loss_lbn = 6;
   optional int32 loss_scale_factor = 7 [default = 1];
   optional string global_step_lbn = 8;
+  optional string primary_lr_lbn = 9;
+  optional string secondary_lr_lbn = 10;
 
   required float primary_lr = 101;
-  optional float secondary_lr = 102 [default = -1];
+  optional float secondary_lr = 102;
   optional float weight_l1 = 103 [default = 0];
   optional float bias_l1 = 104 [default = 0];
   optional float weight_l2 = 105 [default = 0];
diff --git a/oneflow/core/job/job_desc.cpp b/oneflow/core/job/job_desc.cpp
index bea2ab1f40ff6f56f9adcfe08fac26880d9f55c4..2d894cdac133c00caabdffd8214aa435e3456f87 100644
--- a/oneflow/core/job/job_desc.cpp
+++ b/oneflow/core/job/job_desc.cpp
@@ -56,8 +56,6 @@ int64_t JobDesc::NumOfPiecesInBatch() const {
   CHECK_EQ(BatchSize() % RecordPieceSize(), 0);
   return BatchSize() / RecordPieceSize();
 }
-float JobDesc::primary_lr() const { return job_conf_.train_conf().primary_lr(); }
-float JobDesc::secondary_lr() const { return job_conf_.train_conf().secondary_lr(); }
 float JobDesc::weight_l1() const { return job_conf_.train_conf().weight_l1(); }
 float JobDesc::bias_l1() const { return job_conf_.train_conf().bias_l1(); }
 float JobDesc::weight_l2() const { return job_conf_.train_conf().weight_l2(); }
diff --git a/oneflow/core/job/job_desc.h b/oneflow/core/job/job_desc.h
index 856cfa01c03bb0f4b726ac139d4ad66bbf46fd71..c28e88cd36e49456bef477204c084cc45a790b97 100644
--- a/oneflow/core/job/job_desc.h
+++ b/oneflow/core/job/job_desc.h
@@ -67,8 +67,6 @@ class JobDesc final {
   int64_t TotalBatchNum() const;
   int64_t BatchSize() const;
   int64_t NumOfPiecesInBatch() const;
-  float primary_lr() const;
-  float secondary_lr() const;
   float weight_l1() const;
   float bias_l1() const;
   float weight_l2() const;
diff --git a/oneflow/core/job_completer/adam_optm.cpp b/oneflow/core/job_completer/adam_optm.cpp
index c3372c1116314a2ef2c5e20c9961a088be2e558b..9170fcd4f4339b6ec69f2fdfe94291c0ac1ab5c8 100644
--- a/oneflow/core/job_completer/adam_optm.cpp
+++ b/oneflow/core/job_completer/adam_optm.cpp
@@ -46,7 +46,8 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_
     SetScalarShapeAndSbpConf(&beta1_t_var, job_builder);
     SetScalarShapeAndSbpConf(&beta2_t_var, job_builder);
   }
-  ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, mdupdt_op_conf);
+  ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, job_builder,
+                        mdupdt_op_conf);
   mdupdt_op_conf->set_m(m_var.name() + "/out");
   mdupdt_op_conf->set_v(v_var.name() + "/out");
   if (adam_conf.do_bias_correction()) {
diff --git a/oneflow/core/job_completer/auto_global_step.cpp b/oneflow/core/job_completer/auto_global_step.cpp
index f8b32a9cac7b9facd4de632a494f1555bd77096b..d172cca258c81e8829e9b3420b0c3e6bcc2b6eef 100644
--- a/oneflow/core/job_completer/auto_global_step.cpp
+++ b/oneflow/core/job_completer/auto_global_step.cpp
@@ -36,7 +36,7 @@ void AutoGlobalStep(const OpGraph& op_graph, Job* job) {
   assign_conf->set_value(GenLogicalBlobName(scalar_add_op_conf.name(), scalar_add_conf->out()));
 
   JobBuilder job_builder(job);
-  job_builder.AddOps(GenParallelConfOfCpuZeroOnAllMachines(),
+  job_builder.AddOps(GenParallelConfOfCpuZeroOnMaster(),
                      {variable_op_conf, identity_op_conf, scalar_add_op_conf, assign_op_conf});
   job->mutable_job_conf()->mutable_train_conf()->set_global_step_lbn(global_step_lbn);
 }
diff --git a/oneflow/core/job_completer/auto_learning_rate.cpp b/oneflow/core/job_completer/auto_learning_rate.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5bb213652162f6f77de5662b7b8a9ccaaf0fb0cb
--- /dev/null
+++ b/oneflow/core/job_completer/auto_learning_rate.cpp
@@ -0,0 +1,61 @@
+#include "oneflow/core/graph/op_graph.h"
+#include "oneflow/core/job/job.pb.h"
+
+namespace oneflow {
+
+void AutoLearningRate(const OpGraph& op_graph, Job* job) {
+  JobBuilder job_builder(job);
+  const TrainConf& train_conf = job->job_conf().train_conf();
+  auto AddScheduleOp = [&](const std::string& op_name, const float learning_rate) -> std::string {
+    const ParallelConf& parallel_conf =
+        op_graph.OpNode4OpName(GenLogicalBlobId(train_conf.global_step_lbn()).op_name())
+            ->parallel_desc()
+            .parallel_conf();
+    const NormalModelUpdateOpUserConf& model_update_conf = train_conf.model_update_conf();
+    if (model_update_conf.has_warmup_conf() || model_update_conf.has_learning_rate_decay()) {
+      OperatorConf schedule_op_conf{};
+      schedule_op_conf.set_name(op_name);
+      LearningRateScheduleOpConf* schedule_conf =
+          schedule_op_conf.mutable_learning_rate_schedule_conf();
+      schedule_conf->set_global_step(train_conf.global_step_lbn());
+      schedule_conf->set_learning_rate(learning_rate);
+      schedule_conf->set_out("out");
+      if (model_update_conf.has_warmup_conf()) {
+        *schedule_conf->mutable_warmup_conf() = model_update_conf.warmup_conf();
+      }
+      if (model_update_conf.has_learning_rate_decay()) {
+        *schedule_conf->mutable_learning_rate_decay() = model_update_conf.learning_rate_decay();
+      }
+      job_builder.AddOps(parallel_conf, {schedule_op_conf});
+      return GenLogicalBlobName(op_name, schedule_conf->out());
+    } else {
+      OperatorConf constant_op_conf{};
+      constant_op_conf.set_name(op_name);
+      ConstantOpConf* constant_conf = constant_op_conf.mutable_constant_conf();
+      constant_conf->set_out("out");
+      *constant_conf->mutable_shape()->mutable_dim()->Add() = 1;
+      constant_conf->set_data_type(DataType::kFloat);
+      constant_conf->mutable_initializer()->mutable_constant_conf()->set_value(learning_rate);
+      job_builder.AddOps(parallel_conf, {constant_op_conf});
+      return GenLogicalBlobName(op_name, constant_conf->out());
+    }
+  };
+  if (!train_conf.has_primary_lr_lbn()) {
+    CHECK(train_conf.has_primary_lr());
+    const std::string lbn =
+        AddScheduleOp("System-Train-PrimaryLearningRate-Scheduler", train_conf.primary_lr());
+    job->mutable_job_conf()->mutable_train_conf()->set_primary_lr_lbn(lbn);
+  }
+  if (!train_conf.has_secondary_lr_lbn()) {
+    if (train_conf.has_secondary_lr()) {
+      const std::string lbn =
+          AddScheduleOp("System-Train-SecondaryLearningRate-Scheduler", train_conf.secondary_lr());
+      job->mutable_job_conf()->mutable_train_conf()->set_secondary_lr_lbn(lbn);
+    } else {
+      job->mutable_job_conf()->mutable_train_conf()->set_secondary_lr_lbn(
+          train_conf.primary_lr_lbn());
+    }
+  }
+}
+
+}  // namespace oneflow
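
The AutoLearningRate pass above rewires TrainConf so that learning rates are produced by ops at runtime: a learning_rate_schedule op when warmup or decay is configured, otherwise a constant op, with the resulting logical blob name (op_name/output) recorded in primary_lr_lbn, and secondary_lr_lbn falling back to the primary LBN when no secondary rate was configured. A minimal standalone sketch of that fallback logic; FakeTrainConf and ResolveSecondaryLr are illustrative stand-ins, not types from the codebase:

#include <iostream>
#include <string>

// Hypothetical stand-in for the TrainConf fields touched by AutoLearningRate.
struct FakeTrainConf {
  std::string primary_lr_lbn;    // e.g. "System-Train-PrimaryLearningRate-Scheduler/out"
  std::string secondary_lr_lbn;  // empty means "not set"
  bool has_secondary_lr;         // was a base secondary rate configured at all?
};

// Mirrors the tail of AutoLearningRate: when no secondary rate is configured,
// the secondary LBN simply aliases the primary one.
void ResolveSecondaryLr(FakeTrainConf* conf, const std::string& scheduled_secondary_lbn) {
  if (!conf->secondary_lr_lbn.empty()) { return; }  // user already supplied an LBN
  conf->secondary_lr_lbn =
      conf->has_secondary_lr ? scheduled_secondary_lbn : conf->primary_lr_lbn;
}

int main() {
  FakeTrainConf conf{"System-Train-PrimaryLearningRate-Scheduler/out", "", false};
  ResolveSecondaryLr(&conf, "System-Train-SecondaryLearningRate-Scheduler/out");
  std::cout << conf.secondary_lr_lbn << "\n";  // falls back to the primary LBN
}
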
diff --git a/oneflow/core/job_completer/auto_learning_rate.h b/oneflow/core/job_completer/auto_learning_rate.h
new file mode 100644
index 0000000000000000000000000000000000000000..5ffdc5f7296beaf1229f8223477d8bae60fa6479
--- /dev/null
+++ b/oneflow/core/job_completer/auto_learning_rate.h
@@ -0,0 +1,13 @@
+#ifndef ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_
+#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_
+
+namespace oneflow {
+
+class OpGraph;
+class Job;
+
+void AutoLearningRate(const OpGraph& op_graph, Job* job);
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_
diff --git a/oneflow/core/job_completer/job_completer.cpp b/oneflow/core/job_completer/job_completer.cpp
index 20e00b67028a10ac1bb0201a05ad74558a72d368..118761df9685bbb6350354ef93c7e26443a72a4a 100644
--- a/oneflow/core/job_completer/job_completer.cpp
+++ b/oneflow/core/job_completer/job_completer.cpp
@@ -12,6 +12,7 @@
 #include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h"
 #include "oneflow/core/job_completer/auto_mixed_precision.h"
 #include "oneflow/core/job_completer/auto_global_step.h"
+#include "oneflow/core/job_completer/auto_learning_rate.h"
 
 namespace oneflow {
 
@@ -341,6 +342,7 @@ void JobCompleter::Complete(Job* job) const {
     WithOpGraphAndMutJob(job, &TieUpChainHeadersUnReachableFromAnyVariableOps);
     WithOpGraphAndMutJobBuilder(job, &EnableAutoMixedPrecision);
     WithOpGraphAndMutJob(job, &AutoGlobalStep);
+    WithOpGraphAndMutJob(job, &AutoLearningRate);
     // complete ops for trainning
     WithOpGraphAndMutJobBuilder(job, &GenerateOpConf4Trainning);
     WithOpGraphAndMutJobBuilder(job, &RewriteBoxingWithAllReduce);
diff --git a/oneflow/core/job_completer/lars_optm.cpp b/oneflow/core/job_completer/lars_optm.cpp
index 81728fa21559ebddf7b2c823748e8c1531346532..b5ac48e0c294afc96194faa9aa99b768d4c52c7f 100644
--- a/oneflow/core/job_completer/lars_optm.cpp
+++ b/oneflow/core/job_completer/lars_optm.cpp
@@ -8,6 +8,9 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_
                              JobBuilder* job_builder, const LogicalBlobId& diff_lbi_of_var_out,
                              const LogicalBlobId& total_loss_instance_num_lbi) {
   OperatorConf momentum_var(op.op_conf());
+  InitializerConf constant_initializer;
+  constant_initializer.mutable_constant_conf()->set_value(0.f);
+  *(momentum_var.mutable_variable_conf()->mutable_initializer()) = constant_initializer;
   momentum_var.set_name(op.op_name() + "-momentum");
   momentum_var.mutable_variable_conf()->set_out("out");
   job_builder->AddOps(parallel_conf, {momentum_var});
@@ -16,7 +19,8 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_
   OperatorConf mdupdt_op;
   mdupdt_op.set_name(op.op_name() + "_optimizer");
   auto* mdupdt_op_conf = mdupdt_op.mutable_lars_model_update_conf();
-  ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, mdupdt_op_conf);
+  ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, job_builder,
+                        mdupdt_op_conf);
   mdupdt_op_conf->set_momentum(momentum_var.name() + "/out");
   job_builder->AddOps(parallel_conf, {mdupdt_op});
 }
diff --git a/oneflow/core/job_completer/momentum_optm.cpp b/oneflow/core/job_completer/momentum_optm.cpp
index 443ce2e43d3158a28452bf70071e9d31df62d7fc..0f3964b759850a1fcca1accb20030de457c94930 100644
--- a/oneflow/core/job_completer/momentum_optm.cpp
+++ b/oneflow/core/job_completer/momentum_optm.cpp
@@ -8,6 +8,9 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_
                              JobBuilder* job_builder, const LogicalBlobId& diff_lbi_of_var_out,
                              const LogicalBlobId& total_loss_instance_num_lbi) {
   OperatorConf momentum_var(op.op_conf());
+  InitializerConf constant_initializer;
+  constant_initializer.mutable_constant_conf()->set_value(0.f);
+  *(momentum_var.mutable_variable_conf()->mutable_initializer()) = constant_initializer;
   momentum_var.set_name(op.op_name() + "-momentum");
   momentum_var.mutable_variable_conf()->set_out("out");
   job_builder->AddOps(parallel_conf, {momentum_var});
@@ -16,7 +19,8 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_
   OperatorConf mdupdt_op;
   mdupdt_op.set_name(op.op_name() + "_optimizer");
   auto* mdupdt_op_conf = mdupdt_op.mutable_momentum_model_update_conf();
-  ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, mdupdt_op_conf);
+  ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, job_builder,
+                        mdupdt_op_conf);
   mdupdt_op_conf->set_momentum(momentum_var.name() + "/out");
   job_builder->AddOps(parallel_conf, {mdupdt_op});
 }
diff --git a/oneflow/core/job_completer/naive_optm.cpp b/oneflow/core/job_completer/naive_optm.cpp
index d688c2539992eac585710a912aac43de45e1cf62..67477480c6cc07c5c7c04641eb78d960dfbbae7b 100644
--- a/oneflow/core/job_completer/naive_optm.cpp
+++ b/oneflow/core/job_completer/naive_optm.cpp
@@ -9,7 +9,7 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_
                              const LogicalBlobId& total_loss_instance_num_lbi) {
   OperatorConf mdupdt_op;
   mdupdt_op.set_name(op.op_name() + "_optimizer");
-  ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi,
+  ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, job_builder,
                         mdupdt_op.mutable_naive_model_update_conf());
   job_builder->AddOps(parallel_conf, {mdupdt_op});
 }
diff --git a/oneflow/core/job_completer/optimizer.cpp b/oneflow/core/job_completer/optimizer.cpp
index 7f85f0635d300b45b028666d8fb0c847c5a5854a..ec4df9f815d3f9bde89ddb45a4ad9f22f2a2708c 100644
--- a/oneflow/core/job_completer/optimizer.cpp
+++ b/oneflow/core/job_completer/optimizer.cpp
@@ -61,34 +61,36 @@ void BindTwoVariableOpObnSbpConf(const std::string& lhs_op_name, const std::stri
 
 template<typename T>
 void ConstructMdUpdtOpConf(const VariableOp& op, const LogicalBlobId& diff_lbi_of_var_out,
-                           const LogicalBlobId& total_loss_instance_num_lbi, T* mdupdt_op_conf) {
-  const auto& train_conf = GlobalJobDesc().job_conf().train_conf();
+                           const LogicalBlobId& total_loss_instance_num_lbi,
+                           JobBuilder* job_builder, T* mdupdt_op_conf) {
+  const auto& train_conf = job_builder->job().job_conf().train_conf();
   *mdupdt_op_conf->mutable_user_conf() = train_conf.model_update_conf();
   mdupdt_op_conf->set_model_diff(GenLogicalBlobName(diff_lbi_of_var_out));
   mdupdt_op_conf->set_total_instance_num_diff(GenLogicalBlobName(total_loss_instance_num_lbi));
   mdupdt_op_conf->set_model(GenLogicalBlobName(op.BnInOp2Lbi("out")));
-  float primary_lr = GlobalJobDesc().primary_lr();
-  float secondary_lr = GlobalJobDesc().secondary_lr();
-  if (secondary_lr < 0) { secondary_lr = primary_lr; }
+  mdupdt_op_conf->set_global_step(train_conf.global_step_lbn());
+  const std::string& primary_lr_lbn = train_conf.primary_lr_lbn();
+  const std::string& secondary_lr_lbn = train_conf.secondary_lr_lbn();
   if (op.op_conf().variable_conf().model_name() == "weight") {
-    mdupdt_op_conf->set_learning_rate(primary_lr);
-    mdupdt_op_conf->set_l1(GlobalJobDesc().weight_l1());
-    mdupdt_op_conf->set_l2(GlobalJobDesc().weight_l2());
+    mdupdt_op_conf->set_learning_rate(primary_lr_lbn);
+    mdupdt_op_conf->set_l1(train_conf.weight_l1());
+    mdupdt_op_conf->set_l2(train_conf.weight_l2());
   } else if (op.op_conf().variable_conf().model_name() == "bias") {
-    mdupdt_op_conf->set_learning_rate(secondary_lr);
-    mdupdt_op_conf->set_l1(GlobalJobDesc().bias_l1());
-    mdupdt_op_conf->set_l2(GlobalJobDesc().bias_l2());
+    mdupdt_op_conf->set_learning_rate(secondary_lr_lbn);
+    mdupdt_op_conf->set_l1(train_conf.bias_l1());
+    mdupdt_op_conf->set_l2(train_conf.bias_l2());
   } else {
-    mdupdt_op_conf->set_learning_rate(primary_lr);
+    mdupdt_op_conf->set_learning_rate(primary_lr_lbn);
     mdupdt_op_conf->set_l1(0);
     mdupdt_op_conf->set_l2(0);
   }
 }
 
-#define INSTANTIATE_CONSTRUCTOR_MDUPDT_OP_CONF(T)                     \
-  template void ConstructMdUpdtOpConf<T>(                             \
-      const VariableOp& op, const LogicalBlobId& diff_lbi_of_var_out, \
-      const LogicalBlobId& total_loss_instance_num_lbi, T* mdupdt_op_conf)
+#define INSTANTIATE_CONSTRUCTOR_MDUPDT_OP_CONF(T)                                          \
+  template void ConstructMdUpdtOpConf<T>(const VariableOp& op,                             \
+                                         const LogicalBlobId& diff_lbi_of_var_out,         \
+                                         const LogicalBlobId& total_loss_instance_num_lbi, \
+                                         JobBuilder* job_builder, T* mdupdt_op_conf)
 
 INSTANTIATE_CONSTRUCTOR_MDUPDT_OP_CONF(NaiveModelUpdateOpConf);
 INSTANTIATE_CONSTRUCTOR_MDUPDT_OP_CONF(MomentumModelUpdateOpConf);
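
With this change ConstructMdUpdtOpConf no longer bakes a float learning rate into the update op; it wires in the LBN strings from TrainConf (plus global_step_lbn), choosing the primary or secondary rate and the matching L1/L2 terms by the variable's model_name. A small sketch of that selection, using an illustrative LrAndReg struct rather than the real op-conf types:

#include <iostream>
#include <string>

// Hypothetical mirror of the selection in ConstructMdUpdtOpConf: "weight" variables get
// the primary learning-rate LBN plus weight_l1/weight_l2, "bias" variables get the
// secondary LBN plus bias_l1/bias_l2, and everything else gets the primary LBN with no
// regularization.
struct LrAndReg {
  std::string lr_lbn;
  float l1;
  float l2;
};

LrAndReg SelectLrAndReg(const std::string& model_name, const std::string& primary_lr_lbn,
                        const std::string& secondary_lr_lbn, float weight_l1, float weight_l2,
                        float bias_l1, float bias_l2) {
  if (model_name == "weight") { return {primary_lr_lbn, weight_l1, weight_l2}; }
  if (model_name == "bias") { return {secondary_lr_lbn, bias_l1, bias_l2}; }
  return {primary_lr_lbn, 0.f, 0.f};
}

int main() {
  LrAndReg r = SelectLrAndReg("bias", "primary_lr/out", "secondary_lr/out", 1e-4f, 1e-4f, 0.f, 0.f);
  std::cout << r.lr_lbn << " l1=" << r.l1 << " l2=" << r.l2 << "\n";
}
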
diff --git a/oneflow/core/job_completer/optimizer.h b/oneflow/core/job_completer/optimizer.h
index 88197d4e2a83ee96b60f5cb00a8da6754499b82a..9479ec16f9e06bc3e61f8caafab9ec4f05bbc9b5 100644
--- a/oneflow/core/job_completer/optimizer.h
+++ b/oneflow/core/job_completer/optimizer.h
@@ -15,7 +15,8 @@ void BindTwoVariableOpObnSbpConf(const std::string& lhs_op_name, const std::stri
                                  JobBuilder* job_builder);
 template<typename T>
 void ConstructMdUpdtOpConf(const VariableOp& op, const LogicalBlobId& diff_lbi_of_var_out,
-                           const LogicalBlobId& total_loss_instance_num_lbi, T*);
+                           const LogicalBlobId& total_loss_instance_num_lbi,
+                           JobBuilder* job_builder, T*);
 
 class GenerateOptimizerOpConfWrapperStruct final {
  public:
diff --git a/oneflow/core/job_completer/rmsprop_optm.cpp b/oneflow/core/job_completer/rmsprop_optm.cpp
index 166e1ab75ee0efcc1bde291e5ae0b97132a47cd7..c8021d266dd7e748905ccd931262c84f587f5791 100644
--- a/oneflow/core/job_completer/rmsprop_optm.cpp
+++ b/oneflow/core/job_completer/rmsprop_optm.cpp
@@ -9,7 +9,7 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_
                              const LogicalBlobId& total_loss_instance_num_lbi) {
   OperatorConf mdupdt_op;
   mdupdt_op.set_name(op.op_name() + "_optimizer");
-  ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi,
+  ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, job_builder,
                         mdupdt_op.mutable_rmsprop_model_update_conf());
   job_builder->AddOps(parallel_conf, {mdupdt_op});
 }
diff --git a/oneflow/core/kernel/adam_model_update_kernel.cpp b/oneflow/core/kernel/adam_model_update_kernel.cpp
index 900dac0f209e366d54009edcfed62e3ef99bf804..ad953fcabb3fbace93c6ee9bbae8ef70d514f41e 100644
--- a/oneflow/core/kernel/adam_model_update_kernel.cpp
+++ b/oneflow/core/kernel/adam_model_update_kernel.cpp
@@ -31,26 +31,25 @@ const PbMessage& AdamMdUpdateKernel<device_type, T>::GetCustomizedOpConf() const
 
 template<DeviceType device_type, typename T>
 void AdamMdUpdateKernel<device_type, T>::UpdateModel(
-    DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2,
-    int64_t next_model_vid, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+    DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, const int64_t* global_step,
+    const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   Blob* model_blob = BnInOp2Blob("model");
   Blob* m_blob = BnInOp2Blob("m");
   Blob* v_blob = BnInOp2Blob("v");
   Blob* beta1_t_blob = BnInOp2Blob("beta1_t");
   Blob* beta2_t_blob = BnInOp2Blob("beta2_t");
   const auto& adam_conf = GetAdamModelUpdateConf(this->op_conf());
-  if ((next_model_vid != 1) && adam_conf.do_bias_correction()) {
-    KernelUtil<device_type, T>::Scal(ctx, 1, static_cast<T>(adam_conf.beta1()),
-                                     beta1_t_blob->mut_dptr<T>(), 1);
-    KernelUtil<device_type, T>::Scal(ctx, 1, static_cast<T>(adam_conf.beta2()),
-                                     beta2_t_blob->mut_dptr<T>(), 1);
+  if (adam_conf.do_bias_correction()) {
+    AdamMdUpdateKernelUtil<device_type, T>::DoBiasCorrection(
+        ctx, global_step, static_cast<T>(adam_conf.beta1()), static_cast<T>(adam_conf.beta2()),
+        beta1_t_blob->mut_dptr<T>(), beta2_t_blob->mut_dptr<T>());
   }
   KernelUtil<device_type, T>::Div(ctx, model_blob->shape().elem_cnt(),
                                   BnInOp2Blob("model_diff")->mut_dptr<T>(), batch_instance_num_ptr);
   AdamMdUpdateKernelUtil<device_type, T>::UpdateModel(
       ctx, model_blob->shape().elem_cnt(), learning_rate, l1, l2, static_cast<T>(adam_conf.beta1()),
       static_cast<T>(adam_conf.beta2()), static_cast<T>(adam_conf.epsilon()),
-      adam_conf.do_bias_correction(), next_model_vid,
+      adam_conf.do_bias_correction(), global_step,
       (beta1_t_blob ? beta1_t_blob->dptr<T>() : nullptr),
       (beta2_t_blob ? beta2_t_blob->dptr<T>() : nullptr), BnInOp2Blob("model_diff")->mut_dptr<T>(),
       model_blob->mut_dptr<T>(), m_blob->mut_dptr<T>(), v_blob->mut_dptr<T>());
@@ -59,9 +58,10 @@ void AdamMdUpdateKernel<device_type, T>::UpdateModel(
 template<typename T>
 class AdamMdUpdateKernelUtil<DeviceType::kCPU, T> final {
  public:
-  static void UpdateModel(DeviceCtx* ctx, int64_t n, T learning_rate, T l1, T l2, T beta1, T beta2,
-                          T epsilon, bool do_bias_correction, int64_t next_model_vid,
-                          const T* beta1_t, const T* beta2_t, T* model_diff, T* model, T* m, T* v) {
+  static void UpdateModel(DeviceCtx* ctx, int64_t n, const float* learning_rate, T l1, T l2,
+                          T beta1, T beta2, T epsilon, bool do_bias_correction,
+                          const int64_t* global_step, const T* beta1_t, const T* beta2_t,
+                          T* model_diff, T* model, T* m, T* v) {
     // first-order moment
     UpdateMomentEstimate<T>(n, do_bias_correction, beta1, 1, model_diff, beta1_t, m);
     // second-order moment
@@ -69,7 +69,14 @@ class AdamMdUpdateKernelUtil<DeviceType::kCPU, T> final {
     FOR_RANGE(int64_t, i, 0, n) {
       model_diff[i] = m[i] / (std::sqrt(v[i]) + epsilon);
       T reg_diff = RegDiff(model_diff[i], l1, l2, model[i]);
-      model[i] = model[i] - learning_rate * reg_diff;
+      model[i] = model[i] - *learning_rate * reg_diff;
+    }
+  }
+  static void DoBiasCorrection(DeviceCtx*, const int64_t* global_step, const T beta1, const T beta2,
+                               T* beta1_t, T* beta2_t) {
+    if (*global_step != 0) {
+      *beta1_t *= beta1;
+      *beta2_t *= beta2;
     }
   }
 };
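
The Adam kernel now reads global_step from a blob and keeps the bias-correction bookkeeping in DoBiasCorrection: the multiply is skipped when *global_step == 0, so the accumulators hold running powers of beta1/beta2. A CPU-only simulation of that bookkeeping, assuming (as is conventional) that beta1_t starts at beta1; nothing here is taken from the kernel beyond the guard shown above:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Simulates the DoBiasCorrection accumulator: skipped at global_step == 0, multiplied by
// beta1 on every later step, so after step t it equals beta1^(t+1).
int main() {
  const double beta1 = 0.9;
  double beta1_t = beta1;  // assumed initial value, matching the usual Adam convention
  for (int64_t global_step = 0; global_step < 5; ++global_step) {
    if (global_step != 0) { beta1_t *= beta1; }  // same guard as DoBiasCorrection
    std::printf("step %lld: beta1_t=%.6f  beta1^(t+1)=%.6f\n",
                static_cast<long long>(global_step), beta1_t,
                std::pow(beta1, global_step + 1));
  }
}
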
diff --git a/oneflow/core/kernel/adam_model_update_kernel.cu b/oneflow/core/kernel/adam_model_update_kernel.cu
index 8c34784799ee1e9800e699f94f513c4f70158431..181bc0cd24ad8cba1aa2454d33d1f11b0d67ae4e 100644
--- a/oneflow/core/kernel/adam_model_update_kernel.cu
+++ b/oneflow/core/kernel/adam_model_update_kernel.cu
@@ -44,17 +44,17 @@ __device__ void UpdateMomentEstimate(T beta, const T* model_diff, const T* beta_
 }
 
 template<typename T>
-__device__ void UpdateModel(T learning_rate, T l1, T l2, T epsilon, T* model_diff, T* model, T* m,
-                            T* v) {
+__device__ void UpdateModel(const float* learning_rate, T l1, T l2, T epsilon, T* model_diff,
+                            T* model, T* m, T* v) {
   *model_diff = *m / (sqrt(*v) + epsilon);
   T reg_diff = RegDiff(*model_diff, l1, l2, *model);
-  *model = *model - learning_rate * reg_diff;
+  *model = *model - *learning_rate * reg_diff;
 }
 
 template<bool do_bias_correction, typename T>
-__global__ void UpdateModelGpu(int64_t n, T learning_rate, T l1, T l2, T beta1, T beta2, T epsilon,
-                               const T* beta1_t, const T* beta2_t, T* model_diff, T* model, T* m,
-                               T* v) {
+__global__ void UpdateModelGpu(int64_t n, const float* learning_rate, T l1, T l2, T beta1, T beta2,
+                               T epsilon, const T* beta1_t, const T* beta2_t, T* model_diff,
+                               T* model, T* m, T* v) {
   CUDA_1D_KERNEL_LOOP(i, n) {
     UpdateMomentEstimate<1, do_bias_correction>(beta1, model_diff + i, beta1_t, m + i);
     UpdateMomentEstimate<2, do_bias_correction>(beta2, model_diff + i, beta2_t, v + i);
@@ -62,14 +62,24 @@ __global__ void UpdateModelGpu(int64_t n, T learning_rate, T l1, T l2, T beta1,
   }
 }
 
+template<typename T>
+__global__ void DoBiasCorrectionGpu(const int64_t* global_step, const T beta1, const T beta2,
+                                    T* beta1_t, T* beta2_t) {
+  if (*global_step != 0) {
+    *beta1_t *= beta1;
+    *beta2_t *= beta2;
+  }
+}
+
 }  // namespace
 
 template<typename T>
 class AdamMdUpdateKernelUtil<DeviceType::kGPU, T> final {
  public:
-  static void UpdateModel(DeviceCtx* ctx, int64_t n, T learning_rate, T l1, T l2, T beta1, T beta2,
-                          T epsilon, bool do_bias_correction, int64_t next_model_vid,
-                          const T* beta1_t, const T* beta2_t, T* model_diff, T* model, T* m, T* v) {
+  static void UpdateModel(DeviceCtx* ctx, int64_t n, const float* learning_rate, T l1, T l2,
+                          T beta1, T beta2, T epsilon, bool do_bias_correction,
+                          const int64_t* global_step, const T* beta1_t, const T* beta2_t,
+                          T* model_diff, T* model, T* m, T* v) {
     if (do_bias_correction) {
       UpdateModelGpu<true, T>
           <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
@@ -82,6 +92,12 @@ class AdamMdUpdateKernelUtil<DeviceType::kGPU, T> final {
               m, v);
     }
   }
+
+  static void DoBiasCorrection(DeviceCtx* ctx, const int64_t* global_step, const T beta1,
+                               const T beta2, T* beta1_t, T* beta2_t) {
+    DoBiasCorrectionGpu<T>
+        <<<1, 1, 0, ctx->cuda_stream()>>>(global_step, beta1, beta2, beta1_t, beta2_t);
+  }
 };
 
 #define INSTANTIATE_GPU_KERNEL_UTIL(type_cpp, type_proto) \
diff --git a/oneflow/core/kernel/adam_model_update_kernel.h b/oneflow/core/kernel/adam_model_update_kernel.h
index 7567a810a0e6b6f4bbe0ae33eadb9f52cd5f7362..289d518d867748bff1e29b13e1dc24d2624b97bb 100644
--- a/oneflow/core/kernel/adam_model_update_kernel.h
+++ b/oneflow/core/kernel/adam_model_update_kernel.h
@@ -14,17 +14,19 @@ class AdamMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T> {
 
  private:
   const PbMessage& GetCustomizedOpConf() const override;
-  void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2,
-                   int64_t next_model_vid,
+  void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2,
+                   const int64_t* global_step, const float* learning_rate,
                    std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
 };
 
 template<DeviceType device_type, typename T>
 class AdamMdUpdateKernelUtil final {
  public:
-  static void UpdateModel(DeviceCtx*, int64_t n, T learning_rate, T l1, T l2, T beta1, T beta2,
-                          T epsilon, bool do_bias_correction, int64_t next_model_vid,
+  static void UpdateModel(DeviceCtx*, int64_t n, const float* learning_rate, T l1, T l2, T beta1,
+                          T beta2, T epsilon, bool do_bias_correction, const int64_t* global_step,
                           const T* beta1_t, const T* beta2_t, T* model_diff, T* model, T* m, T* v);
+  static void DoBiasCorrection(DeviceCtx*, const int64_t* global_step, T beta1, T beta2, T* beta1_t,
+                               T* beta2_t);
 };
 
 DECLARE_MDUPDT_KERNEL_CREATOR(Adam);
diff --git a/oneflow/core/kernel/lars_model_update_kernel.cpp b/oneflow/core/kernel/lars_model_update_kernel.cpp
index 31091f2ac458c14052522d87ac7ee7bf1fb45f05..341abb4853ccde032e397d5b557bbecd2a4e3c35 100644
--- a/oneflow/core/kernel/lars_model_update_kernel.cpp
+++ b/oneflow/core/kernel/lars_model_update_kernel.cpp
@@ -18,23 +18,19 @@ const PbMessage& LARSMdUpdateKernel<device_type, T>::GetCustomizedOpConf() const
 
 template<DeviceType device_type, typename T>
 void LARSMdUpdateKernel<device_type, T>::UpdateModel(
-    DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2,
-    int64_t next_model_vid, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+    DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, const int64_t* global_step,
+    const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   const Blob* model_diff_blob = BnInOp2Blob("model_diff");
   Blob* model_blob = BnInOp2Blob("model");
   Blob* momentum_blob = BnInOp2Blob("momentum");
-  Blob* data_tmp_blob = BnInOp2Blob("data_tmp");
+  Blob* data_tmp_blob = BnInOp2Blob("lars_data_tmp");
   const LARSModelUpdateConf& lars_conf = GetLARSModelUpdateConf(this->op_conf());
-  if (next_model_vid == 1) {
-    Memset<device_type>(ctx, momentum_blob->mut_dptr<T>(), 0,
-                        momentum_blob->ByteSizeOfDataContentField());
-  }
   Memset<device_type>(ctx, data_tmp_blob->mut_dptr<T>(), 0,
                       data_tmp_blob->ByteSizeOfDataContentField());
   LARSMdUpdateKernelUtil<device_type, T>::UpdateModel(
       ctx, model_blob->shape().elem_cnt(), batch_instance_num_ptr, learning_rate, l1, l2,
       static_cast<T>(lars_conf.momentum_beta()), static_cast<T>(lars_conf.epsilon()),
-      static_cast<T>(lars_conf.lars_coefficient()), next_model_vid, model_diff_blob->dptr<T>(),
+      static_cast<T>(lars_conf.lars_coefficient()), global_step, model_diff_blob->dptr<T>(),
       model_blob->mut_dptr<T>(), momentum_blob->mut_dptr<T>(), data_tmp_blob->mut_dptr<T>());
 }
 
@@ -42,9 +38,9 @@ template<typename T>
 class LARSMdUpdateKernelUtil<DeviceType::kCPU, T> final {
  public:
   static void UpdateModel(DeviceCtx* ctx, int64_t n, const T* batch_instance_num_ptr,
-                          T learning_rate, T l1, T l2, T momentum_beta, T epsilon,
-                          T lars_coefficient, int64_t next_model_vid, const T* model_diff, T* model,
-                          T* momentum, T* data_tmp) {
+                          const float* learning_rate, T l1, T l2, T momentum_beta, T epsilon,
+                          T lars_coefficient, const int64_t* global_step, const T* model_diff,
+                          T* model, T* momentum, T* data_tmp) {
     T model_norm = 0;
     T model_diff_norm = 0;
     FOR_RANGE(int64_t, i, 0, n) {
@@ -54,11 +50,11 @@ class LARSMdUpdateKernelUtil<DeviceType::kCPU, T> final {
     model_norm = std::sqrt(model_norm / n);
     model_diff_norm = std::sqrt(model_diff_norm / n);
     T local_learning_rate = 0;
-    if (next_model_vid == 1) {
+    if (*global_step == 0) {
       local_learning_rate =
-          learning_rate * lars_coefficient * model_norm / (epsilon + model_diff_norm);
+          *learning_rate * lars_coefficient * model_norm / (epsilon + model_diff_norm);
     } else {
-      local_learning_rate = learning_rate * lars_coefficient * model_norm
+      local_learning_rate = *learning_rate * lars_coefficient * model_norm
                             / (epsilon + model_diff_norm + l2 * model_norm);
     }
     FOR_RANGE(int64_t, i, 0, n) {
diff --git a/oneflow/core/kernel/lars_model_update_kernel.cu b/oneflow/core/kernel/lars_model_update_kernel.cu
index 36f1143564f8725831036898a3c0c682285ada72..cd74e947fdf736feb2b114d918b7df5f5753ea83 100644
--- a/oneflow/core/kernel/lars_model_update_kernel.cu
+++ b/oneflow/core/kernel/lars_model_update_kernel.cu
@@ -7,19 +7,19 @@ namespace oneflow {
 namespace {
 
 template<typename T>
-__global__ void GetLocalLearningRateGpu(const T* batch_instance_num_ptr, T learning_rate, T l2,
-                                        T epsilon, T lars_coefficient, int64_t next_model_vid,
-                                        T* data_tmp) {
+__global__ void GetLocalLearningRateGpu(const T* batch_instance_num_ptr, const float* learning_rate,
+                                        T l2, T epsilon, T lars_coefficient,
+                                        const int64_t* global_step, T* data_tmp) {
   T* model_norm = &data_tmp[0];
   T* model_diff_norm = &data_tmp[1];
   T* local_learning_rate = &data_tmp[2];
   *model_norm = std::sqrt(*model_norm);
   *model_diff_norm = std::sqrt(*model_diff_norm) / *batch_instance_num_ptr;  // TODO(shiyuan)
-  if (next_model_vid == 1) {
+  if (*global_step == 0) {
     *local_learning_rate =
-        learning_rate * lars_coefficient * (*model_norm) / (epsilon + (*model_diff_norm));
+        *learning_rate * lars_coefficient * (*model_norm) / (epsilon + (*model_diff_norm));
   } else {
-    *local_learning_rate = learning_rate * lars_coefficient * (*model_norm)
+    *local_learning_rate = *learning_rate * lars_coefficient * (*model_norm)
                            / (epsilon + (*model_diff_norm) + l2 * (*model_diff_norm));
   }
 }
@@ -41,14 +41,14 @@ template<typename T>
 class LARSMdUpdateKernelUtil<DeviceType::kGPU, T> final {
  public:
   static void UpdateModel(DeviceCtx* ctx, int64_t n, const T* batch_instance_num_ptr,
-                          T learning_rate, T l1, T l2, T momentum_beta, T epsilon,
-                          T lars_coefficient, int64_t next_model_vid, const T* model_diff, T* model,
-                          T* momentum, T* data_tmp) {
+                          const float* learning_rate, T l1, T l2, T momentum_beta, T epsilon,
+                          T lars_coefficient, const int64_t* global_step, const T* model_diff,
+                          T* model, T* momentum, T* data_tmp) {
     KernelUtil<DeviceType::kGPU, T>::Dot(ctx, n, model, 1, model, 1, &data_tmp[0]);
     KernelUtil<DeviceType::kGPU, T>::Dot(ctx, n, model_diff, 1, model_diff, 1, &data_tmp[1]);
     GetLocalLearningRateGpu<T>
         <<<1, 1, 0, ctx->cuda_stream()>>>(batch_instance_num_ptr, learning_rate, l2, epsilon,
-                                          lars_coefficient, next_model_vid, data_tmp);
+                                          lars_coefficient, global_step, data_tmp);
     UpdateModelGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
         n, batch_instance_num_ptr, l1, l2, momentum_beta, model_diff, model, momentum, data_tmp);
   }
diff --git a/oneflow/core/kernel/lars_model_update_kernel.h b/oneflow/core/kernel/lars_model_update_kernel.h
index ce90fe15847899c5d1d200368590e6252a834194..ca37bbabafa3331b65acb61c5acb0a688404d894 100644
--- a/oneflow/core/kernel/lars_model_update_kernel.h
+++ b/oneflow/core/kernel/lars_model_update_kernel.h
@@ -14,18 +14,18 @@ class LARSMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T> {
 
  private:
   const PbMessage& GetCustomizedOpConf() const override;
-  void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2,
-                   int64_t next_model_vid,
+  void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2,
+                   const int64_t* global_step, const float* learning_rate,
                    std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
 };
 
 template<DeviceType device_type, typename T>
 class LARSMdUpdateKernelUtil final {
  public:
-  static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, T learning_rate,
-                          T l1, T l2, T momentum_beta, T epsilon, T lars_coefficient,
-                          int64_t next_model_vid, const T* model_diff, T* model, T* momentum,
-                          T* data_tmp);
+  static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr,
+                          const float* learning_rate, T l1, T l2, T momentum_beta, T epsilon,
+                          T lars_coefficient, const int64_t* global_step, const T* model_diff,
+                          T* model, T* momentum, T* data_tmp);
 };
 
 DECLARE_MDUPDT_KERNEL_CREATOR(LARS);
diff --git a/oneflow/core/kernel/momentum_model_update_kernel.cpp b/oneflow/core/kernel/momentum_model_update_kernel.cpp
index 032dcbab6143cba3801ff8559a3d8d6e64b5c322..92bf35609c9fbf9f57cab0304625547cde7ba05d 100644
--- a/oneflow/core/kernel/momentum_model_update_kernel.cpp
+++ b/oneflow/core/kernel/momentum_model_update_kernel.cpp
@@ -18,17 +18,16 @@ const PbMessage& MomentumMdUpdateKernel<device_type, T>::GetCustomizedOpConf() c
 
 template<DeviceType device_type, typename T>
 void MomentumMdUpdateKernel<device_type, T>::UpdateModel(
-    DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2,
-    int64_t next_model_vid, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+    DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, const int64_t* global_step,
+    const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   const Blob* model_diff_blob = BnInOp2Blob("model_diff");
   Blob* model_blob = BnInOp2Blob("model");
   Blob* momentum_blob = BnInOp2Blob("momentum");
   float beta = GetMomentumModelUpdateConf(this->op_conf()).beta();
-  if (next_model_vid == 1) { beta = 0.0f; }
 
   MomentumMdUpdateKernelUtil<device_type, T>::UpdateModel(
       ctx, model_blob->shape().elem_cnt(), batch_instance_num_ptr, static_cast<T>(beta),
-      learning_rate, l1, l2, model_diff_blob->dptr<T>(), model_blob->mut_dptr<T>(),
+      global_step, learning_rate, l1, l2, model_diff_blob->dptr<T>(), model_blob->mut_dptr<T>(),
       momentum_blob->mut_dptr<T>());
 }
 
@@ -36,10 +35,12 @@ template<typename T>
 class MomentumMdUpdateKernelUtil<DeviceType::kCPU, T> final {
  public:
   static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, T beta,
-                          T learning_rate, T l1, T l2, const T* model_diff, T* model, T* momentum) {
+                          const int64_t* global_step, const float* learning_rate, T l1, T l2,
+                          const T* model_diff, T* model, T* momentum) {
+    T cur_beta = *global_step == 0 ? 0 : beta;
     for (int64_t i = 0; i != n; ++i) {
       T reg_diff = RegularizeDiff(model_diff[i], *batch_instance_num_ptr, l1, l2, model[i]);
-      momentum[i] = beta * momentum[i] - learning_rate * reg_diff;
+      momentum[i] = cur_beta * momentum[i] - *learning_rate * reg_diff;
       model[i] = model[i] + momentum[i];
     }
   }
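
For the momentum kernel, the host-side special case (beta forced to 0 on the first iteration) moves into the kernel as cur_beta gated on *global_step == 0, while the momentum variable itself now gets an explicit zero constant initializer in momentum_optm.cpp. A minimal sketch of the gated update, omitting the RegularizeDiff term (batch-size scaling and L1/L2) for brevity:

#include <cstdint>
#include <cstdio>
#include <vector>

// At global_step == 0 the history term is dropped (cur_beta == 0), so the very first
// step is a plain SGD step on a zero-initialized momentum buffer.
void MomentumStep(int64_t global_step, float beta, float lr,
                  const std::vector<float>& grad, std::vector<float>* model,
                  std::vector<float>* momentum) {
  const float cur_beta = global_step == 0 ? 0.f : beta;
  for (size_t i = 0; i < model->size(); ++i) {
    (*momentum)[i] = cur_beta * (*momentum)[i] - lr * grad[i];
    (*model)[i] += (*momentum)[i];
  }
}

int main() {
  std::vector<float> model{1.f}, momentum{0.f}, grad{0.5f};
  MomentumStep(0, 0.9f, 0.1f, grad, &model, &momentum);  // plain SGD step
  MomentumStep(1, 0.9f, 0.1f, grad, &model, &momentum);  // history term kicks in
  std::printf("model=%.4f momentum=%.4f\n", model[0], momentum[0]);
}
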
diff --git a/oneflow/core/kernel/momentum_model_update_kernel.cu b/oneflow/core/kernel/momentum_model_update_kernel.cu
index e0a74db6fc9808dffeeb653472d33af836160bc1..6c0bfd2fda5c0b68e1667f4cb6f105fc3b68eb73 100644
--- a/oneflow/core/kernel/momentum_model_update_kernel.cu
+++ b/oneflow/core/kernel/momentum_model_update_kernel.cu
@@ -7,11 +7,13 @@ namespace oneflow {
 namespace {
 
 template<typename T>
-__global__ void UpdateModelGpu(int64_t n, const T* batch_instance_num_ptr, T beta, T learning_rate,
-                               T l1, T l2, const T* model_diff, T* model, T* momentum) {
+__global__ void UpdateModelGpu(int64_t n, const T* batch_instance_num_ptr, T beta,
+                               const int64_t* global_step, const float* learning_rate, T l1, T l2,
+                               const T* model_diff, T* model, T* momentum) {
+  T cur_beta = *global_step == 0 ? 0 : beta;
   CUDA_1D_KERNEL_LOOP(i, n) {
     T reg_diff = RegularizeDiff(model_diff[i], *batch_instance_num_ptr, l1, l2, model[i]);
-    momentum[i] = beta * momentum[i] - learning_rate * reg_diff;
+    momentum[i] = cur_beta * momentum[i] - *learning_rate * reg_diff;
     model[i] = model[i] + momentum[i];
   }
 }
@@ -22,10 +24,11 @@ template<typename T>
 class MomentumMdUpdateKernelUtil<DeviceType::kGPU, T> final {
  public:
   static void UpdateModel(DeviceCtx* ctx, int64_t n, const T* batch_instance_num_ptr, T beta,
-                          T learning_rate, const T l1, const T l2, const T* model_diff, T* model,
-                          T* momentum) {
+                          const int64_t* global_step, const float* learning_rate, const T l1,
+                          const T l2, const T* model_diff, T* model, T* momentum) {
     UpdateModelGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
-        n, batch_instance_num_ptr, beta, learning_rate, l1, l2, model_diff, model, momentum);
+        n, batch_instance_num_ptr, beta, global_step, learning_rate, l1, l2, model_diff, model,
+        momentum);
   }
 };
 
diff --git a/oneflow/core/kernel/momentum_model_update_kernel.h b/oneflow/core/kernel/momentum_model_update_kernel.h
index 0fbbfe6b67da1bf2feaf7e3cc1d3aaecf3191da2..8442d2dcac350b9a4017afcde1de40ac8b05740e 100644
--- a/oneflow/core/kernel/momentum_model_update_kernel.h
+++ b/oneflow/core/kernel/momentum_model_update_kernel.h
@@ -16,8 +16,8 @@ class MomentumMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T>
   const PbMessage& GetCustomizedOpConf() const override;
 
  private:
-  void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2,
-                   int64_t next_model_vid,
+  void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2,
+                   const int64_t* global_step, const float* learning_rate,
                    std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
 };
 
@@ -25,7 +25,8 @@ template<DeviceType device_type, typename T>
 class MomentumMdUpdateKernelUtil final {
  public:
   static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, T beta,
-                          T learning_rate, T l1, T l2, const T* model_diff, T* model, T* momentum);
+                          const int64_t* global_step, const float* learning_rate, T l1, T l2,
+                          const T* model_diff, T* model, T* momentum);
 };
 
 DECLARE_MDUPDT_KERNEL_CREATOR(Momentum);
diff --git a/oneflow/core/kernel/naive_model_update_kernel.cpp b/oneflow/core/kernel/naive_model_update_kernel.cpp
index d6a086ee5027cfd0836cc215451cc461d5bc5ed5..3ca58170c7bb7dbda3a8f31a2b7a6f62388706b3 100644
--- a/oneflow/core/kernel/naive_model_update_kernel.cpp
+++ b/oneflow/core/kernel/naive_model_update_kernel.cpp
@@ -5,8 +5,8 @@ namespace oneflow {
 
 template<DeviceType device_type, typename T>
 void NaiveMdUpdateKernel<device_type, T>::UpdateModel(
-    DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2,
-    int64_t next_model_vid, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+    DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, const int64_t* global_step,
+    const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   const Blob* model_diff_blob = BnInOp2Blob("model_diff");
   Blob* model_blob = BnInOp2Blob("model");
   // model = model - alpha * model_diff
@@ -24,10 +24,10 @@ template<typename T>
 class NaiveMdUpdateKernelUtil<DeviceType::kCPU, T> final {
  public:
   static void UpdateModel(DeviceCtx*, const int64_t n, const T* batch_instance_num_ptr,
-                          T learning_rate, T l1, T l2, const T* model_diff, T* model) {
+                          const float* learning_rate, T l1, T l2, const T* model_diff, T* model) {
     for (int64_t i = 0; i != n; ++i) {
       T reg_diff = RegularizeDiff(model_diff[i], *batch_instance_num_ptr, l1, l2, model[i]);
-      model[i] = model[i] - learning_rate * reg_diff;
+      model[i] = model[i] - *learning_rate * reg_diff;
     }
   }
 };
diff --git a/oneflow/core/kernel/naive_model_update_kernel.cu b/oneflow/core/kernel/naive_model_update_kernel.cu
index 4e21d3656954b75ac6cdf8d83b06c8d93bf9c78e..2233e57af2b3416a01a7ad43ecac8bf8fbbf9bb8 100644
--- a/oneflow/core/kernel/naive_model_update_kernel.cu
+++ b/oneflow/core/kernel/naive_model_update_kernel.cu
@@ -7,11 +7,12 @@ namespace oneflow {
 namespace {
 
 template<typename T>
-__global__ void UpdateModelGpu(const int64_t n, const T* batch_instance_num_ptr, T learning_rate,
-                               T l1, T l2, const T* model_diff, T* model) {
+__global__ void UpdateModelGpu(const int64_t n, const T* batch_instance_num_ptr,
+                               const float* learning_rate, T l1, T l2, const T* model_diff,
+                               T* model) {
   CUDA_1D_KERNEL_LOOP(i, n) {
     T reg_diff = RegularizeDiff(model_diff[i], *batch_instance_num_ptr, l1, l2, model[i]);
-    model[i] = model[i] - learning_rate * reg_diff;
+    model[i] = model[i] - *learning_rate * reg_diff;
   }
 }
 
@@ -21,7 +22,7 @@ template<typename T>
 class NaiveMdUpdateKernelUtil<DeviceType::kGPU, T> final {
  public:
   static void UpdateModel(DeviceCtx* ctx, const int64_t n, const T* batch_instance_num_ptr,
-                          T learning_rate, T l1, T l2, const T* model_diff, T* model) {
+                          const float* learning_rate, T l1, T l2, const T* model_diff, T* model) {
     UpdateModelGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
         n, batch_instance_num_ptr, learning_rate, l1, l2, model_diff, model);
   }
diff --git a/oneflow/core/kernel/naive_model_update_kernel.h b/oneflow/core/kernel/naive_model_update_kernel.h
index 2307a395977db21b4a0036fa1b973ff83136359b..c203f63631bc66707bdadaa9cd3f18c815381cc0 100644
--- a/oneflow/core/kernel/naive_model_update_kernel.h
+++ b/oneflow/core/kernel/naive_model_update_kernel.h
@@ -14,16 +14,16 @@ class NaiveMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T> {
 
  private:
   const PbMessage& GetCustomizedOpConf() const override;
-  void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2,
-                   int64_t next_model_vid,
+  void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2,
+                   const int64_t* global_step, const float* learning_rate,
                    std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
 };
 
 template<DeviceType device_type, typename T>
 class NaiveMdUpdateKernelUtil final {
  public:
-  static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, T learning_rate,
-                          T l1, T l2, const T* model_diff, T* model);
+  static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr,
+                          const float* learning_rate, T l1, T l2, const T* model_diff, T* model);
 };
 
 DECLARE_MDUPDT_KERNEL_CREATOR(Naive);
diff --git a/oneflow/core/kernel/normal_model_update_kernel.cpp b/oneflow/core/kernel/normal_model_update_kernel.cpp
index f5114d71bf0b02064692f5639b02e8947bf338bd..9e132f069caafd004cbe98d1d37fbf68192812db 100644
--- a/oneflow/core/kernel/normal_model_update_kernel.cpp
+++ b/oneflow/core/kernel/normal_model_update_kernel.cpp
@@ -10,28 +10,18 @@ namespace oneflow {
 template<DeviceType device_type, typename T>
 void NormalMdUpdateKernel<device_type, T>::Forward(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  int64_t cur_batch_num = std::get<0>(
-      *(reinterpret_cast<std::tuple<int64_t, std::function<const Blob*(const LogicalBlobId&)>>*>(
-          ctx.other)));
-  int64_t next_model_vid = cur_batch_num + 1;
   const PbMessage& op_conf = this->GetCustomizedOpConf();
   const auto& conf = *GetMsgPtrFromPbMessage<NormalModelUpdateOpUserConf>(op_conf, "user_conf");
-  float learning_rate = GetValFromPbMessage<float>(op_conf, "learning_rate");
   const T* batch_instance_num_ptr = BnInOp2Blob("total_instance_num_diff")->dptr<T>();
+  const int64_t* global_step_ptr = BnInOp2Blob("global_step")->dptr<int64_t>();
+  const float* learning_rate_ptr = BnInOp2Blob("learning_rate")->dptr<float>();
   if (conf.has_clip_conf()) {
-    ClipGradient(ctx.device_ctx, cur_batch_num, conf.clip_conf(), batch_instance_num_ptr,
-                 BnInOp2Blob);
-  }
-  if (TriggerWarmup(conf, learning_rate, next_model_vid)) {
-    learning_rate = GetWarmupLearningRate(conf.warmup_conf(), learning_rate, next_model_vid);
-  } else if (conf.has_learning_rate_decay()) {
-    learning_rate =
-        GetDecayedLearningRate(conf.learning_rate_decay(), learning_rate, cur_batch_num);
+    ClipGradient(ctx.device_ctx, conf.clip_conf(), batch_instance_num_ptr, BnInOp2Blob);
   }
   float l1 = GetValFromPbMessage<float>(op_conf, "l1");
   float l2 = GetValFromPbMessage<float>(op_conf, "l2");
-  UpdateModel(ctx.device_ctx, batch_instance_num_ptr, static_cast<T>(learning_rate),
-              static_cast<T>(l1), static_cast<T>(l2), next_model_vid, BnInOp2Blob);
+  UpdateModel(ctx.device_ctx, batch_instance_num_ptr, static_cast<T>(l1), static_cast<T>(l2),
+              global_step_ptr, learning_rate_ptr, BnInOp2Blob);
 }
 
 #define INSTANTIATE_KERNEL(device_type, data_type_pair) \
@@ -58,109 +48,8 @@ Kernel* CreateMdUpdtKernel(const KernelConf& kernel_conf) {
   }
 }
 
-double ExponentialDecayedLearningRate(const ExponentialDecayConf& conf, double lr,
-                                      int64_t cur_batch_num) {
-  CHECK_GT(conf.decay_batches(), 0);
-  double p = static_cast<double>(cur_batch_num) / static_cast<double>(conf.decay_batches());
-  if (conf.staircase()) { p = std::floor(p); }
-  return lr * std::pow(conf.decay_rate(), p);
-}
-
-double InverseTimeDecayedLearningRate(const InverseTimeDecayConf& conf, double lr,
-                                      int64_t cur_batch_num) {
-  CHECK_GT(conf.decay_batches(), 0);
-  double p = static_cast<double>(cur_batch_num) / static_cast<double>(conf.decay_batches());
-  if (conf.staircase()) { p = std::floor(p); }
-  return lr / (1.0 + conf.decay_rate() * p);
-}
-
-double NaturalExpDecayedLearningRate(const NaturalExpDecayConf& conf, double lr,
-                                     int64_t cur_batch_num) {
-  CHECK_GT(conf.decay_batches(), 0);
-  double p = static_cast<double>(cur_batch_num) / static_cast<double>(conf.decay_batches());
-  if (conf.staircase()) { p = std::floor(p); }
-  return lr * std::exp(-conf.decay_rate() * p);
-}
-
-double PiecewiseConstantLearningRate(const PiecewiseConstantConf& conf, double lr,
-                                     int64_t cur_batch_num) {
-  const PbRf<int64_t>& boundaries = conf.boundaries();
-  const PbRf<double>& values = conf.values();
-  CHECK_EQ(boundaries.size() + 1, values.size());
-  size_t i = 0;
-  for (; i < boundaries.size(); ++i) {
-    if (cur_batch_num <= boundaries[i]) { break; }
-  }
-  return values[i];
-}
-
-double PolynomialDecayedLearningRate(const PolynomialDecayConf& conf, double lr,
-                                     int64_t cur_batch_num) {
-  CHECK_GT(conf.decay_batches(), 0);
-  double cur_batch = static_cast<double>(cur_batch_num);
-  double decay_batches = static_cast<double>(conf.decay_batches());
-  if (conf.cycle()) {
-    if (cur_batch_num == 0) { cur_batch = 1.0; }
-    decay_batches = decay_batches * std::ceil(cur_batch / decay_batches);
-  } else {
-    cur_batch = std::min(cur_batch, decay_batches);
-  }
-  return (lr - conf.end_learning_rate()) * std::pow(1.0 - (cur_batch / decay_batches), conf.power())
-         + conf.end_learning_rate();
-}
-
-double CosineDecayedLearningRate(const CosineDecayConf& conf, double lr, int64_t cur_batch_num) {
-  CHECK_GT(conf.decay_batches(), 0);
-  const double PI = std::atan(1.0) * 4.0;
-  double cur_batch = static_cast<double>(cur_batch_num);
-  double decay_batches = static_cast<double>(conf.decay_batches());
-  cur_batch = std::min(cur_batch, decay_batches);
-  double cosine_decay = 0.5 * (1.0 + std::cos(PI * cur_batch / decay_batches));
-  double decayed = (1.0 - conf.alpha()) * cosine_decay + conf.alpha();
-  return lr * decayed;
-}
-
-double LinearCosineDecayedLearningRate(const LinearCosineDecayConf& conf, double lr,
-                                       int64_t cur_batch_num) {
-  CHECK_GT(conf.decay_batches(), 0);
-  const double PI = std::atan(1.0) * 4.0;
-  double cur_batch = static_cast<double>(cur_batch_num);
-  double decay_batches = static_cast<double>(conf.decay_batches());
-  cur_batch = std::min(cur_batch, decay_batches);
-  double linear_decay = (decay_batches - cur_batch) / decay_batches;
-  double cosine_decay =
-      0.5 * (1.0 + std::cos(PI * 2.0 * conf.num_periods() * cur_batch / decay_batches));
-  double decayed = (conf.alpha() + linear_decay) * cosine_decay + conf.beta();
-  return lr * decayed;
-}
-
-double ConstantWarmupLearningRate(const ConstantWarmupConf& conf, double lr,
-                                  int64_t next_batch_num) {
-  CHECK_GT(conf.warmup_batches(), 0);
-  CHECK_GT(conf.multiplier(), 0);
-  CHECK_LT(conf.multiplier(), 1);
-  if (next_batch_num <= conf.warmup_batches()) {
-    return lr * conf.multiplier();
-  } else {
-    return lr;
-  }
-}
-
-double LinearWarmupLearningRate(const LinearWarmupConf& conf, double lr, int64_t next_batch_num) {
-  CHECK_GT(conf.warmup_batches(), 0);
-  CHECK_GE(conf.start_multiplier(), 0);
-  CHECK_LT(conf.start_multiplier(), 1);
-  double start_multiplier = conf.start_multiplier();
-  double multiplier = 1.0;
-  if (next_batch_num <= conf.warmup_batches()) {
-    multiplier = start_multiplier
-                 + (1.0 - start_multiplier) * (next_batch_num * 1.0 / conf.warmup_batches());
-  }
-  return lr * multiplier;
-}
-
 template<DeviceType device_type, typename T>
-void ClipByGlobalNorm(DeviceCtx* ctx, const int64_t cur_batch_num, const ClipByGlobalNormConf& conf,
+void ClipByGlobalNorm(DeviceCtx* ctx, const ClipByGlobalNormConf& conf,
                       const T* batch_instance_num_ptr,
                       std::function<Blob*(const std::string&)> BnInOp2Blob) {
   int64_t n = BnInOp2Blob("model_diff")->shape().elem_cnt();
@@ -182,62 +71,13 @@ void ClipByGlobalNorm(DeviceCtx* ctx, const int64_t cur_batch_num, const ClipByG
 
 }  // namespace
 
-template<DeviceType device_type, typename T>
-bool NormalMdUpdateKernel<device_type, T>::TriggerWarmup(const NormalModelUpdateOpUserConf& conf,
-                                                         double lr, int64_t next_batch_num) const {
-  if (!conf.has_warmup_conf()) { return false; }
-  const WarmupConf& warmup_conf = conf.warmup_conf();
-  if (warmup_conf.has_constant_conf()) {
-    return (next_batch_num <= warmup_conf.constant_conf().warmup_batches());
-  } else if (warmup_conf.has_linear_conf()) {
-    return (next_batch_num <= warmup_conf.linear_conf().warmup_batches());
-  } else {
-    UNIMPLEMENTED();
-  }
-}
-
-template<DeviceType device_type, typename T>
-double NormalMdUpdateKernel<device_type, T>::GetWarmupLearningRate(const WarmupConf& conf,
-                                                                   double lr,
-                                                                   int64_t next_batch_num) const {
-  if (conf.has_constant_conf()) {
-    return ConstantWarmupLearningRate(conf.constant_conf(), lr, next_batch_num);
-  } else if (conf.has_linear_conf()) {
-    return LinearWarmupLearningRate(conf.linear_conf(), lr, next_batch_num);
-  } else {
-    UNIMPLEMENTED();
-  }
-}
-
 template<DeviceType device_type, typename T>
 void NormalMdUpdateKernel<device_type, T>::ClipGradient(
-    DeviceCtx* ctx, const int64_t cur_batch_num, const ClipConf& conf,
-    const T* batch_instance_num_ptr, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+    DeviceCtx* ctx, const ClipConf& conf, const T* batch_instance_num_ptr,
+    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   if (conf.has_clip_by_global_norm()) {
-    ClipByGlobalNorm<device_type, T>(ctx, cur_batch_num, conf.clip_by_global_norm(),
-                                     batch_instance_num_ptr, BnInOp2Blob);
-  } else {
-    UNIMPLEMENTED();
-  }
-}
-
-template<DeviceType device_type, typename T>
-double NormalMdUpdateKernel<device_type, T>::GetDecayedLearningRate(
-    const LearningRateDecayConf& conf, double lr, int64_t cur_batch_num) const {
-  if (conf.has_exponential_conf()) {
-    return ExponentialDecayedLearningRate(conf.exponential_conf(), lr, cur_batch_num);
-  } else if (conf.has_inverse_time_conf()) {
-    return InverseTimeDecayedLearningRate(conf.inverse_time_conf(), lr, cur_batch_num);
-  } else if (conf.has_natural_exp_conf()) {
-    return NaturalExpDecayedLearningRate(conf.natural_exp_conf(), lr, cur_batch_num);
-  } else if (conf.has_piecewise_constant_conf()) {
-    return PiecewiseConstantLearningRate(conf.piecewise_constant_conf(), lr, cur_batch_num);
-  } else if (conf.has_polynomial_conf()) {
-    return PolynomialDecayedLearningRate(conf.polynomial_conf(), lr, cur_batch_num);
-  } else if (conf.has_cosine_conf()) {
-    return CosineDecayedLearningRate(conf.cosine_conf(), lr, cur_batch_num);
-  } else if (conf.has_linear_cosine_conf()) {
-    return LinearCosineDecayedLearningRate(conf.linear_cosine_conf(), lr, cur_batch_num);
+    ClipByGlobalNorm<device_type, T>(ctx, conf.clip_by_global_norm(), batch_instance_num_ptr,
+                                     BnInOp2Blob);
   } else {
     UNIMPLEMENTED();
   }
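This hunk, together with the deletions above, removes the learning-rate schedule from the updater kernel entirely: warmup, decay, and the cur_batch_num plumbing are gone, and the scheduled value is expected to arrive through the new learning_rate / global_step input blobs introduced later in this patch. For reference, the deleted linear-warmup branch computed the multiplier sketched below; this is a standalone reconstruction of the removed formula, not code that remains in the tree.

#include <cstdint>

// Reconstruction of the removed linear-warmup rule: while next_batch_num is still
// inside the warmup window, the effective lr ramps linearly from
// start_multiplier * lr up to lr; afterwards the multiplier stays at 1.0.
double LinearWarmupMultiplier(double start_multiplier, int64_t next_batch_num,
                              int64_t warmup_batches) {
  double multiplier = 1.0;
  if (next_batch_num <= warmup_batches) {
    multiplier = start_multiplier
                 + (1.0 - start_multiplier) * (next_batch_num * 1.0 / warmup_batches);
  }
  return multiplier;
}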
diff --git a/oneflow/core/kernel/normal_model_update_kernel.h b/oneflow/core/kernel/normal_model_update_kernel.h
index ae5accae68144baf2db3ac891b0818c5614a6623..040e89819d111e984ef1a7d6c0bb20139933ef05 100644
--- a/oneflow/core/kernel/normal_model_update_kernel.h
+++ b/oneflow/core/kernel/normal_model_update_kernel.h
@@ -16,18 +16,12 @@ class NormalMdUpdateKernel : public KernelIf<device_type> {
 
  protected:
   NormalMdUpdateKernel() = default;
-  virtual void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1,
-                           T l2, int64_t next_model_vid,
+  virtual void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2,
+                           const int64_t* global_step, const float* learning_rate,
                            std::function<Blob*(const std::string&)> BnInOp2Blob) const = 0;
 
  private:
-  bool TriggerWarmup(const NormalModelUpdateOpUserConf& conf, double lr,
-                     int64_t cur_batch_num) const;
-  double GetWarmupLearningRate(const WarmupConf&, double lr, int64_t cur_batch_num) const;
-  double GetDecayedLearningRate(const LearningRateDecayConf&, double lr,
-                                int64_t cur_batch_num) const;
-  void ClipGradient(DeviceCtx* ctx, const int64_t cur_batch_num, const ClipConf& conf,
-                    const T* batch_instance_num_ptr,
+  void ClipGradient(DeviceCtx* ctx, const ClipConf& conf, const T* batch_instance_num_ptr,
                     std::function<Blob*(const std::string&)> BnInOp2Blob) const;
 };
 
diff --git a/oneflow/core/kernel/rmsprop_model_update_kernel.cpp b/oneflow/core/kernel/rmsprop_model_update_kernel.cpp
index d169e7657580e8b6c7d8a09831ec805e14d2f3c4..193943400cc0b901b16d5afee6e6eb5c39cda9ad 100644
--- a/oneflow/core/kernel/rmsprop_model_update_kernel.cpp
+++ b/oneflow/core/kernel/rmsprop_model_update_kernel.cpp
@@ -18,31 +18,30 @@ const PbMessage& RMSPropMdUpdateKernel<device_type, T>::GetCustomizedOpConf() co
 
 template<DeviceType device_type, typename T>
 void RMSPropMdUpdateKernel<device_type, T>::UpdateModel(
-    DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2,
-    int64_t next_model_vid, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+    DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, const int64_t* global_step,
+    const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   const Blob* model_diff_blob = BnInOp2Blob("model_diff");
   Blob* model_blob = BnInOp2Blob("model");
   Blob* mean_square_blob = BnInOp2Blob("mean_square");
   const RMSPropModelUpdateConf& conf = GetRMSPropModelUpdateConf(this->op_conf());
-  float decay_rate = conf.decay_rate();
-  if (next_model_vid == 1) { decay_rate = 0.0f; }
 
   RMSPropMdUpdateKernelUtil<device_type, T>::UpdateModel(
-      ctx, model_blob->shape().elem_cnt(), batch_instance_num_ptr, learning_rate,
-      static_cast<T>(decay_rate), static_cast<T>(conf.epsilon()), l1, l2,
+      ctx, model_blob->shape().elem_cnt(), batch_instance_num_ptr, global_step, learning_rate,
+      static_cast<T>(conf.decay_rate()), static_cast<T>(conf.epsilon()), l1, l2,
       model_diff_blob->dptr<T>(), model_blob->mut_dptr<T>(), mean_square_blob->mut_dptr<T>());
 }
 
 template<typename T>
 class RMSPropMdUpdateKernelUtil<DeviceType::kCPU, T> final {
  public:
-  static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, T learning_rate,
-                          T decay_rate, T epsilon, T l1, T l2, const T* model_diff, T* model,
-                          T* mean_square) {
+  static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr,
+                          const int64_t* global_step, const float* learning_rate, T decay_rate,
+                          T epsilon, T l1, T l2, const T* model_diff, T* model, T* mean_square) {
+    const T cur_decay_rate = *global_step == 0 ? 0 : decay_rate;
     for (int64_t i = 0; i < n; ++i) {
       T reg_diff = RegularizeDiff(model_diff[i], *batch_instance_num_ptr, l1, l2, model[i]);
-      mean_square[i] = (1 - decay_rate) * reg_diff * reg_diff + decay_rate * mean_square[i];
-      model[i] = model[i] - learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon);
+      mean_square[i] = (1 - cur_decay_rate) * reg_diff * reg_diff + cur_decay_rate * mean_square[i];
+      model[i] = model[i] - *learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon);
     }
   }
 };
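With next_model_vid gone, the first-iteration special case is keyed off the global_step blob instead: when *global_step == 0 the decay rate is forced to zero, so mean_square is seeded directly from the first squared gradient rather than blended with its zero-initialized value. Restating the per-element update above (s = mean_square, w = model, g = the regularized diff, \rho = cur_decay_rate, \eta = the value read from the learning_rate blob):

\begin{aligned}
s_i &\leftarrow (1 - \rho)\, g_i^2 + \rho\, s_i \\
w_i &\leftarrow w_i - \eta\, \frac{g_i}{\sqrt{s_i + \epsilon}}
\end{aligned}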
diff --git a/oneflow/core/kernel/rmsprop_model_update_kernel.cu b/oneflow/core/kernel/rmsprop_model_update_kernel.cu
index d02973d4481df200d6062ef72449840a1277f1b4..54757ae6c9f82855d9e25e67c25b6e6ad6bf40e7 100644
--- a/oneflow/core/kernel/rmsprop_model_update_kernel.cu
+++ b/oneflow/core/kernel/rmsprop_model_update_kernel.cu
@@ -7,13 +7,15 @@ namespace oneflow {
 namespace {
 
 template<typename T>
-__global__ void UpdateModelGpu(int64_t n, const T* batch_instance_num_ptr, T learning_rate,
-                               T decay_rate, T epsilon, T l1, T l2, const T* model_diff, T* model,
+__global__ void UpdateModelGpu(int64_t n, const T* batch_instance_num_ptr,
+                               const int64_t* global_step, const float* learning_rate, T decay_rate,
+                               T epsilon, T l1, T l2, const T* model_diff, T* model,
                                T* mean_square) {
+  const T cur_decay_rate = *global_step == 0 ? 0 : decay_rate;
   CUDA_1D_KERNEL_LOOP(i, n) {
     T reg_diff = RegularizeDiff(model_diff[i], *batch_instance_num_ptr, l1, l2, model[i]);
-    mean_square[i] = (1 - decay_rate) * reg_diff * reg_diff + decay_rate * mean_square[i];
-    model[i] = model[i] - learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon);
+    mean_square[i] = (1 - cur_decay_rate) * reg_diff * reg_diff + cur_decay_rate * mean_square[i];
+    model[i] = model[i] - *learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon);
   }
 }
 
@@ -23,11 +25,11 @@ template<typename T>
 class RMSPropMdUpdateKernelUtil<DeviceType::kGPU, T> final {
  public:
   static void UpdateModel(DeviceCtx* ctx, int64_t n, const T* batch_instance_num_ptr,
-                          T learning_rate, T decay_rate, T epsilon, T l1, T l2, const T* model_diff,
-                          T* model, T* mean_square) {
+                          const int64_t* global_step, const float* learning_rate, T decay_rate,
+                          T epsilon, T l1, T l2, const T* model_diff, T* model, T* mean_square) {
     UpdateModelGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
-        n, batch_instance_num_ptr, learning_rate, decay_rate, epsilon, l1, l2, model_diff, model,
-        mean_square);
+        n, batch_instance_num_ptr, global_step, learning_rate, decay_rate, epsilon, l1, l2,
+        model_diff, model, mean_square);
   }
 };
 
diff --git a/oneflow/core/kernel/rmsprop_model_update_kernel.h b/oneflow/core/kernel/rmsprop_model_update_kernel.h
index 8c404c1b83e00e3cc2f8b5de1c7067dd4c76513e..4b2ed30a5aaebd91877045b98cee69adba3b7f3c 100644
--- a/oneflow/core/kernel/rmsprop_model_update_kernel.h
+++ b/oneflow/core/kernel/rmsprop_model_update_kernel.h
@@ -14,8 +14,8 @@ class RMSPropMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T>
 
  private:
   const PbMessage& GetCustomizedOpConf() const override;
-  void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2,
-                   int64_t next_model_vid,
+  void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2,
+                   const int64_t* global_step, const float* learning_rate,
                    std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
 };
 
@@ -24,9 +24,9 @@ class RMSPropMdUpdateKernelUtil final {
  public:
   // mean_square = (1 - decay_rate) * model_diff ^ 2 + decay_rate * mean_square
   // model = model - learning_rate * model_diff / sqrt(mean_square + epsilon)
-  static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, T learning_rate,
-                          T decay_rate, T epsilon, T l1, T l2, const T* model_diff, T* model,
-                          T* mean_square);
+  static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr,
+                          const int64_t* global_step, const float* learning_rate, T decay_rate,
+                          T epsilon, T l1, T l2, const T* model_diff, T* model, T* mean_square);
 };
 
 DECLARE_MDUPDT_KERNEL_CREATOR(RMSProp);
diff --git a/oneflow/core/operator/lars_model_update_op.cpp b/oneflow/core/operator/lars_model_update_op.cpp
index 573aad5a727c934eb39a8919c09918a9d75884bc..6a82c5af9d11f21fed508534f70ec3aa2cf7f939 100644
--- a/oneflow/core/operator/lars_model_update_op.cpp
+++ b/oneflow/core/operator/lars_model_update_op.cpp
@@ -4,7 +4,7 @@ namespace oneflow {
 
 void LARSModelUpdateOp::MdUpdtVirtualInitFromOpConf() {
   EnrollInputBn("momentum", false)->set_is_mutable(true);
-  EnrollTmpBn("data_tmp");
+  EnrollTmpBn("lars_data_tmp");
 }
 
 Maybe<void> LARSModelUpdateOp::MdUpdtVirtualInferBlobDescs(
@@ -13,11 +13,11 @@ Maybe<void> LARSModelUpdateOp::MdUpdtVirtualInferBlobDescs(
   const BlobDesc* model_blob_desc = GetBlobDesc4BnInOp("model");
   CHECK_OR_RETURN(*GetBlobDesc4BnInOp("momentum") == *model_blob_desc);
 
-  // data_tmp for gpu compute
-  // data_tmp[0] for model_norm, data_tmp[1] for model_diff_norm, data_tmp[2] for
+  // lars_data_tmp for gpu compute
+  // lars_data_tmp[0] for model_norm, lars_data_tmp[1] for model_diff_norm, lars_data_tmp[2] for
   // local_learning_rate
-  *GetBlobDesc4BnInOp("data_tmp") = *model_blob_desc;
-  GetBlobDesc4BnInOp("data_tmp")->mut_shape() = Shape({3});
+  *GetBlobDesc4BnInOp("lars_data_tmp") = *model_blob_desc;
+  GetBlobDesc4BnInOp("lars_data_tmp")->mut_shape() = Shape({3});
   return Maybe<void>::Ok();
 }
 
@@ -26,7 +26,7 @@ const PbMessage& LARSModelUpdateOp::GetCustomizedConf() const {
 }
 
 const HashSet<std::string> LARSModelUpdateOp::AlwaysBroadcastParallelBns() const {
-  return HashSet<std::string>{"data_tmp"};
+  return HashSet<std::string>{"lars_data_tmp"};
 }
 
 REGISTER_CLASS(NormalModelUpdateOpUserConf::kLarsConf, NormalModelUpdtOp, LARSModelUpdateOp);
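The rename to lars_data_tmp gives the blob a kernel-specific name (presumably to avoid colliding with generic data_tmp bns elsewhere) while keeping the same 3-element layout spelled out in the comment above. A small illustrative sketch of how a consumer could address those slots; the enum and the blob lookup are assumptions for readability, not code taken from the LARS kernel:

// Illustrative indices into the 3-element lars_data_tmp blob (names are assumptions).
enum LarsDataTmpIndex {
  kModelNorm = 0,          // lars_data_tmp[0]
  kModelDiffNorm = 1,      // lars_data_tmp[1]
  kLocalLearningRate = 2,  // lars_data_tmp[2]
};

// e.g. inside a kernel, assuming BnInOp2Blob resolves the renamed blob:
//   T* lars_tmp = BnInOp2Blob("lars_data_tmp")->mut_dptr<T>();
//   const T local_lr = lars_tmp[kLocalLearningRate];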
diff --git a/oneflow/core/operator/normal_model_update_op.cpp b/oneflow/core/operator/normal_model_update_op.cpp
index 5c0d8c658b08045643bd0c1bf0134b735598f9ff..9c8f04b1d7b002e77db48cb0fc28e57191761c3c 100644
--- a/oneflow/core/operator/normal_model_update_op.cpp
+++ b/oneflow/core/operator/normal_model_update_op.cpp
@@ -7,6 +7,8 @@ void NormalModelUpdtOp::InitFromOpConf() {
   EnrollInputBn("model_diff", false);
   EnrollInputBn("total_instance_num_diff", false);
   EnrollInputBn("model", false)->set_is_mutable(true);
+  EnrollInputBn("learning_rate", false);
+  EnrollInputBn("global_step", false);
   const PbMessage& conf = this->GetCustomizedConf();
   const auto& user_conf = *GetMsgPtrFromPbMessage<NormalModelUpdateOpUserConf>(conf, "user_conf");
   if (user_conf.has_clip_conf() && user_conf.clip_conf().has_clip_by_global_norm()) {
@@ -50,6 +52,8 @@ void NormalModelUpdtOp::GetSbpSignatures(
   const auto& bns = AlwaysBroadcastParallelBns();
   PbRpf<std::string> broadcast_bns = {bns.begin(), bns.end()};
   *broadcast_bns.Add() = "total_instance_num_diff";
+  *broadcast_bns.Add() = "learning_rate";
+  *broadcast_bns.Add() = "global_step";
   FOR_RANGE(int64_t, i, 0, LogicalBlobDesc4Ibn("model").shape().NumAxes()) {
     SbpSignatureBuilder()
         .Split(input_bns(), i)
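The two new inputs enrolled above are also appended to the broadcast list here: learning_rate and global_step are (presumably single-element) scalar blobs, so under model-split parallelism every shard must read the same copy rather than a slice. Reading the hunk, the signature generated for a model split on axis i would look roughly like the sketch below (an interpretation, not generated SBP output):

// model, model_diff           : Split(i)
// total_instance_num_diff     : Broadcast
// learning_rate, global_step  : Broadcast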
diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto
index bdc2ef3e48b751e1a37a1135c9d728108ac536c5..ba1e7fca7f6577cdce2b7d17761833cae656ec7e 100644
--- a/oneflow/core/operator/op_conf.proto
+++ b/oneflow/core/operator/op_conf.proto
@@ -574,9 +574,10 @@ message NormalModelUpdateOpConf {
   required string model_diff = 2;
   required string total_instance_num_diff = 3;
   required string model = 4;
-  required float learning_rate = 5;
-  required float l1 = 6;
-  required float l2 = 7;
+  required string global_step = 5;
+  required string learning_rate = 6;
+  required float l1 = 7;
+  required float l2 = 8;
 }
 
 message NaiveModelUpdateOpConf {
@@ -584,9 +585,10 @@ message NaiveModelUpdateOpConf {
   required string model_diff = 2;
   required string total_instance_num_diff = 3;
   required string model = 4;
-  required float learning_rate = 5;
-  required float l1 = 6;
-  required float l2 = 7;
+  required string global_step = 5;
+  required string learning_rate = 6;
+  required float l1 = 7;
+  required float l2 = 8;
 }
 
 message MomentumModelUpdateOpConf {
@@ -595,9 +597,10 @@ message MomentumModelUpdateOpConf {
   required string model_diff = 3;
   required string total_instance_num_diff = 4;
   required string model = 5;
-  required float learning_rate = 6;
-  required float l1 = 7;
-  required float l2 = 8;
+  required string global_step = 6;
+  required string learning_rate = 7;
+  required float l1 = 8;
+  required float l2 = 9;
 }
 
 message RMSPropModelUpdateOpConf {
@@ -605,9 +608,10 @@ message RMSPropModelUpdateOpConf {
   required string model_diff = 2;
   required string total_instance_num_diff = 3;
   required string model = 4;
-  required float learning_rate = 5;
-  required float l1 = 6;
-  required float l2 = 7;
+  required string global_step = 5;
+  required string learning_rate = 6;
+  required float l1 = 7;
+  required float l2 = 8;
 }
 
 message LARSModelUpdateOpConf {
@@ -616,9 +620,10 @@ message LARSModelUpdateOpConf {
   required string model_diff = 3;
   required string total_instance_num_diff = 4;
   required string model = 5;
-  required float learning_rate = 6;
-  required float l1 = 7;
-  required float l2 = 8;
+  required string global_step = 6;
+  required string learning_rate = 7;
+  required float l1 = 8;
+  required float l2 = 9;
 }
 
 message AdamModelUpdateOpConf {
@@ -630,9 +635,10 @@ message AdamModelUpdateOpConf {
   required string model_diff = 6;
   required string total_instance_num_diff = 7;
   required string model = 8;
-  required float learning_rate = 9;
-  required float l1 = 10;
-  required float l2 = 11;
+  required string global_step = 9;
+  required string learning_rate = 10;
+  required float l1 = 11;
+  required float l2 = 12;
 }
 
 message AccumulateOpConf {
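Across all of the updater confs above, the literal float learning_rate is replaced by two string fields, global_step and learning_rate, which name the logical blobs that carry those values at runtime (mirroring the new primary_lr_lbn / secondary_lr_lbn fields in TrainConf). A hedged sketch of how a job-completer pass might fill the new fields, using only the setters protoc generates for string fields; the function, variable names, and surrounding pass are assumptions:

// Sketch only: wiring lbn strings instead of a literal float.
// Assumes #include "oneflow/core/operator/op_conf.pb.h" and <string>.
void WireScheduleInputs(NaiveModelUpdateOpConf* mdupdt_conf, const std::string& lr_lbn,
                        const std::string& global_step_lbn) {
  mdupdt_conf->set_learning_rate(lr_lbn);         // field 6; was: required float learning_rate = 5
  mdupdt_conf->set_global_step(global_step_lbn);  // field 5, new
}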