From 4131ea97ad9d8fc72138d706a8b82c5945540a9e Mon Sep 17 00:00:00 2001 From: daquexian <daquexian566@gmail.com> Date: Thu, 10 Sep 2020 18:35:46 +0800 Subject: [PATCH] split BlobObject and EagerBlobObject (#3485) * split BaseObject and EagerBaseObject * Use BlobObject in FeedOrFetch, rename functions * ForEachIbnAndEagerBlobObject->ForEachIbnAndBlobObject * address reviews * give lazy ref blob a full blob desc * rename ForEachObnAndEagerBlobObject->ForEachMutBnAndBlobObject, ForEachIbnAndBlobObject->ForEachConstBnAndBlobObject * fix bug --- oneflow/core/eager/blob_object.cpp | 50 ----------- oneflow/core/eager/blob_object.h | 29 ++----- oneflow/core/eager/eager_blob_object.cpp | 73 ++++++++++++++++ oneflow/core/eager/eager_blob_object.h | 59 +++++++++++++ oneflow/core/eager/lazy_ref_blob_object.h | 18 ++-- .../core/eager/opkernel_instruction_type.cpp | 83 ++++++++++--------- 6 files changed, 193 insertions(+), 119 deletions(-) create mode 100644 oneflow/core/eager/eager_blob_object.cpp create mode 100644 oneflow/core/eager/eager_blob_object.h diff --git a/oneflow/core/eager/blob_object.cpp b/oneflow/core/eager/blob_object.cpp index aa6fbc18a..fdb24e468 100644 --- a/oneflow/core/eager/blob_object.cpp +++ b/oneflow/core/eager/blob_object.cpp @@ -14,32 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/eager/blob_object.h" -#include "oneflow/core/vm/allocator.h" #include "oneflow/core/job/parallel_desc.h" -#include "oneflow/core/framework/to_string.h" namespace oneflow { namespace eager { -Maybe<void> BlobObject::TryInitBlob() { - if (!blob_) { JUST(InitBlob()); } - return Maybe<void>::Ok(); -} - -Maybe<void> BlobObject::InitBlob() { - CHECK_NE_OR_RETURN(blob_desc_.data_type(), DataType::kInvalidDataType); - rt_blob_desc_.reset(new RtBlobDesc(blob_desc_)); - { - header_buffer_.reset(); - int64_t header_byte_size = rt_blob_desc_->ByteSizeOfBlobHeader(); - const auto& FreeHeader = [header_byte_size](char* dptr) { std::free(dptr); }; - char* ptr = reinterpret_cast<char*>(std::malloc(header_byte_size)); - header_buffer_ = std::unique_ptr<char, std::function<void(char*)>>(ptr, FreeHeader); - } - blob_.reset(new Blob(*mem_case_, rt_blob_desc_.get(), header_buffer_.get(), nullptr)); - return Maybe<void>::Ok(); -} - Maybe<void> BlobObject::CheckMemCase(const ParallelDesc& parallel_desc, int64_t machine_id) const { CHECK_OR_RETURN(parallel_desc.HasMachineId(machine_id)) << "ParallelDesc does not contain machine_id: " << machine_id; @@ -60,34 +39,5 @@ Maybe<void> BlobObject::CheckMemCase(const ParallelDesc& parallel_desc, int64_t return Maybe<void>::Ok(); } -void BlobObject::TryAllocateBlobBodyMemory(DeviceCtx* device_ctx) { - vm::Allocator* allocator = device_ctx->mut_allocator(); - CHECK_NOTNULL(allocator); - Blob* blob = mut_blob(); - CHECK_NOTNULL(blob); - const std::size_t required_body_bytes = blob->AlignedByteSizeOfBlobBody(); - if (required_body_bytes == 0) { - CHECK_ISNULL(blob->dptr()); - return; - } - if (blob->dptr() != nullptr) { - CHECK_EQ(blob_body_bytes_, required_body_bytes); - return; - } - { - // reset blob_dptr_; - const auto& Free = [allocator, required_body_bytes](char* dptr) { - allocator->Deallocate(dptr, required_body_bytes); - }; - char* dptr = nullptr; - blob_dptr_.reset(); - allocator->Allocate(&dptr, 
required_body_bytes); - blob_dptr_ = std::unique_ptr<char, std::function<void(char*)>>(dptr, Free); - blob->reset_dptr(dptr); - InitNonPODTypeBlobIfNeed(&non_pod_initer_, blob_.get()); - } - blob_body_bytes_ = required_body_bytes; -} - } // namespace eager } // namespace oneflow diff --git a/oneflow/core/eager/blob_object.h b/oneflow/core/eager/blob_object.h index c5d59de19..a1240c1e3 100644 --- a/oneflow/core/eager/blob_object.h +++ b/oneflow/core/eager/blob_object.h @@ -19,8 +19,6 @@ limitations under the License. #include "oneflow/core/vm/object.h" #include "oneflow/core/register/blob_desc.h" #include "oneflow/core/register/blob.h" -#include "oneflow/core/common/maybe.h" -#include "oneflow/core/memory/memory_allocator.h" namespace oneflow { @@ -30,36 +28,25 @@ namespace eager { class BlobObject : public vm::Object { public: + BlobObject(const std::shared_ptr<MemoryCase>& mem_case, DataType data_type) + : mem_case_(mem_case), blob_desc_(data_type) {} BlobObject(const BlobObject&) = delete; BlobObject(BlobObject&&) = delete; - BlobObject(const std::shared_ptr<MemoryCase>& mem_case, DataType data_type) - : mem_case_(mem_case), blob_body_bytes_(0), blob_desc_(data_type) {} virtual ~BlobObject() override = default; const BlobDesc& blob_desc() const { return blob_desc_; } - BlobDesc* mut_blob_desc() { return &blob_desc_; } + virtual BlobDesc* mut_blob_desc() = 0; - virtual const Blob& blob() const { return *blob_; } - virtual Blob* mut_blob() { return blob_.get(); } - virtual Maybe<void> TryInitBlob(); + virtual const Blob& blob() const = 0; + virtual Blob* mut_blob() = 0; + virtual Maybe<void> TryInitBlob() = 0; + virtual void TryAllocateBlobBodyMemory(DeviceCtx* device_ctx) = 0; Maybe<void> CheckMemCase(const ParallelDesc& parallel_desc, int64_t machine_id) const; - void TryAllocateBlobBodyMemory(DeviceCtx* device_ctx); - - private: - Maybe<void> InitBlob(); - - std::shared_ptr<MemoryCase> mem_case_; - std::unique_ptr<Blob> blob_; - std::unique_ptr<char, 
std::function<void(char*)>> header_buffer_; - std::unique_ptr<char, std::function<void(char*)>> blob_dptr_; - std::size_t blob_body_bytes_; - MemoryAllocator non_pod_initer_; - protected: + std::shared_ptr<MemoryCase> mem_case_; BlobDesc blob_desc_; - std::unique_ptr<RtBlobDesc> rt_blob_desc_; }; } // namespace eager diff --git a/oneflow/core/eager/eager_blob_object.cpp b/oneflow/core/eager/eager_blob_object.cpp new file mode 100644 index 000000000..3c35f8389 --- /dev/null +++ b/oneflow/core/eager/eager_blob_object.cpp @@ -0,0 +1,73 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/eager/eager_blob_object.h" +#include "oneflow/core/vm/allocator.h" +#include "oneflow/core/job/parallel_desc.h" +#include "oneflow/core/framework/to_string.h" + +namespace oneflow { +namespace eager { + +Maybe<void> EagerBlobObject::TryInitBlob() { + if (!blob_) { JUST(InitBlob()); } + return Maybe<void>::Ok(); +} + +Maybe<void> EagerBlobObject::InitBlob() { + CHECK_NE_OR_RETURN(blob_desc_.data_type(), DataType::kInvalidDataType); + rt_blob_desc_.reset(new RtBlobDesc(blob_desc_)); + { + header_buffer_.reset(); + int64_t header_byte_size = rt_blob_desc_->ByteSizeOfBlobHeader(); + const auto& FreeHeader = [header_byte_size](char* dptr) { std::free(dptr); }; + char* ptr = reinterpret_cast<char*>(std::malloc(header_byte_size)); + header_buffer_ = std::unique_ptr<char, std::function<void(char*)>>(ptr, FreeHeader); + } + blob_.reset(new Blob(*mem_case_, rt_blob_desc_.get(), header_buffer_.get(), nullptr)); + return Maybe<void>::Ok(); +} + +void EagerBlobObject::TryAllocateBlobBodyMemory(DeviceCtx* device_ctx) { + vm::Allocator* allocator = device_ctx->mut_allocator(); + CHECK_NOTNULL(allocator); + Blob* blob = mut_blob(); + CHECK_NOTNULL(blob); + const std::size_t required_body_bytes = blob->AlignedByteSizeOfBlobBody(); + if (required_body_bytes == 0) { + CHECK_ISNULL(blob->dptr()); + return; + } + if (blob->dptr() != nullptr) { + CHECK_EQ(blob_body_bytes_, required_body_bytes); + return; + } + { + // reset blob_dptr_; + const auto& Free = [allocator, required_body_bytes](char* dptr) { + allocator->Deallocate(dptr, required_body_bytes); + }; + char* dptr = nullptr; + blob_dptr_.reset(); + allocator->Allocate(&dptr, required_body_bytes); + blob_dptr_ = std::unique_ptr<char, std::function<void(char*)>>(dptr, Free); + blob->reset_dptr(dptr); + InitNonPODTypeBlobIfNeed(&non_pod_initer_, blob_.get()); + } + blob_body_bytes_ = required_body_bytes; +} + +} // namespace eager +} // namespace oneflow diff --git a/oneflow/core/eager/eager_blob_object.h 
b/oneflow/core/eager/eager_blob_object.h new file mode 100644 index 000000000..cd9fcb309 --- /dev/null +++ b/oneflow/core/eager/eager_blob_object.h @@ -0,0 +1,59 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_EAGER_EAGER_BLOB_OBJECT_H_ +#define ONEFLOW_CORE_EAGER_EAGER_BLOB_OBJECT_H_ + +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/eager/blob_object.h" +#include "oneflow/core/memory/memory_allocator.h" + +namespace oneflow { + +namespace eager { + +class EagerBlobObject : public BlobObject { + public: + EagerBlobObject(const EagerBlobObject&) = delete; + EagerBlobObject(EagerBlobObject&&) = delete; + EagerBlobObject(const std::shared_ptr<MemoryCase>& mem_case, DataType data_type) + : BlobObject(mem_case, data_type), blob_body_bytes_(0) {} + virtual ~EagerBlobObject() override = default; + + virtual BlobDesc* mut_blob_desc() override { return &blob_desc_; } + + virtual const Blob& blob() const override { return *blob_; } + virtual Blob* mut_blob() override { return blob_.get(); } + virtual Maybe<void> TryInitBlob() override; + + virtual void TryAllocateBlobBodyMemory(DeviceCtx* device_ctx) override; + + private: + Maybe<void> InitBlob(); + + std::unique_ptr<Blob> blob_; + std::unique_ptr<char, std::function<void(char*)>> header_buffer_; + std::unique_ptr<char, std::function<void(char*)>> blob_dptr_; + std::size_t blob_body_bytes_; + MemoryAllocator non_pod_initer_; + + protected: + 
std::unique_ptr<RtBlobDesc> rt_blob_desc_; +}; + +} // namespace eager +} // namespace oneflow + +#endif // ONEFLOW_CORE_EAGER_EAGER_BLOB_OBJECT_H_ diff --git a/oneflow/core/eager/lazy_ref_blob_object.h b/oneflow/core/eager/lazy_ref_blob_object.h index 5d31f048d..fa83f82de 100644 --- a/oneflow/core/eager/lazy_ref_blob_object.h +++ b/oneflow/core/eager/lazy_ref_blob_object.h @@ -16,10 +16,6 @@ limitations under the License. #ifndef ONEFLOW_CORE_EAGER_LAZY_REF_BLOB_OBJECT_H_ #define ONEFLOW_CORE_EAGER_LAZY_REF_BLOB_OBJECT_H_ -#include "oneflow/core/vm/object.h" -#include "oneflow/core/register/blob_desc.h" -#include "oneflow/core/register/blob.h" -#include "oneflow/core/common/maybe.h" #include "oneflow/core/eager/blob_object.h" namespace oneflow { @@ -31,17 +27,23 @@ class LazyRefBlobObject : public BlobObject { LazyRefBlobObject(LazyRefBlobObject&&) = delete; LazyRefBlobObject(Blob* blob) : BlobObject(std::make_shared<MemoryCase>(blob->mem_case()), blob->data_type()) { - rt_blob_desc_.reset(new RtBlobDesc(blob_desc())); + const auto& rt_blob_desc = blob->blob_desc(); + blob_desc_ = BlobDesc(rt_blob_desc.body(), rt_blob_desc.is_tensor_list(), + rt_blob_desc.is_body_disabled(), rt_blob_desc.is_dynamic()); ref_blob_ = blob; } virtual ~LazyRefBlobObject() override = default; + virtual BlobDesc* mut_blob_desc() override { UNIMPLEMENTED(); } + virtual const Blob& blob() const override { return *ref_blob_; } virtual Blob* mut_blob() override { return ref_blob_; } - // TODO(daquexian): Separate LazyBlobObject and EagerBlobObject, - // remove "virtual xxx override { Unimplemented }" - virtual Maybe<void> TryInitBlob() override { return Error::Unimplemented(); }; + virtual void TryAllocateBlobBodyMemory(DeviceCtx* device_ctx) override{ + // do nothing + }; + + virtual Maybe<void> TryInitBlob() override { return Maybe<void>::Ok(); } private: Blob* ref_blob_ = nullptr; diff --git a/oneflow/core/eager/opkernel_instruction_type.cpp 
b/oneflow/core/eager/opkernel_instruction_type.cpp index 17beb825e..9b4e2d549 100644 --- a/oneflow/core/eager/opkernel_instruction_type.cpp +++ b/oneflow/core/eager/opkernel_instruction_type.cpp @@ -19,7 +19,7 @@ limitations under the License. #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/eager/opkernel_object.h" -#include "oneflow/core/eager/blob_object.h" +#include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/vm/object_wrapper.h" #include "oneflow/core/vm/string_object.h" #include "oneflow/core/vm/stream.msg.h" @@ -112,8 +112,8 @@ std::shared_ptr<MemoryCase> MakeMemCase(const DeviceType device_type, const int6 } template<typename T, typename CallbackT> -Maybe<void> ForEachIbnAndBlobObject(vm::Instruction* instruction, const T& args, - const CallbackT& Callback) { +Maybe<void> ForEachConstBnAndBlobObject(vm::Instruction* instruction, const T& args, + const CallbackT& Callback) { CHECK_EQ_OR_RETURN(args.ibn_size(), args.input_blob_size()); FOR_RANGE(int, i, 0, args.ibn_size()) { const auto* operand_ibn = instruction->operand_type(args.ibn(i)); @@ -129,8 +129,8 @@ Maybe<void> ForEachIbnAndBlobObject(vm::Instruction* instruction, const T& args, } template<typename T, typename CallbackT> -Maybe<void> ForEachObnAndBlobObject(vm::Instruction* instruction, const T& args, - const CallbackT& Callback) { +Maybe<void> ForEachMutBnAndBlobObject(vm::Instruction* instruction, const T& args, + const CallbackT& Callback) { CHECK_EQ_OR_RETURN(args.obn_size(), args.output_blob_size()); FOR_RANGE(int, i, 0, args.obn_size()) { const auto* operand_obn = instruction->operand_type(args.obn(i)); @@ -138,7 +138,7 @@ Maybe<void> ForEachObnAndBlobObject(vm::Instruction* instruction, const T& args, const std::string& bn_in_op = JUST(operand_obn->template Get<vm::StringObject>())->str(); auto* operand_output_blob = instruction->mut_operand_type(args.output_blob(i)); CHECK_NOTNULL_OR_RETURN(operand_output_blob) 
<< "obn: " << bn_in_op; - auto* blob_object = operand_output_blob->template Mut<BlobObject>(); + auto* blob_object = operand_output_blob->template Mut<EagerBlobObject>(); JUST(Callback(bn_in_op, blob_object)); } CHECK_EQ_OR_RETURN(args.mut2_obn_size(), args.mut2_output_blob_size()); @@ -148,7 +148,7 @@ Maybe<void> ForEachObnAndBlobObject(vm::Instruction* instruction, const T& args, const std::string& bn_in_op = JUST(operand_obn->template Get<vm::StringObject>())->str(); auto* operand_output_blob = instruction->mut_operand_type(args.mut2_output_blob(i)); CHECK_NOTNULL_OR_RETURN(operand_output_blob) << "obn: " << bn_in_op; - auto* blob_object = operand_output_blob->template Mut<BlobObject>(); + auto* blob_object = operand_output_blob->template Mut<EagerBlobObject>(); JUST(Callback(bn_in_op, blob_object)); } return Maybe<void>::Ok(); @@ -160,9 +160,9 @@ Maybe<void> MakeBlobDesc4BnInOp(vm::Instruction* instruction, const T& args, const auto& obn2blob_desc = std::make_shared<HashMap<std::string, BlobDesc*>>(); { HashSet<const BlobDesc*> out_blob_descs; - JUST(ForEachObnAndBlobObject( + JUST(ForEachMutBnAndBlobObject( instruction, args, - [&](const std::string& bn_in_op, BlobObject* blob_object) -> Maybe<void> { + [&](const std::string& bn_in_op, EagerBlobObject* blob_object) -> Maybe<void> { auto* blob_desc = blob_object->mut_blob_desc(); CHECK_OR_RETURN(out_blob_descs.insert(blob_desc).second); CHECK_OR_RETURN(obn2blob_desc->emplace(bn_in_op, blob_desc).second); @@ -170,7 +170,7 @@ Maybe<void> MakeBlobDesc4BnInOp(vm::Instruction* instruction, const T& args, })); } const auto& ibn2blob_desc = std::make_shared<HashMap<std::string, const BlobDesc*>>(); - JUST(ForEachIbnAndBlobObject( + JUST(ForEachConstBnAndBlobObject( instruction, args, [&](const std::string& bn_in_op, const BlobObject& blob_object) -> Maybe<void> { CHECK_OR_RETURN(ibn2blob_desc->emplace(bn_in_op, &blob_object.blob_desc()).second); @@ -190,16 +190,17 @@ template<typename T> Maybe<void> MakeBlob4BnInOp( 
vm::Instruction* instruction, const T& args, std::function<Blob*(const std::string&)>* Blob4BnInOp, - const std::function<bool(const std::string&, const BlobObject&)>& FilterOutBlob) { + const std::function<bool(const std::string&, const EagerBlobObject&)>& FilterOutBlob) { const auto& obn2blob = std::make_shared<HashMap<std::string, Blob*>>(); - JUST(ForEachObnAndBlobObject( - instruction, args, [&](const std::string& bn_in_op, BlobObject* blob_object) -> Maybe<void> { + JUST(ForEachMutBnAndBlobObject( + instruction, args, + [&](const std::string& bn_in_op, EagerBlobObject* blob_object) -> Maybe<void> { if (!FilterOutBlob(bn_in_op, *blob_object)) { return Maybe<void>::Ok(); } CHECK_OR_RETURN(obn2blob->emplace(bn_in_op, blob_object->mut_blob()).second); return Maybe<void>::Ok(); })); const auto& ibn2blob = std::make_shared<HashMap<std::string, const Blob*>>(); - JUST(ForEachIbnAndBlobObject( + JUST(ForEachConstBnAndBlobObject( instruction, args, [&](const std::string& bn_in_op, const BlobObject& blob_object) -> Maybe<void> { CHECK_OR_RETURN(ibn2blob->emplace(bn_in_op, &blob_object.blob()).second); @@ -219,7 +220,7 @@ template<typename T> Maybe<void> MakeBlob4BnInOp(vm::Instruction* instruction, const T& args, std::function<Blob*(const std::string&)>* Blob4BnInOp) { return MakeBlob4BnInOp(instruction, args, Blob4BnInOp, - [](const std::string&, const BlobObject&) { return true; }); + [](const std::string&, const EagerBlobObject&) { return true; }); } template<typename T> @@ -229,9 +230,9 @@ void InitOutputBlobObjects(vm::Instruction* instruction, const T& args, const auto& parallel_desc = instruction->parallel_desc(); CHECK(static_cast<bool>(parallel_desc)); if (rw_mutexed_object->has_object()) { - CHECK(rw_mutexed_object->Has<BlobObject>()); + CHECK(rw_mutexed_object->Has<EagerBlobObject>()); } else { - rw_mutexed_object->Init<BlobObject>(mem_case, data_type); + rw_mutexed_object->Init<EagerBlobObject>(mem_case, data_type); } }; FOR_RANGE(int, i, 0, 
args.output_blob_size()) { @@ -259,15 +260,16 @@ Maybe<void> CheckBlobParallel(vm::Instruction* instruction, const T& args, return symbol_storage_ptr->GetPtr(symbol_id).get(); }; - JUST(ForEachObnAndBlobObject( - instruction, args, [&](const std::string& bn_in_op, BlobObject* blob_object) -> Maybe<void> { + JUST(ForEachMutBnAndBlobObject( + instruction, args, + [&](const std::string& bn_in_op, EagerBlobObject* blob_object) -> Maybe<void> { const auto* parallel_desc = JUST(ParallelDesc4BnInOp(bn_in_op)); if (parallel_desc == nullptr) { return Maybe<void>::Ok(); } JUST(blob_object->CheckMemCase(*parallel_desc, instruction->stream().machine_id())); return Maybe<void>::Ok(); })); - JUST(ForEachIbnAndBlobObject( + JUST(ForEachConstBnAndBlobObject( instruction, args, [&](const std::string& bn_in_op, const BlobObject& blob_object) -> Maybe<void> { const auto* parallel_desc = JUST(ParallelDesc4BnInOp(bn_in_op)); @@ -301,13 +303,13 @@ Maybe<void> OpKernelInfer(OpKernelObject* opkernel_obj, vm::Instruction* instruc JUST(opkernel_obj->ResetOpAndKernel(*op_node_signature, &parallel_ctx, BlobDesc4BnInOp, instruction->parallel_desc().get())); JUST(CheckBlobParallel(instruction, args, op_node_signature)); - JUST(ForEachObnAndBlobObject(instruction, args, - [](const std::string& obn, BlobObject* blob_object) -> Maybe<void> { - return blob_object->TryInitBlob(); - })); + JUST(ForEachMutBnAndBlobObject( + instruction, args, [](const std::string& obn, EagerBlobObject* blob_object) -> Maybe<void> { + return blob_object->TryInitBlob(); + })); std::function<Blob*(const std::string&)> Blob4BnInOp; Shape empty_shape{}; - const auto& FilterOutBlob = [&](const std::string& bn_in_op, const BlobObject& blob_object) { + const auto& FilterOutBlob = [&](const std::string& bn_in_op, const EagerBlobObject& blob_object) { return !(bn_in_op == "tmp_buffer_0" && blob_object.blob_desc().shape() == empty_shape); }; JUST(MakeBlob4BnInOp(instruction, args, &Blob4BnInOp, FilterOutBlob)); @@ -338,10 +340,10 @@ 
Maybe<void> OpKernelInfer(SystemOpKernelObject* opkernel_obj, vm::Instruction* i JUST(opkernel_obj->ResetKernel(*op_node_signature, &parallel_ctx, BlobDesc4BnInOp, instruction->parallel_desc().get())); JUST(CheckBlobParallel(instruction, args, op_node_signature)); - JUST(ForEachObnAndBlobObject(instruction, args, - [](const std::string& obn, BlobObject* blob_object) -> Maybe<void> { - return blob_object->TryInitBlob(); - })); + JUST(ForEachMutBnAndBlobObject( + instruction, args, [](const std::string& obn, EagerBlobObject* blob_object) -> Maybe<void> { + return blob_object->TryInitBlob(); + })); std::function<Blob*(const std::string&)> Blob4BnInOp; JUST(MakeBlob4BnInOp(instruction, args, &Blob4BnInOp)); opkernel_obj->kernel().SystemForwardHeader(KernelCtx(), Blob4BnInOp); @@ -352,16 +354,17 @@ template<typename T> Maybe<void> OpKernelCompute(OpKernelObject* opkernel_obj, vm::Instruction* instruction, const T& args) { DeviceCtx* device_ctx = instruction->stream().device_ctx().get(); - JUST(ForEachObnAndBlobObject(instruction, args, - [&](const std::string&, BlobObject* blob_object) -> Maybe<void> { - blob_object->TryAllocateBlobBodyMemory(device_ctx); - return Maybe<void>::Ok(); - })); + JUST(ForEachMutBnAndBlobObject( + instruction, args, [&](const std::string&, EagerBlobObject* blob_object) -> Maybe<void> { + blob_object->TryAllocateBlobBodyMemory(device_ctx); + return Maybe<void>::Ok(); + })); std::shared_ptr<user_op::OpKernelState> new_state; { std::function<Blob*(const std::string&)> Blob4BnInOp; Shape empty_shape{}; - const auto& FilterOutBlob = [&](const std::string& bn_in_op, const BlobObject& blob_object) { + const auto& FilterOutBlob = [&](const std::string& bn_in_op, + const EagerBlobObject& blob_object) { return !(bn_in_op == "tmp_buffer_0" && blob_object.blob_desc().shape() == empty_shape); }; JUST(MakeBlob4BnInOp(instruction, args, &Blob4BnInOp, FilterOutBlob)); @@ -376,11 +379,11 @@ Maybe<void> OpKernelCompute(OpKernelObject* opkernel_obj, 
vm::Instruction* instr Maybe<void> OpKernelCompute(SystemOpKernelObject* opkernel_obj, vm::Instruction* instruction, const StatelessCallOpKernelInstrOperand& args) { DeviceCtx* device_ctx = instruction->stream().device_ctx().get(); - JUST(ForEachObnAndBlobObject(instruction, args, - [&](const std::string&, BlobObject* blob_object) -> Maybe<void> { - blob_object->TryAllocateBlobBodyMemory(device_ctx); - return Maybe<void>::Ok(); - })); + JUST(ForEachMutBnAndBlobObject( + instruction, args, [&](const std::string&, EagerBlobObject* blob_object) -> Maybe<void> { + blob_object->TryAllocateBlobBodyMemory(device_ctx); + return Maybe<void>::Ok(); + })); KernelCtx kernel_ctx; kernel_ctx.device_ctx = device_ctx; std::function<Blob*(const std::string&)> Blob4BnInOp; -- GitLab