Commit 59dc6433 authored by Li Xinqi, committed by Will Zhang

A BnInOp2Blob wrapper is better than GetOutDiffBlob (#889)

* A BnInOp2Blob wrapper is better than GetOutDiffBlob

* sv

* format
parent 4b4881cf
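Before the diffs, the shape of the change: each activation-aware kernel used to fetch its output-diff blob through a protected GetOutDiffBlob helper; this commit deletes that helper and instead has Kernel::Backward pass callees a wrapped BnInOp2Blob that redirects the output-diff blob name to "activation_buf" whenever an activation is configured. Below is a minimal, compilable sketch of that wrapper pattern. Blob, RedirectBn, and the map of blobs are simplified stand-ins for illustration, not OneFlow's real types or helpers.

// Sketch of the "wrap the resolver" pattern this commit adopts.
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

struct Blob { std::string name; };  // stand-in for oneflow::Blob

using BnInOp2BlobFn = std::function<Blob*(const std::string&)>;

// Wrap an existing blob-name resolver so lookups of `from` are transparently
// redirected to `to` -- the same shape as the lambda Kernel::Backward now
// builds around BnInOp2Blob when an activation is fused into the kernel.
BnInOp2BlobFn RedirectBn(BnInOp2BlobFn base, std::string from, std::string to) {
  return [base, from, to](const std::string& bn) -> Blob* {
    return base(bn == from ? to : bn);
  };
}

int main() {
  std::unordered_map<std::string, Blob> blobs = {
      {"out_diff", {"out_diff"}}, {"activation_buf", {"activation_buf"}}};
  BnInOp2BlobFn plain = [&](const std::string& bn) { return &blobs.at(bn); };

  // With the wrapper installed, a kernel that asks for "out_diff"
  // silently receives the activation buffer instead.
  BnInOp2BlobFn wrapped = RedirectBn(plain, "out_diff", "activation_buf");
  std::cout << wrapped("out_diff")->name << "\n";  // prints "activation_buf"
}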
@@ -29,7 +29,7 @@ void ConvKernelIf<device_type, T>::ForwardDataContent(
 template<DeviceType device_type, typename T>
 void ConvKernelIf<device_type, T>::BackwardDataContent(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  const Blob* conv_out_diff = this->GetOutDiffBlob(BnInOp2Blob);
+  const Blob* conv_out_diff = BnInOp2Blob("out_diff");
   if (this->template GetValFromCustomizedOpConf<bool>("use_bias")) {
     BiasBackward(ctx.device_ctx, conv_out_diff, BnInOp2Blob("bias_diff"), BnInOp2Blob);
   }
@@ -29,7 +29,7 @@ template<DeviceType device_type, typename T>
 void FullyConnectedKernel<device_type, T>::BackwardDataContent(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   const Blob* in_blob = BnInOp2Blob("in");
-  const Blob* out_diff_blob = this->GetOutDiffBlob(BnInOp2Blob);
+  const Blob* out_diff_blob = BnInOp2Blob("out_diff");
   Blob* in_diff_blob = BnInOp2Blob("in_diff");
@@ -50,7 +50,15 @@ void Kernel::Forward(const KernelCtx& ctx,
 void Kernel::Backward(const KernelCtx& ctx,
                       std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   BackwardActivate(ctx, BnInOp2Blob);
-  BackwardDataContent(ctx, BnInOp2Blob);
+  BackwardDataContent(ctx, [BnInOp2Blob, this](const std::string& bn) -> Blob* {
+    const PbRpf<std::string> odbns = this->op_attribute().output_diff_bns();
+    if (this->GetActivationType() != ActivationType::kNone) {
+      CHECK_EQ(odbns.size(), 1);
+      if (bn == odbns[0]) { return BnInOp2Blob("activation_buf"); }
+    }
+    return BnInOp2Blob(bn);
+  });
   if (kernel_conf_.need_do_data_id()) { BackwardDataId(ctx, BnInOp2Blob); }
   if (kernel_conf_.need_do_col_num()) { BackwardColNum(ctx, BnInOp2Blob); }
 }
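Two details of this hunk are worth noting. First, the redirect now lives in exactly one place, the base Kernel::Backward, instead of every activation-aware subclass re-fetching the blob through a helper. Second, the wrapper is applied only to the BackwardDataContent call; BackwardDataId and BackwardColNum still receive the unwrapped BnInOp2Blob, so only data-content gradients are rerouted through "activation_buf". The `bn == odbns[0]` test presumably matches the "out_diff" name that the conv and fully-connected kernels above now request, though the diff does not show those kernels' output_diff_bns() configuration.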
@@ -132,18 +140,6 @@ ActivationType KernelIfWithActivation<device_type, T>::GetActivationType() const
   return static_cast<ActivationType>(this->GetEnumFromCustomizedOpConf("activation"));
 }
 
-template<DeviceType device_type, typename T>
-const Blob* KernelIfWithActivation<device_type, T>::GetOutDiffBlob(
-    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  if (this->GetActivationType() != ActivationType::kNone) {
-    return BnInOp2Blob("activation_buf");
-  } else {
-    const PbRpf<std::string> odbns = this->op_attribute().output_diff_bns();
-    CHECK_EQ(odbns.size(), 1);
-    return BnInOp2Blob(odbns[0]);
-  }
-}
-
 template<DeviceType device_type, typename T>
 void KernelIfWithActivation<device_type, T>::ForwardActivate(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
@@ -44,6 +44,8 @@ class Kernel {
                              const std::string& model_load_dir,
                              std::function<Blob*(const std::string&)> BnInOp2Blob) const {}
 
+  virtual ActivationType GetActivationType() const { return ActivationType::kNone; }
+
   virtual void Forward(const KernelCtx& ctx,
                        std::function<Blob*(const std::string&)> BnInOp2Blob) const;
   virtual void ForwardDataContent(const KernelCtx& ctx,
@@ -152,8 +154,7 @@ class KernelIfWithActivation : virtual public KernelIf<device_type> {
  protected:
  KernelIfWithActivation() = default;
 
-  ActivationType GetActivationType() const;
-  const Blob* GetOutDiffBlob(std::function<Blob*(const std::string&)> BnInOp2Blob) const;
+  ActivationType GetActivationType() const override;
   void ForwardActivate(const KernelCtx& ctx,
                        std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
   void BackwardActivate(const KernelCtx& ctx,
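The header-side change above is the classic virtual-hook-with-default pattern: the Kernel base class answers ActivationType::kNone by default, and activation-aware mixins like KernelIfWithActivation override the hook, so generic code such as Kernel::Backward can branch on the activation without knowing the concrete kernel type. A compilable toy version follows; the class names and enum values are simplified stand-ins, not OneFlow's real declarations.

#include <cassert>

enum class ActivationType { kNone, kRelu };

struct KernelBase {
  virtual ~KernelBase() = default;
  // Default hook: kernels with no fused activation need not override.
  virtual ActivationType GetActivationType() const { return ActivationType::kNone; }
};

struct ReluKernel : KernelBase {
  // Activation-aware kernel opts in, mirroring KernelIfWithActivation's override.
  ActivationType GetActivationType() const override { return ActivationType::kRelu; }
};

int main() {
  ReluKernel k;
  const KernelBase& base = k;
  // Generic code sees the override through the base interface.
  assert(base.GetActivationType() == ActivationType::kRelu);
}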