Commit e0a2a6df authored by guo ran, committed by Shenghang Tsai

prelu (#2086)

* prelu

* fix

* fix
parent 560bd2f1
Showing 612 additions and 181 deletions
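This commit splits PReLU's backward pass out of PReluKernel into two dedicated operators, PReluDataGradOp (producing dx) and PReluAlphaGradOp (producing alpha_grad), and lets PReluOp take alpha either as an explicit input blob or as an auto-generated model variable. As a reference only (not part of the diff), a minimal sketch of the per-element math these kernels implement, with illustrative names:
// For element i in channel c, with slope alpha[c] (alpha[0] when channel_shared):
//   forward:     y[i]  = (x[i] > 0) ? x[i] : alpha[c] * x[i]
//   data grad:   dx[i] = (x[i] > 0) ? dy[i] : alpha[c] * dy[i]
//   alpha grad:  alpha_grad[c] += (x[i] <= 0) ? dy[i] * x[i] : 0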
#include "oneflow/core/job_completer/autovar.h"
namespace oneflow {
namespace {
void GenerateInputVarOpConf(
const Operator& op, std::vector<OperatorConf>* op_confs,
const std::function<const BlobDesc&(const std::string&)>& BlobDesc4ModelBn) {
CHECK(op.op_conf().has_prelu_conf());
OperatorConf prelu_conf(op.op_conf());
auto* mut_conf = prelu_conf.mutable_prelu_conf();
const auto& conf = op.op_conf().prelu_conf();
if (!conf.has_alpha()) {
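// alpha was not supplied as an input blob: create a model variable op named
// "<op_name>-alpha", initialize it to alpha_init with a constant initializer,
// and rewire prelu_conf.alpha to consume that variable's "out" blob.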
OperatorConf alpha_var_op =
GenerateVariableOpConf(BlobDesc4ModelBn("alpha"), op.op_name() + "-alpha", "alpha");
alpha_var_op.mutable_variable_conf()->mutable_initializer()->mutable_constant_conf()->set_value(
conf.alpha_init());
op_confs->push_back(alpha_var_op);
mut_conf->set_alpha(alpha_var_op.name() + "/out");
}
op_confs->push_back(prelu_conf);
}
} // namespace
REGISTER_OP_INPUT_VAR(OperatorConf::kPreluConf, &GenerateInputVarOpConf);
} // namespace oneflow
#include "oneflow/core/job_completer/autograd.h"
namespace oneflow {
namespace {
void GenerateBackwardOpConf(
const Operator& op, std::vector<OperatorConf>* op_confs,
const std::function<LogicalBlobId*(const std::string&)>& DiffLbi4BnInOp) {
CHECK(op.op_conf().has_prelu_conf());
const auto& conf = op.op_conf().prelu_conf();
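// Two backward ops are generated on demand: a prelu_data_grad op producing dx
// for the "in" blob, and a prelu_alpha_grad op producing alpha_grad for "alpha".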
if (DiffLbi4BnInOp("in") != nullptr) {
OperatorConf prelu_data_grad_op;
prelu_data_grad_op.set_name(op.op_name() + "_data_grad");
PReluDataGradOpConf* prelu_data_grad_op_conf =
prelu_data_grad_op.mutable_prelu_data_grad_conf();
prelu_data_grad_op_conf->set_dy(GenLogicalBlobName(*DiffLbi4BnInOp("out")));
prelu_data_grad_op_conf->set_x(GenLogicalBlobName(op.BnInOp2Lbi("in")));
prelu_data_grad_op_conf->set_alpha(GenLogicalBlobName(op.BnInOp2Lbi("alpha")));
prelu_data_grad_op_conf->set_data_format(conf.data_format());
prelu_data_grad_op_conf->set_channel_shared(conf.channel_shared());
prelu_data_grad_op_conf->set_dx("dx");
op_confs->push_back(prelu_data_grad_op);
DiffLbi4BnInOp("in")->set_op_name(prelu_data_grad_op.name());
DiffLbi4BnInOp("in")->set_blob_name("dx");
}
if (DiffLbi4BnInOp("alpha") != nullptr) {
OperatorConf prelu_alpha_grad_op;
prelu_alpha_grad_op.set_name(op.op_name() + "_alpha_grad");
PReluAlphaGradOpConf* prelu_alpha_grad_op_conf =
prelu_alpha_grad_op.mutable_prelu_alpha_grad_conf();
prelu_alpha_grad_op_conf->set_dy(GenLogicalBlobName(*DiffLbi4BnInOp("out")));
prelu_alpha_grad_op_conf->set_x(GenLogicalBlobName(op.BnInOp2Lbi("in")));
prelu_alpha_grad_op_conf->set_data_format(conf.data_format());
prelu_alpha_grad_op_conf->set_channel_shared(conf.channel_shared());
prelu_alpha_grad_op_conf->set_alpha_grad("alpha_grad");
op_confs->push_back(prelu_alpha_grad_op);
DiffLbi4BnInOp("alpha")->set_op_name(prelu_alpha_grad_op.name());
DiffLbi4BnInOp("alpha")->set_blob_name("alpha_grad");
}
}
} // namespace
REGISTER_OP_GRAD(OperatorConf::kPreluConf, &GenerateBackwardOpConf);
} // namespace oneflow
@@ -141,7 +141,7 @@ message RecordLoadKernelConf {
required int64 device_piece_size = 1;
}
message PReluKernelConf {
message PReluAlphaGradKernelConf {
repeated int32 perm = 1;
}
@@ -198,7 +198,7 @@ message KernelConf {
GatherKernelConf gather_conf = 406;
VariableKernelConf variable_conf = 407;
RecordLoadKernelConf record_load_conf = 408;
PReluKernelConf prelu_conf = 409;
PReluAlphaGradKernelConf prelu_alpha_grad_conf = 409;
ConvFilterGradKernelConf conv_filter_grad_conf = 410;
ConvDataGradKernelConf conv_data_grad_conf = 411;
ShapeElemCntKernelConf shape_elem_cnt_conf = 412;
......
#include "oneflow/core/kernel/prelu_alpha_grad_kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
void PReluAlphaGradKernel<device_type, T>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
Blob* alpha_grad_blob = BnInOp2Blob("alpha_grad");
Memset<device_type>(ctx.device_ctx, alpha_grad_blob->mut_dptr<T>(), 0,
alpha_grad_blob->ByteSizeOfDataContentField());
PReluAlphaGradKernelUtil<device_type, T>::Compute(
ctx, this->op_conf().prelu_alpha_grad_conf(),
this->kernel_conf().prelu_alpha_grad_conf().perm(), BnInOp2Blob("x"), BnInOp2Blob("dy"),
BnInOp2Blob("bw_buf"), BnInOp2Blob("alpha_grad_buf"), alpha_grad_blob);
}
template<typename T>
struct PReluAlphaGradKernelUtil<DeviceType::kCPU, T> {
static void Compute(const KernelCtx& ctx, const PReluAlphaGradOpConf& conf,
const PbRf<int32_t>& permutation, const Blob* x_blob, const Blob* dy_blob,
Blob* bw_buf_blob, Blob* alpha_grad_buf_blob, Blob* alpha_grad_blob) {
const T* x = x_blob->dptr<T>();
const T* dy = dy_blob->dptr<T>();
T* alpha_grad_dptr = alpha_grad_blob->mut_dptr<T>();
const int64_t elem_cnt = x_blob->shape().elem_cnt();
if (conf.data_format() == "channels_first") {
const int64_t channel_num = x_blob->shape().At(1);
const int64_t alpha_channel_num = conf.channel_shared() ? channel_num : 1;
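// When alpha is shared across channels, alpha_channel_num == channel_num, so the
// division below maps every channel index to 0; otherwise dividing by 1 keeps the
// per-channel index.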
const int64_t area = x_blob->shape().Count(2);
FOR_RANGE(int64_t, i, 0, elem_cnt) {
if (x[i] <= 0) {
int64_t c = (i / area) % channel_num / alpha_channel_num;
alpha_grad_dptr[c] += dy[i] * x[i];
}
}
} else if (conf.data_format() == "channels_last") {
const int64_t channel_num = x_blob->shape().At(x_blob->shape().NumAxes() - 1);
const int64_t alpha_channel_num = conf.channel_shared() ? channel_num : 1;
FOR_RANGE(int64_t, i, 0, elem_cnt) {
if (x[i] <= 0) {
int64_t c = i % channel_num / alpha_channel_num;
alpha_grad_dptr[c] += dy[i] * x[i];
}
}
} else {
UNIMPLEMENTED();
}
}
};
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kPreluAlphaGradConf, PReluAlphaGradKernel,
FLOATING_DATA_TYPE_SEQ);
} // namespace oneflow
#include "oneflow/core/kernel/prelu_alpha_grad_kernel.h"
namespace oneflow {
namespace {
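// Writes each element's contribution to the alpha gradient, dy[i] * x[i] where
// x[i] <= 0 (0 elsewhere), into a buffer that Compute below reduces per channel,
// or sums entirely when alpha is shared.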
template<typename T>
__global__ void PReluAlphaBackward(const int64_t elem_cnt, const T* x, const T* dy,
T* alpha_grad_buf_dptr) {
CUDA_1D_KERNEL_LOOP(i, elem_cnt) { alpha_grad_buf_dptr[i] = (x[i] <= 0) ? dy[i] * x[i] : 0; }
}
} // namespace
template<typename T>
struct PReluAlphaGradKernelUtil<DeviceType::kGPU, T> {
static void Compute(const KernelCtx& ctx, const PReluAlphaGradOpConf& conf,
const PbRf<int32_t>& permutation, const Blob* x_blob, const Blob* dy_blob,
Blob* bw_buf_blob, Blob* alpha_grad_buf_blob, Blob* alpha_grad_blob) {
const int64_t elem_cnt = dy_blob->shape().elem_cnt();
PReluAlphaBackward<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(
elem_cnt, x_blob->dptr<T>(), dy_blob->dptr<T>(), alpha_grad_buf_blob->mut_dptr<T>());
if (conf.channel_shared()) {
KernelUtil<DeviceType::kGPU, T>::Sum(
ctx.device_ctx, elem_cnt, alpha_grad_buf_blob->dptr<T>(), alpha_grad_blob->mut_dptr<T>(),
bw_buf_blob->mut_dptr<T>(), bw_buf_blob->ByteSizeOfDataContentField());
} else {
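// Per-channel alpha: transpose the per-element contributions so the channel axis
// leads (using the perm computed in PReluAlphaGradOp::VirtualGenKernelConf), then
// RowSum reduces each channel's row into alpha_grad.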
KernelUtil<DeviceType::kGPU, T>::Transpose(
ctx.device_ctx, alpha_grad_buf_blob->shape().NumAxes(), alpha_grad_buf_blob->shape(),
bw_buf_blob->shape(), permutation, alpha_grad_buf_blob->shape().elem_cnt(),
alpha_grad_buf_blob->dptr<T>(), bw_buf_blob->mut_dptr<T>());
CHECK_EQ(elem_cnt, bw_buf_blob->shape().elem_cnt());
if (conf.data_format() == "channels_first") {
const int64_t channel_num = dy_blob->shape().At(1);
CHECK_EQ(channel_num, bw_buf_blob->shape().At(0));
KernelUtil<DeviceType::kGPU, T>::RowSum(
ctx.device_ctx, channel_num, bw_buf_blob->shape().Count(1), bw_buf_blob->dptr<T>(),
alpha_grad_blob->mut_dptr<T>(), alpha_grad_buf_blob->mut_dptr<T>(),
alpha_grad_buf_blob->ByteSizeOfDataContentField());
} else if (conf.data_format() == "channels_last") {
const int64_t channel_num = dy_blob->shape().At(x_blob->shape().NumAxes() - 1);
CHECK_EQ(channel_num, bw_buf_blob->shape().At(0));
KernelUtil<DeviceType::kGPU, T>::RowSum(
ctx.device_ctx, channel_num, bw_buf_blob->shape().Count(1), bw_buf_blob->dptr<T>(),
alpha_grad_blob->mut_dptr<T>(), alpha_grad_buf_blob->mut_dptr<T>(),
alpha_grad_buf_blob->ByteSizeOfDataContentField());
} else {
UNIMPLEMENTED();
}
}
}
};
#define INSTANTIATE_P_RELU_ALPHA_GRAD_KERNEL_UTIL(type_cpp, type_proto) \
template class PReluAlphaGradKernelUtil<DeviceType::kGPU, type_cpp>;
OF_PP_FOR_EACH_TUPLE(INSTANTIATE_P_RELU_ALPHA_GRAD_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ);
} // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_PRELU_ALPHA_GRAD_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_PRELU_ALPHA_GRAD_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
class PReluAlphaGradKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(PReluAlphaGradKernel);
PReluAlphaGradKernel() = default;
~PReluAlphaGradKernel() = default;
private:
void ForwardDataContent(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
};
template<DeviceType device_type, typename T>
struct PReluAlphaGradKernelUtil {
static void Compute(const KernelCtx& ctx, const PReluAlphaGradOpConf& conf,
const PbRf<int32_t>& permutation, const Blob* x_blob, const Blob* dy_blob,
Blob* bw_buf_blob, Blob* alpha_grad_buf_blob, Blob* alpha_grad_blob);
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_PRELU_ALPHA_GRAD_KERNEL_H_
#include "oneflow/core/kernel/prelu_data_grad_kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
void PReluDataGradKernel<device_type, T>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
Blob* dx_blob = BnInOp2Blob("dx");
if (dx_blob == nullptr) { return; }
Memset<device_type>(ctx.device_ctx, dx_blob->mut_dptr<T>(), 0,
dx_blob->ByteSizeOfDataContentField());
PReluDataGradKernelUtil<device_type, T>::Compute(ctx, this->op_conf().prelu_data_grad_conf(),
BnInOp2Blob("x"), BnInOp2Blob("alpha"),
BnInOp2Blob("dy"), dx_blob);
}
template<typename T>
struct PReluDataGradKernelUtil<DeviceType::kCPU, T> {
static void Compute(const KernelCtx& ctx, const PReluDataGradOpConf& conf, const Blob* x_blob,
const Blob* alpha_blob, const Blob* dy_blob, Blob* dx_blob) {
const T* x = x_blob->dptr<T>();
const T* alpha_dptr = alpha_blob->dptr<T>();
const T* dy = dy_blob->dptr<T>();
T* dx = dx_blob->mut_dptr<T>();
const int64_t elem_cnt = x_blob->shape().elem_cnt();
if (conf.data_format() == "channels_first") {
const int64_t channel_num = x_blob->shape().At(1);
const int64_t alpha_channel_num = conf.channel_shared() ? channel_num : 1;
const int64_t area = x_blob->shape().Count(2);
FOR_RANGE(int64_t, i, 0, elem_cnt) {
if (x[i] > 0) {
dx[i] = dy[i];
} else {
int64_t c = (i / area) % channel_num / alpha_channel_num;
dx[i] = alpha_dptr[c] * dy[i];
}
}
} else if (conf.data_format() == "channels_last") {
const int64_t channel_num = x_blob->shape().At(x_blob->shape().NumAxes() - 1);
const int64_t alpha_channel_num = conf.channel_shared() ? channel_num : 1;
FOR_RANGE(int64_t, i, 0, elem_cnt) {
if (x[i] > 0) {
dx[i] = dy[i];
} else {
int64_t c = i % channel_num / alpha_channel_num;
dx[i] = alpha_dptr[c] * dy[i];
}
}
} else {
UNIMPLEMENTED();
}
}
};
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kPreluDataGradConf, PReluDataGradKernel,
FLOATING_DATA_TYPE_SEQ);
} // namespace oneflow
#include "oneflow/core/kernel/prelu_data_grad_kernel.h"
namespace oneflow {
namespace {
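// dx[i] = dy[i] where x[i] > 0, otherwise alpha[c] * dy[i]; channel_num and area
// encode the channels_first / channels_last layout (both 1 when alpha is shared,
// so every element reads alpha[0]).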
template<typename T>
__global__ void PReluDataBackward(const int64_t elem_cnt, const int64_t channel_num,
const int64_t area, const T* x, const T* alpha_dptr, const T* dy,
T* dx) {
CUDA_1D_KERNEL_LOOP(i, elem_cnt) {
int64_t c = (i / area) % channel_num;
dx[i] = (x[i] <= 0) ? dy[i] * alpha_dptr[c] : dy[i];
}
}
} // namespace
template<typename T>
struct PReluDataGradKernelUtil<DeviceType::kGPU, T> {
static void Compute(const KernelCtx& ctx, const PReluDataGradOpConf& conf, const Blob* x_blob,
const Blob* alpha_blob, const Blob* dy_blob, Blob* dx_blob) {
const int64_t elem_cnt = dy_blob->shape().elem_cnt();
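// A shared alpha is handled by launching with channel_num = 1 and area = 1, so the
// kernel always indexes alpha[0]; otherwise the launch passes the real channel count
// and spatial area for the configured data_format.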
if (conf.channel_shared()) {
PReluDataBackward<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(
elem_cnt, 1, 1, x_blob->dptr<T>(), alpha_blob->dptr<T>(), dy_blob->dptr<T>(),
dx_blob->mut_dptr<T>());
} else {
if (conf.data_format() == "channels_first") {
const int64_t channel_num = dy_blob->shape().At(1);
PReluDataBackward<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(
elem_cnt, channel_num, dy_blob->shape().Count(2), x_blob->dptr<T>(),
alpha_blob->dptr<T>(), dy_blob->dptr<T>(), dx_blob->mut_dptr<T>());
} else if (conf.data_format() == "channels_last") {
const int64_t channel_num = dy_blob->shape().At(x_blob->shape().NumAxes() - 1);
PReluDataBackward<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(
elem_cnt, channel_num, 1, x_blob->dptr<T>(), alpha_blob->dptr<T>(), dy_blob->dptr<T>(),
dx_blob->mut_dptr<T>());
} else {
UNIMPLEMENTED();
}
}
}
};
#define INSTANTIATE_P_RELU_DATA_GRAD_KERNEL_UTIL(type_cpp, type_proto) \
template class PReluDataGradKernelUtil<DeviceType::kGPU, type_cpp>;
OF_PP_FOR_EACH_TUPLE(INSTANTIATE_P_RELU_DATA_GRAD_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ);
} // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_PRELU_DATA_GRAD_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_PRELU_DATA_GRAD_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
class PReluDataGradKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(PReluDataGradKernel);
PReluDataGradKernel() = default;
~PReluDataGradKernel() = default;
private:
void ForwardDataContent(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
};
template<DeviceType device_type, typename T>
struct PReluDataGradKernelUtil {
static void Compute(const KernelCtx& ctx, const PReluDataGradOpConf& conf, const Blob* x_blob,
const Blob* alpha_blob, const Blob* dy_blob, Blob* dx_blob);
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_PRELU_DATA_GRAD_KERNEL_H_
@@ -9,22 +9,6 @@ void PReluKernel<device_type, T>::ForwardDataContent(
BnInOp2Blob("alpha"), BnInOp2Blob("out"));
}
template<DeviceType device_type, typename T>
void PReluKernel<device_type, T>::BackwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
Blob* in_diff_blob = BnInOp2Blob("in_diff");
Blob* alpha_diff_blob = BnInOp2Blob("alpha_diff");
if (in_diff_blob == nullptr) { return; }
Memset<device_type>(ctx.device_ctx, in_diff_blob->mut_dptr<T>(), 0,
in_diff_blob->ByteSizeOfDataContentField());
Memset<device_type>(ctx.device_ctx, alpha_diff_blob->mut_dptr<T>(), 0,
alpha_diff_blob->ByteSizeOfDataContentField());
PReluKernelUtil<device_type, T>::Backward(
ctx, this->op_conf().prelu_conf(), this->kernel_conf().prelu_conf().perm(), BnInOp2Blob("in"),
BnInOp2Blob("alpha"), BnInOp2Blob("out_diff"), BnInOp2Blob("bw_buf"), in_diff_blob,
alpha_diff_blob);
}
template<typename T>
struct PReluKernelUtil<DeviceType::kCPU, T> {
static void Forward(const KernelCtx& ctx, const PReluOpConf& conf, const Blob* in_blob,
@@ -56,51 +40,6 @@ struct PReluKernelUtil<DeviceType::kCPU, T> {
}
}
}
static void Backward(const KernelCtx& ctx, const PReluOpConf& conf,
const PbRf<int32_t>& permutation, const Blob* in_blob,
const Blob* alpha_blob, const Blob* out_diff_blob, Blob* bw_buf_blob,
Blob* in_diff_blob, Blob* alpha_diff_blob) {
const T* in_dptr = in_blob->dptr<T>();
const T* alpha_dptr = alpha_blob->dptr<T>();
const T* out_diff_dptr = out_diff_blob->dptr<T>();
T* in_diff_dptr = in_diff_blob->mut_dptr<T>();
T* alpha_diff_dptr = alpha_diff_blob->mut_dptr<T>();
const int64_t elem_cnt = in_blob->shape().elem_cnt();
if (conf.data_format() == "channels_first") {
const int64_t channel_num = in_blob->shape().At(1);
const int64_t alpha_channel_num = conf.channel_shared() ? channel_num : 1;
const int64_t area = in_blob->shape().Count(2);
FOR_RANGE(int64_t, i, 0, elem_cnt) {
if (in_dptr[i] <= 0) {
int64_t c = (i / area) % channel_num / alpha_channel_num;
alpha_diff_dptr[c] += out_diff_dptr[i] * in_dptr[i];
}
if (in_dptr[i] > 0) {
in_diff_dptr[i] = out_diff_dptr[i];
} else {
int64_t c = (i / area) % channel_num / alpha_channel_num;
in_diff_dptr[i] = alpha_dptr[c] * out_diff_dptr[i];
}
}
} else if (conf.data_format() == "channels_last") {
const int64_t channel_num = in_blob->shape().At(in_blob->shape().NumAxes() - 1);
const int64_t alpha_channel_num = conf.channel_shared() ? channel_num : 1;
FOR_RANGE(int64_t, i, 0, elem_cnt) {
if (in_dptr[i] <= 0) {
int64_t c = i % channel_num / alpha_channel_num;
alpha_diff_dptr[c] += out_diff_dptr[i] * in_dptr[i];
}
if (in_dptr[i] > 0) {
in_diff_dptr[i] = out_diff_dptr[i];
} else {
int64_t c = i % channel_num / alpha_channel_num;
in_diff_dptr[i] = alpha_dptr[c] * out_diff_dptr[i];
}
}
} else {
UNIMPLEMENTED();
}
}
};
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kPreluConf, PReluKernel, FLOATING_DATA_TYPE_SEQ);
......
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/kernel/prelu_kernel.h"
#include "oneflow/core/kernel/kernel_util.cuh"
#include "oneflow/core/device/cuda_util.h"
#include <cub/cub.cuh>
namespace oneflow {
namespace {
@@ -15,25 +10,6 @@ __global__ void PReluForward(const int64_t elem_cnt, const int64_t channel_num,
out_dptr[i] = (in_dptr[i] <= 0) ? in_dptr[i] * alpha_dptr[c] : in_dptr[i];
}
}
template<typename T>
__global__ void PReluDataBackward(const int64_t elem_cnt, const int64_t channel_num,
const int64_t area, const T* in_dptr, const T* alpha_dptr,
const T* out_dff_dptr, T* in_diff_dptr) {
CUDA_1D_KERNEL_LOOP(i, elem_cnt) {
int64_t c = (i / area) % channel_num;
in_diff_dptr[i] = (in_dptr[i] <= 0) ? out_dff_dptr[i] * alpha_dptr[c] : out_dff_dptr[i];
}
}
template<typename T>
__global__ void PReluAlphaBackward(const int64_t elem_cnt, const T* in_dptr, const T* out_diff_dptr,
T* alpha_diff_buf_dptr) {
CUDA_1D_KERNEL_LOOP(i, elem_cnt) {
alpha_diff_buf_dptr[i] = (in_dptr[i] <= 0) ? out_diff_dptr[i] * in_dptr[i] : 0;
}
}
} // namespace
template<typename T>
@@ -64,57 +40,6 @@ struct PReluKernelUtil<DeviceType::kGPU, T> {
}
}
}
static void Backward(const KernelCtx& ctx, const PReluOpConf& conf,
const PbRf<int32_t>& permutation, const Blob* in_blob,
const Blob* alpha_blob, const Blob* out_diff_blob, Blob* bw_buf_blob,
Blob* in_diff_blob, Blob* alpha_diff_blob) {
const int64_t elem_cnt = out_diff_blob->shape().elem_cnt();
// in_diff_blob acts as buffer here
PReluAlphaBackward<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(
elem_cnt, in_blob->dptr<T>(), out_diff_blob->dptr<T>(), in_diff_blob->mut_dptr<T>());
if (conf.channel_shared()) {
KernelUtil<DeviceType::kGPU, T>::Sum(
ctx.device_ctx, elem_cnt, in_diff_blob->dptr<T>(), alpha_diff_blob->mut_dptr<T>(),
bw_buf_blob->mut_dptr<T>(), bw_buf_blob->ByteSizeOfDataContentField());
PReluDataBackward<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(
elem_cnt, 1, 1, in_blob->dptr<T>(), alpha_blob->dptr<T>(), out_diff_blob->dptr<T>(),
in_diff_blob->mut_dptr<T>());
} else {
KernelUtil<DeviceType::kGPU, T>::Transpose(
ctx.device_ctx, in_diff_blob->shape().NumAxes(), in_diff_blob->shape(),
bw_buf_blob->shape(), permutation, in_diff_blob->shape().elem_cnt(),
in_diff_blob->dptr<T>(), bw_buf_blob->mut_dptr<T>());
CHECK_EQ(elem_cnt, bw_buf_blob->shape().elem_cnt());
if (conf.data_format() == "channels_first") {
const int64_t channel_num = out_diff_blob->shape().At(1);
CHECK_EQ(channel_num, bw_buf_blob->shape().At(0));
KernelUtil<DeviceType::kGPU, T>::RowSum(
ctx.device_ctx, channel_num, bw_buf_blob->shape().Count(1), bw_buf_blob->dptr<T>(),
alpha_diff_blob->mut_dptr<T>(), in_diff_blob->mut_dptr<T>(),
in_diff_blob->ByteSizeOfDataContentField());
PReluDataBackward<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(
elem_cnt, channel_num, out_diff_blob->shape().Count(2), in_blob->dptr<T>(),
alpha_blob->dptr<T>(), out_diff_blob->dptr<T>(), in_diff_blob->mut_dptr<T>());
} else if (conf.data_format() == "channels_last") {
const int64_t channel_num = out_diff_blob->shape().At(in_blob->shape().NumAxes() - 1);
CHECK_EQ(channel_num, bw_buf_blob->shape().At(0));
KernelUtil<DeviceType::kGPU, T>::RowSum(
ctx.device_ctx, channel_num, bw_buf_blob->shape().Count(1), bw_buf_blob->dptr<T>(),
alpha_diff_blob->mut_dptr<T>(), in_diff_blob->mut_dptr<T>(),
in_diff_blob->ByteSizeOfDataContentField());
PReluDataBackward<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(
elem_cnt, channel_num, 1, in_blob->dptr<T>(), alpha_blob->dptr<T>(),
out_diff_blob->dptr<T>(), in_diff_blob->mut_dptr<T>());
} else {
UNIMPLEMENTED();
}
}
}
};
#define INSTANTIATE_P_RELU_KERNEL_UTIL(type_cpp, type_proto) \
......
@@ -15,18 +15,12 @@ class PReluKernel final : public KernelIf<device_type> {
private:
void ForwardDataContent(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
void BackwardDataContent(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
};
template<DeviceType device_type, typename T>
struct PReluKernelUtil {
static void Forward(const KernelCtx& ctx, const PReluOpConf& conf, const Blob* in_blob,
const Blob* alpha_blob, Blob* out_blob);
static void Backward(const KernelCtx& ctx, const PReluOpConf& conf,
const PbRf<int32_t>& permutation, const Blob* in_blob,
const Blob* alpha_blob, const Blob* out_diff_blob, Blob* bw_buf_blob,
Blob* in_diff_blob, Blob* alpha_diff_blob);
};
} // namespace oneflow
......
@@ -279,6 +279,24 @@ message PReluOpConf {
required string data_format = 3;
optional bool channel_shared = 4 [default = false];
optional float alpha_init = 5 [default = 0.25];
optional string alpha = 6;
}
message PReluDataGradOpConf {
required string dy = 1;
required string x = 2;
required string alpha = 3;
required string dx = 4;
required string data_format = 5;
required bool channel_shared = 6;
}
message PReluAlphaGradOpConf {
required string dy = 1;
required string x = 2;
required string alpha_grad = 3;
required string data_format = 4;
required bool channel_shared = 5;
}
message SigmoidOpConf {
@@ -1595,6 +1613,8 @@ message OperatorConf {
AssignOpConf assign_conf = 296;
SnapshotOpConf snapshot_conf = 297;
LearningRateScheduleOpConf learning_rate_schedule_conf = 298;
PReluDataGradOpConf prelu_data_grad_conf = 299;
PReluAlphaGradOpConf prelu_alpha_grad_conf = 300;
// math op
BroadcastAddOpConf broadcast_add_conf = 500;
......
#include "oneflow/core/operator/prelu_alpha_grad_op.h"
#include "oneflow/core/job/sbp_signature_builder.h"
namespace oneflow {
void PReluAlphaGradOp::InitFromOpConf() {
CHECK(op_conf().has_prelu_alpha_grad_conf());
EnrollInputBn("dy", false);
EnrollInputBn("x", false);
EnrollOutputBn("alpha_grad", false);
if (device_type() == DeviceType::kGPU) {
EnrollTmpBn("bw_buf");
EnrollTmpBn("alpha_grad_buf");
}
}
const PbMessage& PReluAlphaGradOp::GetCustomizedConf() const {
return op_conf().prelu_alpha_grad_conf();
}
Maybe<void> PReluAlphaGradOp::InferBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const PReluAlphaGradOpConf& conf = op_conf().prelu_alpha_grad_conf();
const BlobDesc* x_blob_desc = GetBlobDesc4BnInOp("x");
BlobDesc* alpha_grad_blob_desc = GetBlobDesc4BnInOp("alpha_grad");
if (conf.channel_shared()) {
alpha_grad_blob_desc->mut_shape() = Shape({1});
} else {
if (conf.data_format() == "channels_first") {
alpha_grad_blob_desc->mut_shape() = Shape({x_blob_desc->shape().At(1)});
} else if (conf.data_format() == "channels_last") {
alpha_grad_blob_desc->mut_shape() =
Shape({x_blob_desc->shape().At(x_blob_desc->shape().NumAxes() - 1)});
} else {
UNIMPLEMENTED_THEN_RETURN();
}
}
alpha_grad_blob_desc->set_data_type(x_blob_desc->data_type());
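// GPU-only scratch blobs: alpha_grad_buf mirrors x and holds the per-element
// contributions; bw_buf either mirrors x (shared alpha) or has the channel axis
// swapped to dim 0 so RowSum can reduce one channel per row.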
if (device_type() == DeviceType::kGPU) {
BlobDesc* bw_buf_desc = GetBlobDesc4BnInOp("bw_buf");
BlobDesc* alpha_grad_buf_desc = GetBlobDesc4BnInOp("alpha_grad_buf");
*alpha_grad_buf_desc = *x_blob_desc;
if (conf.channel_shared()) {
*bw_buf_desc = *x_blob_desc;
} else {
bw_buf_desc->set_data_type(x_blob_desc->data_type());
std::vector<int64_t> bw_buf_shape_vec = x_blob_desc->shape().dim_vec();
if (conf.data_format() == "channels_first") {
bw_buf_shape_vec[0] = x_blob_desc->shape().At(1);
bw_buf_shape_vec[1] = x_blob_desc->shape().At(0);
bw_buf_desc->mut_shape() = Shape(bw_buf_shape_vec);
} else if (conf.data_format() == "channels_last") {
bw_buf_shape_vec[0] = x_blob_desc->shape().At(x_blob_desc->shape().NumAxes() - 1);
bw_buf_shape_vec[x_blob_desc->shape().NumAxes() - 1] = x_blob_desc->shape().At(0);
bw_buf_desc->mut_shape() = Shape(bw_buf_shape_vec);
} else {
UNIMPLEMENTED_THEN_RETURN();
}
}
}
return Maybe<void>::Ok();
}
void PReluAlphaGradOp::VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
const PReluAlphaGradOpConf& conf = op_conf().prelu_alpha_grad_conf();
PbRf<int32_t>* perm = kernel_conf->mutable_prelu_alpha_grad_conf()->mutable_perm();
const BlobDesc* x_blob_desc = GetBlobDesc4BnInOp("x");
int64_t num_axes = x_blob_desc->shape().NumAxes();
FOR_RANGE(int64_t, i, 0, num_axes) { perm->Add(i); }
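// perm starts as the identity permutation; for a per-channel alpha it swaps the
// channel axis with axis 0 so the GPU kernel's Transpose brings channels to the
// front before the row-wise reduction.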
if (!conf.channel_shared()) {
if (conf.data_format() == "channels_first") {
(*perm)[0] = 1;
(*perm)[1] = 0;
} else if (conf.data_format() == "channels_last") {
(*perm)[num_axes - 1] = 0;
(*perm)[0] = num_axes - 1;
} else {
UNIMPLEMENTED();
}
}
}
Maybe<void> PReluAlphaGradOp::InferHasBatchDim(
std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const {
CHECK_OR_RETURN(*HasBatchDim4BnInOp("dy"));
CHECK_OR_RETURN(*HasBatchDim4BnInOp("x"));
*HasBatchDim4BnInOp("alpha_grad") = false;
return Maybe<void>::Ok();
}
void PReluAlphaGradOp::GetSbpSignatures(
const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const {
SbpSignatureBuilder()
.Split("dy", 0)
.Split("x", 0)
.PartialSum("alpha_grad")
.Build(sbp_sig_list->mutable_sbp_signature()->Add());
}
REGISTER_OP(OperatorConf::kPreluAlphaGradConf, PReluAlphaGradOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_PRELU_ALPHA_GRAD_OP_H_
#define ONEFLOW_CORE_OPERATOR_PRELU_ALPHA_GRAD_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class PReluAlphaGradOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(PReluAlphaGradOp);
PReluAlphaGradOp() = default;
~PReluAlphaGradOp() = default;
void InitFromOpConf() override;
const PbMessage& GetCustomizedConf() const override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
private:
void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext*, KernelConf*) const override;
Maybe<void> InferHasBatchDim(
std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const override;
void GetSbpSignatures(
const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_PRELU_ALPHA_GRAD_OP_H_
#include "oneflow/core/operator/prelu_data_grad_op.h"
#include "oneflow/core/job/sbp_signature_builder.h"
namespace oneflow {
void PReluDataGradOp::InitFromOpConf() {
CHECK(op_conf().has_prelu_data_grad_conf());
EnrollInputBn("dy", false);
EnrollInputBn("alpha", false);
EnrollInputBn("x", false);
EnrollOutputBn("dx", false);
}
const PbMessage& PReluDataGradOp::GetCustomizedConf() const {
return op_conf().prelu_data_grad_conf();
}
Maybe<void> PReluDataGradOp::InferBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
*GetBlobDesc4BnInOp("dx") = *GetBlobDesc4BnInOp("x");
return Maybe<void>::Ok();
}
Maybe<void> PReluDataGradOp::InferHasBatchDim(
std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const {
CHECK_OR_RETURN(*HasBatchDim4BnInOp("dy"));
CHECK_OR_RETURN(*HasBatchDim4BnInOp("x"));
CHECK_OR_RETURN(*HasBatchDim4BnInOp("alpha") == false);
*HasBatchDim4BnInOp("dx") = true;
return Maybe<void>::Ok();
}
void PReluDataGradOp::GetSbpSignatures(
const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const {
SbpSignatureBuilder().Split("dy", 0).Broadcast("alpha").Split("x", 0).Split("dx", 0).Build(
sbp_sig_list->mutable_sbp_signature()->Add());
}
REGISTER_OP(OperatorConf::kPreluDataGradConf, PReluDataGradOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_PRELU_DATA_GRAD_OP_H_
#define ONEFLOW_CORE_OPERATOR_PRELU_DATA_GRAD_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class PReluDataGradOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(PReluDataGradOp);
PReluDataGradOp() = default;
~PReluDataGradOp() = default;
void InitFromOpConf() override;
const PbMessage& GetCustomizedConf() const override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
private:
Maybe<void> InferHasBatchDim(
std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const override;
void GetSbpSignatures(
const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_PRELU_DATA_GRAD_OP_H_
@@ -6,10 +6,16 @@ namespace oneflow {
void PReluOp::InitFromOpConf() {
CHECK(op_conf().has_prelu_conf());
const PReluOpConf& conf = op_conf().prelu_conf();
StrFieldTolower("data_format");
EnrollInputBn("in");
EnrollTmpBn("alpha");
if (conf.has_alpha()) {
EnrollInputBn("alpha");
} else {
EnrollTmpBn("alpha");
}
EnrollOutputBn("out")->set_mutable_inplace_ibn("in");
;
}
const PbMessage& PReluOp::GetCustomizedConf() const { return op_conf().prelu_conf(); }
@@ -19,46 +25,34 @@ Maybe<void> PReluOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)>
const PReluOpConf& conf = op_conf().prelu_conf();
const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
*GetBlobDesc4BnInOp("out") = *in_blob_desc;
BlobDesc* alpha_blob_desc = GetBlobDesc4BnInOp("alpha");
int32_t alpha_size;
if (conf.channel_shared()) {
alpha_blob_desc->mut_shape() = Shape({1});
alpha_size = 1;
} else {
if (conf.data_format() == "channels_first") {
alpha_blob_desc->mut_shape() = Shape({in_blob_desc->shape().At(1)});
alpha_size = in_blob_desc->shape().At(1);
} else if (conf.data_format() == "channels_last") {
alpha_blob_desc->mut_shape() =
Shape({in_blob_desc->shape().At(in_blob_desc->shape().NumAxes() - 1)});
alpha_size = in_blob_desc->shape().At(in_blob_desc->shape().NumAxes() - 1);
} else {
UNIMPLEMENTED_THEN_RETURN();
}
}
alpha_blob_desc->set_data_type(in_blob_desc->data_type());
const Shape alpha_shape({alpha_size});
if (conf.has_alpha()) {
CHECK_EQ_OR_RETURN(GetBlobDesc4BnInOp("alpha")->shape(), alpha_shape);
CHECK_EQ_OR_RETURN(GetBlobDesc4BnInOp("alpha")->data_type(), in_blob_desc->data_type());
} else {
BlobDesc* alpha_blob_desc = GetBlobDesc4BnInOp("alpha");
alpha_blob_desc->set_data_type(in_blob_desc->data_type());
alpha_blob_desc->mut_shape() = alpha_shape;
}
return Maybe<void>::Ok();
}
void PReluOp::VirtualFixParallelDesc(ParallelDesc* pr_desc) const {
pr_desc->set_policy(ParallelPolicy::kDataParallel);
}
void PReluOp::VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
const PReluOpConf& conf = op_conf().prelu_conf();
PbRf<int32_t>* perm = kernel_conf->mutable_prelu_conf()->mutable_perm();
const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
int64_t num_axes = in_blob_desc->shape().NumAxes();
FOR_RANGE(int64_t, i, 0, num_axes) { perm->Add(i); }
if (!conf.channel_shared()) {
if (conf.data_format() == "channels_first") {
(*perm)[0] = 1;
(*perm)[1] = 0;
} else if (conf.data_format() == "channels_last") {
(*perm)[num_axes - 1] = 0;
(*perm)[0] = num_axes - 1;
} else {
UNIMPLEMENTED();
}
}
Maybe<void> PReluOp::InferHasBatchDim(
std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const {
*HasBatchDim4BnInOp("out") = *HasBatchDim4BnInOp("in");
return Maybe<void>::Ok();
}
void PReluOp::GetSbpSignatures(
......
@@ -15,16 +15,10 @@ class PReluOp final : public Operator {
const PbMessage& GetCustomizedConf() const override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
void VirtualFixParallelDesc(ParallelDesc* pr_desc) const override;
private:
void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext*, KernelConf*) const override;
Maybe<void> InferHasBatchDim(
std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const override {
return NaiveInferHasBatchDim(HasBatchDim4BnInOp);
}
std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const override;
void GetSbpSignatures(
const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4Ibn,
......