Skip to content
Snippets Groups Projects
Commit a80acf32 authored by Juncheng's avatar Juncheng Committed by Li Xinqi
Browse files

add AutoGlobalStep (#2073)

parent 069b3fcf
No related branches found
No related tags found
No related merge requests found
...@@ -17,6 +17,7 @@ message TrainConf { ...@@ -17,6 +17,7 @@ message TrainConf {
required int32 num_of_batches_in_snapshot = 5; required int32 num_of_batches_in_snapshot = 5;
repeated string loss_lbn = 6; repeated string loss_lbn = 6;
optional int32 loss_scale_factor = 7 [default = 1]; optional int32 loss_scale_factor = 7 [default = 1];
optional string global_step_lbn = 8;
// default_initializer_conf here is now deprecated // default_initializer_conf here is now deprecated
optional InitializerConf default_initializer_conf = 100; optional InitializerConf default_initializer_conf = 100;
......
...@@ -133,4 +133,20 @@ std::tuple<int32_t, int32_t> GetPartIdAndPartNumFromParallelCtx( ...@@ -133,4 +133,20 @@ std::tuple<int32_t, int32_t> GetPartIdAndPartNumFromParallelCtx(
} }
} }
// Builds a data-parallel ParallelConf whose single device entry is "0:cpu:0".
ParallelConf GenParallelConfOfCpuZeroOnMaster() {
  ParallelConf master_cpu_conf;
  master_cpu_conf.add_device_name("0:cpu:0");
  master_cpu_conf.set_policy(kDataParallel);
  return master_cpu_conf;
}
// Builds a data-parallel ParallelConf with one device entry "<i>:cpu:0" for
// every i in [0, TotalMachineNum()), in ascending order.
ParallelConf GenParallelConfOfCpuZeroOnAllMachines() {
  ParallelConf all_machines_conf;
  all_machines_conf.set_policy(kDataParallel);
  // Read the machine count once up front; it bounds the loop below.
  const int64_t machine_num = Global<ResourceDesc>::Get()->TotalMachineNum();
  for (int64_t machine_id = 0; machine_id < machine_num; ++machine_id) {
    all_machines_conf.add_device_name(std::to_string(machine_id) + ":cpu:0");
  }
  return all_machines_conf;
}
} // namespace oneflow } // namespace oneflow
...@@ -66,6 +66,9 @@ inline bool operator!=(const ParallelConf& lhs, const ParallelConf& rhs) { ...@@ -66,6 +66,9 @@ inline bool operator!=(const ParallelConf& lhs, const ParallelConf& rhs) {
std::tuple<int32_t, int32_t> GetPartIdAndPartNumFromParallelCtx( std::tuple<int32_t, int32_t> GetPartIdAndPartNumFromParallelCtx(
const ParallelContext* parallel_ctx); const ParallelContext* parallel_ctx);
ParallelConf GenParallelConfOfCpuZeroOnMaster();
ParallelConf GenParallelConfOfCpuZeroOnAllMachines();
} // namespace oneflow } // namespace oneflow
namespace std { namespace std {
......
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/job.pb.h"
namespace oneflow {
void AutoGlobalStep(const OpGraph& op_graph, Job* job) {
if (job->job_conf().train_conf().has_global_step_lbn()) { return; }
OperatorConf variable_op_conf{};
const std::string global_step_name = "System-Train-GlobalStep-" + job->job_conf().job_name();
variable_op_conf.set_name(global_step_name);
VariableOpConf* variable_conf = variable_op_conf.mutable_variable_conf();
variable_conf->set_out("out");
*variable_conf->mutable_shape()->mutable_dim()->Add() = 1;
variable_conf->set_data_type(DataType::kInt64);
variable_conf->mutable_initializer()->mutable_constant_int_conf()->set_value(0);
OperatorConf identity_op_conf{};
identity_op_conf.set_name(global_step_name + "-Identity");
IdentityOpConf* identity_conf = identity_op_conf.mutable_identity_conf();
identity_conf->set_in(GenLogicalBlobName(variable_op_conf.name(), variable_conf->out()));
identity_conf->set_out("out");
const std::string& global_step_lbn =
GenLogicalBlobName(identity_op_conf.name(), identity_conf->out());
OperatorConf scalar_add_op_conf{};
scalar_add_op_conf.set_name(global_step_name + "-ScalarAdd");
ScalarAddOpConf* scalar_add_conf = scalar_add_op_conf.mutable_scalar_add_conf();
scalar_add_conf->set_in(global_step_lbn);
scalar_add_conf->set_out("out");
scalar_add_conf->set_int_operand(1);
OperatorConf assign_op_conf{};
assign_op_conf.set_name(global_step_name + "-Assign");
AssignOpConf* assign_conf = assign_op_conf.mutable_assign_conf();
assign_conf->set_ref(GenLogicalBlobName(variable_op_conf.name(), variable_conf->out()));
assign_conf->set_value(GenLogicalBlobName(scalar_add_op_conf.name(), scalar_add_conf->out()));
JobBuilder job_builder(job);
job_builder.AddOps(GenParallelConfOfCpuZeroOnAllMachines(),
{variable_op_conf, identity_op_conf, scalar_add_op_conf, assign_op_conf});
job->mutable_job_conf()->mutable_train_conf()->set_global_step_lbn(global_step_lbn);
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_JOB_COMPLETER_AUTO_GLOBAL_STEP_H_
#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_GLOBAL_STEP_H_
namespace oneflow {
// Forward declarations: this header only needs the names, not the definitions.
class OpGraph;
class Job;
// Job-completion pass for the global step counter.
// If the job's train_conf has no global_step_lbn, appends variable, identity,
// scalar-add and assign ops that hold and increment an int64 counter, and
// records the identity op's output as train_conf.global_step_lbn.
// Does nothing when global_step_lbn is already set.
void AutoGlobalStep(const OpGraph& op_graph, Job* job);
}  // namespace oneflow
#endif  // ONEFLOW_CORE_JOB_COMPLETER_AUTO_GLOBAL_STEP_H_
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "oneflow/core/job_completer/all_reduce_sequence_pass.h" #include "oneflow/core/job_completer/all_reduce_sequence_pass.h"
#include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h" #include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h"
#include "oneflow/core/job_completer/auto_mixed_precision.h" #include "oneflow/core/job_completer/auto_mixed_precision.h"
#include "oneflow/core/job_completer/auto_global_step.h"
namespace oneflow { namespace oneflow {
...@@ -339,6 +340,7 @@ void JobCompleter::Complete(Job* job) const { ...@@ -339,6 +340,7 @@ void JobCompleter::Complete(Job* job) const {
if (GlobalJobDesc().IsTrain()) { if (GlobalJobDesc().IsTrain()) {
WithOpGraphAndMutJob(job, &TieUpChainHeadersUnReachableFromAnyVariableOps); WithOpGraphAndMutJob(job, &TieUpChainHeadersUnReachableFromAnyVariableOps);
WithOpGraphAndMutJobBuilder(job, &EnableAutoMixedPrecision); WithOpGraphAndMutJobBuilder(job, &EnableAutoMixedPrecision);
WithOpGraphAndMutJob(job, &AutoGlobalStep);
// complete ops for trainning // complete ops for trainning
WithOpGraphAndMutJobBuilder(job, &GenerateOpConf4Trainning); WithOpGraphAndMutJobBuilder(job, &GenerateOpConf4Trainning);
WithOpGraphAndMutJobBuilder(job, &RewriteBoxingWithAllReduce); WithOpGraphAndMutJobBuilder(job, &RewriteBoxingWithAllReduce);
......
...@@ -326,11 +326,15 @@ inline LogicalBlobId GenLogicalBlobId(const std::string& lbn) { ...@@ -326,11 +326,15 @@ inline LogicalBlobId GenLogicalBlobId(const std::string& lbn) {
return lbi; return lbi;
} }
// Joins an op name and a blob name into a logical blob name of the form
// "<op_name>/<blob_name>". Neither argument is validated.
inline std::string GenLogicalBlobName(const std::string& op_name, const std::string& blob_name) {
  std::string lbn(op_name);
  lbn += '/';
  lbn += blob_name;
  return lbn;
}
inline std::string GenLogicalBlobName(const LogicalBlobId& lbi) { inline std::string GenLogicalBlobName(const LogicalBlobId& lbi) {
CHECK_EQ(lbi.has_op_name(), true); CHECK_EQ(lbi.has_op_name(), true);
CHECK_EQ(lbi.has_blob_name(), true); CHECK_EQ(lbi.has_blob_name(), true);
CHECK_EQ(lbi.is_packed_id(), false); CHECK_EQ(lbi.is_packed_id(), false);
return lbi.op_name() + "/" + lbi.blob_name(); return GenLogicalBlobName(lbi.op_name(), lbi.blob_name());
} }
} // namespace oneflow } // namespace oneflow
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment