Skip to content
Snippets Groups Projects
Unverified commit 110f94b6, authored by binbinHan, committed by GitHub
Browse files

Add device infer fn cpu only (#4832)


* device_infer_fn

* Device::local_call_instruction_name

* implement UserOpExprDeviceInferContext

* merge master

* refactor UserOpConfWrapper

* fix a ci bug

* Refine (#4825)

* fix segmentation fault bug

* add cpu only device_infer_fn

* optimize code

* del blank line

* minor fix

* override Device::New

* minor fix

Co-authored-by: lixinqi <lixinqi0703106@163.com>
Co-authored-by: Li Xinqi <lixinqi2010@gmail.com>
Co-authored-by: Houjiang Chen <chenhoujiangcug@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 866a8b24
No related branches found
No related tags found
No related merge requests found
......@@ -48,6 +48,10 @@ Maybe<void> Device::Init() {
return std::shared_ptr<const Device>(device);
}
// Convenience overload: builds a device of the given type, using this process's rank
// modulo the number of processes per node as the device id.
/*static*/ Maybe<const Device> Device::New(const std::string& type) {
  const auto device_id = GlobalProcessCtx::Rank() % GlobalProcessCtx::NumOfProcessPerNode();
  return New(type, device_id);
}
// Looks up the ParallelDesc registered for this device in the global EnvGlobalObjectsScope
// and hands back the reference that lookup returns.
const std::shared_ptr<const ParallelDesc>& Device::parallel_desc_ptr() const {
  auto* env_scope = Global<EnvGlobalObjectsScope>::Get();
  return env_scope->MutParallelDesc4Device(*this);
}
......
......@@ -43,6 +43,7 @@ class Device final {
// Returns the stored mem_case_ shared pointer by const reference (no copy is made).
const std::shared_ptr<MemoryCase>& mem_case() const { return mem_case_; }
// Creates a Device from a type string and an explicit device id.
static Maybe<const Device> New(const std::string& type, int64_t device_id);
// Creates a Device from a type string alone; the device id is derived from the current
// process's rank modulo the number of processes per node.
static Maybe<const Device> New(const std::string& type);
// Returns a ParallelDesc for the given device.
// NOTE(review): definition is not visible in this hunk — exact construction rules to be confirmed.
static Maybe<const ParallelDesc> MakeParallelDescByDevice(const Device& device);
// Returns a Device for the given ParallelDesc.
// NOTE(review): presumably expects a ParallelDesc describing a single device — confirm against
// the definition.
static Maybe<const Device> MakeDeviceByParallelDesc(const ParallelDesc& parallel_desc);
......
......@@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/framework/user_op_registry.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/infer_util.h"
#include "oneflow/core/framework/attr_value.h"
#include "oneflow/core/framework/attr_value_accessor.h"
......@@ -239,6 +240,28 @@ OpRegistry& OpRegistry::Finish() {
// No SBP function was registered for this op: fall back to
// GetSbpFnUtil::DefaultBroadcastToBroadcast.
if (result_.get_sbp_fn == nullptr) {
result_.get_sbp_fn = GetSbpFnUtil::DefaultBroadcastToBroadcast;
}
// CPU-only ops that did not register their own device_infer_fn get a default one.
// The default:
//   1. requires every input tensor's device type to be "cpu" (returns an error otherwise);
//   2. picks the device of the first input, or a newly created "cpu" device when the op
//      has no inputs;
//   3. assigns that device to every output and returns it.
if (result_.cpu_only_supported && result_.device_infer_fn == nullptr) {
  result_.device_infer_fn = [](DeviceInferContext* ctx) -> Maybe<const Device> {
    // Fetch the (arg name, index) pairs once; binding to a const reference keeps the
    // result alive for the whole lambda body regardless of how inputs() returns it.
    const auto& inputs = ctx->inputs();
    for (const auto& pair : inputs) {
      const std::shared_ptr<const Device>& input_device =
          ctx->InputTensorDevice4ArgNameAndIndex(pair.first, pair.second);
      CHECK_EQ_OR_RETURN(JUST(input_device->of_type()), "cpu");
    }
    std::shared_ptr<const Device> default_device;
    if (!inputs.empty()) {
      // Use the first pair's own index rather than a hard-coded 0 so the lookup refers
      // to exactly the (name, index) entry that was validated above.
      const auto& first_input = *inputs.begin();
      default_device =
          ctx->InputTensorDevice4ArgNameAndIndex(first_input.first, first_input.second);
    } else {
      default_device = JUST(Device::New("cpu"));
    }
    for (const auto& pair : ctx->outputs()) {
      *ctx->OutputTensorDevice4ArgNameAndIndex(pair.first, pair.second) = default_device;
    }
    return default_device;
  };
}
return *this;
}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment