Skip to content
Snippets Groups Projects
Commit 7ee0d54c authored by Juncheng's avatar Juncheng Committed by Niu Chong
Browse files

Dev kernel launch synchronized (#2230)

* IsKernelLaunchSynchronized

* virtual

* refine

* refine
parent 2d813ffb
No related branches found
No related tags found
No related merge requests found
......@@ -46,6 +46,11 @@ void Actor::Init(const JobDesc* job_desc, const TaskProto& task_proto,
exec_kernel_vec_.push_back(std::move(ek));
}
is_kernel_launch_synchronized_ =
std::all_of(exec_kernel_vec_.cbegin(), exec_kernel_vec_.cend(),
[](const ExecKernel& ek) { return ek.kernel->IsKernelLaunchSynchronized(); });
if (!is_kernel_launch_synchronized_) { CHECK_EQ(exec_kernel_vec_.size(), 1); }
remaining_eord_cnt_ = 0;
msg_handler_ = nullptr;
eord_regst_desc_ids_.clear();
......@@ -618,8 +623,9 @@ int Actor::TryUpdtStateAsProducedRegst(Regst* regst) {
}
void Actor::EnqueueAsyncMsg(const ActorMsg& msg) {
if (GetGlobalWorkStreamId()
== Global<IDMgr>::Get()->GlobalWorkStreamId4ActorId(msg.dst_actor_id())) {
if (is_kernel_launch_synchronized_
&& GetGlobalWorkStreamId()
== Global<IDMgr>::Get()->GlobalWorkStreamId4ActorId(msg.dst_actor_id())) {
Global<ActorMsgBus>::Get()->SendMsg(msg);
} else {
async_msg_queue_.push_back(msg);
......
......@@ -230,6 +230,7 @@ class Actor {
HashMap<int64_t, int64_t> inplace_regst_desc_id_out2in_;
std::deque<ActorMsg> async_msg_queue_;
bool is_kernel_launch_synchronized_;
};
std::unique_ptr<Actor> NewActor(const TaskProto&, const ThreadCtx&);
......
......@@ -31,6 +31,12 @@ class Kernel {
const LogicalBlobId& BnInOp2Lbi(const std::string& bn_in_op) const;
const OperatorConf& op_conf() const { return op_attribute().op_conf(); }
const OpAttribute& op_attribute() const { return kernel_conf().op_attribute(); }
/*
* return true means all below must be guaranteed when `Launch` function return:
* 1) all out blob header has been set (e.g. SyncSetHeadKernel)
* 2) all asynchronous task has been queued (e.g. NCCL related kernel)
*/
virtual bool IsKernelLaunchSynchronized() const { return true; }
protected:
Kernel() = default;
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment