diff --git a/oneflow/core/eager/eager_blob_object.cpp b/oneflow/core/eager/eager_blob_object.cpp index 8d21a8c95395beb1f39f0306494cfa577e63c501..811a99f961c6a792071ee8c52379838ee0596d42 100644 --- a/oneflow/core/eager/eager_blob_object.cpp +++ b/oneflow/core/eager/eager_blob_object.cpp @@ -52,14 +52,9 @@ Maybe<void> EagerBlobObject::TryInitBlob() { Maybe<void> EagerBlobObject::InitBlob() { CHECK_NE_OR_RETURN(blob_desc_.data_type(), DataType::kInvalidDataType); - { - header_buffer_.reset(); - int64_t header_byte_size = blob_desc_.AlignedByteSizeOfBlobHeader(); - const auto& FreeHeader = [header_byte_size](char* dptr) { std::free(dptr); }; - char* ptr = reinterpret_cast<char*>(std::malloc(header_byte_size)); - header_buffer_ = std::unique_ptr<char, std::function<void(char*)>>(ptr, FreeHeader); - } - blob_.reset(new Blob(*mem_case_, &blob_desc_, header_buffer_.get(), nullptr)); + char* header_buffer = + reinterpret_cast<char*>(const_cast<int64_t*>(blob_desc_.shape().dim_vec().data())); + blob_.reset(new Blob(*mem_case_, &blob_desc_, header_buffer, nullptr)); return Maybe<void>::Ok(); } diff --git a/oneflow/core/eager/eager_blob_object.h b/oneflow/core/eager/eager_blob_object.h index c888a8e95b6a787f1e24282ccc5584effc508544..69c1e27afe2eb1fcb0047dc45f824d300bfe6d95 100644 --- a/oneflow/core/eager/eager_blob_object.h +++ b/oneflow/core/eager/eager_blob_object.h @@ -51,7 +51,6 @@ class EagerBlobObject final : public BlobObject { ~EagerBlobObject() override { non_pod_initer_.reset(); tensor_buffer_.reset(); - header_buffer_.reset(); blob_.reset(); } @@ -79,7 +78,6 @@ class EagerBlobObject final : public BlobObject { private: std::unique_ptr<Blob> blob_; - std::unique_ptr<char, std::function<void(char*)>> header_buffer_; std::shared_ptr<TensorBuffer> tensor_buffer_; std::size_t blob_body_bytes_; std::unique_ptr<MemoryAllocator> non_pod_initer_; diff --git a/oneflow/core/kernel/input_kernel.cpp b/oneflow/core/kernel/input_kernel.cpp index 242f47e465e894f100172fb107b9efe077eeee1c..e5af04b8090c5debe6a3a90f4c342dc2010c3e8f 100644 --- a/oneflow/core/kernel/input_kernel.cpp +++ b/oneflow/core/kernel/input_kernel.cpp @@ -13,7 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include "oneflow/core/kernel/kernel.h" +#include "oneflow/core/common/buffer_manager.h" +#include "oneflow/core/job/job_instance.h" +#include "oneflow/core/job/global_for.h" namespace oneflow { @@ -30,7 +34,21 @@ class InputKernel final : public KernelIf<device_type> { void Forward(const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override {} void ForwardDataContent(const KernelCtx& ctx, - std::function<Blob*(const std::string&)> BnInOp2Blob) const override {} + std::function<Blob*(const std::string&)> BnInOp2Blob) const override { + if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) { + const auto& job_name = this->job_desc().job_name(); + const auto& op_name = this->op_conf().name(); + auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get(); + auto* buffer = buffer_mgr->Get(GetInputBufferName(job_name, op_name)); + std::shared_ptr<JobInstance> job_instance; + BufferStatus buffer_status = buffer->TryReceive(&job_instance); + CHECK_NE(buffer_status, kBufferStatusEmpty); + if (buffer_status == kBufferStatusSuccess) { + OfBlob ofblob(ctx.device_ctx, BnInOp2Blob("out")); + job_instance->PushBlobByOpName(reinterpret_cast<uint64_t>(&ofblob), op_name); + } + } + } void ForwardHeader(const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override {} }; diff --git a/oneflow/core/kernel/output_kernel.cpp b/oneflow/core/kernel/output_kernel.cpp index 231948abb0d938a9309d09d35f7f30f0ecf2b340..de9a7599de4d88d8f47f8a4d017fade4d5d60f95 100644 --- a/oneflow/core/kernel/output_kernel.cpp +++ b/oneflow/core/kernel/output_kernel.cpp @@ -14,19 +14,40 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/output_kernel.h" +#include "oneflow/core/common/buffer_manager.h" +#include "oneflow/core/job/job_instance.h" +#include "oneflow/core/job/global_for.h" namespace oneflow { template<DeviceType device_type> void OutputKernel<device_type>::ForwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { - BnInOp2Blob("out")->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob("in")); + if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) { + const auto& job_name = this->job_desc().job_name(); + const auto& op_name = this->op_conf().name(); + auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get(); + auto* buffer = buffer_mgr->Get(GetOutputBufferName(job_name, op_name)); + std::shared_ptr<JobInstance> job_instance; + BufferStatus buffer_status = buffer->TryReceive(&job_instance); + CHECK_NE(buffer_status, kBufferStatusEmpty); + if (buffer_status == kBufferStatusSuccess) { + OfBlob ofblob(ctx.device_ctx, BnInOp2Blob("in")); + job_instance->PullBlobByOpName(reinterpret_cast<uint64_t>(&ofblob), op_name); + } + } else { + BnInOp2Blob("out")->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob("in")); + } } template<DeviceType device_type> void OutputKernel<device_type>::ForwardHeader( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { - BnInOp2Blob("out")->CopyHeaderFrom(ctx.device_ctx, BnInOp2Blob("in")); + if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) { + // Do nothing. + } else { + BnInOp2Blob("out")->CopyHeaderFrom(ctx.device_ctx, BnInOp2Blob("in")); + } } ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kOutputConf, OutputKernel); diff --git a/oneflow/core/kernel/return_kernel.cpp b/oneflow/core/kernel/return_kernel.cpp index 77da8acd33983fc64268ca503854742bfa2a0b8c..4b51e93d8ab716298286674ac6034bd348c5b0ba 100644 --- a/oneflow/core/kernel/return_kernel.cpp +++ b/oneflow/core/kernel/return_kernel.cpp @@ -14,20 +14,41 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/kernel/return_kernel.h" +#include "oneflow/core/common/buffer_manager.h" +#include "oneflow/core/job/job_instance.h" +#include "oneflow/core/job/global_for.h" namespace oneflow { template<DeviceType device_type> void ReturnKernel<device_type>::ForwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { - BnInOp2Blob("out")->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob("in")); - ctx.device_ctx->SyncDevice(); + if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) { + const auto& job_name = this->job_desc().job_name(); + const auto& op_name = this->op_conf().name(); + auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get(); + auto* buffer = buffer_mgr->Get(GetOutputBufferName(job_name, op_name)); + std::shared_ptr<JobInstance> job_instance; + BufferStatus buffer_status = buffer->TryReceive(&job_instance); + CHECK_NE(buffer_status, kBufferStatusEmpty); + if (buffer_status == kBufferStatusSuccess) { + OfBlob ofblob(ctx.device_ctx, BnInOp2Blob("in")); + job_instance->PullBlobByOpName(reinterpret_cast<uint64_t>(&ofblob), op_name); + } + } else { + BnInOp2Blob("out")->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob("in")); + ctx.device_ctx->SyncDevice(); + } } template<DeviceType device_type> void ReturnKernel<device_type>::ForwardHeader( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { - BnInOp2Blob("out")->CopyHeaderFrom(ctx.device_ctx, BnInOp2Blob("in")); + if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) { + // Do nothing. + } else { + BnInOp2Blob("out")->CopyHeaderFrom(ctx.device_ctx, BnInOp2Blob("in")); + } } ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kReturnConf, ReturnKernel); diff --git a/oneflow/core/kernel/wait_and_send_ids_kernel.cpp b/oneflow/core/kernel/wait_and_send_ids_kernel.cpp index 63583ff6022c10098941912e57449f64c1efbd53..1bcd99d9a5c1415a2299764057c2805173cf0e4f 100644 --- a/oneflow/core/kernel/wait_and_send_ids_kernel.cpp +++ b/oneflow/core/kernel/wait_and_send_ids_kernel.cpp @@ -13,7 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include "oneflow/core/kernel/wait_and_send_ids_kernel.h" +#include "oneflow/core/common/buffer_manager.h" +#include "oneflow/core/job/job_instance.h" +#include "oneflow/core/job/global_for.h" namespace oneflow { @@ -24,11 +28,25 @@ void WaitAndSendIdsKernel<T>::ForwardDataContent( auto* status = static_cast<WaitAndSendIdsStatus*>(ctx.other); const auto& conf = this->op_conf().wait_and_send_ids_conf(); if (status->out_idx_ >= status->out_num_) { - status->buffer_status_ = - Global<BufferMgr<int64_t>>::Get()->Get(conf.wait_buffer_name())->Receive(&status->in_id_); - if (status->buffer_status_ == kBufferStatusErrorClosed) { return; } - status->out_idx_ = 0; - status->out_num_ = conf.id_list(status->in_id_).value_size(); + if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) { + const auto& job_name = this->job_desc().job_name(); + auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get(); + auto* buffer = buffer_mgr->Get(GetSourceTickBufferName(job_name)); + status->in_id_ = 0; + { + std::shared_ptr<JobInstance> job_instance; + status->buffer_status_ = buffer->Receive(&job_instance); + } + if (status->buffer_status_ == kBufferStatusErrorClosed) { return; } + status->out_idx_ = 0; + status->out_num_ = 1; + } else { + auto* buffer_mgr = Global<BufferMgr<int64_t>>::Get(); + status->buffer_status_ = buffer_mgr->Get(conf.wait_buffer_name())->Receive(&status->in_id_); + if (status->buffer_status_ == kBufferStatusErrorClosed) { return; } + status->out_idx_ = 0; + status->out_num_ = conf.id_list(status->in_id_).value_size(); + } } *BnInOp2Blob("out")->mut_dptr<T>() = conf.id_list(status->in_id_).value(status->out_idx_); ++status->out_idx_;