Skip to content
Snippets Groups Projects
Unverified Commit 36fe61c9 authored by Juncheng's avatar Juncheng Committed by GitHub
Browse files

AutoRegistrationFactory add key type (#3660)


* AutoRegistrationFactory add key type

* AutoRegistrationFactory add creators accessors (#3662)

* AutoRegistrationFactory add creators accessors

* refine

* const

Co-authored-by: oneflow-bot <69100618+oneflow-bot@users.noreply.github.com>
parent 76631a41
No related branches found
No related tags found
No related merge requests found
Showing
with 111 additions and 70 deletions
......@@ -703,7 +703,7 @@ Regst* Actor::GetNaiveCurWriteable(int64_t regst_desc_id) const {
}
std::unique_ptr<Actor> NewActor(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
Actor* rptr = NewObj<Actor>(task_proto.task_type());
Actor* rptr = NewObj<int32_t, Actor>(task_proto.task_type());
const auto& job_descs = *Global<RuntimeJobDescs>::Get();
rptr->Init(&job_descs.job_desc(task_proto.job_id()), task_proto, thread_ctx);
return std::unique_ptr<Actor>(rptr);
......
......@@ -252,7 +252,7 @@ class Actor {
std::unique_ptr<Actor> NewActor(const TaskProto&, const ThreadCtx&);
#define REGISTER_ACTOR(task_type, ActorType) REGISTER_CLASS(task_type, Actor, ActorType)
#define REGISTER_ACTOR(task_type, ActorType) REGISTER_CLASS(int32_t, task_type, Actor, ActorType)
} // namespace oneflow
......
......@@ -20,59 +20,81 @@ limitations under the License.
namespace oneflow {
template<typename Base, typename... Args>
template<typename Key, typename Base, typename... Args>
struct AutoRegistrationFactory {
public:
using Creator = std::function<Base*(Args&&...)>;
template<typename Derived>
struct RawRegisterType {
RawRegisterType(int32_t k) {
CHECK((AutoRegistrationFactory<Base, Args...>::Get()
.creators_.emplace(k, [](Args&&...) { return new Derived; })
RawRegisterType(Key k) {
CHECK((AutoRegistrationFactory<Key, Base, Args...>::Get()
.mutable_creators()
->emplace(k, [](Args&&...) { return new Derived; })
.second))
<< k;
}
};
struct CreatorRegisterType {
CreatorRegisterType(int32_t k, std::function<Base*(Args&&...)> v) {
CHECK((AutoRegistrationFactory<Base, Args...>::Get().creators_.emplace(k, v).second)) << k;
CreatorRegisterType(Key k, Creator v) {
CHECK((AutoRegistrationFactory<Key, Base, Args...>::Get()
.mutable_creators()
->emplace(k, v)
.second))
<< k;
}
};
Base* New(int32_t k, Args&&... args) {
auto creators_it = creators_.find(k);
CHECK(creators_it != creators_.end()) << "Unregistered: " << k;
// Builds the object registered under key `k`, forwarding `args` to the stored creator.
// The CHECK fires with "Unregistered: <k>" when no creator exists for `k`.
Base* New(Key k, Args&&... args) const {
  const HashMap<Key, Creator>& registry = creators();
  const auto entry = registry.find(k);
  CHECK(entry != registry.end()) << "Unregistered: " << k;
  return entry->second(std::forward<Args>(args)...);
}
bool IsClassRegistered(int32_t k, Args&&... args) { return creators_.find(k) != creators_.end(); }
// Reports whether a creator has been registered under key `k`.
// `args` is not used by the lookup. The map is reached through creators(),
// so its CHECK still applies when nothing was ever registered.
bool IsClassRegistered(Key k, Args&&... args) const {
  const HashMap<Key, Creator>& registry = creators();
  const bool found = registry.find(k) != registry.end();
  return found;
}
static AutoRegistrationFactory<Base, Args...>& Get() {
static AutoRegistrationFactory<Base, Args...> obj;
// Returns the single factory instance for this <Key, Base, Args...> combination.
static AutoRegistrationFactory<Key, Base, Args...>& Get() {
  using Factory = AutoRegistrationFactory<Key, Base, Args...>;
  static Factory instance;  // function-local static, constructed on first call
  return instance;
}
private:
HashMap<int32_t, std::function<Base*(Args&&...)>> creators_;
std::unique_ptr<HashMap<Key, Creator>> creators_;
// True once mutable_creators() has allocated the creator map.
bool has_creators() const { return creators_ != nullptr; }
// Read-only access to the creator map.
// The CHECK fires when the map was never allocated, i.e. no registration has
// gone through mutable_creators() for this factory.
const HashMap<Key, Creator>& creators() const {
  CHECK(has_creators()) << "Unregistered key type: " << typeid(Key).name();
  return *creators_;
}
// Returns the creator map for insertion, allocating it on first use.
HashMap<Key, Creator>* mutable_creators() {
  if (creators_) { return creators_.get(); }
  creators_.reset(new HashMap<Key, Creator>);
  return creators_.get();
}
};
#define REGISTER_VAR_NAME OF_PP_CAT(g_registry_var, __COUNTER__)
#define REGISTER_CLASS(k, Base, Derived) \
static AutoRegistrationFactory<Base>::RawRegisterType<Derived> REGISTER_VAR_NAME(k)
#define REGISTER_CLASS_WITH_ARGS(k, Base, Derived, ...) \
static AutoRegistrationFactory<Base, __VA_ARGS__>::RawRegisterType<Derived> REGISTER_VAR_NAME(k)
#define REGISTER_CLASS_CREATOR(k, Base, f, ...) \
static AutoRegistrationFactory<Base, ##__VA_ARGS__>::CreatorRegisterType REGISTER_VAR_NAME(k, f)
template<typename Base, typename... Args>
inline Base* NewObj(int32_t k, Args&&... args) {
return AutoRegistrationFactory<Base, Args...>::Get().New(k, std::forward<Args>(args)...);
// Registers `Derived` under key `k` (of type `Key`) in the argument-less factory for `Base`.
// Expands to a static RawRegisterType<Derived> object named via REGISTER_VAR_NAME; the
// registration itself happens in that object's constructor.
#define REGISTER_CLASS(Key, k, Base, Derived) \
static AutoRegistrationFactory<Key, Base>::RawRegisterType<Derived> REGISTER_VAR_NAME(k)
// Same as REGISTER_CLASS, but targets the factory whose creators take the argument
// types listed in `...`.
#define REGISTER_CLASS_WITH_ARGS(Key, k, Base, Derived, ...) \
static AutoRegistrationFactory<Key, Base, __VA_ARGS__>::RawRegisterType<Derived> \
REGISTER_VAR_NAME(k)
// Registers the creator callable `f` under key `k`. The optional trailing `...` names the
// creator's argument types; `##__VA_ARGS__` is the GNU extension that drops the preceding
// comma when no extra arguments are given.
#define REGISTER_CLASS_CREATOR(Key, k, Base, f, ...) \
static AutoRegistrationFactory<Key, Base, ##__VA_ARGS__>::CreatorRegisterType REGISTER_VAR_NAME( \
k, f)
// Convenience wrapper: looks up the <Key, Base, Args...> factory and asks it to
// create the object registered under `k`, forwarding `args` unchanged.
template<typename Key, typename Base, typename... Args>
inline Base* NewObj(Key k, Args&&... args) {
  auto& factory = AutoRegistrationFactory<Key, Base, Args...>::Get();
  return factory.New(k, std::forward<Args>(args)...);
}
template<typename Base, typename... Args>
inline bool IsClassRegistered(int32_t k, Args&&... args) {
return AutoRegistrationFactory<Base, Args...>::Get().IsClassRegistered(
// Convenience wrapper: asks the <Key, Base, Args...> factory whether a creator
// is registered under `k`.
template<typename Key, typename Base, typename... Args>
inline bool IsClassRegistered(Key k, Args&&... args) {
  auto& factory = AutoRegistrationFactory<Key, Base, Args...>::Get();
  return factory.IsClassRegistered(k, std::forward<Args>(args)...);
}
......
......@@ -276,7 +276,7 @@ REGISTER_DEFAULT_MEMORY_COPIER(DeviceType::kGPU, []() { return new CudaAsyncMemo
MemoryCopier* NewDefaultMemoryCopier(DeviceType device_type) {
return std::unique_ptr<DefaultMemoryCopierCreator>(
NewObj<DefaultMemoryCopierCreator>(device_type))
NewObj<int32_t, DefaultMemoryCopierCreator>(device_type))
->Create();
}
......
......@@ -108,8 +108,8 @@ class DefaultMemoryCopierCreator final {
const Func func_;
};
#define REGISTER_DEFAULT_MEMORY_COPIER(device_type, func) \
REGISTER_CLASS_CREATOR(device_type, DefaultMemoryCopierCreator, \
#define REGISTER_DEFAULT_MEMORY_COPIER(device_type, func) \
REGISTER_CLASS_CREATOR(int32_t, device_type, DefaultMemoryCopierCreator, \
([] { return new DefaultMemoryCopierCreator(func); }))
MemoryCopier* NewDefaultMemoryCopier(DeviceType device_type);
......
......@@ -87,12 +87,6 @@ OpRegistry& OpRegistry::SetOutputBufferNum(int32_t num) {
return *this;
}
OpRegistry& OpRegistry::SetAreaId(int64_t area_id) {
CHECK_NE(area_id, AreaType::kInvalidArea);
result_.area_id = area_id;
return *this;
}
OpRegistry& OpRegistry::Attr(const std::string& name, UserOpAttrType type) {
CHECK(InsertIfNotExists(name, &unique_names_));
UserOpDef::AttrDef attr_def;
......
......@@ -22,7 +22,6 @@ limitations under the License.
#include "oneflow/core/framework/user_op_attr.pb.h"
#include "oneflow/core/framework/user_op_conf.pb.h"
#include "oneflow/core/operator/op_attribute.pb.h"
#include "oneflow/core/job/task.pb.h"
namespace oneflow {
......@@ -48,8 +47,7 @@ using GetOutputArgModifier =
using OutputArgModifyFn = std::function<void(GetOutputArgModifier, const UserOpConfWrapper&)>;
struct OpRegistryResult {
OpRegistryResult()
: cpu_only_supported(false), same_output_regst_num(-1), area_id(AreaType::kInvalidArea) {}
OpRegistryResult() : cpu_only_supported(false), same_output_regst_num(-1) {}
~OpRegistryResult() = default;
std::string op_type_name;
......@@ -64,7 +62,6 @@ struct OpRegistryResult {
// performance other than op definition
InputArgModifyFn input_arg_modify_fn;
OutputArgModifyFn output_arg_modify_fn;
int64_t area_id;
};
class OpRegistry final {
......@@ -87,7 +84,6 @@ class OpRegistry final {
OpRegistry& SupportCpuOnly();
OpRegistry& SetOutputBufferNum(int32_t num);
OpRegistry& SetAreaId(int64_t area_id);
OpRegistry& Attr(const std::string& name, UserOpAttrType type);
template<typename T>
......
......@@ -27,7 +27,7 @@ Maybe<SubTskGphBuilderStatus> ToInterfaceSubTskGphBuilder::Build(
const SbpParallel& dst_sbp_parallel) const {
const LogicalNode* dst_logical_node = sorted_dst_comp_tasks.front()->logical_node();
if (dst_logical_node->op_vec().size() != 1) { return Error::BoxingNotSupportedError(); }
if (!IsClassRegistered<IsInterfaceOpConf4OpTypeCase>(
if (!IsClassRegistered<int32_t, IsInterfaceOpConf4OpTypeCase>(
dst_logical_node->SoleOp()->op_conf().op_type_case())) {
return Error::BoxingNotSupportedError();
}
......
......@@ -197,7 +197,8 @@ BldSubTskGphMthd GetMthdForBldSubTskGph(const LogicalNode* src_node, const Logic
CHECK(src_pd->parallel_num() == dst_pd->parallel_num());
}
auto IsTickNode = [&](const LogicalNode* node) {
return IsClassRegistered<IsTickTockOpTypeCase>(node->SoleOp()->op_conf().op_type_case());
return IsClassRegistered<int32_t, IsTickTockOpTypeCase>(
node->SoleOp()->op_conf().op_type_case());
};
if (IsTickNode(src_node) || IsTickNode(dst_node)) {
if (src_pd->parallel_num() > 1 && dst_pd->parallel_num() == 1
......@@ -273,11 +274,11 @@ CompTaskNode* NormalForwardLogicalNode::NewCompTaskNode() const {
int64_t NormalForwardLogicalNode::GetAreaId() const {
if (this->SoleOp()->op_conf().has_user_conf()) {
auto* registration_val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(
this->SoleOp()->op_conf().user_conf().op_type_name());
CHECK_NOTNULL(registration_val);
if (registration_val->area_id != AreaType::kInvalidArea) {
return registration_val->area_id;
const std::string& op_type_name = this->SoleOp()->op_conf().user_conf().op_type_name();
if (IsClassRegistered<std::string, UserOpAreaIdCreator>(op_type_name)) {
return std::unique_ptr<UserOpAreaIdCreator>(
NewObj<std::string, UserOpAreaIdCreator>(op_type_name))
->GetAreaId();
} else {
return AreaType::kDataForwardArea;
}
......@@ -297,4 +298,11 @@ int64_t NewAreaId() {
return ++next_area_id;
}
// User ops assigned to AreaType::kMdUpdtArea. Each line registers a creator keyed by
// the op type name; the macro expansion supplies its own terminating ';'.
REGISTER_USER_OP_AREA_ID("sgd_update", AreaType::kMdUpdtArea)
REGISTER_USER_OP_AREA_ID("indexed_slices_sgd_update", AreaType::kMdUpdtArea)
REGISTER_USER_OP_AREA_ID("momentum_update", AreaType::kMdUpdtArea)
REGISTER_USER_OP_AREA_ID("indexed_slices_momentum_update", AreaType::kMdUpdtArea)
REGISTER_USER_OP_AREA_ID("adam_update", AreaType::kMdUpdtArea)
REGISTER_USER_OP_AREA_ID("indexed_slices_adam_update", AreaType::kMdUpdtArea)
} // namespace oneflow
......@@ -211,6 +211,26 @@ DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(Case);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(Esac);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(DecodeH2D);
// Interface for objects that report the area id associated with a user op.
class UserOpAreaIdCreator {
 public:
  // Returns the area id this creator stands for.
  virtual int64_t GetAreaId() = 0;

  // Virtual so that implementations can be destroyed through a base pointer.
  virtual ~UserOpAreaIdCreator() = default;
};
// UserOpAreaIdCreator that always reports the area id it was constructed with.
class FixedUserOpAreaIdCreator : public UserOpAreaIdCreator {
 public:
  explicit FixedUserOpAreaIdCreator(int64_t area_id) : area_id_{area_id} {}

  // Hands back the constructor argument unchanged.
  int64_t GetAreaId() override { return area_id_; }

 private:
  int64_t area_id_;  // value returned by GetAreaId()
};
// Registers, under the std::string key `op_type_name`, a creator that produces a
// FixedUserOpAreaIdCreator returning `area_id`.
// NOTE: the expansion already ends with ';', so call sites may omit it.
#define REGISTER_USER_OP_AREA_ID(op_type_name, area_id) \
REGISTER_CLASS_CREATOR(std::string, op_type_name, UserOpAreaIdCreator, \
([] { return new FixedUserOpAreaIdCreator(area_id); }));
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_LOGICAL_NODE_H_
......@@ -25,9 +25,9 @@ namespace oneflow {
namespace {
size_t RegstNum4OpSameOutputBlob(OperatorConf::OpTypeCase op_type_case) {
if (IsClassRegistered<RuntimeRegstNum4OpSameOutputBlob>(op_type_case)) {
if (IsClassRegistered<int32_t, RuntimeRegstNum4OpSameOutputBlob>(op_type_case)) {
std::unique_ptr<RuntimeRegstNum4OpSameOutputBlob> ptr;
ptr.reset(NewObj<RuntimeRegstNum4OpSameOutputBlob>(op_type_case));
ptr.reset(NewObj<int32_t, RuntimeRegstNum4OpSameOutputBlob>(op_type_case));
return *ptr;
} else {
return -1;
......
......@@ -44,7 +44,7 @@ bool IsInterfaceTask(const TaskNode* node) {
if (comp_task_node == nullptr) { return false; }
if (comp_task_node->logical_node()->op_vec().size() != 1) { return false; }
auto op_type_case = comp_task_node->logical_node()->SoleOp()->op_conf().op_type_case();
return IsClassRegistered<IsInterfaceOpConf4OpTypeCase>(op_type_case);
return IsClassRegistered<int32_t, IsInterfaceOpConf4OpTypeCase>(op_type_case);
}
bool IsConnectToTickOp(const TaskNode* node) {
......@@ -603,7 +603,7 @@ TaskNode* TaskGraph::BuildTaskStep(
TaskNode* TaskGraph::TryAddCopyH2DTaskTo(TaskNode* task) {
if (IsInterfaceTask(task)) { return nullptr; }
if (IsClassRegistered<TickTockTaskType>(task->GetTaskType())) { return nullptr; }
if (IsClassRegistered<int32_t, TickTockTaskType>(task->GetTaskType())) { return nullptr; }
CHECK_EQ(task->device_type(), DeviceType::kGPU);
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::H2D, task->machine_id(), task->GpuPhyId());
......
......@@ -185,14 +185,15 @@ struct IndependentThreadNum4TaskType final {
std::function<size_t()> get_num_;
};
#define REGISTER_INDEPENDENT_THREAD_NUM(task_type, ...) \
REGISTER_CLASS_CREATOR(task_type, IndependentThreadNum4TaskType, \
#define REGISTER_INDEPENDENT_THREAD_NUM(task_type, ...) \
REGISTER_CLASS_CREATOR(int32_t, task_type, IndependentThreadNum4TaskType, \
([] { return new IndependentThreadNum4TaskType(__VA_ARGS__); }))
struct TickTockTaskType final {};
#define REGISTER_TICK_TOCK_TASK_TYPE(task_type) \
REGISTER_CLASS_CREATOR(task_type, TickTockTaskType, ([] { return new TickTockTaskType; }))
#define REGISTER_TICK_TOCK_TASK_TYPE(task_type) \
REGISTER_CLASS_CREATOR(int32_t, task_type, TickTockTaskType, \
([] { return new TickTockTaskType; }))
} // namespace oneflow
......
......@@ -94,7 +94,7 @@ const UserOpAttrVal& JobDesc::GetFunctionFlagVal(const std::string& field_name)
}
bool IsInterfaceOpConf(const OperatorConf& op_conf) {
return IsClassRegistered<IsInterfaceOpConf4OpTypeCase>(op_conf.op_type_case());
return IsClassRegistered<int32_t, IsInterfaceOpConf4OpTypeCase>(op_conf.op_type_case());
}
GlobalJobDescScope::GlobalJobDescScope(const JobConfigProto& job_conf, int64_t job_id) {
......
......@@ -24,7 +24,7 @@ ThrdIdGenerator::ThrdIdGenerator(std::vector<std::pair<int64_t, TaskType>>& mach
HashMap<int64_t, std::set<TaskType>> machine2task_types;
// machine_task_type = <machine_id, task_type>
for (const auto& machine_task_type : machine_task_type_vec) {
if (IsClassRegistered<TickTockTaskType>(machine_task_type.second)) { continue; }
if (IsClassRegistered<int32_t, TickTockTaskType>(machine_task_type.second)) { continue; }
if (TaskTypeThrdNumEqMax(machine_task_type.second,
machine_task_type2thrd_num_[machine_task_type])) {
continue;
......@@ -37,7 +37,7 @@ ThrdIdGenerator::ThrdIdGenerator(std::vector<std::pair<int64_t, TaskType>>& mach
}
int64_t ThrdIdGenerator::GenerateThrdId(int64_t machine_id, int64_t task_type) {
if (IsClassRegistered<TickTockTaskType>(task_type)) {
if (IsClassRegistered<int32_t, TickTockTaskType>(task_type)) {
return Global<IDMgr>::Get()->TickTockThrdId();
}
auto key = std::make_pair(machine_id, task_type);
......@@ -55,9 +55,9 @@ int64_t ThrdIdGenerator::GetModThrdId(std::pair<int64_t, int64_t> machine_task_t
}
bool ThrdIdGenerator::TaskTypeThrdNumEqMax(int64_t task_type, int32_t thrd_num) {
if (IsClassRegistered<IndependentThreadNum4TaskType>(task_type)) {
if (IsClassRegistered<int32_t, IndependentThreadNum4TaskType>(task_type)) {
std::unique_ptr<IndependentThreadNum4TaskType> thread_num;
thread_num.reset(NewObj<IndependentThreadNum4TaskType>(task_type));
thread_num.reset(NewObj<int32_t, IndependentThreadNum4TaskType>(task_type));
return (thrd_num == *thread_num);
} else {
return false;
......
......@@ -550,10 +550,10 @@ Maybe<void> GenerateBackwardOpConfIf(
const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4BnInOp) {
std::unique_ptr<GenerateBackwardOpConfWrapperStruct> obj;
const auto& op_type_case = op.op_conf().op_type_case();
if (!IsClassRegistered<GenerateBackwardOpConfWrapperStruct>(op_type_case)) {
if (!IsClassRegistered<int32_t, GenerateBackwardOpConfWrapperStruct>(op_type_case)) {
return Error::GradientFunctionNotFound() << PbMessage2TxtString(op.op_conf());
}
obj.reset(NewObj<GenerateBackwardOpConfWrapperStruct>(op_type_case));
obj.reset(NewObj<int32_t, GenerateBackwardOpConfWrapperStruct>(op_type_case));
return obj->Call(op, op_confs, DiffLbi4BnInOp, LogicalBlobDesc4BnInOp);
}
......
......@@ -63,8 +63,8 @@ class GenerateBackwardOpConfWrapperStruct final {
const std::unique_ptr<const MaybeFunc> maybe_func_;
};
#define REGISTER_OP_GRAD(op_type_case, gen_grad_func) \
REGISTER_CLASS_CREATOR(op_type_case, GenerateBackwardOpConfWrapperStruct, \
#define REGISTER_OP_GRAD(op_type_case, gen_grad_func) \
REGISTER_CLASS_CREATOR(int32_t, op_type_case, GenerateBackwardOpConfWrapperStruct, \
([] { return new GenerateBackwardOpConfWrapperStruct(gen_grad_func); }))
} // namespace oneflow
......
......@@ -24,8 +24,8 @@ namespace {
std::unique_ptr<MutOpConTickInputHelper> NewMutOpConTickInputHelper(const OperatorConf& op_conf) {
std::unique_ptr<MutOpConTickInputHelper> ret;
if (IsClassRegistered<MutOpConTickInputHelper>(op_conf.op_type_case())) {
ret.reset(NewObj<MutOpConTickInputHelper>(op_conf.op_type_case()));
if (IsClassRegistered<int32_t, MutOpConTickInputHelper>(op_conf.op_type_case())) {
ret.reset(NewObj<int32_t, MutOpConTickInputHelper>(op_conf.op_type_case()));
ret->InitFromOpConf(op_conf);
}
return ret;
......
......@@ -45,7 +45,7 @@ class MutOpConTickInputHelper {
};
#define REGISTER_AUTO_TICK(op_type_case, HelperType) \
REGISTER_CLASS(op_type_case, MutOpConTickInputHelper, HelperType)
REGISTER_CLASS(int32_t, op_type_case, MutOpConTickInputHelper, HelperType)
} // namespace oneflow
......
......@@ -26,7 +26,7 @@ void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder)
HashMap<const OpNode*, OperatorConf> op_node2op_conf;
op_graph.ForEachNode([&](const OpNode* node) {
OperatorConf::OpTypeCase op_type_case = node->op().op_conf().op_type_case();
if (IsClassRegistered<DisableInputBoxingGroup>(op_type_case)) { return; }
if (IsClassRegistered<int32_t, DisableInputBoxingGroup>(op_type_case)) { return; }
for (const std::string& ibn : node->op().input_bns()) {
const LogicalBlobId& lbi = node->op().BnInOp2Lbi(ibn);
const OpNode& producer = node->ProducerOpNode4Lbi(lbi);
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment