Skip to content
Snippets Groups Projects
Commit 5f84cae8 authored by lixinqi's avatar lixinqi
Browse files

enable pseudo chain merge

parent 2955965e
No related branches found
No related tags found
No related merge requests found
Showing
with 304 additions and 77 deletions
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
namespace {
template<ConfigDefType config_def_type>
ConfigDef* MutGlobalConfigDef() {
  // One lazily-initialized mutable registry per config scope (env/session/function).
  static ConfigDef global_config_def;
  return &global_config_def;
}
template<ConfigDefType config_def_type>
void CheckNoExistedField(const std::string& name) {
  // Abort (CHECK) if a flag with the same name was already registered in this scope.
  const auto& registered_flags = MutGlobalConfigDef<config_def_type>()->flag();
  auto HasSameName = [&](const ConfigFlagDef& flag) { return flag.name() == name; };
  CHECK(std::find_if(registered_flags.begin(), registered_flags.end(), HasSameName)
        == registered_flags.end());
}
} // namespace
// Read-only accessors to the per-scope flag registries populated at static-init time.
const ConfigDef& GlobalEnvConfigDef() { return *MutGlobalConfigDef<kEnvConfigType>(); }
const ConfigDef& GlobalSessionConfigDef() { return *MutGlobalConfigDef<kSessionConfigType>(); }
const ConfigDef& GlobalFunctionConfigDef() { return *MutGlobalConfigDef<kFunctionConfigType>(); }
template<ConfigDefType config_def_type>
const ConfigDefBuidler<config_def_type>& ConfigDefBuidler<config_def_type>::Bool(
    const std::string& name, bool default_val) const {
  // Registering the same flag name twice is a programming error.
  CheckNoExistedField<config_def_type>(name);
  ConfigFlagDef* new_flag = MutGlobalConfigDef<config_def_type>()->mutable_flag()->Add();
  new_flag->set_name(name);
  new_flag->set_type(UserOpAttrType::kAtBool);
  new_flag->mutable_default_val()->set_at_bool(default_val);
  // Return *this so registrations can be chained fluently.
  return *this;
}
template<ConfigDefType config_def_type>
const ConfigDefBuidler<config_def_type>& ConfigDefBuidler<config_def_type>::Int64(
    const std::string& name, int64_t default_val) const {
  // Duplicate names abort via CHECK before anything is appended.
  CheckNoExistedField<config_def_type>(name);
  ConfigFlagDef* new_flag = MutGlobalConfigDef<config_def_type>()->mutable_flag()->Add();
  new_flag->set_type(UserOpAttrType::kAtInt64);
  new_flag->set_name(name);
  new_flag->mutable_default_val()->set_at_int64(default_val);
  // Fluent interface: allow chained registrations.
  return *this;
}
template<ConfigDefType config_def_type>
const ConfigDefBuidler<config_def_type>& ConfigDefBuidler<config_def_type>::Double(
    const std::string& name, double default_val) const {
  // Duplicate names abort via CHECK before anything is appended.
  CheckNoExistedField<config_def_type>(name);
  ConfigFlagDef* new_flag = MutGlobalConfigDef<config_def_type>()->mutable_flag()->Add();
  new_flag->set_type(UserOpAttrType::kAtDouble);
  new_flag->set_name(name);
  new_flag->mutable_default_val()->set_at_double(default_val);
  // Fluent interface: allow chained registrations.
  return *this;
}
template<ConfigDefType config_def_type>
const ConfigDefBuidler<config_def_type>& ConfigDefBuidler<config_def_type>::String(
    const std::string& name, const std::string& default_val) const {
  // Duplicate names abort via CHECK before anything is appended.
  CheckNoExistedField<config_def_type>(name);
  ConfigFlagDef* new_flag = MutGlobalConfigDef<config_def_type>()->mutable_flag()->Add();
  new_flag->set_name(name);
  new_flag->set_type(UserOpAttrType::kAtString);
  new_flag->mutable_default_val()->set_at_string(default_val);
  // Fluent interface: allow chained registrations.
  return *this;
}
// Explicit instantiations: only these three config scopes exist (see ConfigDefType).
template class ConfigDefBuidler<kEnvConfigType>;
template class ConfigDefBuidler<kSessionConfigType>;
template class ConfigDefBuidler<kFunctionConfigType>;
} // namespace oneflow
#ifndef ONEFLOW_CORE_JOB_CONFIG_DEF_H_
#define ONEFLOW_CORE_JOB_CONFIG_DEF_H_
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/framework/user_op_attr.pb.h"
#include "oneflow/core/framework/config_def.pb.h"
namespace oneflow {
// Fluent builder used by the REGISTER_*_CONFIG_DEF macros to declare typed config
// flags (with default values) in the global registry of one config scope.
// NOTE(review): "Buidler" looks like a typo for "Builder" — renaming would touch
// every call site, so the spelling is preserved here.
template<ConfigDefType config_def_type>
struct ConfigDefBuidler final {
  // Each method registers one flag of the given type and returns *this for chaining.
  const ConfigDefBuidler& Bool(const std::string& name, bool default_val) const;
  const ConfigDefBuidler& Int64(const std::string& name, int64_t default_val) const;
  const ConfigDefBuidler& Double(const std::string& name, double default_val) const;
  const ConfigDefBuidler& String(const std::string& name, const std::string& default_val) const;
};
// Scope-specific convenience wrappers around REGISTER_CONFIG_DEF.
#define REGISTER_ENV_CONFIG_DEF() REGISTER_CONFIG_DEF(kEnvConfigType)
#define REGISTER_SESSION_CONFIG_DEF() REGISTER_CONFIG_DEF(kSessionConfigType)
#define REGISTER_FUNCTION_CONFIG_DEF() REGISTER_CONFIG_DEF(kFunctionConfigType)
// Creates a unique file-local builder object so that flag registration runs during
// static initialization; __LINE__ keeps multiple registrations in one file distinct.
#define REGISTER_CONFIG_DEF(config_def_type) \
static ConfigDefBuidler<config_def_type> OF_PP_CAT(g_##config_def_type##_def_, __LINE__) = \
ConfigDefBuidler<config_def_type>()
// Accessors for the populated registries (defined in config_def.cpp).
const ConfigDef& GlobalEnvConfigDef();
const ConfigDef& GlobalSessionConfigDef();
const ConfigDef& GlobalFunctionConfigDef();
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_CONFIG_DEF_H_
syntax = "proto2";
package oneflow;
import "oneflow/core/framework/user_op_attr.proto";
// Scope of a config-flag registry: process environment, session, or per-function (job).
enum ConfigDefType {
kEnvConfigType = 1;
kSessionConfigType = 2;
kFunctionConfigType = 3;
}
// One registered flag: its name, value type, and optional default value.
// `type` must agree with the oneof member set in `default_val` (same tag numbers).
message ConfigFlagDef {
required string name = 1;
required UserOpAttrType type = 2;
optional UserOpAttrVal default_val = 3;
}
// A registry: the full list of flags registered for one ConfigDefType scope.
message ConfigDef {
repeated ConfigFlagDef flag = 1;
}
syntax = "proto2";
package oneflow;
import "oneflow/core/common/shape.proto";
// Value types supported for user-op attributes and config flags.
// Tag numbers intentionally mirror the field numbers of UserOpAttrVal.value.
enum UserOpAttrType {
kAtInt32 = 1;
kAtInt64 = 2;
kAtBool = 3;
kAtFloat = 4;
kAtDouble = 5;
kAtString = 6;
kAtShape = 7;
kAtListInt32 = 8;
kAtListInt64 = 9;
kAtListFloat = 10;
}
// A single attribute/flag value. Exactly one oneof member is set; value_case()
// therefore corresponds to the matching UserOpAttrType tag (same numbering).
message UserOpAttrVal {
// Wrapper messages for repeated scalars (oneof members cannot be repeated).
message ListInt32 {
repeated int32 val = 1;
}
message ListInt64 {
repeated int64 val = 1;
}
message ListFloat {
repeated float val = 1;
}
oneof value {
int32 at_int32 = 1;
int64 at_int64 = 2;
bool at_bool = 3;
float at_float = 4;
double at_double = 5;
string at_string = 6;
ShapeProto at_shape = 7;
ListInt32 at_list_int32 = 8;
ListInt64 at_list_int64 = 9;
ListFloat at_list_float = 10;
}
}
......@@ -10,6 +10,7 @@ import "oneflow/core/register/blob_desc.proto";
import "oneflow/core/operator/op_conf.proto";
import "oneflow/core/common/shape.proto";
import "oneflow/core/job/sbp_parallel.proto";
import "oneflow/core/framework/user_op_attr.proto";
message TrainConf {
required NormalModelUpdateOpUserConf model_update_conf = 3;
......@@ -105,6 +106,8 @@ message JobConfigProto {
optional bool enable_auto_mixed_precision = 602 [default = false];
optional int64 concurrency_width = 1000 [default = 128];
map<string, UserOpAttrVal> flag_name2flag_value = 2000;
}
message OpTimeShape {
......
......@@ -8,9 +8,31 @@
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/job_builder.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/framework/config_def.h"
namespace oneflow {
namespace {
void CheckAndCompleteFunctionConfigDefault(JobConfigProto* job_conf) {
  // Index every registered function-scope flag by name.
  HashMap<std::string, const ConfigFlagDef*> name2flag_def;
  for (const ConfigFlagDef& flag_def : GlobalFunctionConfigDef().flag()) {
    name2flag_def[flag_def.name()] = &flag_def;
  }
  // Each user-supplied flag must be registered and carry the registered value type
  // (UserOpAttrVal's oneof tags mirror the UserOpAttrType enum values).
  for (const auto& name_and_value : job_conf->flag_name2flag_value()) {
    const auto& def_iter = name2flag_def.find(name_and_value.first);
    CHECK(def_iter != name2flag_def.end());
    CHECK_EQ(static_cast<int>(def_iter->second->type()),
             static_cast<int>(name_and_value.second.value_case()));
  }
  // Fill in the registered default for every flag the user did not set explicitly.
  for (const ConfigFlagDef& flag_def : GlobalFunctionConfigDef().flag()) {
    if (job_conf->flag_name2flag_value().count(flag_def.name()) > 0) { continue; }
    (*job_conf->mutable_flag_name2flag_value())[flag_def.name()] = flag_def.default_val();
  }
}
} // namespace
int64_t JobDesc::all_reduce_group_min_byte() const {
int64_t ret = job_conf_.all_reduce_group_min_mbyte() * 1024 * 1024;
CHECK_GT(ret, 0);
......@@ -83,6 +105,7 @@ void JobDesc::Init() {
#ifndef WITH_CUDA
CHECK_EQ(Global<ResourceDesc>::Get()->GpuDeviceNum(), 0);
#endif
CheckAndCompleteFunctionConfigDefault(&job_conf_);
}
bool IsInterfaceOpConf(const OperatorConf& op_conf) {
......
......@@ -4,6 +4,7 @@
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/job/dlnet_conf.pb.h"
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/framework/user_op_attr.pb.h"
#include "oneflow/core/job/placement.pb.h"
#include "oneflow/core/job/inter_user_job_info.pb.h"
#include "oneflow/core/register/logical_blob_id.pb.h"
......@@ -62,6 +63,17 @@ class JobDesc final {
bool all_reduce_fp16() const;
int64_t cudnn_buf_limit_mbyte() const { return job_conf_.cudnn_buf_limit_mbyte(); }
// Typed getters for function-scope config flags, e.g. Bool("enable_pseudo_chain_merge").
// NOTE(review): the macro reuses `field_name` both as the getter's string parameter name
// and as the UserOpAttrVal accessor suffix (has_at_bool()/at_bool()); map::at() throws
// if the flag was never set — presumably defaults were completed beforehand by
// CheckAndCompleteFunctionConfigDefault — verify against JobDesc::Init().
#define DEFINE_FUNCTION_CONFIG_GETTER(T, func_name, field_name) \
T func_name(const std::string& field_name) const { \
const UserOpAttrVal& attr_val = job_conf_.flag_name2flag_value().at(field_name); \
CHECK(attr_val.has_##field_name()); \
return attr_val.field_name(); \
}
DEFINE_FUNCTION_CONFIG_GETTER(bool, Bool, at_bool);
DEFINE_FUNCTION_CONFIG_GETTER(int64_t, Int64, at_int64);
DEFINE_FUNCTION_CONFIG_GETTER(double, Double, at_double);
DEFINE_FUNCTION_CONFIG_GETTER(const std::string&, String, at_string);
// Train conf
int64_t TotalBatchNum() const;
int64_t NumOfPiecesInBatch() const;
......
......@@ -261,7 +261,7 @@ void BuildAllReduceStruct(
} // namespace
void AllReduceAddPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
void AllReduceAddPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) {
auto GetLastTouchedOpName = MakeGetterArg4ThisFuncPrevCall();
auto ProducerOpNode4Lbi = MakeGetterProducerOpNode4Lbi(op_graph);
std::vector<LogicalBlobId> lbis;
......
#ifndef ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_ADD_PASS_H_
#define ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_ADD_PASS_H_
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
namespace oneflow {
class OpGraph;
class AllReduceAddPass final {
class AllReduceAddPass final : public OpGraphPass {
public:
AllReduceAddPass() = default;
~AllReduceAddPass() = default;
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;
bool IsEnabled() const override {
return !GlobalJobDesc().enable_non_distributed_optimizer()
&& GlobalJobDesc().enable_all_reduce_group();
}
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) override;
};
} // namespace oneflow
......
......@@ -54,7 +54,7 @@ void ReOrderAllReduceGroups(std::vector<AllReduceGroup>* all_reduce_groups) {
} // namespace
void AllReduceSequencePass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
void AllReduceSequencePass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) {
std::vector<AllReduceGroup> all_reduce_groups;
FindAllReduceGroups(op_graph, &all_reduce_groups);
ReOrderAllReduceGroups(&all_reduce_groups);
......
......@@ -2,17 +2,18 @@
#define ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_SEQUENCE_PASS_H_
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
namespace oneflow {
class OpGraph;
class AllReduceSequencePass final {
class AllReduceSequencePass final : public OpGraphPass {
public:
AllReduceSequencePass() = default;
~AllReduceSequencePass() = default;
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;
bool IsEnabled() const override { return !GlobalJobDesc().disable_all_reduce_sequence(); }
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) override;
};
} // namespace oneflow
......
#include <algorithm>
#include "oneflow/core/job_completer/auto_mixed_precision.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/device/cuda_util.h"
namespace oneflow {
......@@ -177,6 +178,7 @@ void InsertCastOpImpl(bool f2h, const OpGraph& op_graph, const HashSet<OpNode*>&
} // namespace
void AutoMixedPrecision::Apply(const OpGraph& op_graph, JobBuilder* job_builder) {
CHECK_GE(CUDA_VERSION, 10000);
CHECK(GlobalJobDesc().DefaultDataType() == DataType::kFloat);
std::function<std::string(OpNode* const&)> OpName4Node = [](OpNode* const& node) {
......
......@@ -2,7 +2,7 @@
#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_MIXED_PRECISION_H_
#include "oneflow/core/job_completer/auto_mixed_precision_lists.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
namespace oneflow {
......@@ -10,15 +10,18 @@ class OpGraph;
class OpNode;
class Job;
class AutoMixedPrecision final {
class AutoMixedPrecision final : public OpGraphPass {
public:
OF_DISALLOW_COPY_AND_MOVE(AutoMixedPrecision);
AutoMixedPrecision(const AMPList& white, const AMPList& black, const AMPList& gray,
const AMPList& clear)
: white_list_(white), black_list_(black), gray_list_(gray), clear_list_(clear) {}
AutoMixedPrecision()
: white_list_(AutoMixedPrecisionLists::WhiteList()),
black_list_(AutoMixedPrecisionLists::BlackList()),
gray_list_(AutoMixedPrecisionLists::GrayList()),
clear_list_(AutoMixedPrecisionLists::ClearList()) {}
~AutoMixedPrecision() = default;
void Apply(const OpGraph& op_graph, JobBuilder* job_builder);
bool IsEnabled() const override { return GlobalJobDesc().enable_auto_mixed_precision(); }
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) override;
private:
void FillBlackSet(const OpGraph& op_graph, HashSet<OpNode*>* black_set);
......
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/job_completer/job_completer.h"
#include "oneflow/core/job_completer/autograd.h"
#include "oneflow/core/job_completer/autotick.h"
......@@ -15,6 +14,7 @@
#include "oneflow/core/job_completer/auto_train_step.h"
#include "oneflow/core/job_completer/auto_learning_rate.h"
#include "oneflow/core/job_completer/add_lbi_diff_watcher.h"
#include "oneflow/core/framework/config_def.h"
namespace oneflow {
......@@ -230,35 +230,39 @@ std::function<bool(OpNode*)> MakePredicatorIsReachableFromAnyVariableOps(const O
};
}
void TieUpChainHeadersUnReachableFromAnyVariableOps(const OpGraph& op_graph, Job* job) {
auto IsReachableFromAnyVariableOps = MakePredicatorIsReachableFromAnyVariableOps(op_graph);
auto GetSourceNodesAndEdges = [&](const HashSet<OpNode*>& chain_nodes,
std::vector<OpNode*>* source_nodes,
std::vector<OpEdge*>* source_edges) {
for (OpNode* node : chain_nodes) {
for (OpEdge* edge : node->in_edges()) {
if (chain_nodes.find(edge->src_node()) == chain_nodes.end()
&& IsReachableFromAnyVariableOps(edge->src_node()) == false) {
source_edges->push_back(edge);
source_nodes->push_back(node);
REGISTER_FUNCTION_CONFIG_DEF().Bool("enable_pseudo_chain_merge", true);
class TieUpChainHeadersUnReachableFromAnyVariableOps final : public OpGraphPass {
bool IsEnabled() const override { return GlobalJobDesc().Bool("enable_pseudo_chain_merge"); }
void Apply(const OpGraph& op_graph, Job* job) override {
auto IsReachableFromAnyVariableOps = MakePredicatorIsReachableFromAnyVariableOps(op_graph);
auto GetSourceNodesAndEdges = [&](const HashSet<OpNode*>& chain_nodes,
std::vector<OpNode*>* source_nodes,
std::vector<OpEdge*>* source_edges) {
for (OpNode* node : chain_nodes) {
for (OpEdge* edge : node->in_edges()) {
if (chain_nodes.find(edge->src_node()) == chain_nodes.end()
&& IsReachableFromAnyVariableOps(edge->src_node()) == false) {
source_edges->push_back(edge);
source_nodes->push_back(node);
}
}
}
}
};
auto MutOperatorConf4OpName = MakeMutableOperatorConf4OpName(job);
auto ParallelConf4OpName = MakeGetterParallelConf4OpName(job->placement());
op_graph.ForEachChainFamily([&](const HashSet<OpNode*>& chain_nodes) {
std::vector<OpNode*> source_nodes;
std::vector<OpEdge*> source_edges;
GetSourceNodesAndEdges(chain_nodes, &source_nodes, &source_edges);
if (source_edges.size() <= 1) { return; }
if (source_nodes.size() <= 1) { return; }
// ignore small chain
if (chain_nodes.size() - source_nodes.size() <= 2) { return; }
AddIdentityOpAndReconnect("pseudo_chain_header_", job, source_edges, MutOperatorConf4OpName,
*ParallelConf4OpName(source_nodes.at(0)->op().op_name()));
});
}
};
auto MutOperatorConf4OpName = MakeMutableOperatorConf4OpName(job);
auto ParallelConf4OpName = MakeGetterParallelConf4OpName(job->placement());
op_graph.ForEachChainFamily([&](const HashSet<OpNode*>& chain_nodes) {
std::vector<OpNode*> source_nodes;
std::vector<OpEdge*> source_edges;
GetSourceNodesAndEdges(chain_nodes, &source_nodes, &source_edges);
if (source_edges.size() <= 1) { return; }
if (source_nodes.size() <= 1) { return; }
// ignore small chain
if (chain_nodes.size() - source_nodes.size() <= 2) { return; }
AddIdentityOpAndReconnect("pseudo_chain_header_", job, source_edges, MutOperatorConf4OpName,
*ParallelConf4OpName(source_nodes.at(0)->op().op_name()));
});
}
};
void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder) {
auto IsMutableConsumedLbi = [](const Operator& op, const LogicalBlobId& lbi) -> bool {
......@@ -309,37 +313,11 @@ void SetOpTimeShape7BatchAxisLbis(const OpGraph& op_graph, JobBuilder* job_build
op_graph.DumpBatchAxisLbi(job_builder);
}
void RewriteBoxingWithAllReduce(const OpGraph& op_graph, JobBuilder* job_builder) {
if (!GlobalJobDesc().enable_non_distributed_optimizer()
&& GlobalJobDesc().enable_all_reduce_group()) {
AllReduceAddPass().Apply(op_graph, job_builder);
}
}
// Writes the op graph's logical blob descriptions and SBP signatures back into the job.
void DumpLogicalBlobDescAndSbpSignature(const OpGraph& op_graph, JobBuilder* job_builder) {
op_graph.DumpLogicalBlobDesc(job_builder);
op_graph.DumpSbpSignature(job_builder);
}
void MakeAllReduceSequence(const OpGraph& op_graph, JobBuilder* job_builder) {
if (GlobalJobDesc().disable_all_reduce_sequence()) { return; }
AllReduceSequencePass().Apply(op_graph, job_builder);
}
void EnableAutoMixedPrecision(const OpGraph& op_graph, JobBuilder* job_builder) {
if (!GlobalJobDesc().enable_auto_mixed_precision()) { return; }
CHECK_GE(CUDA_VERSION, 10000);
AutoMixedPrecision(AutoMixedPrecisionLists::WhiteList(), AutoMixedPrecisionLists::BlackList(),
AutoMixedPrecisionLists::GrayList(), AutoMixedPrecisionLists::ClearList())
.Apply(op_graph, job_builder);
}
void EnableNonDistributedOptimizer(const OpGraph& op_graph, JobBuilder* job_builder) {
if (!GlobalJobDesc().enable_non_distributed_optimizer()) { return; }
CHECK(GlobalJobDesc().enable_nccl());
NonDistributedOptimizerPass().Apply(op_graph, job_builder);
}
void MakeNcclTupleBroadcastReduceSequence(const OpGraph& op_graph, JobBuilder* job_builder) {
NcclTupleBroadcastReduceSequencePass().Apply(op_graph, job_builder);
}
......@@ -351,18 +329,18 @@ void JobCompleter::Complete(Job* job) const {
WithOpGraphAndMutJobBuilder(job, &ReplaceFacade);
// complete variable ops
WithOpGraphAndMutJobBuilder(job, &SetDefaultVariableConf);
WithOpGraphAndMutJobBuilder(job, &EnableAutoMixedPrecision);
AutoMixedPrecision()(job);
if (GlobalJobDesc().IsTrain()) {
WithOpGraphAndMutJob(job, &TieUpChainHeadersUnReachableFromAnyVariableOps);
WithOpGraphAndMutJobBuilder(job, &EnableNonDistributedOptimizer);
TieUpChainHeadersUnReachableFromAnyVariableOps()(job);
NonDistributedOptimizerPass()(job);
WithOpGraphAndMutJob(job, &AutoTrainStep);
WithOpGraphAndMutJob(job, &AutoLearningRate);
// complete ops for trainning
WithOpGraphAndMutJobBuilder(job, &GenerateOpConf4Trainning);
WithOpGraphAndMutJobBuilder(job, &MakeNcclTupleBroadcastReduceSequence);
WithOpGraphAndMutJobBuilder(job, &RewriteBoxingWithAllReduce);
AllReduceAddPass()(job);
AddLbiDiffWatcherOpConfs(job);
WithOpGraphAndMutJobBuilder(job, &MakeAllReduceSequence);
AllReduceSequencePass()(job);
}
WithOpGraphAndMutJobBuilder(job, &DumpLogicalBlobDescAndSbpSignature);
WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel);
......
......@@ -2,19 +2,20 @@
#define ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
namespace oneflow {
class OpGraph;
class JobBuilder;
class NonDistributedOptimizerPass final {
class NonDistributedOptimizerPass final : public OpGraphPass {
public:
OF_DISALLOW_COPY_AND_MOVE(NonDistributedOptimizerPass);
NonDistributedOptimizerPass() = default;
~NonDistributedOptimizerPass() = default;
void Apply(const OpGraph& op_graph, JobBuilder* job_builder);
bool IsEnabled() const override { return GlobalJobDesc().enable_non_distributed_optimizer(); }
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) override;
};
} // namespace oneflow
......
#ifndef ONEFLOW_CORE_JOB_COMPLETER_OP_GRAPH_PASS_H_
#define ONEFLOW_CORE_JOB_COMPLETER_OP_GRAPH_PASS_H_
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/job_builder.h"
namespace oneflow {
// Base class for job-rewriting passes. A pass is invoked as a function object on a
// Job; it is skipped entirely when IsEnabled() returns false.
class OpGraphPass {
 public:
  // Polymorphic base: a virtual destructor makes delete-through-base well-defined.
  virtual ~OpGraphPass() = default;

  // Entry point: build an OpGraph from the job and apply the pass if enabled.
  void operator()(Job* job) {
    if (!IsEnabled()) { return; }
    const OpGraph op_graph(*job);
    Apply(op_graph, job);
  }
  // Subclasses override this to gate themselves on job/global configuration.
  virtual bool IsEnabled() const { return true; }
  // Job-level Apply; the default forwards to the JobBuilder-level overload.
  virtual void Apply(const OpGraph& op_graph, Job* job) {
    JobBuilder job_builder(job);
    Apply(op_graph, &job_builder);
  }
  // Most passes override this; the base aborts if neither overload is implemented.
  virtual void Apply(const OpGraph& op_graph, JobBuilder* job_builder) { UNIMPLEMENTED(); }
};
}
#endif // ONEFLOW_CORE_JOB_COMPLETER_OP_GRAPH_PASS_H_
......@@ -249,6 +249,11 @@ def set_weight_l2(func_desc, value):
def set_bias_l2(func_desc, value):
func_desc.job_config_proto.train_conf.bias_l2 = value
@oneflow_function_config('enable_pseudo_chain_merge')
def set_enable_pseudo_chain_merge(func_desc, value):
    """Enable/disable merging of pseudo chains for the given function config.

    Renamed from a copy-pasted ``set_bias_l2``, which silently redefined the
    bias_l2 setter declared just above.
    """
    assert type(value) is bool
    # flag_name2flag_value is a protobuf map<string, UserOpAttrVal>; message-valued
    # map entries cannot be assigned directly, so set the oneof member instead.
    func_desc.job_config_proto.flag_name2flag_value['enable_pseudo_chain_merge'].at_bool = value
@oneflow_function_config('default_placement_scope')
def set_default_placement(func_desc, value):
assert isinstance(value, placement_ctx.PlacementScope)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment