Skip to content
Snippets Groups Projects
Unverified Commit 120a7f86 authored by Liang Depeng's avatar Liang Depeng Committed by GitHub
Browse files

refactor message OperatorConf, change device_type to device_tag (#3411)

* refactor message OperatorConf, change device_type to device_tag

* subsititute HobDeviceType with HobDeviceTag in user_op kernel registration

* remove c_api_util.DeviceType4DeviceTag

* fix error when buil with cuda off

* fix can not use CHECK_JUST macro in another macro
parent 4d44113e
No related branches found
No related tags found
No related merge requests found
Showing
with 120 additions and 65 deletions
......@@ -120,7 +120,7 @@ if(WIN32)
#set(CMAKE_EXE_LINKER_FLAGS_DEBUG "${CMAKE_EXE_LINKER_FLAGS} /DEBUG:FASTLINK")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /D_ITERATOR_DEBUG_LEVEL=0")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Wno-sign-compare -Wno-unused-function -fPIC -Werror=return-type")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Wno-sign-compare -Wno-unused-function -fPIC")
endif()
if (THIRD_PARTY)
......
......@@ -57,7 +57,7 @@ enum JobBuildAndInferError {
kLogicalBlobNameInvalid = 402;
kOpNameExist = 450;
kOpConfDeviceTypeNoSet = 460;
kOpConfDeviceTagNoSet = 460;
kPlacementError = 470;
kBlobSplitAxisInferError = 480;
kUnknownJobBuildAndInferError = 500;
......
......@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/to_string.h"
namespace oneflow {
......@@ -20,7 +21,7 @@ namespace oneflow {
Maybe<const char*> DeviceTag4DeviceType(DeviceType device_type) {
if (device_type == kCPU) { return "cpu"; }
if (device_type == kGPU) { return "gpu"; }
return Error::DeviceTagNotFound() << "invalid";
return Error::DeviceTagNotFound() << "invalid_device";
}
Maybe<DeviceType> DeviceType4DeviceTag(const std::string& device_tag) {
......
......@@ -40,13 +40,6 @@ hob::BoolFunctorPtr<KernelRegContext> HobFalse() {
return krbf_ptr;
}
hob::HobContextGetter<KernelRegContext, DeviceType> HobDeviceType() {
std::ostringstream string_stream;
string_stream << "device_type";
return hob::HobContextGetter<KernelRegContext, DeviceType>(
string_stream.str(), [](const KernelRegContext& ctx) { return ctx.device_type(); });
}
hob::HobContextGetter<KernelRegContext, DataType> HobDataType(const std::string& tensor_name,
int tensor_idx) {
std::ostringstream string_stream;
......@@ -58,6 +51,14 @@ hob::HobContextGetter<KernelRegContext, DataType> HobDataType(const std::string&
});
}
HobStringContextGetter<KernelRegContext> HobDeviceTag() {
std::ostringstream string_stream;
string_stream << "device_tag";
return HobStringContextGetter<KernelRegContext>(
string_stream.str(),
[](const KernelRegContext& ctx) -> const std::string& { return ctx.device_tag(); });
}
} // namespace user_op
} // namespace oneflow
......@@ -16,8 +16,9 @@ limitations under the License.
#ifndef ONEFLOW_CORE_FRAMEWORK_USER_OP_HOB_H_
#define ONEFLOW_CORE_FRAMEWORK_USER_OP_HOB_H_
#include "oneflow/core/common/high_order_bool.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/high_order_bool.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
namespace oneflow {
......@@ -28,8 +29,6 @@ hob::BoolFunctorPtr<KernelRegContext> HobTrue();
hob::BoolFunctorPtr<KernelRegContext> HobFalse();
hob::HobContextGetter<KernelRegContext, DeviceType> HobDeviceType();
hob::HobContextGetter<KernelRegContext, DataType> HobDataType(const std::string& tensor_name,
int tensor_idx);
......@@ -47,6 +46,46 @@ hob::HobContextGetter<user_op::KernelRegContext, T> HobAttr(const std::string& a
});
}
template<typename ContextT>
class HobStringContextGetter final {
public:
HobStringContextGetter(const DeviceType& device_type) {
std::string str = ToString(device_type);
debug_str_ = str;
context_getter_ = [str](const ContextT&) -> const std::string& { return str; };
}
HobStringContextGetter(const char* const_value) {
std::string str(const_value);
debug_str_ = str;
context_getter_ = [str](const ContextT&) -> const std::string& { return str; };
}
HobStringContextGetter(const std::string& const_value)
: debug_str_(const_value),
context_getter_(
[const_value](const ContextT&) -> const std::string& { return const_value; }) {}
HobStringContextGetter(const std::string& debug_str,
const std::function<const std::string&(const ContextT&)>& context_getter)
: debug_str_(debug_str), context_getter_(context_getter) {}
hob::BoolFunctorPtr<ContextT> operator==(const HobStringContextGetter& other) const {
std::ostringstream string_stream;
string_stream << debug_str_ << " == " << other.debug_str_;
std::function<std::string(const ContextT&)> l_fn = this->context_getter_;
std::function<std::string(const ContextT&)> r_fn = other.context_getter_;
std::shared_ptr<const hob::BoolFunctor<ContextT>> krbf_ptr =
std::make_shared<const hob::HighOrderBoolFunctor<ContextT>>(
string_stream.str(),
[l_fn, r_fn](const ContextT& ctx) { return l_fn(ctx) == r_fn(ctx); });
return krbf_ptr;
}
private:
std::string debug_str_;
std::function<const std::string&(const ContextT&)> context_getter_;
};
HobStringContextGetter<KernelRegContext> HobDeviceTag();
} // namespace user_op
} // namespace oneflow
......
......@@ -37,6 +37,7 @@ class KernelRegContext {
virtual ~KernelRegContext() = default;
virtual DeviceType device_type() const = 0;
virtual const std::string& device_tag() const = 0;
virtual const ParallelContext& parallel_ctx() const = 0;
virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0;
......
......@@ -13,10 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/boxing/chain_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
#include "oneflow/core/graph/collective_boxing_task_node.h"
#include "oneflow/core/graph/boxing/chain_sub_task_graph_builder.h"
#include "oneflow/core/graph/slice_boxing_task_node.h"
namespace oneflow {
......@@ -31,7 +32,7 @@ void NcclInitCollectiveNode(CollectiveBoxingGenericTaskNode* node,
const BlobDesc& logical_blob_desc, OpType op_type, int64_t root) {
OperatorConf op_conf;
op_conf.set_name(name);
op_conf.set_device_type(DeviceType::kGPU);
op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(DeviceType::kGPU)));
CollectiveBoxingGenericOpConf* conf = op_conf.mutable_collective_boxing_generic_conf();
*conf->mutable_lbi() = lbi;
RankDesc* rank_desc = conf->mutable_rank_desc();
......
......@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/boxing_identity_compute_task_node.h"
#include "oneflow/core/graph/logical_node.h"
......@@ -41,7 +42,7 @@ void BoxingIdentityCompTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
OperatorConf op_conf;
op_conf.set_name("System-Boxing-Identity-" + NewUniqueId());
op_conf.set_device_type(this->device_type());
op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(this->device_type())));
*op_conf.mutable_boxing_identity_conf()->mutable_lbi() = lbi_;
std::shared_ptr<Operator> sole_op = ConstructOp(op_conf, &GlobalJobDesc());
node->mut_op() = sole_op;
......
......@@ -20,6 +20,7 @@ limitations under the License.
#include "oneflow/core/thread/thread_pool.h"
#include "oneflow/core/common/blocking_counter.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/job/global_for.h"
namespace oneflow {
......@@ -180,10 +181,8 @@ void CollectIgnoreTaskEdgesInFirstMergedChains(const std::vector<std::vector<Tas
if (fw_node == nullptr) { continue; }
if (fw_node->logical_node()->op_vec().size() != 1) { continue; }
const auto& src_op = *fw_node->logical_node()->SoleOp();
if (src_op.op_conf().has_variable_conf()
&& src_op.op_conf().device_type() == DeviceType::kGPU) {
return true;
}
DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(src_op.op_conf().device_tag()));
if (src_op.op_conf().has_variable_conf() && device_type == DeviceType::kGPU) { return true; }
}
return false;
};
......
......@@ -13,9 +13,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/job/thrd_id_generator.h"
#include "oneflow/core/operator/operator.h"
namespace oneflow {
......@@ -79,7 +80,7 @@ void CopyHdTaskNode::InitProducedRegstMemCase(MemoryCase* mem_case) {
OperatorConf CopyHdTaskNode::NewCopyOpConf() {
OperatorConf conf;
conf.set_name("copy_hd_" + NewUniqueId());
conf.set_device_type(device_type());
conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(device_type())));
conf.mutable_copy_hd_conf()->set_type(copy_type_);
auto in_regst = GetSoleConsumedRegst("copy_in");
if (in_regst->NumOfLbi() == 1) {
......@@ -141,7 +142,7 @@ void CopyCommNetTaskNode::PinConsumedRegstMemCase(MemoryCase* mem_case) {
OperatorConf CopyCommNetTaskNode::NewCopyOpConf() {
OperatorConf conf;
conf.set_name("copy_comm_net_" + NewUniqueId());
conf.set_device_type(device_type());
conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(this->device_type())));
conf.mutable_copy_comm_net_conf();
return conf;
}
......
......@@ -13,12 +13,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/logical_graph.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/operator/op_conf_util.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/job/global_for.h"
namespace oneflow {
......@@ -63,7 +64,7 @@ void LogicalGraph::NaiveBuildFwStruct(
auto parallel_desc_ptr_it = name2parallel_desc.find(cur_op_conf.name());
CHECK(parallel_desc_ptr_it != name2parallel_desc.end());
const std::shared_ptr<ParallelDesc>& parallel_desc_ptr = parallel_desc_ptr_it->second;
cur_op_conf.set_device_type(parallel_desc_ptr->device_type());
cur_op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(parallel_desc_ptr->device_type())));
std::shared_ptr<Operator> cur_op = ConstructOp(cur_op_conf, &GlobalJobDesc());
LogicalNode* cur_node = cur_op->NewProperLogicalNode();
AddAllocatedNode(cur_node);
......
......@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/slice_boxing_task_node.h"
namespace oneflow {
......@@ -102,7 +103,7 @@ void SliceBoxingTaskNode::SetOutShape(const Shape& shape) { out_shape_ = shape;
OperatorConf SliceBoxingTaskNode::GetBoxingOpConf() {
OperatorConf op_conf{};
op_conf.set_device_type(device_type());
op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(device_type())));
SliceBoxingConf boxing_conf{};
*boxing_conf.mutable_lbi() = lbi_;
out_slice_.ToProto(boxing_conf.mutable_out_slice());
......
......@@ -13,17 +13,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/job/job_build_and_infer_ctx.h"
#include "oneflow/core/job_rewriter/op_graph_pass.h"
#include "oneflow/core/job_rewriter/autograd.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/job/mirrored_sig_infer_hint.h"
#include "oneflow/core/job/foreign_callback.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/job/foreign_callback.h"
#include "oneflow/core/job/job_build_and_infer_ctx.h"
#include "oneflow/core/job/mirrored_sig_infer_hint.h"
#include "oneflow/core/job/scope.h"
#include <google/protobuf/text_format.h>
#include "oneflow/core/job_rewriter/autograd.h"
#include "oneflow/core/job_rewriter/op_graph_pass.h"
#include "oneflow/user/summary/summary_converter.h"
#include <google/protobuf/text_format.h>
#include <json.hpp>
namespace oneflow {
......@@ -505,9 +507,9 @@ Maybe<OpAttribute> JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con
CHECK_OR_RETURN(op_name2op_.find(op_name) == op_name2op_.end())
<< JobBuildAndInferError::kOpNameExist << "op_name: " << op_name
<< " already exist in job: " << job_->job_conf().job_name();
CHECK_NE_OR_RETURN(op_conf.device_type(), DeviceType::kInvalidDevice)
<< JobBuildAndInferError::kOpConfDeviceTypeNoSet << "op_name: " << op_name
<< " not set device type";
CHECK_NE_OR_RETURN(op_conf.device_tag(), "invalid_device")
<< JobBuildAndInferError::kOpConfDeviceTagNoSet << "op_name: " << op_name
<< " not set device tag";
op_name2op_.emplace(op_name, ConstructOp(op_conf, job_desc));
Operator* op = op_name2op_.at(op_name).get();
......@@ -836,7 +838,7 @@ Maybe<LogicalBlobId> LazyJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompati
lbi_vec->push_back(sub_lbi);
};
OperatorConf op_conf;
op_conf.set_device_type(parallel_desc.device_type());
op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(parallel_desc.device_type())));
if (sbp.has_broadcast_parallel()) {
op_conf.set_name(kAutoMirroredBlobNamePrefix + "-DistributeClone-" + NewUniqueId());
auto* distribute_clone = op_conf.mutable_distribute_clone_conf();
......@@ -890,7 +892,8 @@ Maybe<LogicalBlobId> EagerJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompat
CHECK_OR_RETURN(producer_op_conf.has_scope_symbol_id());
op_conf.set_scope_symbol_id(producer_op_conf.scope_symbol_id());
}
op_conf.set_device_type(parallel_desc.device_type());
// const char* device_tag = JUST(DeviceTag4DeviceType(parallel_desc.device_type()));
op_conf.set_device_tag(JUST(DeviceTag4DeviceType(parallel_desc.device_type())));
op_conf.set_name(kAutoMirroredBlobNamePrefix + "-CastToMirrored-" + NewUniqueId());
auto* cast_to_mirrored_conf = op_conf.mutable_cast_to_mirrored_conf();
cast_to_mirrored_conf->set_in(lbn);
......
......@@ -157,10 +157,9 @@ Maybe<void> ParallelDesc::CheckWithResourceDesc(const ResourceDesc& resource_des
ParallelConf ParallelDesc::GetParallelIdOnlyParallelConf(int64_t parallel_id) const {
ParallelConf parallel_conf;
const char* device_tag = CHECK_JUST(DeviceTag4DeviceType(device_type()));
std::string machine_id = std::to_string(MachineIdForParallelId(parallel_id));
std::string device_id = std::to_string(DeviceIdForParallelId(parallel_id));
parallel_conf.set_device_tag(device_tag);
parallel_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(device_type())));
parallel_conf.add_device_name(machine_id + ":" + device_id);
return parallel_conf;
}
......
......@@ -13,9 +13,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/vm/symbol_storage.h"
namespace oneflow {
......@@ -42,7 +43,7 @@ Maybe<const JobDesc*> Scope::job_desc() const {
}
Maybe<int64_t> Scope::GetParallelDescSymbolId(const OperatorConf& op_conf) const {
if (op_conf.device_type() == DeviceType::kCPU || IsCpuOnly(op_conf)) {
if (op_conf.device_tag() == "cpu" || IsCpuOnly(op_conf)) {
return scope_proto_.host_parallel_desc_symbol_id();
} else {
return scope_proto_.device_parallel_desc_symbol_id();
......@@ -50,7 +51,7 @@ Maybe<int64_t> Scope::GetParallelDescSymbolId(const OperatorConf& op_conf) const
}
Maybe<const ParallelDesc*> Scope::GetParallelDesc(const OperatorConf& op_conf) const {
if (op_conf.device_type() == DeviceType::kCPU || IsCpuOnly(op_conf)) {
if (op_conf.device_tag() == "cpu" || IsCpuOnly(op_conf)) {
return host_parallel_desc_.get();
} else {
return device_parallel_desc_.get();
......
......@@ -13,8 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/job_rewriter/op_graph_pass.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job_rewriter/op_graph_pass.h"
namespace oneflow {
......@@ -35,19 +36,19 @@ Maybe<void> DumpVariableInfoPass::Apply(const OpGraph& op_graph, JobBuilder* job
const std::string sep = "\t";
auto log_stream =
TeePersistentLogStream::Create("variable_table_" + std::to_string(GlobalJobDesc().job_id()));
(*log_stream) << "id" << sep << "name" << sep << "device_type" << sep << "parallel_num" << sep
(*log_stream) << "id" << sep << "name" << sep << "device_tag" << sep << "parallel_num" << sep
<< "distribute" << sep << "data_type" << sep << "shape" << sep << "elem_cnt" << sep
<< "size"
<< "\n";
op_graph.TopoForEachNode([&](const OpNode* node) {
JUST(op_graph.TopoForEachNodeWithErrorCaptured([&](const OpNode* node) -> Maybe<void> {
const OperatorConf& op_conf = node->op().op_conf();
if (!op_conf.has_variable_conf()) { return; }
if (!op_conf.has_variable_conf()) { return Maybe<void>::Ok(); }
const VariableOpConf& conf = op_conf.variable_conf();
(*log_stream) << std::to_string(cnt);
(*log_stream) << sep;
(*log_stream) << op_conf.name();
(*log_stream) << sep;
(*log_stream) << DeviceType_Name(op_conf.device_type());
(*log_stream) << op_conf.device_tag();
(*log_stream) << sep;
(*log_stream) << std::to_string(node->parallel_desc().parallel_num());
(*log_stream) << sep;
......@@ -67,7 +68,8 @@ Maybe<void> DumpVariableInfoPass::Apply(const OpGraph& op_graph, JobBuilder* job
(*log_stream) << std::to_string(shape.elem_cnt() * GetSizeOfDataType(conf.data_type()));
(*log_stream) << "\n";
cnt += 1;
});
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}
......
......@@ -13,9 +13,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/kernel/arg_where_kernel_util.h"
#include "oneflow/core/common/nd_index_offset_helper.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/kernel/arg_where_kernel_util.h"
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
......@@ -40,14 +41,14 @@ class ArgWhereKernel : public KernelIf<DeviceType::kCPU> {
}
};
#define REGISTER_ARG_WHERE_KERNEL(device_type_v, dtype, itype, ndims) \
NEW_REGISTER_KERNEL(OperatorConf::kArgWhereConf, \
ArgWhereKernel<device_type_v, dtype, itype, ndims>) \
.SetIsMatchedPred([](const KernelConf& conf) { \
return (device_type_v == conf.op_attribute().op_conf().device_type()) \
&& (GetDataType<itype>::value == conf.data_type()) \
&& (GetDataType<dtype>::value == conf.arg_where_conf().in_data_type()) \
&& (ndims == conf.arg_where_conf().num_axes()); \
#define REGISTER_ARG_WHERE_KERNEL(device_type_v, dtype, itype, ndims) \
NEW_REGISTER_KERNEL(OperatorConf::kArgWhereConf, \
ArgWhereKernel<device_type_v, dtype, itype, ndims>) \
.SetIsMatchedPred([](const KernelConf& conf) -> bool { \
return (conf.op_attribute().op_conf().device_tag() == ToString(device_type_v)) \
&& (GetDataType<itype>::value == conf.data_type()) \
&& (GetDataType<dtype>::value == conf.arg_where_conf().in_data_type()) \
&& (ndims == conf.arg_where_conf().num_axes()); \
});
#define REGISTER_ARG_WHERE_KERNELS_AT_NDIMS(device_type_v, dtype, itype) \
......
......@@ -13,10 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/kernel/kernel_context.h"
#include "oneflow/core/kernel/indexed_slices_reduce_sum_kernel_util.h"
#include "oneflow/core/kernel/indexed_slices_lazy_adam_model_update_kernel_util.h"
#include "oneflow/core/kernel/indexed_slices_reduce_sum_kernel_util.h"
namespace oneflow {
......@@ -80,7 +81,7 @@ void IndexedSlicesLazyAdamMdUpdateKernel<device_type, T, K>::ForwardDataContent(
OF_PP_PAIR_FIRST(indices_type_pair)>) \
.SetIsMatchedPred([](const KernelConf& kernel_conf) -> bool { \
return ( \
(kernel_conf.op_attribute().op_conf().device_type() == device_type_v) \
(kernel_conf.op_attribute().op_conf().device_tag() == ToString(device_type_v)) \
&& ((OF_PP_PAIR_SECOND(data_type_pair)) == kernel_conf.data_type()) \
&& (OF_PP_PAIR_SECOND(indices_type_pair) \
== kernel_conf.indexed_slices_lazy_adam_model_update_conf().indices_data_type())); \
......
......@@ -13,10 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/kernel/kernel_context.h"
#include "oneflow/core/kernel/indexed_slices_reduce_sum_kernel_util.h"
#include "oneflow/core/kernel/indexed_slices_momentum_model_update_kernel_util.h"
#include "oneflow/core/kernel/indexed_slices_reduce_sum_kernel_util.h"
namespace oneflow {
......@@ -75,7 +76,7 @@ void IndexedSlicesMomentumMdUpdateKernel<device_type, T, K>::ForwardDataContent(
OF_PP_PAIR_FIRST(indices_type_pair)>) \
.SetIsMatchedPred([](const KernelConf& kernel_conf) -> bool { \
return ( \
(kernel_conf.op_attribute().op_conf().device_type() == device_type_v) \
(kernel_conf.op_attribute().op_conf().device_tag() == ToString(device_type_v)) \
&& ((OF_PP_PAIR_SECOND(data_type_pair)) == kernel_conf.data_type()) \
&& (OF_PP_PAIR_SECOND(indices_type_pair) \
== kernel_conf.indexed_slices_momentum_model_update_conf().indices_data_type())); \
......
......@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/kernel/indexed_slices_naive_model_update_kernel_util.h"
......@@ -59,7 +60,7 @@ void IndexedSlicesNaiveMdUpdateKernel<device_type, T, K>::ForwardDataContent(
IndexedSlicesNaiveMdUpdateKernel<device_type_v, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(indices_type_pair)>) \
.SetIsMatchedPred([](const KernelConf& kernel_conf) -> bool { \
return ((kernel_conf.op_attribute().op_conf().device_type() == device_type_v) \
return ((kernel_conf.op_attribute().op_conf().device_tag() == ToString(device_type_v)) \
&& ((OF_PP_PAIR_SECOND(data_type_pair)) == kernel_conf.data_type()) \
&& (OF_PP_PAIR_SECOND(indices_type_pair) \
== kernel_conf.indexed_slices_naive_model_update_conf().indices_data_type())); \
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment