Skip to content
Snippets Groups Projects
Unverified Commit b6e1196f authored by Houjiang Chen's avatar Houjiang Chen Committed by GitHub
Browse files

Use composed attrs when creating kernel state. (#5171)


Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent adeab621
No related branches found
No related tags found
No related merge requests found
......@@ -179,10 +179,12 @@ class LocalUserKernelInitContext final : public user_op::KernelInitContext {
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple,
const EagerBlobObjectListPtr& inputs,
const EagerBlobObjectListPtr& outputs)
const EagerBlobObjectListPtr& outputs,
const ComposedAttrMap* composed_attrs)
: user_op_conf_(user_op_conf),
device_ctx_(device_ctx),
base_ctx_(device_tag, input_arg_tuple, output_arg_tuple) {
base_ctx_(device_tag, input_arg_tuple, output_arg_tuple),
composed_attrs_(composed_attrs) {
base_ctx_.Update(inputs, outputs);
}
~LocalUserKernelInitContext() override = default;
......@@ -224,7 +226,7 @@ class LocalUserKernelInitContext final : public user_op::KernelInitContext {
private:
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
const std::string& attr_name) const override {
return user_op_conf().Attr4Name(attr_name);
return composed_attrs_->Attr4Name(attr_name);
}
const user_op::UserOpConfWrapper& user_op_conf() const override { return *user_op_conf_; }
......@@ -232,6 +234,7 @@ class LocalUserKernelInitContext final : public user_op::KernelInitContext {
const user_op::UserOpConfWrapper* user_op_conf_;
DeviceCtx* device_ctx_;
LocalUserKernelBaseContext base_ctx_;
const ComposedAttrMap* composed_attrs_;
};
LocalUserOpInferContext::LocalUserOpInferContext(
......@@ -423,7 +426,7 @@ void StatefulLocalOpKernel::TryInitOpKernelState(const user_op::OpKernel* op_ker
auto init_ctx = std::make_shared<LocalUserKernelInitContext>(
device_ctx, op_conf_->device_tag(), user_op_conf_.get(), input_arg_tuple_, output_arg_tuple_,
inputs, outputs);
inputs, outputs, composed_attrs_for_scheduler_thread());
auto created_state = op_kernel->CreateOpKernelState(init_ctx.get());
op_kernel_state_map_.emplace(op_kernel, created_state);
*state = created_state.get();
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment