Unverified commit 120a7f86, authored by Liang Depeng, committed by GitHub

refactor message OperatorConf, change device_type to device_tag (#3411)

* refactor message OperatorConf, change device_type to device_tag

* substitute HobDeviceType with HobDeviceTag in user_op kernel registration

* remove c_api_util.DeviceType4DeviceTag

* fix error when building with CUDA off

* fix: CHECK_JUST macro could not be used inside another macro
parent 4d44113e
Showing changed files with 140 additions and 162 deletions
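The core of this change: OperatorConf now carries a string device_tag ("cpu" / "gpu") instead of a DeviceType enum, so kernel-matching predicates compare tag strings and internal code converts back to the enum only where it still needs one. Below is a minimal, self-contained sketch of that two-way mapping with simplified signatures; the real helpers are declared in oneflow/core/framework/to_string.h and return Maybe<> values, which is why the hunks below wrap them in CHECK_JUST.

// Sketch only, not OneFlow's implementation: plain return types instead of Maybe<T>.
#include <stdexcept>
#include <string>

enum class DeviceType { kInvalidDevice = 0, kCPU, kGPU };

// DeviceType -> tag string, as used in predicates like
//   kernel_conf.op_attribute().op_conf().device_tag() == ToString(device_type_v)
inline std::string ToString(DeviceType device_type) {
  switch (device_type) {
    case DeviceType::kCPU: return "cpu";
    case DeviceType::kGPU: return "gpu";
    default: return "invalid_device";
  }
}

// tag string -> DeviceType, used (wrapped in CHECK_JUST in the real code) wherever an
// enum value is still required internally.
inline DeviceType DeviceType4DeviceTag(const std::string& device_tag) {
  if (device_tag == "cpu") { return DeviceType::kCPU; }
  if (device_tag == "gpu") { return DeviceType::kGPU; }
  throw std::runtime_error("unknown device_tag: " + device_tag);
}

int main() {
  // Round trip: "gpu" -> kGPU -> "gpu".
  return ToString(DeviceType4DeviceTag("gpu")) == "gpu" ? 0 : 1;
}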
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/kernel/kernel_context.h"
#include "oneflow/core/kernel/indexed_slices_reduce_sum_kernel_util.h"
@@ -54,17 +55,17 @@ void IndexedSlicesReduceSumKernel<device_type, T, K>::ForwardDataContent(
workspace_ptr, workspace_size_in_bytes);
}
#define MAKE_INDEXED_SLICES_REDUCE_SUM_KERNEL_ENTRY(device_type_v, data_type_pair, \
indices_type_pair) \
NEW_REGISTER_KERNEL( \
OperatorConf::kIndexedSlicesReduceSumConf, \
IndexedSlicesReduceSumKernel<device_type_v, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(indices_type_pair)>) \
.SetIsMatchedPred([](const KernelConf& kernel_conf) -> bool { \
return ((kernel_conf.op_attribute().op_conf().device_type() == device_type_v) \
&& ((OF_PP_PAIR_SECOND(data_type_pair)) == kernel_conf.data_type()) \
&& (OF_PP_PAIR_SECOND(indices_type_pair) \
== kernel_conf.indexed_slices_reduce_sum_conf().indices_data_type())); \
#define MAKE_INDEXED_SLICES_REDUCE_SUM_KERNEL_ENTRY(device_type_v, data_type_pair, \
indices_type_pair) \
NEW_REGISTER_KERNEL( \
OperatorConf::kIndexedSlicesReduceSumConf, \
IndexedSlicesReduceSumKernel<device_type_v, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(indices_type_pair)>) \
.SetIsMatchedPred([](const KernelConf& kernel_conf) -> bool { \
return ((kernel_conf.op_attribute().op_conf().device_tag() == ToString(device_type_v)) \
&& ((OF_PP_PAIR_SECOND(data_type_pair)) == kernel_conf.data_type()) \
&& (OF_PP_PAIR_SECOND(indices_type_pair) \
== kernel_conf.indexed_slices_reduce_sum_conf().indices_data_type())); \
});
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_INDEXED_SLICES_REDUCE_SUM_KERNEL_ENTRY, DEVICE_TYPE_SEQ,
@@ -16,16 +16,17 @@ limitations under the License.
#ifndef ONEFLOW_CORE_KERNEL_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_KERNEL_H_
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/job/resource.pb.h"
#include "oneflow/core/kernel/kernel.pb.h"
#include "oneflow/core/kernel/kernel_registration.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/operator/op_conf_util.h"
#include "oneflow/core/persistence/snapshot.h"
#include "oneflow/core/register/blob.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/operator/op_conf_util.h"
#include "oneflow/core/kernel/kernel_registration.h"
namespace oneflow {
@@ -210,18 +211,19 @@ std::unique_ptr<const Kernel> ConstructKernel(const JobDesc* job_desc, const Ker
{GetHashKey(device_type, OF_PP_PAIR_SECOND(data_type_pair)), \
[]() { return new kernel_class<device_type, OF_PP_PAIR_FIRST(data_type_pair)>(); }},
#define ADD_DEFAULT_KERNEL_CREATOR(op_type_case, kernel_class, data_type_seq) \
namespace { \
\
Kernel* OF_PP_CAT(CreateKernel, __LINE__)(const KernelConf& kernel_conf) { \
static const HashMap<std::string, std::function<Kernel*()>> creators = { \
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_KERNEL_CREATOR_ENTRY, (kernel_class), \
DEVICE_TYPE_SEQ, data_type_seq)}; \
return creators.at(GetHashKey(kernel_conf.op_attribute().op_conf().device_type(), \
kernel_conf.data_type()))(); \
} \
\
REGISTER_KERNEL_CREATOR(op_type_case, OF_PP_CAT(CreateKernel, __LINE__)); \
#define ADD_DEFAULT_KERNEL_CREATOR(op_type_case, kernel_class, data_type_seq) \
namespace { \
\
Kernel* OF_PP_CAT(CreateKernel, __LINE__)(const KernelConf& kernel_conf) { \
static const HashMap<std::string, std::function<Kernel*()>> creators = { \
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_KERNEL_CREATOR_ENTRY, (kernel_class), \
DEVICE_TYPE_SEQ, data_type_seq)}; \
DeviceType device_type = \
CHECK_JUST(DeviceType4DeviceTag(kernel_conf.op_attribute().op_conf().device_tag())); \
return creators.at(GetHashKey(device_type, kernel_conf.data_type()))(); \
} \
\
REGISTER_KERNEL_CREATOR(op_type_case, OF_PP_CAT(CreateKernel, __LINE__)); \
}
#define MAKE_DEVICE_TYPE_KERNEL_CREATOR_ENTRY(kernel_class, device_type) \
@@ -234,7 +236,9 @@ std::unique_ptr<const Kernel> ConstructKernel(const JobDesc* job_desc, const Ker
static const HashMap<int, std::function<Kernel*()>> creators = { \
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_DEVICE_TYPE_KERNEL_CREATOR_ENTRY, (kernel_class), \
DEVICE_TYPE_SEQ)}; \
return creators.at(kernel_conf.op_attribute().op_conf().device_type())(); \
DeviceType device_type = \
CHECK_JUST(DeviceType4DeviceTag(kernel_conf.op_attribute().op_conf().device_tag())); \
return creators.at(device_type)(); \
} \
\
REGISTER_KERNEL_CREATOR(op_type_case, OF_PP_CAT(CreateKernel, __LINE__)); \
@@ -257,20 +261,21 @@ std::unique_ptr<const Kernel> ConstructKernel(const JobDesc* job_desc, const Ker
REGISTER_KERNEL_CREATOR(op_type_case, CreateKernel); \
}
#define ADD_DEFAULT_KERNEL_CREATOR_WITH_GPU_HALF(op_type_case, kernel_class, data_type_seq) \
namespace { \
\
Kernel* OF_PP_CAT(CreateKernel, __LINE__)(const KernelConf& kernel_conf) { \
static const HashMap<std::string, std::function<Kernel*()>> creators = { \
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_KERNEL_CREATOR_ENTRY, (kernel_class), \
DEVICE_TYPE_SEQ, data_type_seq) \
MAKE_KERNEL_CREATOR_ENTRY(kernel_class, DeviceType::kGPU, \
(float16, DataType::kFloat16))}; \
return creators.at(GetHashKey(kernel_conf.op_attribute().op_conf().device_type(), \
kernel_conf.data_type()))(); \
} \
\
REGISTER_KERNEL_CREATOR(op_type_case, OF_PP_CAT(CreateKernel, __LINE__)); \
#define ADD_DEFAULT_KERNEL_CREATOR_WITH_GPU_HALF(op_type_case, kernel_class, data_type_seq) \
namespace { \
\
Kernel* OF_PP_CAT(CreateKernel, __LINE__)(const KernelConf& kernel_conf) { \
static const HashMap<std::string, std::function<Kernel*()>> creators = { \
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_KERNEL_CREATOR_ENTRY, (kernel_class), \
DEVICE_TYPE_SEQ, data_type_seq) \
MAKE_KERNEL_CREATOR_ENTRY(kernel_class, DeviceType::kGPU, \
(float16, DataType::kFloat16))}; \
DeviceType device_type = \
CHECK_JUST(DeviceType4DeviceTag(kernel_conf.op_attribute().op_conf().device_tag())); \
return creators.at(GetHashKey(device_type, kernel_conf.data_type()))(); \
} \
\
REGISTER_KERNEL_CREATOR(op_type_case, OF_PP_CAT(CreateKernel, __LINE__)); \
}
#endif // ONEFLOW_CORE_KERNEL_KERNEL_H_
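A pattern worth noting in the three creator macros above: the tag-to-enum conversion is bound to a named local (DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(...));) in its own statement before the creators.at(...) lookup, presumably related to the "CHECK_JUST macro could not be used inside another macro" fix in the commit message. The sketch below shows the same shape in a self-contained form; the LOOKUP_CREATOR macro, DeviceTypeOrDie helper, and creator map are illustrative stand-ins, not OneFlow code.

#include <cassert>
#include <functional>
#include <map>
#include <string>

enum class DeviceType { kCPU, kGPU };

// Stand-in for CHECK_JUST(DeviceType4DeviceTag(device_tag)).
inline DeviceType DeviceTypeOrDie(const std::string& device_tag) {
  if (device_tag == "cpu") { return DeviceType::kCPU; }
  assert(device_tag == "gpu");
  return DeviceType::kGPU;
}

// Same shape as the updated ADD_DEFAULT_KERNEL_CREATOR body: convert the tag to a
// DeviceType in a separate statement, then look the creator up by the enum value.
#define LOOKUP_CREATOR(creators, device_tag)              \
  [&]() {                                                 \
    DeviceType device_type = DeviceTypeOrDie(device_tag); \
    return (creators).at(device_type)();                  \
  }()

int main() {
  const std::map<DeviceType, std::function<int()>> creators = {
      {DeviceType::kCPU, [] { return 1; }}, {DeviceType::kGPU, [] { return 2; }}};
  return LOOKUP_CREATOR(creators, "gpu") == 2 ? 0 : 1;
}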
@@ -16,13 +16,14 @@ limitations under the License.
#ifndef ONEFLOW_CORE_KERNEL_KERNEL_REGISTRATION_H_
#define ONEFLOW_CORE_KERNEL_KERNEL_REGISTRATION_H_
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/common/device_type.pb.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/str_util.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/operator/op_conf_util.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/kernel/kernel.pb.h"
#include "oneflow/core/operator/op_conf_util.h"
namespace oneflow {
@@ -79,20 +80,20 @@ Kernel* CreateKernel(const KernelConf& kernel_conf);
kernel_registration::KernelRegistrarBuilder(op_type).SetCreateFn( \
[]() { return new __VA_ARGS__(); })
#define REGISTER_KERNEL_WITH_NOTHING(op_type, ...) \
NEW_REGISTER_KERNEL(op_type, __VA_ARGS__).SetIsMatchedPred([](const KernelConf&) { \
return true; \
#define REGISTER_KERNEL_WITH_NOTHING(op_type, ...) \
NEW_REGISTER_KERNEL(op_type, __VA_ARGS__).SetIsMatchedPred([](const KernelConf&) -> bool { \
return true; \
});
#define REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(op_type, device, dtype, ...) \
NEW_REGISTER_KERNEL(op_type, __VA_ARGS__).SetIsMatchedPred([](const KernelConf& conf) { \
return (device == conf.op_attribute().op_conf().device_type()) \
&& (GetDataType<dtype>::value == conf.data_type()); \
#define REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(op_type, device, dtype, ...) \
NEW_REGISTER_KERNEL(op_type, __VA_ARGS__).SetIsMatchedPred([](const KernelConf& conf) -> bool { \
return (ToString(device) == conf.op_attribute().op_conf().device_tag()) \
&& (GetDataType<dtype>::value == conf.data_type()); \
});
#define REGISTER_KERNEL_WITH_DEVICE(op_type, device, ...) \
NEW_REGISTER_KERNEL(op_type, __VA_ARGS__).SetIsMatchedPred([](const KernelConf& conf) { \
return (device == conf.op_attribute().op_conf().device_type()); \
#define REGISTER_KERNEL_WITH_DEVICE(op_type, device, ...) \
NEW_REGISTER_KERNEL(op_type, __VA_ARGS__).SetIsMatchedPred([](const KernelConf& conf) -> bool { \
return (ToString(device) == conf.op_attribute().op_conf().device_tag()); \
});
#define REGISTER_KERNEL_HELPER_CPU_FLOATING(op_type, kernel) \
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef ONEFLOW_CORE_KERNEL_NORMAL_MODEL_UPDATE_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_NORMAL_MODEL_UPDATE_KERNEL_H_
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
@@ -49,13 +50,14 @@ class NormalMdUpdateKernel : public KernelIf<device_type> {
#define DECLARE_MDUPDT_KERNEL_CREATOR(x) Kernel* Create##x##MdUpdtKernel(const KernelConf&);
#define DEFINE_MDUPDT_KERNEL_CREATOR(x) \
Kernel* Create##x##MdUpdtKernel(const KernelConf& kernel_conf) { \
static const HashMap<std::string, std::function<Kernel*()>> creators = { \
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_KERNEL_CREATOR_ENTRY, (x##MdUpdateKernel), \
DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ)}; \
return creators.at(GetHashKey(kernel_conf.op_attribute().op_conf().device_type(), \
kernel_conf.data_type()))(); \
#define DEFINE_MDUPDT_KERNEL_CREATOR(x) \
Kernel* Create##x##MdUpdtKernel(const KernelConf& kernel_conf) { \
static const HashMap<std::string, std::function<Kernel*()>> creators = { \
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_KERNEL_CREATOR_ENTRY, (x##MdUpdateKernel), \
DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ)}; \
DeviceType device_type = \
CHECK_JUST(DeviceType4DeviceTag(kernel_conf.op_attribute().op_conf().device_tag())); \
return creators.at(GetHashKey(device_type, kernel_conf.data_type()))(); \
}
} // namespace oneflow
@@ -86,17 +86,17 @@ class SigmoidCrossEntropyGradGpuKernel final : public KernelIf<DeviceType::kGPU>
#define REGISTER_SIGMOID_CROSS_ENTROPY_GPU_KERNEL(dtype, ltype) \
NEW_REGISTER_KERNEL(OperatorConf::kSigmoidCrossEntropyConf, \
SigmoidCrossEntropyGpuKernel<dtype, ltype>) \
.SetIsMatchedPred([](const KernelConf& conf) { \
return ((conf.op_attribute().op_conf().device_type() == DeviceType::kGPU) \
.SetIsMatchedPred([](const KernelConf& conf) -> bool { \
return ((conf.op_attribute().op_conf().device_tag() == "gpu") \
&& (conf.data_type() == GetDataType<dtype>::value) \
&& (GetDataType<ltype>::value \
== conf.op_attribute().op_conf().sigmoid_cross_entropy_conf().label_type())); \
}); \
NEW_REGISTER_KERNEL(OperatorConf::kSigmoidCrossEntropyGradConf, \
SigmoidCrossEntropyGradGpuKernel<dtype, ltype>) \
.SetIsMatchedPred([](const KernelConf& conf) { \
.SetIsMatchedPred([](const KernelConf& conf) -> bool { \
return ( \
(conf.op_attribute().op_conf().device_type() == DeviceType::kGPU) \
(conf.op_attribute().op_conf().device_tag() == "gpu") \
&& (conf.data_type() == GetDataType<dtype>::value) \
&& (GetDataType<ltype>::value \
== conf.op_attribute().op_conf().sigmoid_cross_entropy_grad_conf().label_type())); \
@@ -13,15 +13,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/util.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/register/register_desc.h"
#include <cstddef>
#include <cstdint>
#include <memory>
#include <mutex>
#include <queue>
#include "oneflow/core/common/util.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/register/register_desc.h"
namespace oneflow {
@@ -101,7 +103,7 @@ class SyncDynamicResizeGPUKernel final : public KernelIf<DeviceType::kGPU> {
#define REGISTER_SYNC_DYNAMIC_RESIZE_GPU_KERNEL(stype) \
NEW_REGISTER_KERNEL(OperatorConf::kSyncDynamicResizeConf, SyncDynamicResizeGPUKernel<stype>) \
.SetIsMatchedPred([](const KernelConf& kernel_conf) { \
return (kernel_conf.op_attribute().op_conf().device_type() == DeviceType::kGPU \
return (kernel_conf.op_attribute().op_conf().device_tag() == "gpu" \
&& GetDataType<stype>::value \
== kernel_conf.sync_dynamic_resize_conf().size_data_type()); \
})
@@ -139,7 +141,7 @@ class SyncDynamicResizeCPUKernel final : public KernelIf<DeviceType::kCPU> {
#define REGISTER_SYNC_DYNAMIC_RESIZE_CPU_KERNEL(stype) \
NEW_REGISTER_KERNEL(OperatorConf::kSyncDynamicResizeConf, SyncDynamicResizeCPUKernel<stype>) \
.SetIsMatchedPred([](const KernelConf& kernel_conf) { \
return (kernel_conf.op_attribute().op_conf().device_type() == DeviceType::kCPU \
return (kernel_conf.op_attribute().op_conf().device_tag() == "cpu" \
&& GetDataType<stype>::value \
== kernel_conf.sync_dynamic_resize_conf().size_data_type()); \
})
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/common/tensor_buffer.h"
@@ -81,12 +82,12 @@ void TensorBufferToTensorListKernel<T>::ForwardDataContent(
CHECK_EQ(out_blob->total_num_of_tensors(), in_blob->shape().elem_cnt());
}
#define REGISTER_TENSOR_BUFFER_TO_TENSOR_LIST_KERNEL(dtype) \
NEW_REGISTER_KERNEL(OperatorConf::kTensorBufferToTensorListConf, \
TensorBufferToTensorListKernel<dtype>) \
.SetIsMatchedPred([](const KernelConf& conf) { \
return (conf.op_attribute().op_conf().device_type() == DeviceType::kCPU) \
&& (conf.data_type() == GetDataType<dtype>::value); \
#define REGISTER_TENSOR_BUFFER_TO_TENSOR_LIST_KERNEL(dtype) \
NEW_REGISTER_KERNEL(OperatorConf::kTensorBufferToTensorListConf, \
TensorBufferToTensorListKernel<dtype>) \
.SetIsMatchedPred([](const KernelConf& conf) { \
return (conf.op_attribute().op_conf().device_tag() == "cpu") \
&& (conf.data_type() == GetDataType<dtype>::value); \
});
REGISTER_TENSOR_BUFFER_TO_TENSOR_LIST_KERNEL(char)
@@ -13,8 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/common/tensor_buffer.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
@@ -59,7 +60,7 @@ void TensorListToTensorBufferKernel::ForwardHeader(
NEW_REGISTER_KERNEL(OperatorConf::kTensorListToTensorBufferConf, TensorListToTensorBufferKernel)
.SetIsMatchedPred([](const KernelConf& conf) {
return (conf.op_attribute().op_conf().device_type() == DeviceType::kCPU)
return (conf.op_attribute().op_conf().device_tag() == "cpu")
&& (conf.data_type() == DataType::kTensorBuffer);
});
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/kernel/kernel_context.h"
#include "oneflow/core/kernel/unique_kernel_util.h"
@@ -57,7 +58,7 @@ void UniqueWithCountsKernel<device_type, T, K>::ForwardDataContent(
UniqueWithCountsKernel<device_type_v, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(indices_type_pair)>) \
.SetIsMatchedPred([](const KernelConf& kernel_conf) -> bool { \
return ((kernel_conf.op_attribute().op_conf().device_type() == device_type_v) \
return ((kernel_conf.op_attribute().op_conf().device_tag() == ToString(device_type_v)) \
&& ((OF_PP_PAIR_SECOND(data_type_pair)) == kernel_conf.data_type()) \
&& (OF_PP_PAIR_SECOND(indices_type_pair) \
== kernel_conf.unique_with_counts_conf().indices_data_type())); \
@@ -13,14 +13,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/kernel/eager_kernel.h"
#include "oneflow/core/framework/infer_util.h"
#include "oneflow/core/framework/op_kernel.h"
#include "oneflow/core/framework/op_kernel_infer_cache.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/framework/infer_util.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
#include "oneflow/core/kernel/eager_kernel.h"
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
@@ -61,8 +61,8 @@ class UserKernelBaseContext {
};
InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().input(), &inputs_);
InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().output(), &outputs_);
device_type_ = kernel_conf.op_attribute().op_conf().device_type();
device_tag_ = kernel_conf.op_attribute().op_conf().device_tag();
device_type_ = CHECK_JUST(DeviceType4DeviceTag(device_tag_));
parallel_ctx_ = kernel_conf.user_conf().parallel_ctx();
for (const auto& pair : kernel_conf.user_conf().bn_in_op2blob_desc()) {
arg2tensor_desc_.emplace(GenUnRepeatedBn(pair.first), user_op::TensorDesc(pair.second));
@@ -71,6 +71,7 @@ class UserKernelBaseContext {
~UserKernelBaseContext() = default;
DeviceType device_type() const { return device_type_; }
const std::string& device_tag() const { return device_tag_; }
const ParallelContext& parallel_ctx() const { return parallel_ctx_; }
const JobDesc& job_desc() const { return job_desc_; }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
@@ -87,6 +88,7 @@ class UserKernelBaseContext {
ArgVec inputs_;
ArgVec outputs_;
DeviceType device_type_;
std::string device_tag_;
ParallelContext parallel_ctx_;
HashMap<std::pair<std::string, int32_t>, user_op::TensorDesc> arg2tensor_desc_;
const JobDesc& job_desc_;
@@ -378,6 +380,7 @@ class UserKernelRegContext final : public user_op::KernelRegContext {
~UserKernelRegContext() = default;
DeviceType device_type() const override { return base_ctx_.device_type(); }
const std::string& device_tag() const override { return base_ctx_.device_tag(); }
const ParallelContext& parallel_ctx() const override { return base_ctx_.parallel_ctx(); }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
@@ -976,7 +976,7 @@ message CastToStaticShapeOpConf {
message OperatorConf {
required string name = 1;
optional bool trainable = 3 [default = true];
optional DeviceType device_type = 4 [default = kInvalidDevice];
optional string device_tag = 4 [default = "invalid_device"];
optional bool enable_cudnn = 5;
optional int64 cudnn_buf_limit_mbyte = 6;
repeated string ctrl_in_op_name = 7;
@@ -13,14 +13,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/graph/logical_node.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
#include "oneflow/core/job/sbp_signature_builder.h"
#include "oneflow/core/graph/logical_node.h"
#include "oneflow/core/job/mirrored_sig_infer_hint.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/job/sbp_signature_builder.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/operator/operator.h"
namespace oneflow {
@@ -39,7 +40,8 @@ DataType GetDataTypeFromBnInOpVec(
std::shared_ptr<Operator> CheckAndConstructOp(const OperatorConf& op_conf,
const JobDesc* job_desc) {
Operator* rptr = NewObj<Operator>(op_conf.op_type_case(), op_conf);
if (IsCpuOnly(op_conf)) { CHECK_EQ(op_conf.device_type(), DeviceType::kCPU); }
DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(op_conf.device_tag()));
if (IsCpuOnly(op_conf)) { CHECK_EQ(device_type, DeviceType::kCPU); }
rptr->Init(op_conf, job_desc);
return std::shared_ptr<Operator>(rptr);
}
@@ -72,6 +74,11 @@ LogicalBlobId* Operator::MutBnInOp2Lbi(const std::string& bn_in_op) {
}
}
DeviceType Operator::device_type() const {
DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(op_attribute_.op_conf().device_tag()));
return device_type;
}
const std::string& Operator::SoleIbn() const {
CHECK_EQ(input_bns().size(), 1);
return input_bns().Get(0);
@@ -636,7 +643,7 @@ bool IsCpuOnly(const OperatorConf& op_conf) {
std::shared_ptr<Operator> ConstructOp(const OperatorConf& op_conf, DeviceType device_type,
const JobDesc* job_desc) {
OperatorConf dev_op_conf = op_conf;
dev_op_conf.set_device_type(device_type);
dev_op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(device_type)));
return CheckAndConstructOp(dev_op_conf, job_desc);
}
@@ -59,7 +59,7 @@ class Operator {
// Getters
const std::string& op_name() const { return op_conf().name(); }
DeviceType device_type() const { return op_attribute_.op_conf().device_type(); }
DeviceType device_type() const;
bool EnableCudnn() const { return op_conf().enable_cudnn(); }
bool DevIsGpuAndEnableCudnn() const { return device_type() == DeviceType::kGPU && EnableCudnn(); }
const OperatorConf& op_conf() const { return op_attribute_.op_conf(); }
@@ -13,12 +13,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/operator/user_op.h"
#include "oneflow/core/operator/user_op_util.h"
#include "oneflow/core/framework/tensor_desc.h"
#include "oneflow/core/framework/batch_axis_context.h"
#include "oneflow/core/framework/infer_util.h"
#include "oneflow/core/framework/sbp_context.h"
#include "oneflow/core/framework/batch_axis_context.h"
#include "oneflow/core/framework/tensor_desc.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/operator/user_op.h"
#include "oneflow/core/operator/user_op_util.h"
namespace oneflow {
@@ -53,7 +54,8 @@ class UserOpKernelRegContext final : public user_op::KernelRegContext {
const auto& op_conf = user_op->op_conf();
CHECK(op_conf.has_user_conf());
device_type_ = op_conf.device_type();
device_tag_ = op_conf.device_tag();
device_type_ = CHECK_JUST(DeviceType4DeviceTag(device_tag_));
parallel_ctx_ = parallel_ctx;
auto InitInOrOut = [&](const PbMap<std::string, UserOpConf::ListString>& arg_map,
@@ -85,6 +87,7 @@ class UserOpKernelRegContext final : public user_op::KernelRegContext {
~UserOpKernelRegContext() = default;
DeviceType device_type() const override { return device_type_; }
const std::string& device_tag() const override { return device_tag_; }
const ParallelContext& parallel_ctx() const override { return *parallel_ctx_; }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
@@ -99,6 +102,7 @@ class UserOpKernelRegContext final : public user_op::KernelRegContext {
ArgVec inputs_;
ArgVec outputs_;
DeviceType device_type_;
std::string device_tag_;
const ParallelContext* parallel_ctx_;
HashMap<std::pair<std::string, int32_t>, user_op::TensorDesc> arg2tensor_desc_;
};
@@ -505,6 +505,7 @@ def ConstructNaiveBoxingOpConf(
):
op_conf = op_conf_pb.OperatorConf()
op_conf.name = "undefined_boxing_op_name"
op_conf.device_tag = "cpu"
op_conf.boxing_conf.lbi.op_name = "undefined_boxing_op_name"
op_conf.boxing_conf.lbi.blob_name = "undefined_boxing_blob_name"
op_conf.boxing_conf.in_num = in_parallel_num
@@ -623,7 +624,7 @@ def BuildCopyHdInstruction(builder, produced_blob_object, to_device_tag):
def _MakeCopyHdOpConfAndRetLbi():
op_conf = op_conf_pb.OperatorConf()
op_conf.name = "copy_hd"
op_conf.device_type = c_api_util.DeviceType4DeviceTag("gpu")
op_conf.device_tag = "gpu"
setattr(op_conf.copy_conf, "in", "%s/in" % op_conf.name)
op_conf.copy_conf.out = "out"
lbi = logical_blob_id_util.LogicalBlobId()
@@ -669,6 +670,8 @@ def _AssignOpConf():
op_conf.name = "assign"
op_conf.assign_conf.ref = "assign/ref"
op_conf.assign_conf.value = "assign/value"
device_tag = oneflow.current_scope().device_parallel_desc_symbol.device_tag
op_conf.device_tag = device_tag
return op_conf
@@ -722,7 +725,7 @@ def ReplaceDeviceTag(parallel_desc_symbol, device_tag, builder=None):
def _GetEagerNcclAllReduce(parallel_conf, ibn2blob_object):
op_conf = op_conf_pb.OperatorConf()
op_conf.device_type = c_api_util.DeviceType4DeviceTag("gpu")
op_conf.device_tag = "gpu"
op_conf.name = "eager_nccl_all_reduce"
op_conf.user_conf.op_type_name = "eager_nccl_all_reduce"
op_conf.user_conf.input["in"].s.append("eager_nccl_all_reduce/in_0")
@@ -21,7 +21,6 @@ import oneflow.core.operator.op_conf_pb2 as op_conf_util
import oneflow.python.eager.vm_util as vm_util
import oneflow.python.eager.boxing_util as boxing_util
import oneflow.python.eager.symbol_storage as symbol_storage
import oneflow.python.framework.device_util as device_util
import oneflow.python.framework.c_api_util as c_api_util
import oneflow.python.framework.remote_blob as remote_blob_util
import oneflow.python.framework.op_arg_util as op_arg_util
@@ -363,7 +362,7 @@ def _GenModelInitOpConfAndRetLbi(var_op_conf):
variable_op_conf.CopyFrom(var_op_conf.variable_conf)
op_conf = op_conf_util.OperatorConf()
op_conf.name = "model_init"
op_conf.device_type = device_util.DeviceType4DeviceTag("cpu")
op_conf.device_tag = "cpu"
op_conf.model_init_conf.out.append("out_0")
op_conf.model_init_conf.variable_op_name.append(var_op_conf.name)
op_conf.model_init_conf.original_variable_conf.append(variable_op_conf)
@@ -379,7 +378,7 @@ def _GenModelLoadOpConfAndRetLbi(var_op_conf, path_lbi):
op_conf = op_conf_util.OperatorConf()
op_conf.name = "model_load"
op_conf.device_type = device_util.DeviceType4DeviceTag("cpu")
op_conf.device_tag = "cpu"
op_conf.model_load_conf.path = "{}/{}".format(path_lbi.op_name, path_lbi.blob_name)
op_conf.model_load_conf.out.append("out_0")
op_conf.model_load_conf.variable_op_name.append(var_op_conf.name)
@@ -394,7 +393,7 @@ def _GenModelLoadOpConfAndRetLbi(var_op_conf, path_lbi):
def _GenModelIOPathInputOpConfAndRetLbi():
op_conf = op_conf_util.OperatorConf()
op_conf.name = "model_io_path_input"
op_conf.device_type = device_util.DeviceType4DeviceTag("cpu")
op_conf.device_tag = "cpu"
op_conf.input_conf.out = "out"
blob_conf = op_conf_util.InterfaceBlobConf()
@@ -413,7 +412,7 @@ def _GenModelIOPathInputOpConfAndRetLbi():
def _GenModelSaveOpConf(var_blobs, path_lbi):
op_conf = op_conf_util.OperatorConf()
op_conf.name = "model_save"
op_conf.device_type = device_util.DeviceType4DeviceTag("cpu")
op_conf.device_tag = "cpu"
op_conf.model_save_conf.path = "{}/{}".format(path_lbi.op_name, path_lbi.blob_name)
for blob in var_blobs:
getattr(op_conf.model_save_conf, "in").append(blob.logical_blob_name)
@@ -554,15 +554,6 @@ def GetMachine2DeviceIdListOFRecordFromParallelConf(parallel_conf):
return text_format.Parse(ofrecord, record_util.OFRecord())
def DeviceType4DeviceTag(device_tag):
device_tag = str(device_tag)
device_type, error_str = oneflow_internal.DeviceType4DeviceTag(device_tag)
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"):
raise JobBuildAndInferError(error)
return device_type
def GetFunctionConfigDef():
func_config_def, error_str = oneflow_internal.GetFunctionConfigDef()
error = text_format.Parse(error_str, error_util.ErrorProto())
@@ -61,9 +61,9 @@ def CurJobAddConsistentOp(op_conf, scope_symbol=None):
if scope_symbol is None:
scope_symbol = oneflow.current_scope()
op_conf.scope_symbol_id = scope_symbol.symbol_id
if not op_conf.HasField("device_type"):
if not op_conf.HasField("device_tag"):
device_tag = scope_symbol.device_parallel_desc_symbol.device_tag
op_conf.device_type = c_api_util.DeviceType4DeviceTag(device_tag)
op_conf.device_tag = device_tag
return c_api_util.CurJobBuildAndInferCtx_AddAndInferConsistentOp(op_conf)
@@ -72,7 +72,7 @@ def CurJobAddMirroredOp(op_conf, scope_symbol=None):
if scope_symbol is None:
scope_symbol = oneflow.current_scope()
op_conf.scope_symbol_id = scope_symbol.symbol_id
if not op_conf.HasField("device_type"):
if not op_conf.HasField("device_tag"):
device_tag = scope_symbol.device_parallel_desc_symbol.device_tag
op_conf.device_type = c_api_util.DeviceType4DeviceTag(device_tag)
op_conf.device_tag = device_tag
return c_api_util.CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import absolute_import
import oneflow.python.framework.c_api_util as c_api_util
def DeviceType4DeviceTag(device_tag):
global _device_tag2device_type
if device_tag not in _device_tag2device_type:
_device_tag2device_type[device_tag] = c_api_util.DeviceType4DeviceTag(
device_tag
)
return _device_tag2device_type[device_tag]
_device_tag2device_type = {}
@@ -20,7 +20,6 @@ import re
import oneflow.core.job.placement_pb2 as placement_pb
import oneflow.python.framework.c_api_util as c_api_util
import oneflow.python.framework.device_util as device_util
import oneflow.python.framework.op_util as op_util
import oneflow.python.framework.session_context as session_ctx
import oneflow.python.framework.scope_util as scope_util
@@ -77,9 +76,6 @@ class PlacementScope(object):
self.GetDeviceTag4OpConf(op_conf), self.machine_device_ids_
)
def GetDeviceType4OpConf(self, op_conf):
return device_util.DeviceType4DeviceTag(self.GetDeviceTag4OpConf(op_conf))
def GetDeviceTag4OpConf(self, op_conf):
return self.default_device_tag
@@ -109,15 +105,6 @@ def PlacementScopeStackTop():
return session_ctx.GetDefaultSession().placement_scope_stack[0]
def CurPlacementGroupGetDeviceType(op_conf):
assert len(session_ctx.GetDefaultSession().placement_scope_stack) > 0
return (
session_ctx.GetDefaultSession()
.placement_scope_stack[0]
.GetDeviceType4OpConf(op_conf)
)
def ParallelConf4OpConf(op_conf):
assert len(session_ctx.GetDefaultSession().placement_scope_stack) > 0
return (