diff --git a/oneflow/core/framework/device.cpp b/oneflow/core/framework/device.cpp
index 452f7474907efae90d2ac7aba2ff2336fa676002..ce1734f7381f76f7a822a31bfc751e7ed2a386c2 100644
--- a/oneflow/core/framework/device.cpp
+++ b/oneflow/core/framework/device.cpp
@@ -48,6 +48,12 @@ Maybe<void> Device::Init() {
   return std::shared_ptr<const Device>(device);
 }
 
+// Creates a device of `type` without an explicit id; the id used is
+// GlobalProcessCtx::Rank() % GlobalProcessCtx::NumOfProcessPerNode().
+/*static*/ Maybe<const Device> Device::New(const std::string& type) {
+  return New(type, GlobalProcessCtx::Rank() % GlobalProcessCtx::NumOfProcessPerNode());
+}
+
 const std::shared_ptr<const ParallelDesc>& Device::parallel_desc_ptr() const {
   return Global<EnvGlobalObjectsScope>::Get()->MutParallelDesc4Device(*this);
 }
diff --git a/oneflow/core/framework/device.h b/oneflow/core/framework/device.h
index f304122cac1e6dd7b7b3adbb741eea3324b04985..0844c1f8f59ebd847c324e75a8a291e063bf4a42 100644
--- a/oneflow/core/framework/device.h
+++ b/oneflow/core/framework/device.h
@@ -43,6 +43,8 @@ class Device final {
   const std::shared_ptr<MemoryCase>& mem_case() const { return mem_case_; }
 
   static Maybe<const Device> New(const std::string& type, int64_t device_id);
+  // Same as New(type, device_id), with the device id derived from the current process rank.
+  static Maybe<const Device> New(const std::string& type);
   static Maybe<const ParallelDesc> MakeParallelDescByDevice(const Device& device);
   static Maybe<const Device> MakeDeviceByParallelDesc(const ParallelDesc& parallel_desc);
 
diff --git a/oneflow/core/framework/user_op_registry.cpp b/oneflow/core/framework/user_op_registry.cpp
index 084b59c0e6652a68ae12dc752fa0b43df5d3e12f..f449a16c9eae08c21f54ff1f0e846947139284e7 100644
--- a/oneflow/core/framework/user_op_registry.cpp
+++ b/oneflow/core/framework/user_op_registry.cpp
@@ -15,6 +15,7 @@ limitations under the License.
*/ #include "oneflow/core/framework/user_op_registry.h" #include "oneflow/core/common/balanced_splitter.h" +#include "oneflow/core/framework/device.h" #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/attr_value.h" #include "oneflow/core/framework/attr_value_accessor.h" @@ -239,6 +240,28 @@ OpRegistry& OpRegistry::Finish() { if (result_.get_sbp_fn == nullptr) { result_.get_sbp_fn = GetSbpFnUtil::DefaultBroadcastToBroadcast; } + if (result_.cpu_only_supported && result_.device_infer_fn == nullptr) { + result_.device_infer_fn = [](DeviceInferContext* ctx) -> Maybe<const Device> { + for (const auto& pair : ctx->inputs()) { + const std::shared_ptr<const Device>& input_device = + ctx->InputTensorDevice4ArgNameAndIndex(pair.first, pair.second); + CHECK_EQ_OR_RETURN(JUST(input_device->of_type()), "cpu"); + } + std::shared_ptr<const Device> default_device; + { + if (ctx->inputs().size() != 0) { + const auto& first_input_name = ctx->inputs().begin()->first; + default_device = ctx->InputTensorDevice4ArgNameAndIndex(first_input_name, 0); + } else { + default_device = JUST(Device::New("cpu")); + } + } + for (const auto& pair : ctx->outputs()) { + *ctx->OutputTensorDevice4ArgNameAndIndex(pair.first, pair.second) = default_device; + } + return default_device; + }; + } return *this; }