diff --git a/oneflow/core/actor/normal_forward_compute_actor.cpp b/oneflow/core/actor/normal_forward_compute_actor.cpp index 7b4c165541daceaad41140e6956ce4928cf49aa3..ec1bae38ffb58c59f217c44b822b8271361ab13b 100644 --- a/oneflow/core/actor/normal_forward_compute_actor.cpp +++ b/oneflow/core/actor/normal_forward_compute_actor.cpp @@ -159,8 +159,7 @@ void NormalForwardCompActor::AsyncInitModelAndConstBuf() { for (const ExecKernel& exec_kernel : exec_kernel_vec()) { KernelCtx kernel_ctx = GenDefaultKernelCtx(); exec_kernel.kernel->InitModelAndConstBuf( - kernel_ctx, parallel_ctx(), Global<SnapshotMgr>::Get()->GetReadableSnapshot(), - [&](const std::string& bn_in_op) { + kernel_ctx, parallel_ctx(), nullptr, [&](const std::string& bn_in_op) { const LogicalBlobId& lbi = exec_kernel.kernel->BnInOp2Lbi(bn_in_op); Blob* blob = nullptr; if (model_regst_) { blob = model_regst_->GetBlobByLbi(lbi); } diff --git a/oneflow/core/job/job.proto b/oneflow/core/job/job.proto index adfa7e687757fc3795962921019e25cbb60fe50f..b5f777e2f1b44eb6f5376861a50ee0320e74119c 100644 --- a/oneflow/core/job/job.proto +++ b/oneflow/core/job/job.proto @@ -18,9 +18,11 @@ message TrainConf { repeated string loss_lbn = 6; optional int32 loss_scale_factor = 7 [default = 1]; optional string global_step_lbn = 8; + optional string primary_lr_lbn = 9; + optional string secondary_lr_lbn = 10; required float primary_lr = 101; - optional float secondary_lr = 102 [default = -1]; + optional float secondary_lr = 102; optional float weight_l1 = 103 [default = 0]; optional float bias_l1 = 104 [default = 0]; optional float weight_l2 = 105 [default = 0]; diff --git a/oneflow/core/job/job_desc.cpp b/oneflow/core/job/job_desc.cpp index bea2ab1f40ff6f56f9adcfe08fac26880d9f55c4..2d894cdac133c00caabdffd8214aa435e3456f87 100644 --- a/oneflow/core/job/job_desc.cpp +++ b/oneflow/core/job/job_desc.cpp @@ -56,8 +56,6 @@ int64_t JobDesc::NumOfPiecesInBatch() const { CHECK_EQ(BatchSize() % RecordPieceSize(), 0); return BatchSize() / RecordPieceSize(); } -float JobDesc::primary_lr() const { return job_conf_.train_conf().primary_lr(); } -float JobDesc::secondary_lr() const { return job_conf_.train_conf().secondary_lr(); } float JobDesc::weight_l1() const { return job_conf_.train_conf().weight_l1(); } float JobDesc::bias_l1() const { return job_conf_.train_conf().bias_l1(); } float JobDesc::weight_l2() const { return job_conf_.train_conf().weight_l2(); } diff --git a/oneflow/core/job/job_desc.h b/oneflow/core/job/job_desc.h index 856cfa01c03bb0f4b726ac139d4ad66bbf46fd71..c28e88cd36e49456bef477204c084cc45a790b97 100644 --- a/oneflow/core/job/job_desc.h +++ b/oneflow/core/job/job_desc.h @@ -67,8 +67,6 @@ class JobDesc final { int64_t TotalBatchNum() const; int64_t BatchSize() const; int64_t NumOfPiecesInBatch() const; - float primary_lr() const; - float secondary_lr() const; float weight_l1() const; float bias_l1() const; float weight_l2() const; diff --git a/oneflow/core/job_completer/adam_optm.cpp b/oneflow/core/job_completer/adam_optm.cpp index c3372c1116314a2ef2c5e20c9961a088be2e558b..9170fcd4f4339b6ec69f2fdfe94291c0ac1ab5c8 100644 --- a/oneflow/core/job_completer/adam_optm.cpp +++ b/oneflow/core/job_completer/adam_optm.cpp @@ -46,7 +46,8 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_ SetScalarShapeAndSbpConf(&beta1_t_var, job_builder); SetScalarShapeAndSbpConf(&beta2_t_var, job_builder); } - ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, mdupdt_op_conf); + ConstructMdUpdtOpConf(op, 
diff_lbi_of_var_out, total_loss_instance_num_lbi, job_builder, + mdupdt_op_conf); mdupdt_op_conf->set_m(m_var.name() + "/out"); mdupdt_op_conf->set_v(v_var.name() + "/out"); if (adam_conf.do_bias_correction()) { diff --git a/oneflow/core/job_completer/auto_global_step.cpp b/oneflow/core/job_completer/auto_global_step.cpp index f8b32a9cac7b9facd4de632a494f1555bd77096b..d172cca258c81e8829e9b3420b0c3e6bcc2b6eef 100644 --- a/oneflow/core/job_completer/auto_global_step.cpp +++ b/oneflow/core/job_completer/auto_global_step.cpp @@ -36,7 +36,7 @@ void AutoGlobalStep(const OpGraph& op_graph, Job* job) { assign_conf->set_value(GenLogicalBlobName(scalar_add_op_conf.name(), scalar_add_conf->out())); JobBuilder job_builder(job); - job_builder.AddOps(GenParallelConfOfCpuZeroOnAllMachines(), + job_builder.AddOps(GenParallelConfOfCpuZeroOnMaster(), {variable_op_conf, identity_op_conf, scalar_add_op_conf, assign_op_conf}); job->mutable_job_conf()->mutable_train_conf()->set_global_step_lbn(global_step_lbn); } diff --git a/oneflow/core/job_completer/auto_learning_rate.cpp b/oneflow/core/job_completer/auto_learning_rate.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5bb213652162f6f77de5662b7b8a9ccaaf0fb0cb --- /dev/null +++ b/oneflow/core/job_completer/auto_learning_rate.cpp @@ -0,0 +1,61 @@ +#include "oneflow/core/graph/op_graph.h" +#include "oneflow/core/job/job.pb.h" + +namespace oneflow { + +void AutoLearningRate(const OpGraph& op_graph, Job* job) { + JobBuilder job_builder(job); + const TrainConf& train_conf = job->job_conf().train_conf(); + auto AddScheduleOp = [&](const std::string& op_name, const float learning_rate) -> std::string { + const ParallelConf& parallel_conf = + op_graph.OpNode4OpName(GenLogicalBlobId(train_conf.global_step_lbn()).op_name()) + ->parallel_desc() + .parallel_conf(); + const NormalModelUpdateOpUserConf& model_update_conf = train_conf.model_update_conf(); + if (model_update_conf.has_warmup_conf() || model_update_conf.has_learning_rate_decay()) { + OperatorConf schedule_op_conf{}; + schedule_op_conf.set_name(op_name); + LearningRateScheduleOpConf* schedule_conf = + schedule_op_conf.mutable_learning_rate_schedule_conf(); + schedule_conf->set_global_step(train_conf.global_step_lbn()); + schedule_conf->set_learning_rate(learning_rate); + schedule_conf->set_out("out"); + if (model_update_conf.has_warmup_conf()) { + *schedule_conf->mutable_warmup_conf() = model_update_conf.warmup_conf(); + } + if (model_update_conf.has_learning_rate_decay()) { + *schedule_conf->mutable_learning_rate_decay() = model_update_conf.learning_rate_decay(); + } + job_builder.AddOps(parallel_conf, {schedule_op_conf}); + return GenLogicalBlobName(op_name, schedule_conf->out()); + } else { + OperatorConf constant_op_conf{}; + constant_op_conf.set_name(op_name); + ConstantOpConf* constant_conf = constant_op_conf.mutable_constant_conf(); + constant_conf->set_out("out"); + *constant_conf->mutable_shape()->mutable_dim()->Add() = 1; + constant_conf->set_data_type(DataType::kFloat); + constant_conf->mutable_initializer()->mutable_constant_conf()->set_value(learning_rate); + job_builder.AddOps(parallel_conf, {constant_op_conf}); + return GenLogicalBlobName(op_name, constant_conf->out()); + } + }; + if (!train_conf.has_primary_lr_lbn()) { + CHECK(train_conf.has_primary_lr()); + const std::string lbn = + AddScheduleOp("System-Train-PrimaryLearningRate-Scheduler", train_conf.primary_lr()); + job->mutable_job_conf()->mutable_train_conf()->set_primary_lr_lbn(lbn); + } + if 
(!train_conf.has_secondary_lr_lbn()) { + if (train_conf.has_secondary_lr()) { + const std::string lbn = + AddScheduleOp("System-Train-SecondaryLearningRate-Scheduler", train_conf.secondary_lr()); + job->mutable_job_conf()->mutable_train_conf()->set_secondary_lr_lbn(lbn); + } else { + job->mutable_job_conf()->mutable_train_conf()->set_secondary_lr_lbn( + train_conf.primary_lr_lbn()); + } + } +} + +} // namespace oneflow diff --git a/oneflow/core/job_completer/auto_learning_rate.h b/oneflow/core/job_completer/auto_learning_rate.h new file mode 100644 index 0000000000000000000000000000000000000000..5ffdc5f7296beaf1229f8223477d8bae60fa6479 --- /dev/null +++ b/oneflow/core/job_completer/auto_learning_rate.h @@ -0,0 +1,13 @@ +#ifndef ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_ +#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_ + +namespace oneflow { + +class OpGraph; +class Job; + +void AutoLearningRate(const OpGraph& op_graph, Job* job); + +} // namespace oneflow + +#endif // ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_ diff --git a/oneflow/core/job_completer/job_completer.cpp b/oneflow/core/job_completer/job_completer.cpp index 20e00b67028a10ac1bb0201a05ad74558a72d368..118761df9685bbb6350354ef93c7e26443a72a4a 100644 --- a/oneflow/core/job_completer/job_completer.cpp +++ b/oneflow/core/job_completer/job_completer.cpp @@ -12,6 +12,7 @@ #include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h" #include "oneflow/core/job_completer/auto_mixed_precision.h" #include "oneflow/core/job_completer/auto_global_step.h" +#include "oneflow/core/job_completer/auto_learning_rate.h" namespace oneflow { @@ -341,6 +342,7 @@ void JobCompleter::Complete(Job* job) const { WithOpGraphAndMutJob(job, &TieUpChainHeadersUnReachableFromAnyVariableOps); WithOpGraphAndMutJobBuilder(job, &EnableAutoMixedPrecision); WithOpGraphAndMutJob(job, &AutoGlobalStep); + WithOpGraphAndMutJob(job, &AutoLearningRate); // complete ops for trainning WithOpGraphAndMutJobBuilder(job, &GenerateOpConf4Trainning); WithOpGraphAndMutJobBuilder(job, &RewriteBoxingWithAllReduce); diff --git a/oneflow/core/job_completer/lars_optm.cpp b/oneflow/core/job_completer/lars_optm.cpp index 81728fa21559ebddf7b2c823748e8c1531346532..b5ac48e0c294afc96194faa9aa99b768d4c52c7f 100644 --- a/oneflow/core/job_completer/lars_optm.cpp +++ b/oneflow/core/job_completer/lars_optm.cpp @@ -8,6 +8,9 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_ JobBuilder* job_builder, const LogicalBlobId& diff_lbi_of_var_out, const LogicalBlobId& total_loss_instance_num_lbi) { OperatorConf momentum_var(op.op_conf()); + InitializerConf constant_initializer; + constant_initializer.mutable_constant_conf()->set_value(0.f); + *(momentum_var.mutable_variable_conf()->mutable_initializer()) = constant_initializer; momentum_var.set_name(op.op_name() + "-momentum"); momentum_var.mutable_variable_conf()->set_out("out"); job_builder->AddOps(parallel_conf, {momentum_var}); @@ -16,7 +19,8 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_ OperatorConf mdupdt_op; mdupdt_op.set_name(op.op_name() + "_optimizer"); auto* mdupdt_op_conf = mdupdt_op.mutable_lars_model_update_conf(); - ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, mdupdt_op_conf); + ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, job_builder, + mdupdt_op_conf); mdupdt_op_conf->set_momentum(momentum_var.name() + "/out"); job_builder->AddOps(parallel_conf, {mdupdt_op}); } diff --git 
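A minimal sketch, not part of the patch, of what the AutoLearningRate pass above guarantees for later passes: once it has run, both learning rates exist as logical blob names in TrainConf, and the secondary lbn aliases the primary one when no separate secondary_lr was configured. The function name is illustrative; the example lbn in the comment is the one the pass generates.

#include "glog/logging.h"
#include "oneflow/core/job/job.pb.h"

void CheckLearningRateLbns(const oneflow::Job& job) {
  const oneflow::TrainConf& train_conf = job.job_conf().train_conf();
  // Filled in by AutoLearningRate, so downstream passes such as
  // GenerateOpConf4Trainning can consume them unconditionally.
  CHECK(train_conf.has_primary_lr_lbn());    // e.g. "System-Train-PrimaryLearningRate-Scheduler/out"
  CHECK(train_conf.has_secondary_lr_lbn());  // same lbn as primary if secondary_lr was not set
}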
a/oneflow/core/job_completer/momentum_optm.cpp b/oneflow/core/job_completer/momentum_optm.cpp index 443ce2e43d3158a28452bf70071e9d31df62d7fc..0f3964b759850a1fcca1accb20030de457c94930 100644 --- a/oneflow/core/job_completer/momentum_optm.cpp +++ b/oneflow/core/job_completer/momentum_optm.cpp @@ -8,6 +8,9 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_ JobBuilder* job_builder, const LogicalBlobId& diff_lbi_of_var_out, const LogicalBlobId& total_loss_instance_num_lbi) { OperatorConf momentum_var(op.op_conf()); + InitializerConf constant_initializer; + constant_initializer.mutable_constant_conf()->set_value(0.f); + *(momentum_var.mutable_variable_conf()->mutable_initializer()) = constant_initializer; momentum_var.set_name(op.op_name() + "-momentum"); momentum_var.mutable_variable_conf()->set_out("out"); job_builder->AddOps(parallel_conf, {momentum_var}); @@ -16,7 +19,8 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_ OperatorConf mdupdt_op; mdupdt_op.set_name(op.op_name() + "_optimizer"); auto* mdupdt_op_conf = mdupdt_op.mutable_momentum_model_update_conf(); - ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, mdupdt_op_conf); + ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, job_builder, + mdupdt_op_conf); mdupdt_op_conf->set_momentum(momentum_var.name() + "/out"); job_builder->AddOps(parallel_conf, {mdupdt_op}); } diff --git a/oneflow/core/job_completer/naive_optm.cpp b/oneflow/core/job_completer/naive_optm.cpp index d688c2539992eac585710a912aac43de45e1cf62..67477480c6cc07c5c7c04641eb78d960dfbbae7b 100644 --- a/oneflow/core/job_completer/naive_optm.cpp +++ b/oneflow/core/job_completer/naive_optm.cpp @@ -9,7 +9,7 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_ const LogicalBlobId& total_loss_instance_num_lbi) { OperatorConf mdupdt_op; mdupdt_op.set_name(op.op_name() + "_optimizer"); - ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, + ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, job_builder, mdupdt_op.mutable_naive_model_update_conf()); job_builder->AddOps(parallel_conf, {mdupdt_op}); } diff --git a/oneflow/core/job_completer/optimizer.cpp b/oneflow/core/job_completer/optimizer.cpp index 7f85f0635d300b45b028666d8fb0c847c5a5854a..ec4df9f815d3f9bde89ddb45a4ad9f22f2a2708c 100644 --- a/oneflow/core/job_completer/optimizer.cpp +++ b/oneflow/core/job_completer/optimizer.cpp @@ -61,34 +61,36 @@ void BindTwoVariableOpObnSbpConf(const std::string& lhs_op_name, const std::stri template<typename T> void ConstructMdUpdtOpConf(const VariableOp& op, const LogicalBlobId& diff_lbi_of_var_out, - const LogicalBlobId& total_loss_instance_num_lbi, T* mdupdt_op_conf) { - const auto& train_conf = GlobalJobDesc().job_conf().train_conf(); + const LogicalBlobId& total_loss_instance_num_lbi, + JobBuilder* job_builder, T* mdupdt_op_conf) { + const auto& train_conf = job_builder->job().job_conf().train_conf(); *mdupdt_op_conf->mutable_user_conf() = train_conf.model_update_conf(); mdupdt_op_conf->set_model_diff(GenLogicalBlobName(diff_lbi_of_var_out)); mdupdt_op_conf->set_total_instance_num_diff(GenLogicalBlobName(total_loss_instance_num_lbi)); mdupdt_op_conf->set_model(GenLogicalBlobName(op.BnInOp2Lbi("out"))); - float primary_lr = GlobalJobDesc().primary_lr(); - float secondary_lr = GlobalJobDesc().secondary_lr(); - if (secondary_lr < 0) { secondary_lr = primary_lr; } + 
mdupdt_op_conf->set_global_step(train_conf.global_step_lbn()); + const std::string& primary_lr_lbn = train_conf.primary_lr_lbn(); + const std::string& secondary_lr_lbn = train_conf.secondary_lr_lbn(); if (op.op_conf().variable_conf().model_name() == "weight") { - mdupdt_op_conf->set_learning_rate(primary_lr); - mdupdt_op_conf->set_l1(GlobalJobDesc().weight_l1()); - mdupdt_op_conf->set_l2(GlobalJobDesc().weight_l2()); + mdupdt_op_conf->set_learning_rate(primary_lr_lbn); + mdupdt_op_conf->set_l1(train_conf.weight_l1()); + mdupdt_op_conf->set_l2(train_conf.weight_l2()); } else if (op.op_conf().variable_conf().model_name() == "bias") { - mdupdt_op_conf->set_learning_rate(secondary_lr); - mdupdt_op_conf->set_l1(GlobalJobDesc().bias_l1()); - mdupdt_op_conf->set_l2(GlobalJobDesc().bias_l2()); + mdupdt_op_conf->set_learning_rate(secondary_lr_lbn); + mdupdt_op_conf->set_l1(train_conf.bias_l1()); + mdupdt_op_conf->set_l2(train_conf.bias_l2()); } else { - mdupdt_op_conf->set_learning_rate(primary_lr); + mdupdt_op_conf->set_learning_rate(primary_lr_lbn); mdupdt_op_conf->set_l1(0); mdupdt_op_conf->set_l2(0); } } -#define INSTANTIATE_CONSTRUCTOR_MDUPDT_OP_CONF(T) \ - template void ConstructMdUpdtOpConf<T>( \ - const VariableOp& op, const LogicalBlobId& diff_lbi_of_var_out, \ - const LogicalBlobId& total_loss_instance_num_lbi, T* mdupdt_op_conf) +#define INSTANTIATE_CONSTRUCTOR_MDUPDT_OP_CONF(T) \ + template void ConstructMdUpdtOpConf<T>(const VariableOp& op, \ + const LogicalBlobId& diff_lbi_of_var_out, \ + const LogicalBlobId& total_loss_instance_num_lbi, \ + JobBuilder* job_builder, T* mdupdt_op_conf) INSTANTIATE_CONSTRUCTOR_MDUPDT_OP_CONF(NaiveModelUpdateOpConf); INSTANTIATE_CONSTRUCTOR_MDUPDT_OP_CONF(MomentumModelUpdateOpConf); diff --git a/oneflow/core/job_completer/optimizer.h b/oneflow/core/job_completer/optimizer.h index 88197d4e2a83ee96b60f5cb00a8da6754499b82a..9479ec16f9e06bc3e61f8caafab9ec4f05bbc9b5 100644 --- a/oneflow/core/job_completer/optimizer.h +++ b/oneflow/core/job_completer/optimizer.h @@ -15,7 +15,8 @@ void BindTwoVariableOpObnSbpConf(const std::string& lhs_op_name, const std::stri JobBuilder* job_builder); template<typename T> void ConstructMdUpdtOpConf(const VariableOp& op, const LogicalBlobId& diff_lbi_of_var_out, - const LogicalBlobId& total_loss_instance_num_lbi, T*); + const LogicalBlobId& total_loss_instance_num_lbi, + JobBuilder* job_builder, T*); class GenerateOptimizerOpConfWrapperStruct final { public: diff --git a/oneflow/core/job_completer/rmsprop_optm.cpp b/oneflow/core/job_completer/rmsprop_optm.cpp index 166e1ab75ee0efcc1bde291e5ae0b97132a47cd7..c8021d266dd7e748905ccd931262c84f587f5791 100644 --- a/oneflow/core/job_completer/rmsprop_optm.cpp +++ b/oneflow/core/job_completer/rmsprop_optm.cpp @@ -9,7 +9,7 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_ const LogicalBlobId& total_loss_instance_num_lbi) { OperatorConf mdupdt_op; mdupdt_op.set_name(op.op_name() + "_optimizer"); - ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, + ConstructMdUpdtOpConf(op, diff_lbi_of_var_out, total_loss_instance_num_lbi, job_builder, mdupdt_op.mutable_rmsprop_model_update_conf()); job_builder->AddOps(parallel_conf, {mdupdt_op}); } diff --git a/oneflow/core/kernel/adam_model_update_kernel.cpp b/oneflow/core/kernel/adam_model_update_kernel.cpp index 900dac0f209e366d54009edcfed62e3ef99bf804..ad953fcabb3fbace93c6ee9bbae8ef70d514f41e 100644 --- a/oneflow/core/kernel/adam_model_update_kernel.cpp +++ 
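A small sketch, not part of the patch, restating the selection rule ConstructMdUpdtOpConf applies above as a standalone function: "weight" and "bias" are the model_name values the real code switches on, and any other variable falls back to the primary schedule. The function name is hypothetical.

#include <string>

// Returns the logical blob name of the learning rate a variable should consume.
std::string LearningRateLbnForVariable(const std::string& model_name,
                                       const std::string& primary_lr_lbn,
                                       const std::string& secondary_lr_lbn) {
  if (model_name == "bias") { return secondary_lr_lbn; }  // bias terms use the secondary rate
  return primary_lr_lbn;  // "weight" and all other variables use the primary rate
}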
b/oneflow/core/kernel/adam_model_update_kernel.cpp @@ -31,26 +31,25 @@ const PbMessage& AdamMdUpdateKernel<device_type, T>::GetCustomizedOpConf() const template<DeviceType device_type, typename T> void AdamMdUpdateKernel<device_type, T>::UpdateModel( - DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, - int64_t next_model_vid, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, const int64_t* global_step, + const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const { Blob* model_blob = BnInOp2Blob("model"); Blob* m_blob = BnInOp2Blob("m"); Blob* v_blob = BnInOp2Blob("v"); Blob* beta1_t_blob = BnInOp2Blob("beta1_t"); Blob* beta2_t_blob = BnInOp2Blob("beta2_t"); const auto& adam_conf = GetAdamModelUpdateConf(this->op_conf()); - if ((next_model_vid != 1) && adam_conf.do_bias_correction()) { - KernelUtil<device_type, T>::Scal(ctx, 1, static_cast<T>(adam_conf.beta1()), - beta1_t_blob->mut_dptr<T>(), 1); - KernelUtil<device_type, T>::Scal(ctx, 1, static_cast<T>(adam_conf.beta2()), - beta2_t_blob->mut_dptr<T>(), 1); + if (adam_conf.do_bias_correction()) { + AdamMdUpdateKernelUtil<device_type, T>::DoBiasCorrection( + ctx, global_step, static_cast<T>(adam_conf.beta1()), static_cast<T>(adam_conf.beta2()), + beta1_t_blob->mut_dptr<T>(), beta2_t_blob->mut_dptr<T>()); } KernelUtil<device_type, T>::Div(ctx, model_blob->shape().elem_cnt(), BnInOp2Blob("model_diff")->mut_dptr<T>(), batch_instance_num_ptr); AdamMdUpdateKernelUtil<device_type, T>::UpdateModel( ctx, model_blob->shape().elem_cnt(), learning_rate, l1, l2, static_cast<T>(adam_conf.beta1()), static_cast<T>(adam_conf.beta2()), static_cast<T>(adam_conf.epsilon()), - adam_conf.do_bias_correction(), next_model_vid, + adam_conf.do_bias_correction(), global_step, (beta1_t_blob ? beta1_t_blob->dptr<T>() : nullptr), (beta2_t_blob ? 
beta2_t_blob->dptr<T>() : nullptr), BnInOp2Blob("model_diff")->mut_dptr<T>(), model_blob->mut_dptr<T>(), m_blob->mut_dptr<T>(), v_blob->mut_dptr<T>()); @@ -59,9 +58,10 @@ void AdamMdUpdateKernel<device_type, T>::UpdateModel( template<typename T> class AdamMdUpdateKernelUtil<DeviceType::kCPU, T> final { public: - static void UpdateModel(DeviceCtx* ctx, int64_t n, T learning_rate, T l1, T l2, T beta1, T beta2, - T epsilon, bool do_bias_correction, int64_t next_model_vid, - const T* beta1_t, const T* beta2_t, T* model_diff, T* model, T* m, T* v) { + static void UpdateModel(DeviceCtx* ctx, int64_t n, const float* learning_rate, T l1, T l2, + T beta1, T beta2, T epsilon, bool do_bias_correction, + const int64_t* global_step, const T* beta1_t, const T* beta2_t, + T* model_diff, T* model, T* m, T* v) { // first-order moment UpdateMomentEstimate<T>(n, do_bias_correction, beta1, 1, model_diff, beta1_t, m); // second-order moment @@ -69,7 +69,14 @@ class AdamMdUpdateKernelUtil<DeviceType::kCPU, T> final { FOR_RANGE(int64_t, i, 0, n) { model_diff[i] = m[i] / (std::sqrt(v[i]) + epsilon); T reg_diff = RegDiff(model_diff[i], l1, l2, model[i]); - model[i] = model[i] - learning_rate * reg_diff; + model[i] = model[i] - *learning_rate * reg_diff; + } + } + static void DoBiasCorrection(DeviceCtx*, const int64_t* global_step, const T beta1, const T beta2, + T* beta1_t, T* beta2_t) { + if (*global_step != 0) { + *beta1_t *= beta1; + *beta2_t *= beta2; } } }; diff --git a/oneflow/core/kernel/adam_model_update_kernel.cu b/oneflow/core/kernel/adam_model_update_kernel.cu index 8c34784799ee1e9800e699f94f513c4f70158431..181bc0cd24ad8cba1aa2454d33d1f11b0d67ae4e 100644 --- a/oneflow/core/kernel/adam_model_update_kernel.cu +++ b/oneflow/core/kernel/adam_model_update_kernel.cu @@ -44,17 +44,17 @@ __device__ void UpdateMomentEstimate(T beta, const T* model_diff, const T* beta_ } template<typename T> -__device__ void UpdateModel(T learning_rate, T l1, T l2, T epsilon, T* model_diff, T* model, T* m, - T* v) { +__device__ void UpdateModel(const float* learning_rate, T l1, T l2, T epsilon, T* model_diff, + T* model, T* m, T* v) { *model_diff = *m / (sqrt(*v) + epsilon); T reg_diff = RegDiff(*model_diff, l1, l2, *model); - *model = *model - learning_rate * reg_diff; + *model = *model - *learning_rate * reg_diff; } template<bool do_bias_correction, typename T> -__global__ void UpdateModelGpu(int64_t n, T learning_rate, T l1, T l2, T beta1, T beta2, T epsilon, - const T* beta1_t, const T* beta2_t, T* model_diff, T* model, T* m, - T* v) { +__global__ void UpdateModelGpu(int64_t n, const float* learning_rate, T l1, T l2, T beta1, T beta2, + T epsilon, const T* beta1_t, const T* beta2_t, T* model_diff, + T* model, T* m, T* v) { CUDA_1D_KERNEL_LOOP(i, n) { UpdateMomentEstimate<1, do_bias_correction>(beta1, model_diff + i, beta1_t, m + i); UpdateMomentEstimate<2, do_bias_correction>(beta2, model_diff + i, beta2_t, v + i); @@ -62,14 +62,24 @@ __global__ void UpdateModelGpu(int64_t n, T learning_rate, T l1, T l2, T beta1, } } +template<typename T> +__global__ void DoBiasCorrectionGpu(const int64_t* global_step, const T beta1, const T beta2, + T* beta1_t, T* beta2_t) { + if (*global_step != 0) { + *beta1_t *= beta1; + *beta2_t *= beta2; + } +} + } // namespace template<typename T> class AdamMdUpdateKernelUtil<DeviceType::kGPU, T> final { public: - static void UpdateModel(DeviceCtx* ctx, int64_t n, T learning_rate, T l1, T l2, T beta1, T beta2, - T epsilon, bool do_bias_correction, int64_t next_model_vid, - const T* beta1_t, const T* 
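A host-side sketch, not part of the patch, of the element-wise arithmetic the Adam kernels above perform, written in the standard textbook form for a single parameter. The real kernels organize the same quantities differently (UpdateMomentEstimate, RegDiff) and add l1/l2 regularization; beta1_t and beta2_t are assumed here to be the running powers of beta1 and beta2 that DoBiasCorrection maintains by multiplying once per step, with the multiply skipped at *global_step == 0 so the blobs keep their initial values on the first update.

#include <cmath>

// One Adam step for a single parameter w with gradient g.
void AdamStep(float lr, float beta1, float beta2, float epsilon,
              float beta1_t, float beta2_t,  // assumed beta1^(t+1), beta2^(t+1) at step t
              float g, float* m, float* v, float* w) {
  *m = beta1 * (*m) + (1.f - beta1) * g;       // first-moment estimate
  *v = beta2 * (*v) + (1.f - beta2) * g * g;   // second-moment estimate
  const float m_hat = *m / (1.f - beta1_t);    // bias-corrected moments
  const float v_hat = *v / (1.f - beta2_t);
  *w -= lr * m_hat / (std::sqrt(v_hat) + epsilon);
}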
beta2_t, T* model_diff, T* model, T* m, T* v) { + static void UpdateModel(DeviceCtx* ctx, int64_t n, const float* learning_rate, T l1, T l2, + T beta1, T beta2, T epsilon, bool do_bias_correction, + const int64_t* global_step, const T* beta1_t, const T* beta2_t, + T* model_diff, T* model, T* m, T* v) { if (do_bias_correction) { UpdateModelGpu<true, T> <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( @@ -82,6 +92,12 @@ class AdamMdUpdateKernelUtil<DeviceType::kGPU, T> final { m, v); } } + + static void DoBiasCorrection(DeviceCtx* ctx, const int64_t* global_step, const T beta1, + const T beta2, T* beta1_t, T* beta2_t) { + DoBiasCorrectionGpu<T> + <<<1, 1, 0, ctx->cuda_stream()>>>(global_step, beta1, beta2, beta1_t, beta2_t); + } }; #define INSTANTIATE_GPU_KERNEL_UTIL(type_cpp, type_proto) \ diff --git a/oneflow/core/kernel/adam_model_update_kernel.h b/oneflow/core/kernel/adam_model_update_kernel.h index 7567a810a0e6b6f4bbe0ae33eadb9f52cd5f7362..289d518d867748bff1e29b13e1dc24d2624b97bb 100644 --- a/oneflow/core/kernel/adam_model_update_kernel.h +++ b/oneflow/core/kernel/adam_model_update_kernel.h @@ -14,17 +14,19 @@ class AdamMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T> { private: const PbMessage& GetCustomizedOpConf() const override; - void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, - int64_t next_model_vid, + void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, + const int64_t* global_step, const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; }; template<DeviceType device_type, typename T> class AdamMdUpdateKernelUtil final { public: - static void UpdateModel(DeviceCtx*, int64_t n, T learning_rate, T l1, T l2, T beta1, T beta2, - T epsilon, bool do_bias_correction, int64_t next_model_vid, + static void UpdateModel(DeviceCtx*, int64_t n, const float* learning_rate, T l1, T l2, T beta1, + T beta2, T epsilon, bool do_bias_correction, const int64_t* global_step, const T* beta1_t, const T* beta2_t, T* model_diff, T* model, T* m, T* v); + static void DoBiasCorrection(DeviceCtx*, const int64_t* global_step, T beta1, T beta2, T* beta1_t, + T* beta2_t); }; DECLARE_MDUPDT_KERNEL_CREATOR(Adam); diff --git a/oneflow/core/kernel/lars_model_update_kernel.cpp b/oneflow/core/kernel/lars_model_update_kernel.cpp index 31091f2ac458c14052522d87ac7ee7bf1fb45f05..341abb4853ccde032e397d5b557bbecd2a4e3c35 100644 --- a/oneflow/core/kernel/lars_model_update_kernel.cpp +++ b/oneflow/core/kernel/lars_model_update_kernel.cpp @@ -18,23 +18,19 @@ const PbMessage& LARSMdUpdateKernel<device_type, T>::GetCustomizedOpConf() const template<DeviceType device_type, typename T> void LARSMdUpdateKernel<device_type, T>::UpdateModel( - DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, - int64_t next_model_vid, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, const int64_t* global_step, + const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const { const Blob* model_diff_blob = BnInOp2Blob("model_diff"); Blob* model_blob = BnInOp2Blob("model"); Blob* momentum_blob = BnInOp2Blob("momentum"); - Blob* data_tmp_blob = BnInOp2Blob("data_tmp"); + Blob* data_tmp_blob = BnInOp2Blob("lars_data_tmp"); const LARSModelUpdateConf& lars_conf = GetLARSModelUpdateConf(this->op_conf()); - if (next_model_vid == 1) { - Memset<device_type>(ctx, 
momentum_blob->mut_dptr<T>(), 0, - momentum_blob->ByteSizeOfDataContentField()); - } Memset<device_type>(ctx, data_tmp_blob->mut_dptr<T>(), 0, data_tmp_blob->ByteSizeOfDataContentField()); LARSMdUpdateKernelUtil<device_type, T>::UpdateModel( ctx, model_blob->shape().elem_cnt(), batch_instance_num_ptr, learning_rate, l1, l2, static_cast<T>(lars_conf.momentum_beta()), static_cast<T>(lars_conf.epsilon()), - static_cast<T>(lars_conf.lars_coefficient()), next_model_vid, model_diff_blob->dptr<T>(), + static_cast<T>(lars_conf.lars_coefficient()), global_step, model_diff_blob->dptr<T>(), model_blob->mut_dptr<T>(), momentum_blob->mut_dptr<T>(), data_tmp_blob->mut_dptr<T>()); } @@ -42,9 +38,9 @@ template<typename T> class LARSMdUpdateKernelUtil<DeviceType::kCPU, T> final { public: static void UpdateModel(DeviceCtx* ctx, int64_t n, const T* batch_instance_num_ptr, - T learning_rate, T l1, T l2, T momentum_beta, T epsilon, - T lars_coefficient, int64_t next_model_vid, const T* model_diff, T* model, - T* momentum, T* data_tmp) { + const float* learning_rate, T l1, T l2, T momentum_beta, T epsilon, + T lars_coefficient, const int64_t* global_step, const T* model_diff, + T* model, T* momentum, T* data_tmp) { T model_norm = 0; T model_diff_norm = 0; FOR_RANGE(int64_t, i, 0, n) { @@ -54,11 +50,11 @@ class LARSMdUpdateKernelUtil<DeviceType::kCPU, T> final { model_norm = std::sqrt(model_norm / n); model_diff_norm = std::sqrt(model_diff_norm / n); T local_learning_rate = 0; - if (next_model_vid == 1) { + if (*global_step == 0) { local_learning_rate = - learning_rate * lars_coefficient * model_norm / (epsilon + model_diff_norm); + *learning_rate * lars_coefficient * model_norm / (epsilon + model_diff_norm); } else { - local_learning_rate = learning_rate * lars_coefficient * model_norm + local_learning_rate = *learning_rate * lars_coefficient * model_norm / (epsilon + model_diff_norm + l2 * model_norm); } FOR_RANGE(int64_t, i, 0, n) { diff --git a/oneflow/core/kernel/lars_model_update_kernel.cu b/oneflow/core/kernel/lars_model_update_kernel.cu index 36f1143564f8725831036898a3c0c682285ada72..cd74e947fdf736feb2b114d918b7df5f5753ea83 100644 --- a/oneflow/core/kernel/lars_model_update_kernel.cu +++ b/oneflow/core/kernel/lars_model_update_kernel.cu @@ -7,19 +7,19 @@ namespace oneflow { namespace { template<typename T> -__global__ void GetLocalLearningRateGpu(const T* batch_instance_num_ptr, T learning_rate, T l2, - T epsilon, T lars_coefficient, int64_t next_model_vid, - T* data_tmp) { +__global__ void GetLocalLearningRateGpu(const T* batch_instance_num_ptr, const float* learning_rate, + T l2, T epsilon, T lars_coefficient, + const int64_t* global_step, T* data_tmp) { T* model_norm = &data_tmp[0]; T* model_diff_norm = &data_tmp[1]; T* local_learning_rate = &data_tmp[2]; *model_norm = std::sqrt(*model_norm); *model_diff_norm = std::sqrt(*model_diff_norm) / *batch_instance_num_ptr; // TODO(shiyuan) - if (next_model_vid == 1) { + if (*global_step == 0) { *local_learning_rate = - learning_rate * lars_coefficient * (*model_norm) / (epsilon + (*model_diff_norm)); + *learning_rate * lars_coefficient * (*model_norm) / (epsilon + (*model_diff_norm)); } else { - *local_learning_rate = learning_rate * lars_coefficient * (*model_norm) + *local_learning_rate = *learning_rate * lars_coefficient * (*model_norm) / (epsilon + (*model_diff_norm) + l2 * (*model_diff_norm)); } } @@ -41,14 +41,14 @@ template<typename T> class LARSMdUpdateKernelUtil<DeviceType::kGPU, T> final { public: static void UpdateModel(DeviceCtx* ctx, int64_t n, 
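A sketch, not part of the patch, of the scalar the LARS utils above derive before the momentum update: a per-variable "local" learning rate scaled by the ratio of the model norm to the gradient norm. It follows the CPU implementation shown (RMS-style norms, i.e. sqrt of the mean square), with *global_step == 0 replacing the old next_model_vid == 1 test for the first step; the function name is illustrative.

#include <cmath>
#include <cstdint>

float LarsLocalLearningRate(float learning_rate, float lars_coefficient, float epsilon,
                            float l2, int64_t global_step,
                            float model_norm, float model_diff_norm) {
  if (global_step == 0) {
    return learning_rate * lars_coefficient * model_norm / (epsilon + model_diff_norm);
  }
  return learning_rate * lars_coefficient * model_norm
         / (epsilon + model_diff_norm + l2 * model_norm);
}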
const T* batch_instance_num_ptr, - T learning_rate, T l1, T l2, T momentum_beta, T epsilon, - T lars_coefficient, int64_t next_model_vid, const T* model_diff, T* model, - T* momentum, T* data_tmp) { + const float* learning_rate, T l1, T l2, T momentum_beta, T epsilon, + T lars_coefficient, const int64_t* global_step, const T* model_diff, + T* model, T* momentum, T* data_tmp) { KernelUtil<DeviceType::kGPU, T>::Dot(ctx, n, model, 1, model, 1, &data_tmp[0]); KernelUtil<DeviceType::kGPU, T>::Dot(ctx, n, model_diff, 1, model_diff, 1, &data_tmp[1]); GetLocalLearningRateGpu<T> <<<1, 1, 0, ctx->cuda_stream()>>>(batch_instance_num_ptr, learning_rate, l2, epsilon, - lars_coefficient, next_model_vid, data_tmp); + lars_coefficient, global_step, data_tmp); UpdateModelGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( n, batch_instance_num_ptr, l1, l2, momentum_beta, model_diff, model, momentum, data_tmp); } diff --git a/oneflow/core/kernel/lars_model_update_kernel.h b/oneflow/core/kernel/lars_model_update_kernel.h index ce90fe15847899c5d1d200368590e6252a834194..ca37bbabafa3331b65acb61c5acb0a688404d894 100644 --- a/oneflow/core/kernel/lars_model_update_kernel.h +++ b/oneflow/core/kernel/lars_model_update_kernel.h @@ -14,18 +14,18 @@ class LARSMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T> { private: const PbMessage& GetCustomizedOpConf() const override; - void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, - int64_t next_model_vid, + void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, + const int64_t* global_step, const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; }; template<DeviceType device_type, typename T> class LARSMdUpdateKernelUtil final { public: - static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, T learning_rate, - T l1, T l2, T momentum_beta, T epsilon, T lars_coefficient, - int64_t next_model_vid, const T* model_diff, T* model, T* momentum, - T* data_tmp); + static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, + const float* learning_rate, T l1, T l2, T momentum_beta, T epsilon, + T lars_coefficient, const int64_t* global_step, const T* model_diff, + T* model, T* momentum, T* data_tmp); }; DECLARE_MDUPDT_KERNEL_CREATOR(LARS); diff --git a/oneflow/core/kernel/momentum_model_update_kernel.cpp b/oneflow/core/kernel/momentum_model_update_kernel.cpp index 032dcbab6143cba3801ff8559a3d8d6e64b5c322..92bf35609c9fbf9f57cab0304625547cde7ba05d 100644 --- a/oneflow/core/kernel/momentum_model_update_kernel.cpp +++ b/oneflow/core/kernel/momentum_model_update_kernel.cpp @@ -18,17 +18,16 @@ const PbMessage& MomentumMdUpdateKernel<device_type, T>::GetCustomizedOpConf() c template<DeviceType device_type, typename T> void MomentumMdUpdateKernel<device_type, T>::UpdateModel( - DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, - int64_t next_model_vid, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, const int64_t* global_step, + const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const { const Blob* model_diff_blob = BnInOp2Blob("model_diff"); Blob* model_blob = BnInOp2Blob("model"); Blob* momentum_blob = BnInOp2Blob("momentum"); float beta = GetMomentumModelUpdateConf(this->op_conf()).beta(); - if (next_model_vid == 1) { beta = 0.0f; } 
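A one-element sketch, not part of the patch, of the momentum update performed above. The behavioural change is that instead of the old first-iteration special cases (zeroing the momentum buffer or forcing beta to 0 when next_model_vid == 1), the kernels now force the effective beta to 0 when *global_step == 0 and the optimizer passes seed the momentum variable with a constant-0 initializer, so the first step produces the same update.

#include <cstdint>

// reg_diff is the l1/l2-regularized, instance-count-normalized gradient
// (what RegularizeDiff computes in the kernels above).
void MomentumStep(float beta, float learning_rate, int64_t global_step,
                  float reg_diff, float* momentum, float* model) {
  const float cur_beta = (global_step == 0) ? 0.f : beta;  // ignore stale momentum on step 0
  *momentum = cur_beta * (*momentum) - learning_rate * reg_diff;
  *model += *momentum;
}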
MomentumMdUpdateKernelUtil<device_type, T>::UpdateModel( ctx, model_blob->shape().elem_cnt(), batch_instance_num_ptr, static_cast<T>(beta), - learning_rate, l1, l2, model_diff_blob->dptr<T>(), model_blob->mut_dptr<T>(), + global_step, learning_rate, l1, l2, model_diff_blob->dptr<T>(), model_blob->mut_dptr<T>(), momentum_blob->mut_dptr<T>()); } @@ -36,10 +35,12 @@ template<typename T> class MomentumMdUpdateKernelUtil<DeviceType::kCPU, T> final { public: static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, T beta, - T learning_rate, T l1, T l2, const T* model_diff, T* model, T* momentum) { + const int64_t* global_step, const float* learning_rate, T l1, T l2, + const T* model_diff, T* model, T* momentum) { + T cur_beta = *global_step == 0 ? 0 : beta; for (int64_t i = 0; i != n; ++i) { T reg_diff = RegularizeDiff(model_diff[i], *batch_instance_num_ptr, l1, l2, model[i]); - momentum[i] = beta * momentum[i] - learning_rate * reg_diff; + momentum[i] = cur_beta * momentum[i] - *learning_rate * reg_diff; model[i] = model[i] + momentum[i]; } } diff --git a/oneflow/core/kernel/momentum_model_update_kernel.cu b/oneflow/core/kernel/momentum_model_update_kernel.cu index e0a74db6fc9808dffeeb653472d33af836160bc1..6c0bfd2fda5c0b68e1667f4cb6f105fc3b68eb73 100644 --- a/oneflow/core/kernel/momentum_model_update_kernel.cu +++ b/oneflow/core/kernel/momentum_model_update_kernel.cu @@ -7,11 +7,13 @@ namespace oneflow { namespace { template<typename T> -__global__ void UpdateModelGpu(int64_t n, const T* batch_instance_num_ptr, T beta, T learning_rate, - T l1, T l2, const T* model_diff, T* model, T* momentum) { +__global__ void UpdateModelGpu(int64_t n, const T* batch_instance_num_ptr, T beta, + const int64_t* global_step, const float* learning_rate, T l1, T l2, + const T* model_diff, T* model, T* momentum) { + T cur_beta = *global_step == 0 ? 
0 : beta; CUDA_1D_KERNEL_LOOP(i, n) { T reg_diff = RegularizeDiff(model_diff[i], *batch_instance_num_ptr, l1, l2, model[i]); - momentum[i] = beta * momentum[i] - learning_rate * reg_diff; + momentum[i] = cur_beta * momentum[i] - *learning_rate * reg_diff; model[i] = model[i] + momentum[i]; } } @@ -22,10 +24,11 @@ template<typename T> class MomentumMdUpdateKernelUtil<DeviceType::kGPU, T> final { public: static void UpdateModel(DeviceCtx* ctx, int64_t n, const T* batch_instance_num_ptr, T beta, - T learning_rate, const T l1, const T l2, const T* model_diff, T* model, - T* momentum) { + const int64_t* global_step, const float* learning_rate, const T l1, + const T l2, const T* model_diff, T* model, T* momentum) { UpdateModelGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( - n, batch_instance_num_ptr, beta, learning_rate, l1, l2, model_diff, model, momentum); + n, batch_instance_num_ptr, beta, global_step, learning_rate, l1, l2, model_diff, model, + momentum); } }; diff --git a/oneflow/core/kernel/momentum_model_update_kernel.h b/oneflow/core/kernel/momentum_model_update_kernel.h index 0fbbfe6b67da1bf2feaf7e3cc1d3aaecf3191da2..8442d2dcac350b9a4017afcde1de40ac8b05740e 100644 --- a/oneflow/core/kernel/momentum_model_update_kernel.h +++ b/oneflow/core/kernel/momentum_model_update_kernel.h @@ -16,8 +16,8 @@ class MomentumMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T> const PbMessage& GetCustomizedOpConf() const override; private: - void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, - int64_t next_model_vid, + void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, + const int64_t* global_step, const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; }; @@ -25,7 +25,8 @@ template<DeviceType device_type, typename T> class MomentumMdUpdateKernelUtil final { public: static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, T beta, - T learning_rate, T l1, T l2, const T* model_diff, T* model, T* momentum); + const int64_t* global_step, const float* learning_rate, T l1, T l2, + const T* model_diff, T* model, T* momentum); }; DECLARE_MDUPDT_KERNEL_CREATOR(Momentum); diff --git a/oneflow/core/kernel/naive_model_update_kernel.cpp b/oneflow/core/kernel/naive_model_update_kernel.cpp index d6a086ee5027cfd0836cc215451cc461d5bc5ed5..3ca58170c7bb7dbda3a8f31a2b7a6f62388706b3 100644 --- a/oneflow/core/kernel/naive_model_update_kernel.cpp +++ b/oneflow/core/kernel/naive_model_update_kernel.cpp @@ -5,8 +5,8 @@ namespace oneflow { template<DeviceType device_type, typename T> void NaiveMdUpdateKernel<device_type, T>::UpdateModel( - DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, - int64_t next_model_vid, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, const int64_t* global_step, + const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const { const Blob* model_diff_blob = BnInOp2Blob("model_diff"); Blob* model_blob = BnInOp2Blob("model"); // model = model - alpha * model_diff @@ -24,10 +24,10 @@ template<typename T> class NaiveMdUpdateKernelUtil<DeviceType::kCPU, T> final { public: static void UpdateModel(DeviceCtx*, const int64_t n, const T* batch_instance_num_ptr, - T learning_rate, T l1, T l2, const T* model_diff, T* model) { + const float* learning_rate, T l1, T l2, const T* model_diff, T* model) { for 
(int64_t i = 0; i != n; ++i) { T reg_diff = RegularizeDiff(model_diff[i], *batch_instance_num_ptr, l1, l2, model[i]); - model[i] = model[i] - learning_rate * reg_diff; + model[i] = model[i] - *learning_rate * reg_diff; } } }; diff --git a/oneflow/core/kernel/naive_model_update_kernel.cu b/oneflow/core/kernel/naive_model_update_kernel.cu index 4e21d3656954b75ac6cdf8d83b06c8d93bf9c78e..2233e57af2b3416a01a7ad43ecac8bf8fbbf9bb8 100644 --- a/oneflow/core/kernel/naive_model_update_kernel.cu +++ b/oneflow/core/kernel/naive_model_update_kernel.cu @@ -7,11 +7,12 @@ namespace oneflow { namespace { template<typename T> -__global__ void UpdateModelGpu(const int64_t n, const T* batch_instance_num_ptr, T learning_rate, - T l1, T l2, const T* model_diff, T* model) { +__global__ void UpdateModelGpu(const int64_t n, const T* batch_instance_num_ptr, + const float* learning_rate, T l1, T l2, const T* model_diff, + T* model) { CUDA_1D_KERNEL_LOOP(i, n) { T reg_diff = RegularizeDiff(model_diff[i], *batch_instance_num_ptr, l1, l2, model[i]); - model[i] = model[i] - learning_rate * reg_diff; + model[i] = model[i] - *learning_rate * reg_diff; } } @@ -21,7 +22,7 @@ template<typename T> class NaiveMdUpdateKernelUtil<DeviceType::kGPU, T> final { public: static void UpdateModel(DeviceCtx* ctx, const int64_t n, const T* batch_instance_num_ptr, - T learning_rate, T l1, T l2, const T* model_diff, T* model) { + const float* learning_rate, T l1, T l2, const T* model_diff, T* model) { UpdateModelGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( n, batch_instance_num_ptr, learning_rate, l1, l2, model_diff, model); } diff --git a/oneflow/core/kernel/naive_model_update_kernel.h b/oneflow/core/kernel/naive_model_update_kernel.h index 2307a395977db21b4a0036fa1b973ff83136359b..c203f63631bc66707bdadaa9cd3f18c815381cc0 100644 --- a/oneflow/core/kernel/naive_model_update_kernel.h +++ b/oneflow/core/kernel/naive_model_update_kernel.h @@ -14,16 +14,16 @@ class NaiveMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T> { private: const PbMessage& GetCustomizedOpConf() const override; - void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, - int64_t next_model_vid, + void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, + const int64_t* global_step, const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; }; template<DeviceType device_type, typename T> class NaiveMdUpdateKernelUtil final { public: - static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, T learning_rate, - T l1, T l2, const T* model_diff, T* model); + static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, + const float* learning_rate, T l1, T l2, const T* model_diff, T* model); }; DECLARE_MDUPDT_KERNEL_CREATOR(Naive); diff --git a/oneflow/core/kernel/normal_model_update_kernel.cpp b/oneflow/core/kernel/normal_model_update_kernel.cpp index f5114d71bf0b02064692f5639b02e8947bf338bd..9e132f069caafd004cbe98d1d37fbf68192812db 100644 --- a/oneflow/core/kernel/normal_model_update_kernel.cpp +++ b/oneflow/core/kernel/normal_model_update_kernel.cpp @@ -10,28 +10,18 @@ namespace oneflow { template<DeviceType device_type, typename T> void NormalMdUpdateKernel<device_type, T>::Forward( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { - int64_t cur_batch_num = std::get<0>( - *(reinterpret_cast<std::tuple<int64_t, std::function<const Blob*(const 
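Throughout the kernel changes above, the learning rate switches from a by-value T to a const float* (and the step counter becomes a const int64_t*) because both now arrive as one-element blobs produced elsewhere in the graph; kernels therefore dereference a pointer, on GPU inside the kernel itself, instead of receiving a host scalar baked in at launch time. A minimal CUDA-style sketch of that pattern, not part of the patch, with a hypothetical kernel name:

#include <cstdint>

__global__ void ScaledStep(int64_t n, const float* learning_rate, const float* diff,
                           float* model) {
  // The scalar lives in device memory, so it is read here rather than captured on the host.
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    model[i] -= (*learning_rate) * diff[i];
  }
}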
LogicalBlobId&)>>*>( - ctx.other))); - int64_t next_model_vid = cur_batch_num + 1; const PbMessage& op_conf = this->GetCustomizedOpConf(); const auto& conf = *GetMsgPtrFromPbMessage<NormalModelUpdateOpUserConf>(op_conf, "user_conf"); - float learning_rate = GetValFromPbMessage<float>(op_conf, "learning_rate"); const T* batch_instance_num_ptr = BnInOp2Blob("total_instance_num_diff")->dptr<T>(); + const int64_t* global_step_ptr = BnInOp2Blob("global_step")->dptr<int64_t>(); + const float* learning_rate_ptr = BnInOp2Blob("learning_rate")->dptr<float>(); if (conf.has_clip_conf()) { - ClipGradient(ctx.device_ctx, cur_batch_num, conf.clip_conf(), batch_instance_num_ptr, - BnInOp2Blob); - } - if (TriggerWarmup(conf, learning_rate, next_model_vid)) { - learning_rate = GetWarmupLearningRate(conf.warmup_conf(), learning_rate, next_model_vid); - } else if (conf.has_learning_rate_decay()) { - learning_rate = - GetDecayedLearningRate(conf.learning_rate_decay(), learning_rate, cur_batch_num); + ClipGradient(ctx.device_ctx, conf.clip_conf(), batch_instance_num_ptr, BnInOp2Blob); } float l1 = GetValFromPbMessage<float>(op_conf, "l1"); float l2 = GetValFromPbMessage<float>(op_conf, "l2"); - UpdateModel(ctx.device_ctx, batch_instance_num_ptr, static_cast<T>(learning_rate), - static_cast<T>(l1), static_cast<T>(l2), next_model_vid, BnInOp2Blob); + UpdateModel(ctx.device_ctx, batch_instance_num_ptr, static_cast<T>(l1), static_cast<T>(l2), + global_step_ptr, learning_rate_ptr, BnInOp2Blob); } #define INSTANTIATE_KERNEL(device_type, data_type_pair) \ @@ -58,109 +48,8 @@ Kernel* CreateMdUpdtKernel(const KernelConf& kernel_conf) { } } -double ExponentialDecayedLearningRate(const ExponentialDecayConf& conf, double lr, - int64_t cur_batch_num) { - CHECK_GT(conf.decay_batches(), 0); - double p = static_cast<double>(cur_batch_num) / static_cast<double>(conf.decay_batches()); - if (conf.staircase()) { p = std::floor(p); } - return lr * std::pow(conf.decay_rate(), p); -} - -double InverseTimeDecayedLearningRate(const InverseTimeDecayConf& conf, double lr, - int64_t cur_batch_num) { - CHECK_GT(conf.decay_batches(), 0); - double p = static_cast<double>(cur_batch_num) / static_cast<double>(conf.decay_batches()); - if (conf.staircase()) { p = std::floor(p); } - return lr / (1.0 + conf.decay_rate() * p); -} - -double NaturalExpDecayedLearningRate(const NaturalExpDecayConf& conf, double lr, - int64_t cur_batch_num) { - CHECK_GT(conf.decay_batches(), 0); - double p = static_cast<double>(cur_batch_num) / static_cast<double>(conf.decay_batches()); - if (conf.staircase()) { p = std::floor(p); } - return lr * std::exp(-conf.decay_rate() * p); -} - -double PiecewiseConstantLearningRate(const PiecewiseConstantConf& conf, double lr, - int64_t cur_batch_num) { - const PbRf<int64_t>& boundaries = conf.boundaries(); - const PbRf<double>& values = conf.values(); - CHECK_EQ(boundaries.size() + 1, values.size()); - size_t i = 0; - for (; i < boundaries.size(); ++i) { - if (cur_batch_num <= boundaries[i]) { break; } - } - return values[i]; -} - -double PolynomialDecayedLearningRate(const PolynomialDecayConf& conf, double lr, - int64_t cur_batch_num) { - CHECK_GT(conf.decay_batches(), 0); - double cur_batch = static_cast<double>(cur_batch_num); - double decay_batches = static_cast<double>(conf.decay_batches()); - if (conf.cycle()) { - if (cur_batch_num == 0) { cur_batch = 1.0; } - decay_batches = decay_batches * std::ceil(cur_batch / decay_batches); - } else { - cur_batch = std::min(cur_batch, decay_batches); - } - return (lr - 
conf.end_learning_rate()) * std::pow(1.0 - (cur_batch / decay_batches), conf.power()) - + conf.end_learning_rate(); -} - -double CosineDecayedLearningRate(const CosineDecayConf& conf, double lr, int64_t cur_batch_num) { - CHECK_GT(conf.decay_batches(), 0); - const double PI = std::atan(1.0) * 4.0; - double cur_batch = static_cast<double>(cur_batch_num); - double decay_batches = static_cast<double>(conf.decay_batches()); - cur_batch = std::min(cur_batch, decay_batches); - double cosine_decay = 0.5 * (1.0 + std::cos(PI * cur_batch / decay_batches)); - double decayed = (1.0 - conf.alpha()) * cosine_decay + conf.alpha(); - return lr * decayed; -} - -double LinearCosineDecayedLearningRate(const LinearCosineDecayConf& conf, double lr, - int64_t cur_batch_num) { - CHECK_GT(conf.decay_batches(), 0); - const double PI = std::atan(1.0) * 4.0; - double cur_batch = static_cast<double>(cur_batch_num); - double decay_batches = static_cast<double>(conf.decay_batches()); - cur_batch = std::min(cur_batch, decay_batches); - double linear_decay = (decay_batches - cur_batch) / decay_batches; - double cosine_decay = - 0.5 * (1.0 + std::cos(PI * 2.0 * conf.num_periods() * cur_batch / decay_batches)); - double decayed = (conf.alpha() + linear_decay) * cosine_decay + conf.beta(); - return lr * decayed; -} - -double ConstantWarmupLearningRate(const ConstantWarmupConf& conf, double lr, - int64_t next_batch_num) { - CHECK_GT(conf.warmup_batches(), 0); - CHECK_GT(conf.multiplier(), 0); - CHECK_LT(conf.multiplier(), 1); - if (next_batch_num <= conf.warmup_batches()) { - return lr * conf.multiplier(); - } else { - return lr; - } -} - -double LinearWarmupLearningRate(const LinearWarmupConf& conf, double lr, int64_t next_batch_num) { - CHECK_GT(conf.warmup_batches(), 0); - CHECK_GE(conf.start_multiplier(), 0); - CHECK_LT(conf.start_multiplier(), 1); - double start_multiplier = conf.start_multiplier(); - double multiplier = 1.0; - if (next_batch_num <= conf.warmup_batches()) { - multiplier = start_multiplier - + (1.0 - start_multiplier) * (next_batch_num * 1.0 / conf.warmup_batches()); - } - return lr * multiplier; -} - template<DeviceType device_type, typename T> -void ClipByGlobalNorm(DeviceCtx* ctx, const int64_t cur_batch_num, const ClipByGlobalNormConf& conf, +void ClipByGlobalNorm(DeviceCtx* ctx, const ClipByGlobalNormConf& conf, const T* batch_instance_num_ptr, std::function<Blob*(const std::string&)> BnInOp2Blob) { int64_t n = BnInOp2Blob("model_diff")->shape().elem_cnt(); @@ -182,62 +71,13 @@ void ClipByGlobalNorm(DeviceCtx* ctx, const int64_t cur_batch_num, const ClipByG } // namespace -template<DeviceType device_type, typename T> -bool NormalMdUpdateKernel<device_type, T>::TriggerWarmup(const NormalModelUpdateOpUserConf& conf, - double lr, int64_t next_batch_num) const { - if (!conf.has_warmup_conf()) { return false; } - const WarmupConf& warmup_conf = conf.warmup_conf(); - if (warmup_conf.has_constant_conf()) { - return (next_batch_num <= warmup_conf.constant_conf().warmup_batches()); - } else if (warmup_conf.has_linear_conf()) { - return (next_batch_num <= warmup_conf.linear_conf().warmup_batches()); - } else { - UNIMPLEMENTED(); - } -} - -template<DeviceType device_type, typename T> -double NormalMdUpdateKernel<device_type, T>::GetWarmupLearningRate(const WarmupConf& conf, - double lr, - int64_t next_batch_num) const { - if (conf.has_constant_conf()) { - return ConstantWarmupLearningRate(conf.constant_conf(), lr, next_batch_num); - } else if (conf.has_linear_conf()) { - return 
LinearWarmupLearningRate(conf.linear_conf(), lr, next_batch_num); - } else { - UNIMPLEMENTED(); - } -} - template<DeviceType device_type, typename T> void NormalMdUpdateKernel<device_type, T>::ClipGradient( - DeviceCtx* ctx, const int64_t cur_batch_num, const ClipConf& conf, - const T* batch_instance_num_ptr, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + DeviceCtx* ctx, const ClipConf& conf, const T* batch_instance_num_ptr, + std::function<Blob*(const std::string&)> BnInOp2Blob) const { if (conf.has_clip_by_global_norm()) { - ClipByGlobalNorm<device_type, T>(ctx, cur_batch_num, conf.clip_by_global_norm(), - batch_instance_num_ptr, BnInOp2Blob); - } else { - UNIMPLEMENTED(); - } -} - -template<DeviceType device_type, typename T> -double NormalMdUpdateKernel<device_type, T>::GetDecayedLearningRate( - const LearningRateDecayConf& conf, double lr, int64_t cur_batch_num) const { - if (conf.has_exponential_conf()) { - return ExponentialDecayedLearningRate(conf.exponential_conf(), lr, cur_batch_num); - } else if (conf.has_inverse_time_conf()) { - return InverseTimeDecayedLearningRate(conf.inverse_time_conf(), lr, cur_batch_num); - } else if (conf.has_natural_exp_conf()) { - return NaturalExpDecayedLearningRate(conf.natural_exp_conf(), lr, cur_batch_num); - } else if (conf.has_piecewise_constant_conf()) { - return PiecewiseConstantLearningRate(conf.piecewise_constant_conf(), lr, cur_batch_num); - } else if (conf.has_polynomial_conf()) { - return PolynomialDecayedLearningRate(conf.polynomial_conf(), lr, cur_batch_num); - } else if (conf.has_cosine_conf()) { - return CosineDecayedLearningRate(conf.cosine_conf(), lr, cur_batch_num); - } else if (conf.has_linear_cosine_conf()) { - return LinearCosineDecayedLearningRate(conf.linear_cosine_conf(), lr, cur_batch_num); + ClipByGlobalNorm<device_type, T>(ctx, conf.clip_by_global_norm(), batch_instance_num_ptr, + BnInOp2Blob); } else { UNIMPLEMENTED(); } diff --git a/oneflow/core/kernel/normal_model_update_kernel.h b/oneflow/core/kernel/normal_model_update_kernel.h index ae5accae68144baf2db3ac891b0818c5614a6623..040e89819d111e984ef1a7d6c0bb20139933ef05 100644 --- a/oneflow/core/kernel/normal_model_update_kernel.h +++ b/oneflow/core/kernel/normal_model_update_kernel.h @@ -16,18 +16,12 @@ class NormalMdUpdateKernel : public KernelIf<device_type> { protected: NormalMdUpdateKernel() = default; - virtual void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, - T l2, int64_t next_model_vid, + virtual void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, + const int64_t* global_step, const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const = 0; private: - bool TriggerWarmup(const NormalModelUpdateOpUserConf& conf, double lr, - int64_t cur_batch_num) const; - double GetWarmupLearningRate(const WarmupConf&, double lr, int64_t cur_batch_num) const; - double GetDecayedLearningRate(const LearningRateDecayConf&, double lr, - int64_t cur_batch_num) const; - void ClipGradient(DeviceCtx* ctx, const int64_t cur_batch_num, const ClipConf& conf, - const T* batch_instance_num_ptr, + void ClipGradient(DeviceCtx* ctx, const ClipConf& conf, const T* batch_instance_num_ptr, std::function<Blob*(const std::string&)> BnInOp2Blob) const; }; diff --git a/oneflow/core/kernel/rmsprop_model_update_kernel.cpp b/oneflow/core/kernel/rmsprop_model_update_kernel.cpp index d169e7657580e8b6c7d8a09831ec805e14d2f3c4..193943400cc0b901b16d5afee6e6eb5c39cda9ad 100644 --- 
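For reference, a host-side sketch, not part of the patch, of the kind of per-step value the code deleted above used to compute inside the update kernel, and which the LearningRateScheduleOp generated by AutoLearningRate is now expected to produce as a blob. It mirrors the removed LinearWarmupLearningRate and ExponentialDecayedLearningRate and their old composition (warmup keyed on the upcoming batch, then decay keyed on the current batch); other warmup/decay variants follow the same shape. The function name is illustrative, not the op's actual implementation.

#include <cmath>
#include <cstdint>

double ScheduledLearningRate(double base_lr, int64_t cur_batch, int64_t warmup_batches,
                             double start_multiplier, int64_t decay_batches,
                             double decay_rate, bool staircase) {
  const int64_t next_batch = cur_batch + 1;
  if (warmup_batches > 0 && next_batch <= warmup_batches) {
    const double multiplier =
        start_multiplier + (1.0 - start_multiplier) * (next_batch * 1.0 / warmup_batches);
    return base_lr * multiplier;
  }
  double p = static_cast<double>(cur_batch) / static_cast<double>(decay_batches);
  if (staircase) { p = std::floor(p); }
  return base_lr * std::pow(decay_rate, p);
}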
a/oneflow/core/kernel/rmsprop_model_update_kernel.cpp +++ b/oneflow/core/kernel/rmsprop_model_update_kernel.cpp @@ -18,31 +18,30 @@ const PbMessage& RMSPropMdUpdateKernel<device_type, T>::GetCustomizedOpConf() co template<DeviceType device_type, typename T> void RMSPropMdUpdateKernel<device_type, T>::UpdateModel( - DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, - int64_t next_model_vid, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, const int64_t* global_step, + const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const { const Blob* model_diff_blob = BnInOp2Blob("model_diff"); Blob* model_blob = BnInOp2Blob("model"); Blob* mean_square_blob = BnInOp2Blob("mean_square"); const RMSPropModelUpdateConf& conf = GetRMSPropModelUpdateConf(this->op_conf()); - float decay_rate = conf.decay_rate(); - if (next_model_vid == 1) { decay_rate = 0.0f; } RMSPropMdUpdateKernelUtil<device_type, T>::UpdateModel( - ctx, model_blob->shape().elem_cnt(), batch_instance_num_ptr, learning_rate, - static_cast<T>(decay_rate), static_cast<T>(conf.epsilon()), l1, l2, + ctx, model_blob->shape().elem_cnt(), batch_instance_num_ptr, global_step, learning_rate, + static_cast<T>(conf.decay_rate()), static_cast<T>(conf.epsilon()), l1, l2, model_diff_blob->dptr<T>(), model_blob->mut_dptr<T>(), mean_square_blob->mut_dptr<T>()); } template<typename T> class RMSPropMdUpdateKernelUtil<DeviceType::kCPU, T> final { public: - static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, T learning_rate, - T decay_rate, T epsilon, T l1, T l2, const T* model_diff, T* model, - T* mean_square) { + static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, + const int64_t* global_step, const float* learning_rate, T decay_rate, + T epsilon, T l1, T l2, const T* model_diff, T* model, T* mean_square) { + const T cur_decay_rate = *global_step == 0 ? 0 : decay_rate; for (int64_t i = 0; i < n; ++i) { T reg_diff = RegularizeDiff(model_diff[i], *batch_instance_num_ptr, l1, l2, model[i]); - mean_square[i] = (1 - decay_rate) * reg_diff * reg_diff + decay_rate * mean_square[i]; - model[i] = model[i] - learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon); + mean_square[i] = (1 - cur_decay_rate) * reg_diff * reg_diff + cur_decay_rate * mean_square[i]; + model[i] = model[i] - *learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon); } } }; diff --git a/oneflow/core/kernel/rmsprop_model_update_kernel.cu b/oneflow/core/kernel/rmsprop_model_update_kernel.cu index d02973d4481df200d6062ef72449840a1277f1b4..54757ae6c9f82855d9e25e67c25b6e6ad6bf40e7 100644 --- a/oneflow/core/kernel/rmsprop_model_update_kernel.cu +++ b/oneflow/core/kernel/rmsprop_model_update_kernel.cu @@ -7,13 +7,15 @@ namespace oneflow { namespace { template<typename T> -__global__ void UpdateModelGpu(int64_t n, const T* batch_instance_num_ptr, T learning_rate, - T decay_rate, T epsilon, T l1, T l2, const T* model_diff, T* model, +__global__ void UpdateModelGpu(int64_t n, const T* batch_instance_num_ptr, + const int64_t* global_step, const float* learning_rate, T decay_rate, + T epsilon, T l1, T l2, const T* model_diff, T* model, T* mean_square) { + const T cur_decay_rate = *global_step == 0 ? 
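A single-element sketch, not part of the patch, of the RMSProp update implemented above; forcing the decay rate to 0 when *global_step == 0 seeds mean_square with the first squared gradient, replacing the old next_model_vid == 1 special case. reg_diff stands for the regularized, instance-count-normalized gradient.

#include <cmath>
#include <cstdint>

void RmsPropStep(float learning_rate, float decay_rate, float epsilon, int64_t global_step,
                 float reg_diff, float* mean_square, float* model) {
  const float cur_decay_rate = (global_step == 0) ? 0.f : decay_rate;
  *mean_square =
      (1.f - cur_decay_rate) * reg_diff * reg_diff + cur_decay_rate * (*mean_square);
  *model -= learning_rate * reg_diff / std::sqrt(*mean_square + epsilon);
}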
0 : decay_rate; CUDA_1D_KERNEL_LOOP(i, n) { T reg_diff = RegularizeDiff(model_diff[i], *batch_instance_num_ptr, l1, l2, model[i]); - mean_square[i] = (1 - decay_rate) * reg_diff * reg_diff + decay_rate * mean_square[i]; - model[i] = model[i] - learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon); + mean_square[i] = (1 - cur_decay_rate) * reg_diff * reg_diff + cur_decay_rate * mean_square[i]; + model[i] = model[i] - *learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon); } } @@ -23,11 +25,11 @@ template<typename T> class RMSPropMdUpdateKernelUtil<DeviceType::kGPU, T> final { public: static void UpdateModel(DeviceCtx* ctx, int64_t n, const T* batch_instance_num_ptr, - T learning_rate, T decay_rate, T epsilon, T l1, T l2, const T* model_diff, - T* model, T* mean_square) { + const int64_t* global_step, const float* learning_rate, T decay_rate, + T epsilon, T l1, T l2, const T* model_diff, T* model, T* mean_square) { UpdateModelGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( - n, batch_instance_num_ptr, learning_rate, decay_rate, epsilon, l1, l2, model_diff, model, - mean_square); + n, batch_instance_num_ptr, global_step, learning_rate, decay_rate, epsilon, l1, l2, + model_diff, model, mean_square); } }; diff --git a/oneflow/core/kernel/rmsprop_model_update_kernel.h b/oneflow/core/kernel/rmsprop_model_update_kernel.h index 8c404c1b83e00e3cc2f8b5de1c7067dd4c76513e..4b2ed30a5aaebd91877045b98cee69adba3b7f3c 100644 --- a/oneflow/core/kernel/rmsprop_model_update_kernel.h +++ b/oneflow/core/kernel/rmsprop_model_update_kernel.h @@ -14,8 +14,8 @@ class RMSPropMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T> private: const PbMessage& GetCustomizedOpConf() const override; - void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, - int64_t next_model_vid, + void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T l1, T l2, + const int64_t* global_step, const float* learning_rate, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; }; @@ -24,9 +24,9 @@ class RMSPropMdUpdateKernelUtil final { public: // mean_square = (1 - decay_rate) * model_diff ^ 2 + decay_rate * mean_square // model = model - learning_rate * model_diff / sqrt(mean_square + epsilon) - static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, T learning_rate, - T decay_rate, T epsilon, T l1, T l2, const T* model_diff, T* model, - T* mean_square); + static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, + const int64_t* global_step, const float* learning_rate, T decay_rate, + T epsilon, T l1, T l2, const T* model_diff, T* model, T* mean_square); }; DECLARE_MDUPDT_KERNEL_CREATOR(RMSProp); diff --git a/oneflow/core/operator/lars_model_update_op.cpp b/oneflow/core/operator/lars_model_update_op.cpp index 573aad5a727c934eb39a8919c09918a9d75884bc..6a82c5af9d11f21fed508534f70ec3aa2cf7f939 100644 --- a/oneflow/core/operator/lars_model_update_op.cpp +++ b/oneflow/core/operator/lars_model_update_op.cpp @@ -4,7 +4,7 @@ namespace oneflow { void LARSModelUpdateOp::MdUpdtVirtualInitFromOpConf() { EnrollInputBn("momentum", false)->set_is_mutable(true); - EnrollTmpBn("data_tmp"); + EnrollTmpBn("lars_data_tmp"); } Maybe<void> LARSModelUpdateOp::MdUpdtVirtualInferBlobDescs( @@ -13,11 +13,11 @@ Maybe<void> LARSModelUpdateOp::MdUpdtVirtualInferBlobDescs( const BlobDesc* model_blob_desc = GetBlobDesc4BnInOp("model"); CHECK_OR_RETURN(*GetBlobDesc4BnInOp("momentum") == 
*model_blob_desc); - // data_tmp for gpu compute - // data_tmp[0] for model_norm, data_tmp[1] for model_diff_norm, data_tmp[2] for + // lars_data_tmp for gpu compute + // lars_data_tmp[0] for model_norm, lars_data_tmp[1] for model_diff_norm, lars_data_tmp[2] for // local_learning_rate - *GetBlobDesc4BnInOp("data_tmp") = *model_blob_desc; - GetBlobDesc4BnInOp("data_tmp")->mut_shape() = Shape({3}); + *GetBlobDesc4BnInOp("lars_data_tmp") = *model_blob_desc; + GetBlobDesc4BnInOp("lars_data_tmp")->mut_shape() = Shape({3}); return Maybe<void>::Ok(); } @@ -26,7 +26,7 @@ const PbMessage& LARSModelUpdateOp::GetCustomizedConf() const { } const HashSet<std::string> LARSModelUpdateOp::AlwaysBroadcastParallelBns() const { - return HashSet<std::string>{"data_tmp"}; + return HashSet<std::string>{"lars_data_tmp"}; } REGISTER_CLASS(NormalModelUpdateOpUserConf::kLarsConf, NormalModelUpdtOp, LARSModelUpdateOp); diff --git a/oneflow/core/operator/normal_model_update_op.cpp b/oneflow/core/operator/normal_model_update_op.cpp index 5c0d8c658b08045643bd0c1bf0134b735598f9ff..9c8f04b1d7b002e77db48cb0fc28e57191761c3c 100644 --- a/oneflow/core/operator/normal_model_update_op.cpp +++ b/oneflow/core/operator/normal_model_update_op.cpp @@ -7,6 +7,8 @@ void NormalModelUpdtOp::InitFromOpConf() { EnrollInputBn("model_diff", false); EnrollInputBn("total_instance_num_diff", false); EnrollInputBn("model", false)->set_is_mutable(true); + EnrollInputBn("learning_rate", false); + EnrollInputBn("global_step", false); const PbMessage& conf = this->GetCustomizedConf(); const auto& user_conf = *GetMsgPtrFromPbMessage<NormalModelUpdateOpUserConf>(conf, "user_conf"); if (user_conf.has_clip_conf() && user_conf.clip_conf().has_clip_by_global_norm()) { @@ -50,6 +52,8 @@ void NormalModelUpdtOp::GetSbpSignatures( const auto& bns = AlwaysBroadcastParallelBns(); PbRpf<std::string> broadcast_bns = {bns.begin(), bns.end()}; *broadcast_bns.Add() = "total_instance_num_diff"; + *broadcast_bns.Add() = "learning_rate"; + *broadcast_bns.Add() = "global_step"; FOR_RANGE(int64_t, i, 0, LogicalBlobDesc4Ibn("model").shape().NumAxes()) { SbpSignatureBuilder() .Split(input_bns(), i) diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto index bdc2ef3e48b751e1a37a1135c9d728108ac536c5..ba1e7fca7f6577cdce2b7d17761833cae656ec7e 100644 --- a/oneflow/core/operator/op_conf.proto +++ b/oneflow/core/operator/op_conf.proto @@ -574,9 +574,10 @@ message NormalModelUpdateOpConf { required string model_diff = 2; required string total_instance_num_diff = 3; required string model = 4; - required float learning_rate = 5; - required float l1 = 6; - required float l2 = 7; + required string global_step = 5; + required string learning_rate = 6; + required float l1 = 7; + required float l2 = 8; } message NaiveModelUpdateOpConf { @@ -584,9 +585,10 @@ message NaiveModelUpdateOpConf { required string model_diff = 2; required string total_instance_num_diff = 3; required string model = 4; - required float learning_rate = 5; - required float l1 = 6; - required float l2 = 7; + required string global_step = 5; + required string learning_rate = 6; + required float l1 = 7; + required float l2 = 8; } message MomentumModelUpdateOpConf { @@ -595,9 +597,10 @@ message MomentumModelUpdateOpConf { required string model_diff = 3; required string total_instance_num_diff = 4; required string model = 5; - required float learning_rate = 6; - required float l1 = 7; - required float l2 = 8; + required string global_step = 6; + required string learning_rate = 7; + 
required float l1 = 8; + required float l2 = 9; } message RMSPropModelUpdateOpConf { @@ -605,9 +608,10 @@ message RMSPropModelUpdateOpConf { required string model_diff = 2; required string total_instance_num_diff = 3; required string model = 4; - required float learning_rate = 5; - required float l1 = 6; - required float l2 = 7; + required string global_step = 5; + required string learning_rate = 6; + required float l1 = 7; + required float l2 = 8; } message LARSModelUpdateOpConf { @@ -616,9 +620,10 @@ message LARSModelUpdateOpConf { required string model_diff = 3; required string total_instance_num_diff = 4; required string model = 5; - required float learning_rate = 6; - required float l1 = 7; - required float l2 = 8; + required string global_step = 6; + required string learning_rate = 7; + required float l1 = 8; + required float l2 = 9; } message AdamModelUpdateOpConf { @@ -630,9 +635,10 @@ message AdamModelUpdateOpConf { required string model_diff = 6; required string total_instance_num_diff = 7; required string model = 8; - required float learning_rate = 9; - required float l1 = 10; - required float l2 = 11; + required string global_step = 9; + required string learning_rate = 10; + required float l1 = 11; + required float l2 = 12; } message AccumulateOpConf {