Skip to content
Snippets Groups Projects
Commit 725a37fc authored by lixinqi's avatar lixinqi
Browse files

refactor ConfigFlagDef

parent 707adc37
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,16 @@ ConfigDef* MutGlobalConfigDef() {
return &config_def;
}
template<ConfigDefType config_def_type>
UserOpAttrVal* AddConfigFlagDef(const std::string& name, const std::string& description) {
auto* name2flag_def = MutGlobalConfigDef<config_def_type>()->mutable_flag_name2flag_def();
CHECK(name2flag_def->find(name) == name2flag_def->end());
auto* flag_def = &(*name2flag_def)[name];
flag_def->set_name(name);
flag_def->set_description(description);
return flag_def->mutable_default_val();
}
} // namespace
const ConfigDef& GlobalEnvConfigDef() { return *MutGlobalConfigDef<kEnvConfigType>(); }
......@@ -19,37 +29,29 @@ const ConfigDef& GlobalFunctionConfigDef() { return *MutGlobalConfigDef<kFunctio
template<ConfigDefType config_def_type>
const ConfigDefBuidler<config_def_type>& ConfigDefBuidler<config_def_type>::Bool(
const std::string& name, bool default_val) const {
auto* flag2default = MutGlobalConfigDef<config_def_type>()->mutable_flag_name2default_val();
CHECK(flag2default->find(name) == flag2default->end());
(*flag2default)[name].set_at_bool(default_val);
const std::string& name, bool default_val, const std::string& description) const {
AddConfigFlagDef<config_def_type>(name, description)->set_at_bool(default_val);
return *this;
}
template<ConfigDefType config_def_type>
const ConfigDefBuidler<config_def_type>& ConfigDefBuidler<config_def_type>::Int64(
const std::string& name, int64_t default_val) const {
auto* flag2default = MutGlobalConfigDef<config_def_type>()->mutable_flag_name2default_val();
CHECK(flag2default->find(name) == flag2default->end());
(*flag2default)[name].set_at_int64(default_val);
const std::string& name, int64_t default_val, const std::string& description) const {
AddConfigFlagDef<config_def_type>(name, description)->set_at_int64(default_val);
return *this;
}
template<ConfigDefType config_def_type>
const ConfigDefBuidler<config_def_type>& ConfigDefBuidler<config_def_type>::Double(
const std::string& name, double default_val) const {
auto* flag2default = MutGlobalConfigDef<config_def_type>()->mutable_flag_name2default_val();
CHECK(flag2default->find(name) == flag2default->end());
(*flag2default)[name].set_at_double(default_val);
const std::string& name, double default_val, const std::string& description) const {
AddConfigFlagDef<config_def_type>(name, description)->set_at_double(default_val);
return *this;
}
template<ConfigDefType config_def_type>
const ConfigDefBuidler<config_def_type>& ConfigDefBuidler<config_def_type>::String(
const std::string& name, const std::string& default_val) const {
auto* flag2default = MutGlobalConfigDef<config_def_type>()->mutable_flag_name2default_val();
CHECK(flag2default->find(name) == flag2default->end());
(*flag2default)[name].set_at_string(default_val);
const std::string& name, const std::string& default_val, const std::string& description) const {
AddConfigFlagDef<config_def_type>(name, description)->set_at_string(default_val);
return *this;
}
......
......@@ -9,10 +9,14 @@ namespace oneflow {
template<ConfigDefType config_def_type>
struct ConfigDefBuidler final {
const ConfigDefBuidler& Bool(const std::string& name, bool default_val) const;
const ConfigDefBuidler& Int64(const std::string& name, int64_t default_val) const;
const ConfigDefBuidler& Double(const std::string& name, double default_val) const;
const ConfigDefBuidler& String(const std::string& name, const std::string& default_val) const;
const ConfigDefBuidler& Bool(const std::string& name, bool default_val,
const std::string& description) const;
const ConfigDefBuidler& Int64(const std::string& name, int64_t default_val,
const std::string& description) const;
const ConfigDefBuidler& Double(const std::string& name, double default_val,
const std::string& description) const;
const ConfigDefBuidler& String(const std::string& name, const std::string& default_val,
const std::string& description) const;
};
#define REGISTER_ENV_CONFIG_DEF() REGISTER_CONFIG_DEF(kEnvConfigType)
......
......@@ -9,6 +9,12 @@ enum ConfigDefType {
kFunctionConfigType = 3;
}
message ConfigFlagDef {
required string name = 1;
required string description = 2;
required UserOpAttrVal default_val = 3;
}
message ConfigDef {
map<string, UserOpAttrVal> flag_name2default_val = 1;
map<string, ConfigFlagDef> flag_name2flag_def = 1;
}
......@@ -14,11 +14,11 @@ namespace oneflow {
namespace {
void CheckFunctionConfig(const JobConfigProto& job_conf) {
const auto& flag_name2default_val = GlobalFunctionConfigDef().flag_name2default_val();
const auto& flag_name2flag_def = GlobalFunctionConfigDef().flag_name2flag_def();
for (const auto& pair : job_conf.flag_name2flag_value()) {
const auto& iter = flag_name2default_val.find(pair.first);
CHECK(iter != flag_name2default_val.end());
CHECK_EQ(iter->second.value_case(), pair.second.value_case());
const auto& iter = flag_name2flag_def.find(pair.first);
CHECK(iter != flag_name2flag_def.end());
CHECK_EQ(iter->second.default_val().value_case(), pair.second.value_case());
}
}
......@@ -102,10 +102,10 @@ void JobDesc::Init() {
const UserOpAttrVal& JobDesc::GetFunctionFlagVal(const std::string& field_name) const {
const auto& iter = job_conf_.flag_name2flag_value().find(field_name);
if (iter != job_conf_.flag_name2flag_value().end()) { return iter->second; }
const auto& flag_name2default_val = GlobalFunctionConfigDef().flag_name2default_val();
const auto& def_iter = flag_name2default_val.find(field_name);
CHECK(def_iter != flag_name2default_val.end());
return def_iter->second;
const auto& flag_name2flag_def = GlobalFunctionConfigDef().flag_name2flag_def();
const auto& def_iter = flag_name2flag_def.find(field_name);
CHECK(def_iter != flag_name2flag_def.end());
return def_iter->second.default_val();
}
bool IsInterfaceOpConf(const OperatorConf& op_conf) {
......
......@@ -230,7 +230,8 @@ std::function<bool(OpNode*)> MakePredicatorIsReachableFromAnyVariableOps(const O
};
}
REGISTER_FUNCTION_CONFIG_DEF().Bool("enable_pseudo_chain_merge", false);
REGISTER_FUNCTION_CONFIG_DEF().Bool("enable_pseudo_chain_merge", false,
"ties up chain headers unreachable from any variable ops");
class TieUpChainHeadersUnReachableFromAnyVariableOps final : public OpGraphPass {
bool IsEnabled() const override { return GlobalJobDesc().Bool("enable_pseudo_chain_merge"); }
void Apply(const OpGraph& op_graph, Job* job) override {
......
......@@ -58,8 +58,8 @@ class Session(object):
return self.job_name2function_desc_[job_name].job_config_proto
def UpdateFunctionFlagName2DefaultVal(self):
flag_name2default_val = c_api_util.GetFunctionConfigDef().flag_name2default_val
self.function_flag_name2default_val_ = flag_name2default_val
items = c_api_util.GetFunctionConfigDef().flag_name2flag_def.items()
self.function_flag_name2default_val_ = {k : v.default_val for k, v in items}
def TryInit(self):
if self.status_ is SessionStatus.OPEN: self.Init()
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment