Skip to content
Snippets Groups Projects
Commit fb138d8a authored by lixinqi's avatar lixinqi
Browse files

add pass DumpTimeShapeAndBlobParallelConfPass

parent 3457db13
No related branches found
No related tags found
No related merge requests found
......@@ -746,8 +746,8 @@ std::list<OpNode*> OpGraph::DataOrCtrlSourceNodes() const {
return ret;
}
void OpGraph::DumpLogicalBlobDesc(JobBuilder* job_builder) const {
auto* helper = job_builder->mutable_helper();
void OpGraph::DumpLogicalBlobDesc(Job* job) const {
auto* helper = job->mutable_helper();
ForEachNode([&](const OpNode* node) {
for (const auto& obn : node->op().output_bns()) {
const auto& lbi = node->op().BnInOp2Lbi(obn);
......@@ -757,17 +757,17 @@ void OpGraph::DumpLogicalBlobDesc(JobBuilder* job_builder) const {
});
}
void OpGraph::DumpSbpSignature(JobBuilder* job_builder) const {
void OpGraph::DumpSbpSignature(Job* job) const {
ForEachNode([&](const OpNode* node) {
(*job_builder->mutable_sbp_conf()->mutable_op_name2sbp_signature_conf())[node->op().op_name()] =
(*job->mutable_sbp_conf()->mutable_op_name2sbp_signature_conf())[node->op().op_name()] =
node->sbp_signature();
});
}
void OpGraph::DumpOpTimeShape(JobBuilder* job_builder) const {
void OpGraph::DumpOpTimeShape(Job* job) const {
ForEachNode([&](OpNode* op_node) {
auto* op_time_shape =
&(*job_builder->mutable_helper()->mutable_op_name2op_time_shape())[op_node->op().op_name()];
&(*job->mutable_helper()->mutable_op_name2op_time_shape())[op_node->op().op_name()];
if (op_node->out_blob_time_shape() != nullptr) {
op_node->out_blob_time_shape()->ToProto(op_time_shape->mutable_out_blob_time_shape());
}
......@@ -778,14 +778,15 @@ void OpGraph::DumpOpTimeShape(JobBuilder* job_builder) const {
});
}
void OpGraph::DumpBatchAxisLbi(JobBuilder* job_builder) const {
auto* lbn2batch_axis = job_builder->mutable_helper()->mutable_lbn2batch_axis();
void OpGraph::DumpBatchAxisLbi(Job* job) const {
auto* lbn2batch_axis = job->mutable_helper()->mutable_lbn2batch_axis();
ForEachNode([&](OpNode* op_node) {
for (const auto& obn : op_node->op().output_bns()) {
const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(obn);
const auto& lbn = GenLogicalBlobName(lbi);
const auto& pair = PbMapPair<std::string, OptInt64>(lbn, op_node->BatchAxis4Lbi(lbi));
CHECK(lbn2batch_axis->insert(pair).second);
const auto& batch_axis = op_node->BatchAxis4Lbi(lbi);
const auto& pair = PbMapPair<std::string, OptInt64>(lbn, batch_axis);
CHECK(lbn2batch_axis->insert(pair).first->second == batch_axis);
}
});
}
......
......@@ -149,10 +149,10 @@ class OpGraph final : public Graph<OpNode, OpEdge> {
void ForEachDataAndCtrlInNode(OpNode* node, const std::function<void(OpNode*)>& Handler) const;
void ForEachDataAndCtrlOutNode(OpNode* node, const std::function<void(OpNode*)>& Handler) const;
void DumpLogicalBlobDesc(JobBuilder* job_builder) const;
void DumpSbpSignature(JobBuilder* job_builder) const;
void DumpOpTimeShape(JobBuilder* job_builder) const;
void DumpBatchAxisLbi(JobBuilder* job_builder) const;
void DumpLogicalBlobDesc(Job* job) const;
void DumpSbpSignature(Job* job) const;
void DumpOpTimeShape(Job* job) const;
void DumpBatchAxisLbi(Job* job) const;
private:
void Init(const Job& job);
......
#include "oneflow/core/job/job_build_and_infer_ctx.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/common/protobuf.h"
namespace oneflow {
......@@ -40,23 +41,28 @@ Maybe<void> JobBuildAndInferCtx::SetJobConf(const JobConfigProto& job_conf) {
return Maybe<void>::Ok();
}
REGISTER_FUNCTION_CONFIG_DEF().Bool("is_user_function", true, "is user defined function");
Maybe<void> JobBuildAndInferCtx::Complete() {
CHECK_NOTNULL(Global<JobDesc>::Get());
Global<JobDesc>::Delete();
auto scope = std::make_unique<GlobalJobDescScope>(job_->job_conf(), job_id_);
auto DoPass = [&](const std::string& pass_name) { FunctionPass(pass_name)(job_); };
DoPass("CompleteOfrecordDecoder");
DoPass("SetDefaultVariableConf");
DoPass("AutoMixedPrecision");
DoPass("TieUpChainHeadersUnReachableFromAnyVariableOps");
DoPass("NonDistributedOptimizerPass");
DoPass("AutoTrainStep");
DoPass("AutoLearningRate");
DoPass("GenerateBackwardAndOptimizerOpConfs");
DoPass("SequentializeNcclTupleBroadcastReducePass");
DoPass("AddAllReduceGroupPass");
DoPass("AddLbiDiffWatcherOpConfs");
DoPass("SequentializeAllReduceGroupPass");
if (GlobalJobDesc().Bool("is_user_function")) {
DoPass("CompleteOfrecordDecoder");
DoPass("SetDefaultVariableConf");
DoPass("AutoMixedPrecision");
DoPass("TieUpChainHeadersUnReachableFromAnyVariableOps");
DoPass("NonDistributedOptimizerPass");
DoPass("AutoTrainStep");
DoPass("AutoLearningRate");
DoPass("GenerateBackwardAndOptimizerOpConfs");
DoPass("SequentializeNcclTupleBroadcastReducePass");
DoPass("AddAllReduceGroupPass");
DoPass("AddLbiDiffWatcherOpConfs");
DoPass("SequentializeAllReduceGroupPass");
}
DoPass("DumpTimeShapeAndBlobParallelConfPass");
return Maybe<void>::Ok();
}
......
......@@ -678,16 +678,17 @@ void MakePushJob(const std::string& job_name, const std::string& op_name,
void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
std::vector<Job> jobs(conf_jobs.size());
std::vector<Plan> sub_plans(conf_jobs.size());
FOR_RANGE(int64_t, job_id, 0, sub_plans.size()) {
jobs.at(job_id) = conf_jobs.Get(job_id);
AddJobName2JobId(jobs.at(job_id).job_conf().job_name(), job_id);
size_t user_job_size = jobs.size();
int64_t job_id = -1;
FOR_RANGE(int64_t, i, 0, sub_plans.size()) {
jobs.at(i) = conf_jobs.Get(i);
AddJobName2JobId(jobs.at(i).job_conf().job_name(), i);
{
auto scope = std::make_unique<GlobalJobDescScope>(jobs.at(job_id).job_conf(), job_id);
CompileCurJobOnMaster(&jobs.at(job_id), &sub_plans.at(job_id), true);
auto scope = std::make_unique<GlobalJobDescScope>(jobs.at(i).job_conf(), i);
CompileCurJobOnMaster(&jobs.at(i), &sub_plans.at(i), true);
}
}
if (Global<MachineCtx>::Get()->IsThisMachineMaster()) {
size_t user_job_size = jobs.size();
HashMap<std::string, ParallelBlobConf> push_op_name2parallel_blob_conf;
FilterOpName2ParallelBlobConf({OperatorConf::kInputConf}, &jobs,
&push_op_name2parallel_blob_conf);
......@@ -697,7 +698,6 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
HashMap<std::string, ParallelBlobConf> var_op_name2parallel_blob_conf;
FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, &jobs,
&var_op_name2parallel_blob_conf);
int64_t job_id = -1;
{
size_t helper_job_size =
push_op_name2parallel_blob_conf.size() + pull_op_name2parallel_blob_conf.size();
......@@ -727,6 +727,8 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
CompileHelperJob(&pull_job);
}
MakeModelIoJobs(jobs, var_op_name2parallel_blob_conf, [&](Job* job) { CompileHelperJob(job); });
}
if (Global<MachineCtx>::Get()->IsThisMachineMaster()) {
MergeSubPlanWithoutGenNetTopo(plan, sub_plans);
InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs(jobs, plan, user_job_size);
InterJobMemSharingUtil::MergeMemSharedInterfaceMemBlockBetweenJobs(jobs, plan);
......
#include "oneflow/core/common/util.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job/job.pb.h"
namespace oneflow {
namespace {
class DumpTimeShapeAndBlobParallelConfPass final : public OpGraphPass {
public:
OF_DISALLOW_COPY_AND_MOVE(DumpTimeShapeAndBlobParallelConfPass);
DumpTimeShapeAndBlobParallelConfPass() = default;
~DumpTimeShapeAndBlobParallelConfPass() override = default;
bool IsEnabled() const override { return true; }
void Apply(const OpGraph& op_graph, Job* job) const override {
op_graph.DumpOpTimeShape(job);
op_graph.DumpBatchAxisLbi(job);
op_graph.DumpLogicalBlobDesc(job);
op_graph.DumpSbpSignature(job);
}
};
REGISTER_FUNCTION_PASS("DumpTimeShapeAndBlobParallelConfPass",
DumpTimeShapeAndBlobParallelConfPass);
} // namespace
} // namespace oneflow
......@@ -73,20 +73,10 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder
}
}
void SetOpTimeShape7BatchAxisLbis(const OpGraph& op_graph, JobBuilder* job_builder) {
op_graph.DumpOpTimeShape(job_builder);
op_graph.DumpBatchAxisLbi(job_builder);
}
void DumpLogicalBlobDescAndSbpSignature(const OpGraph& op_graph, JobBuilder* job_builder) {
op_graph.DumpLogicalBlobDesc(job_builder);
op_graph.DumpSbpSignature(job_builder);
}
} // namespace
void JobCompleter::Complete(Job* job) const {
WithOpGraphAndMutJobBuilder(job, &DumpLogicalBlobDescAndSbpSignature);
FunctionPass("DumpTimeShapeAndBlobParallelConfPass")(job);
WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel);
WithOpGraphAndMutJobBuilder(job, &AddKeepHeaderOnlyOp);
WithOpGraphAndMutJobBuilder(job, &SetCtrlInOpName4VariableOp);
......@@ -97,9 +87,7 @@ void JobCompleter::Complete(Job* job) const {
AddGlobalTotalJobCriticalSection(*job);
WithOpGraphAndMutJobBuilder(job, &AddGlobalInputCriticalSections);
WithOpGraphAndMutJobBuilder(job, &AddGlobalOutputCriticalSections);
WithOpGraphAndMutJobBuilder(job, &DumpLogicalBlobDescAndSbpSignature);
WithOpGraphAndMutJobBuilder(job, &SetOpTimeShape7BatchAxisLbis);
FunctionPass("DumpTimeShapeAndBlobParallelConfPass")(job);
if (XrtCompilationEnabled(GlobalJobDesc())) {
#ifdef OF_WITH_XRT
WithOpGraphAndMutJob(job, &RebuildXrtCompiledJob);
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment