diff --git a/oneflow/core/kernel/concat_kernel.cpp b/oneflow/core/kernel/concat_kernel.cpp index 94a53178d444feeb70fc78c7c2e0419fef4ac0f7..19e05983f0d854ebc4db88b2c3a9ce48fd80b452 100644 --- a/oneflow/core/kernel/concat_kernel.cpp +++ b/oneflow/core/kernel/concat_kernel.cpp @@ -24,24 +24,6 @@ void ConcatKernel<device_type, T>::ForwardDataContent( CHECK_EQ(out_col_offset, out_col_num); } -template<DeviceType device_type, typename T> -void ConcatKernel<device_type, T>::ForwardDataId( - const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { - DataIdIterator input_it(BnInOp2Blob, &this->op_attribute().input_bns(), - this->op_conf().concat_conf().axis()); - DataIdIterator output_it(BnInOp2Blob, &this->op_attribute().output_bns(), 0); - CopyFromIterToIter<device_type>(ctx.device_ctx, input_it, output_it); -} - -template<DeviceType device_type, typename T> -void ConcatKernel<device_type, T>::ForwardColNum( - const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { - ColNumIterator input_it(BnInOp2Blob, &this->op_attribute().input_bns(), - this->op_conf().concat_conf().axis()); - ColNumIterator output_it(BnInOp2Blob, &this->op_attribute().output_bns(), 0); - CopyFromIterToIter<device_type>(ctx.device_ctx, input_it, output_it); -} - template<DeviceType device_type, typename T> void ConcatKernel<device_type, T>::BackwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { diff --git a/oneflow/core/kernel/concat_kernel.h b/oneflow/core/kernel/concat_kernel.h index 9717d9c90fac9f083965cdbf05966d727f3aeb80..afc865fc12719b16e7fb3f28c0d2a8de7746f4f4 100644 --- a/oneflow/core/kernel/concat_kernel.h +++ b/oneflow/core/kernel/concat_kernel.h @@ -20,12 +20,6 @@ class ConcatKernel final : public KernelIf<device_type> { void ForwardDataContent(const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; - void ForwardDataId(const KernelCtx& ctx, - std::function<Blob*(const std::string&)> BnInOp2Blob) const override; - - void ForwardColNum(const KernelCtx& ctx, - std::function<Blob*(const std::string&)> BnInOp2Blob) const override; - void BackwardDataContent(const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; }; diff --git a/oneflow/core/kernel/kernel.cpp b/oneflow/core/kernel/kernel.cpp index 60476855958558f0f3f90ceb7436dab042e6cef7..b96fe1a5270d5bcd39a6617bf4cbb0ae253f8e28 100644 --- a/oneflow/core/kernel/kernel.cpp +++ b/oneflow/core/kernel/kernel.cpp @@ -126,6 +126,10 @@ void KernelIf<device_type>::CopyField(DeviceCtx* ctx, if (from_bns.size() == 1) { const Blob* in_blob = BnInOp2Blob(from_bns[0]); CopyField(ctx, BnInOp2Blob, in_blob, to_bns, Copy); + } else if (to_bns.size() == 1) { + Blob* in_blob = BnInOp2Blob(from_bns[0]); + Blob* out_blob = BnInOp2Blob(to_bns[0]); + (out_blob->*Copy)(ctx, in_blob); } else { CHECK_EQ(from_bns.size(), to_bns.size()); FOR_RANGE(size_t, i, 0, from_bns.size()) {