From 0a3adeaea932a926cec6ec2b5a374ced3b9c3bb2 Mon Sep 17 00:00:00 2001 From: leaves-zwx <kunta0932@gmail.com> Date: Mon, 28 May 2018 13:54:35 +0800 Subject: [PATCH] Dev concat forward field (#906) * remove copy by itor * use already existed copy method * delete forward field override and use CopyField instead * remove needless --- oneflow/core/kernel/concat_kernel.cpp | 18 ------------------ oneflow/core/kernel/concat_kernel.h | 6 ------ oneflow/core/kernel/kernel.cpp | 4 ++++ 3 files changed, 4 insertions(+), 24 deletions(-) diff --git a/oneflow/core/kernel/concat_kernel.cpp b/oneflow/core/kernel/concat_kernel.cpp index 94a53178d..19e05983f 100644 --- a/oneflow/core/kernel/concat_kernel.cpp +++ b/oneflow/core/kernel/concat_kernel.cpp @@ -24,24 +24,6 @@ void ConcatKernel<device_type, T>::ForwardDataContent( CHECK_EQ(out_col_offset, out_col_num); } -template<DeviceType device_type, typename T> -void ConcatKernel<device_type, T>::ForwardDataId( - const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { - DataIdIterator input_it(BnInOp2Blob, &this->op_attribute().input_bns(), - this->op_conf().concat_conf().axis()); - DataIdIterator output_it(BnInOp2Blob, &this->op_attribute().output_bns(), 0); - CopyFromIterToIter<device_type>(ctx.device_ctx, input_it, output_it); -} - -template<DeviceType device_type, typename T> -void ConcatKernel<device_type, T>::ForwardColNum( - const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { - ColNumIterator input_it(BnInOp2Blob, &this->op_attribute().input_bns(), - this->op_conf().concat_conf().axis()); - ColNumIterator output_it(BnInOp2Blob, &this->op_attribute().output_bns(), 0); - CopyFromIterToIter<device_type>(ctx.device_ctx, input_it, output_it); -} - template<DeviceType device_type, typename T> void ConcatKernel<device_type, T>::BackwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { diff --git a/oneflow/core/kernel/concat_kernel.h b/oneflow/core/kernel/concat_kernel.h index 9717d9c90..afc865fc1 100644 --- a/oneflow/core/kernel/concat_kernel.h +++ b/oneflow/core/kernel/concat_kernel.h @@ -20,12 +20,6 @@ class ConcatKernel final : public KernelIf<device_type> { void ForwardDataContent(const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; - void ForwardDataId(const KernelCtx& ctx, - std::function<Blob*(const std::string&)> BnInOp2Blob) const override; - - void ForwardColNum(const KernelCtx& ctx, - std::function<Blob*(const std::string&)> BnInOp2Blob) const override; - void BackwardDataContent(const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; }; diff --git a/oneflow/core/kernel/kernel.cpp b/oneflow/core/kernel/kernel.cpp index 604768559..b96fe1a52 100644 --- a/oneflow/core/kernel/kernel.cpp +++ b/oneflow/core/kernel/kernel.cpp @@ -126,6 +126,10 @@ void KernelIf<device_type>::CopyField(DeviceCtx* ctx, if (from_bns.size() == 1) { const Blob* in_blob = BnInOp2Blob(from_bns[0]); CopyField(ctx, BnInOp2Blob, in_blob, to_bns, Copy); + } else if (to_bns.size() == 1) { + Blob* in_blob = BnInOp2Blob(from_bns[0]); + Blob* out_blob = BnInOp2Blob(to_bns[0]); + (out_blob->*Copy)(ctx, in_blob); } else { CHECK_EQ(from_bns.size(), to_bns.size()); FOR_RANGE(size_t, i, 0, from_bns.size()) { -- GitLab