diff --git a/oneflow/core/eager/eager_blob_object.cpp b/oneflow/core/eager/eager_blob_object.cpp
index b99e7f8829d5aa2a24dfae6f9ea78e44d991c5e7..59c57fa2fde5039fec87bac159818f84e3bdbb95 100644
--- a/oneflow/core/eager/eager_blob_object.cpp
+++ b/oneflow/core/eager/eager_blob_object.cpp
@@ -38,6 +38,7 @@ EagerBlobObject::EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case,
: BlobObject(mem_case, shape, data_type),
tensor_buffer_(tensor_buffer),
blob_body_bytes_(0),
+ is_shape_synced_(true),
infer_local_dep_object_(GetVmLocalDepObject(parallel_desc)),
compute_local_dep_object_(GetVmLocalDepObject(parallel_desc)) {
CHECK(static_cast<bool>(shape));
diff --git a/oneflow/core/eager/eager_blob_object.h b/oneflow/core/eager/eager_blob_object.h
index b0721b91f3b6a179f96295343e0c9ff1326d9511..b41ab71a84790e24b1417b5cd6f372b315862dc9 100644
--- a/oneflow/core/eager/eager_blob_object.h
+++ b/oneflow/core/eager/eager_blob_object.h
@@ -69,12 +69,17 @@ class EagerBlobObject final : public BlobObject {
std::shared_ptr<TensorBuffer>& tensor_buffer() { return tensor_buffer_; }
+ bool is_shape_synced() const { return is_shape_synced_; }
+
+ void set_is_shape_synced(bool val) { is_shape_synced_ = val; }
+
private:
std::unique_ptr<Blob> blob_;
std::unique_ptr<char, std::function<void(char*)>> header_buffer_;
std::shared_ptr<TensorBuffer> tensor_buffer_;
std::size_t blob_body_bytes_;
MemoryAllocator non_pod_initer_;
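+  // Whether the shape cached in blob_desc() still matches the shape produced by
+  // the most recent kernel launch; atomic because the interpreter thread clears
+  // it while the thread calling shape() reads and sets it.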
+ std::atomic<bool> is_shape_synced_;
Maybe<VmLocalDepObject> infer_local_dep_object_;
Maybe<VmLocalDepObject> compute_local_dep_object_;
};
diff --git a/oneflow/core/eager/opkernel_instruction_type.cpp b/oneflow/core/eager/opkernel_instruction_type.cpp
index 692fd6c3bbd2db64d4de18124b37d863abae87af..3320da139d5127a73ac4bfc30287022dc8a59c1e 100644
--- a/oneflow/core/eager/opkernel_instruction_type.cpp
+++ b/oneflow/core/eager/opkernel_instruction_type.cpp
@@ -447,7 +447,6 @@ struct LocalCallOpKernelUtil final {
JUST(operand->mut_opkernel()->ChooseOpKernel(operand->inputs(), operand->outputs())));
operand->mut_opkernel()->ResetDynamicOpAttrs(operand->attrs());
JUST(CheckOutputBlobObjectsMemCase(operand, instruction->stream()));
- JUST(InferOutputTensorDescs(operand));
JUST(InitOutputBlobs(operand));
JUST(InferTempStorageBlobDesc(operand));
JUST(ResetTempStorageBlob(operand));
@@ -523,12 +522,12 @@ struct LocalCallOpKernelUtil final {
const auto& InferTmpSizeFn = operand->opkernel().GetInferTmpSizeFn(operand->user_opkernel());
auto* temp_blob_desc = operand->mut_opkernel()->mut_temp_blob_object()->mut_blob_desc();
CHECK_OR_RETURN(temp_blob_desc->data_type() == DataType::kChar);
- JUST(WithOpInferContext(operand, [&](user_op::InferContext* infer_ctx) -> Maybe<void> {
- size_t temp_size = InferTmpSizeFn(infer_ctx);
- temp_blob_desc->mut_shape() = Shape({static_cast<int64_t>(temp_size)});
- temp_blob_desc->set_is_dynamic(true);
- return Maybe<void>::Ok();
- }));
+ one::LocalUserOpInferContext* op_infer_ctx = operand->opkernel().op_infer_ctx_for_thread_a();
+ op_infer_ctx->Update(operand->inputs(), operand->outputs());
+ size_t temp_size = InferTmpSizeFn(op_infer_ctx);
+ temp_blob_desc->mut_shape() = Shape({static_cast<int64_t>(temp_size)});
+ temp_blob_desc->set_is_dynamic(true);
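+  // Tensor tuples are not allowed to be held by StatefulLocalOpKernel, so detach
+  // them from the infer context as soon as the temp size has been inferred.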
+ op_infer_ctx->Update(nullptr, nullptr);
return Maybe<void>::Ok();
}
@@ -537,16 +536,6 @@ struct LocalCallOpKernelUtil final {
return Maybe<void>::Ok();
}
- template<typename CallbackT>
- static inline Maybe<void> WithOpInferContext(LocalCallOpKernelPhyInstrOperand* operand,
- const CallbackT& Callback) {
- auto* opkernel = operand->mut_opkernel();
- JUST(Callback(opkernel->UpdateInferContext(operand->inputs(), operand->outputs())));
- // tensor tuples are not allowed to be hold by StatefulLocalOpKernel
- opkernel->UpdateInferContext(nullptr, nullptr);
- return Maybe<void>::Ok();
- }
-
template<typename CallbackT>
static inline Maybe<void> WithComputeContext(LocalCallOpKernelPhyInstrOperand* operand,
DeviceCtx* device_ctx, const CallbackT& Callback) {
@@ -558,10 +547,6 @@ struct LocalCallOpKernelUtil final {
return Maybe<void>::Ok();
}
- static inline Maybe<void> InferOutputTensorDescs(LocalCallOpKernelPhyInstrOperand* operand) {
- return WithOpInferContext(operand, operand->opkernel().TensorDescInferFn());
- }
-
static inline void TryInitOpKernelState(LocalCallOpKernelPhyInstrOperand* operand,
DeviceCtx* device_ctx, user_op::OpKernelState** state) {
operand->mut_opkernel()->TryInitOpKernelState(operand->user_opkernel(), device_ctx,
diff --git a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp
index 60f8e38d425d5acba597d5c4d1dc11ffb7804f55..40d9e84f50bb23d187f931fde13dc48d1958f0f0 100644
--- a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp
+++ b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp
@@ -85,7 +85,16 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in
}
const auto kernel = JUST(user_op_expr.MutKernel4Device(*op_device));
- kernel->InferDataType(input_eager_blob_objects, output_eager_blob_objects);
+
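+  // Outputs whose shape is only known after the kernel has run (mut2 obns) start
+  // out unsynced; EagerMirroredTensorImpl::shape() will sync them lazily.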
+ for (int64_t index : kernel->output_tuple_indexes4mut2_obns()) {
+ output_eager_blob_objects->at(index)->set_is_shape_synced(false);
+ }
+
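+  // Run dtype and shape inference eagerly on the interpreter thread, using the
+  // kernel's dedicated "thread b" infer context.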
+ kernel->ResetDynamicOpAttrs(attrs);
+ JUST(kernel->InferDataType(input_eager_blob_objects, output_eager_blob_objects,
+ kernel->op_infer_ctx_for_thread_b()));
+ JUST(kernel->InferTensorDesc(input_eager_blob_objects, output_eager_blob_objects,
+ kernel->op_infer_ctx_for_thread_b()));
const auto& instr_type_name = JUST(op_device->local_call_instruction_name());
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
diff --git a/oneflow/core/framework/tensor_impl.cpp b/oneflow/core/framework/tensor_impl.cpp
index a9e9e0ab554bb5e6e2e53bc2eccdb05d4d899387..a0a421af6204da1d65c9f2c64247cf61badd0ea3 100644
--- a/oneflow/core/framework/tensor_impl.cpp
+++ b/oneflow/core/framework/tensor_impl.cpp
@@ -94,6 +94,8 @@ Maybe<VmLocalDepObject> EagerMirroredTensorImpl::compute_local_dep_object() cons
}
const std::shared_ptr<const Shape>& EagerMirroredTensorImpl::shape() const {
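+  // Fast path: the cached shape is already valid, so avoid the blocking
+  // instruction round-trip below.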
+ if (eager_blob_object_->is_shape_synced()) { return eager_blob_object_->blob_desc().shape_ptr(); }
+
const std::shared_ptr<const Shape>* result = nullptr;
Global<ForeignLockHelper>::Get()->WithScopedRelease([this, &result]() {
BlockingCounter bc(1);
@@ -108,6 +110,7 @@ const std::shared_ptr<const Shape>& EagerMirroredTensorImpl::shape() const {
CHECK_JUST(PhysicalRun(build_instruction));
bc.WaitUntilCntEqualZero();
});
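+  // The blocking access above refreshed the cached shape; record that so later
+  // calls take the fast path.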
+ eager_blob_object_->set_is_shape_synced(true);
return *result;
}
diff --git a/oneflow/python/test/ops/test_stateful_local_kernel.py b/oneflow/python/test/ops/test_stateful_local_kernel.py
index b55da9f61b2867ec5c1ea5af2ca6b5b735e07bf2..ee43da9d6123d056e455b47d5130b91373e037ed 100644
--- a/oneflow/python/test/ops/test_stateful_local_kernel.py
+++ b/oneflow/python/test/ops/test_stateful_local_kernel.py
@@ -40,34 +40,27 @@ class TestStatefulLocalKernel(flow.unittest.TestCase):
test_case.assertEqual(y.shape, flow.Size((2, 3, 1)))
def test_stateful_local_kernel(test_case):
- func_config = flow.FunctionConfig()
- func_config.default_logical_view(flow.scope.mirrored_view())
-
- @flow.global_function(function_config=func_config)
- def job():
- op1 = (
- flow.builtin_op("constant")
- .Output("out")
- .Attr("is_floating_value", True)
- .Attr("floating_value", 3.0)
- .Attr("dtype", flow.float32)
- .Attr("shape", [1, 1])
- .Build()
- )
- op2 = (
- flow.builtin_op("matmul")
- .Input("a")
- .Input("b")
- .Attr("transpose_a", False)
- .Attr("transpose_b", False)
- .Attr("alpha", float(1.0))
- .Output("out")
- .Build()
- )
- x = op1()[0]
- x = op2(x, x)[0]
-
- job()
+ op1 = (
+ flow.builtin_op("constant")
+ .Output("out")
+ .Attr("is_floating_value", True)
+ .Attr("floating_value", 3.0)
+ .Attr("dtype", flow.float32)
+ .Attr("shape", [1, 1])
+ .Build()
+ )
+ op2 = (
+ flow.builtin_op("matmul")
+ .Input("a")
+ .Input("b")
+ .Attr("transpose_a", False)
+ .Attr("transpose_b", False)
+ .Attr("alpha", float(1.0))
+ .Output("out")
+ .Build()
+ )
+ x = op1()[0]
+ x = op2(x, x)[0]
if __name__ == "__main__":
diff --git a/oneflow/user/kernels/stateful_local_opkernel.cpp b/oneflow/user/kernels/stateful_local_opkernel.cpp
index 4c616710fedf330bcd4503331389450b5c108bb4..430ff413dedd7c7ff66bfdfd97e224e26d89d214 100644
--- a/oneflow/user/kernels/stateful_local_opkernel.cpp
+++ b/oneflow/user/kernels/stateful_local_opkernel.cpp
@@ -330,8 +330,10 @@ Maybe<void> InitTensorTupleIndexes4Bns(const std::shared_ptr<const OperatorConf>
std::make_shared<vm::TensorBuffer>()));
const std::string& device_tag = op_conf->device_tag();
- opkernel->op_infer_ctx_.reset(new LocalUserOpInferContext(opkernel->user_op_conf_.get(),
- input_arg_tuple, output_arg_tuple));
+ opkernel->op_infer_ctx_for_thread_a_.reset(new LocalUserOpInferContext(
+ opkernel->user_op_conf_.get(), input_arg_tuple, output_arg_tuple));
+ opkernel->op_infer_ctx_for_thread_b_.reset(new LocalUserOpInferContext(
+ opkernel->user_op_conf_.get(), input_arg_tuple, output_arg_tuple));
opkernel->compute_ctx_.reset(new LocalUserKernelComputeContext(
nullptr, device_tag, opkernel->user_op_conf_.get(), input_arg_tuple, output_arg_tuple,
opkernel->mut_temp_blob_object()));
@@ -416,10 +418,20 @@ user_op::DataTypeInferFn StatefulLocalOpKernel::DataTypeInferFn() const {
return data_type_infer_fn_;
}
-LocalUserOpInferContext* StatefulLocalOpKernel::UpdateInferContext(
- const EagerBlobObjectListPtr& inputs, const EagerBlobObjectListPtr& outputs) {
- op_infer_ctx_->Update(inputs, outputs);
- return op_infer_ctx_.get();
+Maybe<void> StatefulLocalOpKernel::InferTensorDesc(const EagerBlobObjectListPtr& inputs,
+ const EagerBlobObjectListPtr& outputs,
+ LocalUserOpInferContext* op_infer_ctx) {
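+  // Scoped binding: attach the tensor tuples to the infer context for this call
+  // only (InputAndOutputListScope detaches them again when it goes out of scope).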
+ InputAndOutputListScope<LocalUserOpInferContext> scope(op_infer_ctx, inputs, outputs);
+ JUST(tensor_desc_infer_fn_(op_infer_ctx));
+ return Maybe<void>::Ok();
+}
+
+Maybe<void> StatefulLocalOpKernel::InferDataType(const EagerBlobObjectListPtr& inputs,
+ const EagerBlobObjectListPtr& outputs,
+ LocalUserOpInferContext* op_infer_ctx) {
+ InputAndOutputListScope<LocalUserOpInferContext> scope(op_infer_ctx, inputs, outputs);
+ JUST(data_type_infer_fn_(op_infer_ctx));
+ return Maybe<void>::Ok();
}
LocalUserKernelComputeContext* StatefulLocalOpKernel::UpdateComputeContext(
diff --git a/oneflow/user/kernels/stateful_local_opkernel.h b/oneflow/user/kernels/stateful_local_opkernel.h
index e5b02c105a52d3bc933aeff79ef930498fde2da1..9c1945f628478b83989ea0f3ea0f96cafab2ee15 100644
--- a/oneflow/user/kernels/stateful_local_opkernel.h
+++ b/oneflow/user/kernels/stateful_local_opkernel.h
@@ -279,18 +279,26 @@ class StatefulLocalOpKernel final {
return compute_local_dep_object_;
}
- void InferDataType(const EagerBlobObjectListPtr& inputs, const EagerBlobObjectListPtr& outputs) {
- data_type_infer_fn_(UpdateInferContext(inputs, outputs));
- UpdateInferContext(nullptr, nullptr);
- }
+ Maybe<void> InferTensorDesc(const EagerBlobObjectListPtr& inputs,
+ const EagerBlobObjectListPtr& outputs,
+ LocalUserOpInferContext* op_infer_ctx);
+ Maybe<void> InferDataType(const EagerBlobObjectListPtr& inputs,
+ const EagerBlobObjectListPtr& outputs,
+ LocalUserOpInferContext* op_infer_ctx);
void ResetDynamicOpAttrs(const AttrMap& attrs);
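+  // Two independent infer contexts so the vm worker thread ("thread a") and the
+  // interpreter thread ("thread b") can run inference concurrently without
+  // sharing mutable state.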
+ LocalUserOpInferContext* op_infer_ctx_for_thread_a() const {
+ return op_infer_ctx_for_thread_a_.get();
+ }
+
+ LocalUserOpInferContext* op_infer_ctx_for_thread_b() const {
+ return op_infer_ctx_for_thread_b_.get();
+ }
+
private:
friend struct vm::LocalCallOpKernelUtil;
StatefulLocalOpKernel() = default;
- LocalUserOpInferContext* UpdateInferContext(const EagerBlobObjectListPtr& inputs,
- const EagerBlobObjectListPtr& outputs);
LocalUserKernelComputeContext* UpdateComputeContext(const EagerBlobObjectListPtr& inputs,
const EagerBlobObjectListPtr& outputs,
DeviceCtx* device_ctx);
@@ -321,7 +329,8 @@ class StatefulLocalOpKernel final {
std::shared_ptr<MemoryCase> mem_case_;
std::unique_ptr<LocalUserKernelRegContext> reg_ctx_;
std::unique_ptr<LocalUserKernelCreateContext> create_ctx_;
- std::unique_ptr<LocalUserOpInferContext> op_infer_ctx_;
+ std::unique_ptr<LocalUserOpInferContext> op_infer_ctx_for_thread_a_;
+ std::unique_ptr<LocalUserOpInferContext> op_infer_ctx_for_thread_b_;
std::unique_ptr<LocalUserKernelComputeContext> compute_ctx_;
std::shared_ptr<const ArgTuple> input_arg_tuple_;
std::shared_ptr<const ArgTuple> output_arg_tuple_;