Skip to content
Snippets Groups Projects
Unverified Commit edf1a86d authored by binbinHan's avatar binbinHan Committed by GitHub
Browse files

Change get op attributes to get interface op attributes (#4733)

* skip test_gpt_data_loader in eager mode

* change_GetOpAttributes_to_GetInterfaceOpAttributes

* checkout change of test_gpt_data_loader.py
parent 99d45d87
No related branches found
No related tags found
No related merge requests found
......@@ -28,7 +28,7 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
});
m.def("InferOpConf", &InferOpConf);
m.def("GetSerializedOpAttributes", &GetSerializedOpAttributes);
m.def("GetSerializedInterfaceOpAttributes", &GetSerializedInterfaceOpAttributes);
m.def("IsInterfaceOpTypeCase", &IsInterfaceOpTypeCase);
m.def("GetOpParallelSymbolId", &GetOpParallelSymbolId);
......
......@@ -64,7 +64,7 @@ inline Maybe<std::string> InferOpConf(const std::string& op_conf_str,
return PbMessage2TxtString(*op_attribute);
}
inline Maybe<std::string> GetSerializedOpAttributes() {
inline Maybe<std::string> GetSerializedInterfaceOpAttributes() {
OpAttributeList op_attribute_list;
const JobSet& job_set = JUST(GetJobSet());
for (int i = 0; i < job_set.job_size(); i++) {
......@@ -72,8 +72,12 @@ inline Maybe<std::string> GetSerializedOpAttributes() {
auto scope = std::make_unique<GlobalJobDescScope>(job.job_conf(), i);
const auto& op_graph = JUST(OpGraph::New(job));
op_graph->ForEachNode([&op_attribute_list](OpNode* op_node) {
const auto& op_attribute = op_node->op().GetOpAttributeWithoutOpNameAndLbn();
op_attribute_list.mutable_op_attribute()->Add()->CopyFrom(*op_attribute);
const auto& op_type_case = op_node->op().op_conf().op_type_case();
if (oneflow::IsClassRegistered<int32_t, oneflow::IsInterfaceOpConf4OpTypeCase>(
op_type_case)) {
const auto& op_attribute = op_node->op().GetOpAttributeWithoutOpNameAndLbn();
op_attribute_list.mutable_op_attribute()->Add()->CopyFrom(*op_attribute);
}
});
}
return PbMessage2TxtString(op_attribute_list);
......
......@@ -35,8 +35,8 @@ inline std::string InferOpConf(const std::string& serialized_op_conf,
return oneflow::InferOpConf(serialized_op_conf, serialized_op_input_signature).GetOrThrow();
}
inline std::string GetSerializedOpAttributes() {
return oneflow::GetSerializedOpAttributes().GetOrThrow();
inline std::string GetSerializedInterfaceOpAttributes() {
return oneflow::GetSerializedInterfaceOpAttributes().GetOrThrow();
}
inline bool IsInterfaceOpTypeCase(int64_t op_type_case) {
......
......@@ -244,8 +244,8 @@ def GetScopeConfigDef():
return text_format.Parse(scope_config_def, ConfigDef())
def GetOpAttributes():
op_attributes = oneflow._oneflow_internal.GetSerializedOpAttributes()
def GetInterfaceOpAttributes():
op_attributes = oneflow._oneflow_internal.GetSerializedInterfaceOpAttributes()
return text_format.Parse(op_attributes, op_attribute_pb.OpAttributeList())
......
......@@ -168,10 +168,8 @@ class Session(object):
return self
def UpdateInfo4InterfaceOp(self):
for op_attr in c_api_util.GetOpAttributes().op_attribute:
op_conf = op_attr.op_conf
if c_api_util.IsInterfaceOpConf(op_conf):
self.interface_op_name2op_attr_[op_conf.name] = op_attr
for op_attr in c_api_util.GetInterfaceOpAttributes().op_attribute:
self.interface_op_name2op_attr_[op_attr.op_conf.name] = op_attr
for job in c_api_util.GetJobSet().job:
op_name2parallel_conf = {}
for placement_group in job.placement.placement_group:
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment