Skip to content
Snippets Groups Projects
Commit 0a3adeae authored by leaves-zwx's avatar leaves-zwx Committed by Will Zhang
Browse files

Dev concat forward field (#906)

* remove copy by itor

* use already existed copy method

* delete forward field override and use CopyField instead

* remove needless
parent 68cccab8
No related branches found
No related tags found
No related merge requests found
......@@ -24,24 +24,6 @@ void ConcatKernel<device_type, T>::ForwardDataContent(
CHECK_EQ(out_col_offset, out_col_num);
}
template<DeviceType device_type, typename T>
void ConcatKernel<device_type, T>::ForwardDataId(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
DataIdIterator input_it(BnInOp2Blob, &this->op_attribute().input_bns(),
this->op_conf().concat_conf().axis());
DataIdIterator output_it(BnInOp2Blob, &this->op_attribute().output_bns(), 0);
CopyFromIterToIter<device_type>(ctx.device_ctx, input_it, output_it);
}
template<DeviceType device_type, typename T>
void ConcatKernel<device_type, T>::ForwardColNum(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
ColNumIterator input_it(BnInOp2Blob, &this->op_attribute().input_bns(),
this->op_conf().concat_conf().axis());
ColNumIterator output_it(BnInOp2Blob, &this->op_attribute().output_bns(), 0);
CopyFromIterToIter<device_type>(ctx.device_ctx, input_it, output_it);
}
template<DeviceType device_type, typename T>
void ConcatKernel<device_type, T>::BackwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
......
......@@ -20,12 +20,6 @@ class ConcatKernel final : public KernelIf<device_type> {
void ForwardDataContent(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
void ForwardDataId(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
void ForwardColNum(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
void BackwardDataContent(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
};
......
......@@ -126,6 +126,10 @@ void KernelIf<device_type>::CopyField(DeviceCtx* ctx,
if (from_bns.size() == 1) {
const Blob* in_blob = BnInOp2Blob(from_bns[0]);
CopyField(ctx, BnInOp2Blob, in_blob, to_bns, Copy);
} else if (to_bns.size() == 1) {
Blob* in_blob = BnInOp2Blob(from_bns[0]);
Blob* out_blob = BnInOp2Blob(to_bns[0]);
(out_blob->*Copy)(ctx, in_blob);
} else {
CHECK_EQ(from_bns.size(), to_bns.size());
FOR_RANGE(size_t, i, 0, from_bns.size()) {
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment