diff --git a/oneflow/core/job/job.proto b/oneflow/core/job/job.proto
index f487664a4a091eb394129cc5b316c9601dc28107..d099516f81e69c20e2d70398e5c80ae000026ff7 100644
--- a/oneflow/core/job/job.proto
+++ b/oneflow/core/job/job.proto
@@ -17,6 +17,8 @@ message TrainConf {
   required int32 num_of_batches_in_snapshot = 5;
   repeated string loss_lbn = 6;
   optional int32 loss_scale_factor = 7 [default = 1];
+  // lbn of the global step blob; when unset, AutoGlobalStep adds the counter ops and fills it in
+  optional string global_step_lbn = 8;
 
   // default_initializer_conf here is now deprecated
   optional InitializerConf default_initializer_conf = 100;
diff --git a/oneflow/core/job/parallel_desc.cpp b/oneflow/core/job/parallel_desc.cpp
index 55f56bb3045162f05b75d982a725edfb168d8878..49f47ea31472e3bf67ea573398695e5f8d864edb 100644
--- a/oneflow/core/job/parallel_desc.cpp
+++ b/oneflow/core/job/parallel_desc.cpp
@@ -133,4 +133,23 @@ std::tuple<int32_t, int32_t> GetPartIdAndPartNumFromParallelCtx(
   }
 }
 
+// Returns a data-parallel ParallelConf whose only device is "0:cpu:0".
+ParallelConf GenParallelConfOfCpuZeroOnMaster() {
+  ParallelConf master_conf;
+  master_conf.set_policy(kDataParallel);
+  master_conf.add_device_name("0:cpu:0");
+  return master_conf;
+}
+
+// Returns a data-parallel ParallelConf listing "<i>:cpu:0" for every i in
+// [0, TotalMachineNum()) as reported by the global ResourceDesc.
+ParallelConf GenParallelConfOfCpuZeroOnAllMachines() {
+  ParallelConf all_machines_conf;
+  all_machines_conf.set_policy(kDataParallel);
+  FOR_RANGE(int64_t, machine_id, 0, Global<ResourceDesc>::Get()->TotalMachineNum()) {
+    all_machines_conf.add_device_name(std::to_string(machine_id) + ":cpu:0");
+  }
+  return all_machines_conf;
+}
+
 }  // namespace oneflow
diff --git a/oneflow/core/job/parallel_desc.h b/oneflow/core/job/parallel_desc.h
index f385a0b67861dfb42d888621eff44ecf91618964..5e4dc890ee42eb52b97e8303bd26a56480ebfc7b 100644
--- a/oneflow/core/job/parallel_desc.h
+++ b/oneflow/core/job/parallel_desc.h
@@ -66,6 +66,11 @@ inline bool operator!=(const ParallelConf& lhs, const ParallelConf& rhs) {
 std::tuple<int32_t, int32_t> GetPartIdAndPartNumFromParallelCtx(
     const ParallelContext* parallel_ctx);
 
+// Data-parallel conf containing only device "0:cpu:0".
+ParallelConf GenParallelConfOfCpuZeroOnMaster();
+// Data-parallel conf containing "<i>:cpu:0" for each of the TotalMachineNum() machines.
+ParallelConf GenParallelConfOfCpuZeroOnAllMachines();
+
 }  // namespace oneflow
 
 namespace std {
diff --git a/oneflow/core/job_completer/auto_global_step.cpp b/oneflow/core/job_completer/auto_global_step.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f8b32a9cac7b9facd4de632a494f1555bd77096b
--- /dev/null
+++ b/oneflow/core/job_completer/auto_global_step.cpp
@@ -0,0 +1,63 @@
+#include "oneflow/core/job_completer/auto_global_step.h"
+#include "oneflow/core/graph/op_graph.h"
+#include "oneflow/core/job/job.pb.h"
+#include "oneflow/core/job/parallel_desc.h"
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+// Appends a global step counter to `job` and records its lbn in train_conf.global_step_lbn.
+// Does nothing when train_conf.global_step_lbn is already set.
+//
+// Ops added (placed with GenParallelConfOfCpuZeroOnAllMachines()):
+//   <prefix>           Variable, int64, shape (1), constant-initialized to 0
+//   <prefix>-Identity  Identity of the variable; its output becomes global_step_lbn
+//   <prefix>-ScalarAdd Identity output + 1
+//   <prefix>-Assign    stores the ScalarAdd output back into the variable
+// where <prefix> is "System-Train-GlobalStep-" + job name.
+//
+// `op_graph` is not read by this function.
+void AutoGlobalStep(const OpGraph& op_graph, Job* job) {
+  if (job->job_conf().train_conf().has_global_step_lbn()) { return; }
+  const std::string name_prefix = "System-Train-GlobalStep-" + job->job_conf().job_name();
+
+  OperatorConf variable_op_conf{};
+  variable_op_conf.set_name(name_prefix);
+  VariableOpConf* variable_conf = variable_op_conf.mutable_variable_conf();
+  variable_conf->set_out("out");
+  variable_conf->mutable_shape()->add_dim(1);
+  variable_conf->set_data_type(DataType::kInt64);
+  variable_conf->mutable_initializer()->mutable_constant_int_conf()->set_value(0);
+  const std::string variable_lbn =
+      GenLogicalBlobName(variable_op_conf.name(), variable_conf->out());
+
+  // The Identity output, not the variable output, is what gets published as global_step_lbn.
+  OperatorConf identity_op_conf{};
+  identity_op_conf.set_name(name_prefix + "-Identity");
+  IdentityOpConf* identity_conf = identity_op_conf.mutable_identity_conf();
+  identity_conf->set_in(variable_lbn);
+  identity_conf->set_out("out");
+  const std::string global_step_lbn =
+      GenLogicalBlobName(identity_op_conf.name(), identity_conf->out());
+
+  OperatorConf scalar_add_op_conf{};
+  scalar_add_op_conf.set_name(name_prefix + "-ScalarAdd");
+  ScalarAddOpConf* scalar_add_conf = scalar_add_op_conf.mutable_scalar_add_conf();
+  scalar_add_conf->set_in(global_step_lbn);
+  scalar_add_conf->set_out("out");
+  scalar_add_conf->set_int_operand(1);
+
+  // Write step + 1 back into the variable.
+  OperatorConf assign_op_conf{};
+  assign_op_conf.set_name(name_prefix + "-Assign");
+  AssignOpConf* assign_conf = assign_op_conf.mutable_assign_conf();
+  assign_conf->set_ref(variable_lbn);
+  assign_conf->set_value(GenLogicalBlobName(scalar_add_op_conf.name(), scalar_add_conf->out()));
+
+  JobBuilder job_builder(job);
+  job_builder.AddOps(GenParallelConfOfCpuZeroOnAllMachines(),
+                     {variable_op_conf, identity_op_conf, scalar_add_op_conf, assign_op_conf});
+  job->mutable_job_conf()->mutable_train_conf()->set_global_step_lbn(global_step_lbn);
+}
+
+}  // namespace oneflow
diff --git a/oneflow/core/job_completer/auto_global_step.h b/oneflow/core/job_completer/auto_global_step.h
new file mode 100644
index 0000000000000000000000000000000000000000..afef8a996a824fe81a04a22cf6b18698c72046d1
--- /dev/null
+++ b/oneflow/core/job_completer/auto_global_step.h
@@ -0,0 +1,15 @@
+#ifndef ONEFLOW_CORE_JOB_COMPLETER_AUTO_GLOBAL_STEP_H_
+#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_GLOBAL_STEP_H_
+
+namespace oneflow {
+
+class OpGraph;
+class Job;
+
+// Adds the ops that maintain a global step counter to `job` and stores the counter's lbn in
+// job_conf.train_conf.global_step_lbn; a no-op when that field is already set.
+void AutoGlobalStep(const OpGraph& op_graph, Job* job);
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_JOB_COMPLETER_AUTO_GLOBAL_STEP_H_
diff --git a/oneflow/core/job_completer/job_completer.cpp b/oneflow/core/job_completer/job_completer.cpp
index 64518765cca64a311687400d0ca33ead5c2e6485..20e00b67028a10ac1bb0201a05ad74558a72d368 100644
--- a/oneflow/core/job_completer/job_completer.cpp
+++ b/oneflow/core/job_completer/job_completer.cpp
@@ -11,6 +11,7 @@
 #include "oneflow/core/job_completer/all_reduce_sequence_pass.h"
 #include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h"
 #include "oneflow/core/job_completer/auto_mixed_precision.h"
+#include "oneflow/core/job_completer/auto_global_step.h"
 
 namespace oneflow {
 
@@ -339,6 +340,7 @@ void JobCompleter::Complete(Job* job) const {
   if (GlobalJobDesc().IsTrain()) {
     WithOpGraphAndMutJob(job, &TieUpChainHeadersUnReachableFromAnyVariableOps);
     WithOpGraphAndMutJobBuilder(job, &EnableAutoMixedPrecision);
+    WithOpGraphAndMutJob(job, &AutoGlobalStep);
     // complete ops for trainning
     WithOpGraphAndMutJobBuilder(job, &GenerateOpConf4Trainning);
     WithOpGraphAndMutJobBuilder(job, &RewriteBoxingWithAllReduce);
diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h
index d7cf5dacea549e9e1425c6c29c3f33103c645332..0162b6b555fd822b3549626a6c0e997a747f722d 100644
--- a/oneflow/core/operator/operator.h
+++ b/oneflow/core/operator/operator.h
@@ -326,11 +326,16 @@ inline LogicalBlobId GenLogicalBlobId(const std::string& lbn) {
   return lbi;
 }
 
+// Joins an op name and a blob name into the logical blob name "<op_name>/<blob_name>".
+inline std::string GenLogicalBlobName(const std::string& op_name, const std::string& blob_name) {
+  return op_name + "/" + blob_name;
+}
+
 inline std::string GenLogicalBlobName(const LogicalBlobId& lbi) {
   CHECK_EQ(lbi.has_op_name(), true);
   CHECK_EQ(lbi.has_blob_name(), true);
   CHECK_EQ(lbi.is_packed_id(), false);
-  return lbi.op_name() + "/" + lbi.blob_name();
+  return GenLogicalBlobName(lbi.op_name(), lbi.blob_name());
 }
 
 }  // namespace oneflow