Skip to content
Snippets Groups Projects
Unverified Commit 399c401b authored by ZZK's avatar ZZK Committed by GitHub
Browse files

Align pytorch maxpool (#5525)


* align torch maxpool

* remove redundant params

* add torch style functional

* remove cout code

* align pooling backward

* fix ceil mode

* add dataformat back

* add annotation

* add maxpool1d functor

* fix farward to forward

* add getwindowedOutputShape function

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent f8331628
No related branches found
No related tags found
No related merge requests found
......@@ -33,9 +33,7 @@ struct PoolingInterpState : public OpExprInterpState {
size_t indice_index;
std::string data_format;
std::string padding;
std::vector<int32_t> padding_before;
std::vector<int32_t> padding_after;
std::vector<int32_t> padding;
std::vector<int32_t> kernel_size;
std::vector<int32_t> stride;
std::vector<int32_t> dilation;
......@@ -76,9 +74,7 @@ Maybe<void> PoolingNdGrad::Capture(PoolingInterpState* ctx, const TensorTuple& i
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
ctx->padding = JUST(composed_attrs.GetAttr<std::string>("padding"));
ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_before"));
ctx->padding_after = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_after"));
ctx->padding = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding"));
ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("stride"));
ctx->dilation = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation"));
......@@ -100,8 +96,7 @@ Maybe<void> PoolingNdGrad::Apply(const PoolingInterpState* ctx, const TensorTupl
in_grads->resize(1);
in_grads->at(0) = JUST(functional::PoolingNdGrad(
input, output, indice, out_grads.at(0), mode_, ndims, ctx->data_format, ctx->padding,
ctx->padding_before, ctx->padding_after, ctx->kernel_size, ctx->stride, ctx->dilation,
ctx->return_indices, ctx->ceil_mode));
ctx->kernel_size, ctx->stride, ctx->dilation, ctx->return_indices, ctx->ceil_mode));
return Maybe<void>::Ok();
}
......@@ -113,6 +108,7 @@ class MaxpoolNdGrad final : public PoolingNdGrad {
Maybe<void> Init(const OpExpr& op) override { return PoolingNdGrad::Init(op, "max"); }
};
REGISTER_OP_EXPR_GRAD_FUNCTION("maxpool_1d", MaxpoolNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("maxpool_2d", MaxpoolNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("maxpool_3d", MaxpoolNdGrad);
......
......@@ -411,26 +411,31 @@
Int32List strides, Bool ceil_mode)"
bind_python: False
- name: "maxpool_1d"
signature:
"TensorTuple Maxpool1D(Tensor x, *, String data_format=\"channels_first\", Int32List padding,
Int32List kernel_size, Int32List stride, Int32List dilation,
Bool return_indices=True, Bool ceil_mode=False)"
bind_python: True
- name: "maxpool_2d"
signature:
"TensorTuple Maxpool2D(Tensor x, *, String data_format=\"channels_first\", String padding,
Int32List padding_before, Int32List padding_after,
"TensorTuple Maxpool2D(Tensor x, *, String data_format=\"channels_first\", Int32List padding,
Int32List kernel_size, Int32List stride, Int32List dilation,
Bool return_indices=True, Bool ceil_mode=False)"
bind_python: True
- name: "maxpool_3d"
signature:
"TensorTuple Maxpool3D(Tensor x, *, String data_format=\"channels_first\", String padding,
Int32List padding_before, Int32List padding_after,
"TensorTuple Maxpool3D(Tensor x, *, String data_format=\"channels_first\", Int32List padding,
Int32List kernel_size, Int32List stride, Int32List dilation,
Bool return_indices=True, Bool ceil_mode=False)"
bind_python: True
- name: "pooling_grad"
signature:
"Tensor PoolingNdGrad(Tensor x, Tensor y, Tensor indice, Tensor dy, *, String mode, Int32 ndims, String data_format,
String padding, Int32List padding_before, Int32List padding_after, Int32List kernel_size,
"Tensor PoolingNdGrad(Tensor x, Tensor y, Tensor indice, Tensor dy, *, String mode, Int32 ndims,
String data_format, Int32List padding, Int32List kernel_size,
Int32List stride, Int32List dilation, Bool return_indices, Bool ceil_mode)"
bind_python: False
......
......@@ -232,18 +232,14 @@ class PoolingNDFunctor {
PoolingNDFunctor() = default;
virtual ~PoolingNDFunctor() = default;
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x,
const std::string& data_format, const std::string& padding,
const std::vector<int32_t>& padding_before,
const std::vector<int32_t>& padding_after,
const std::string& data_format, const std::vector<int32_t>& padding,
const std::vector<int32_t>& kernel_size,
const std::vector<int32_t>& stride,
const std::vector<int32_t>& dilation, const bool& return_indices,
const bool& ceil_mode) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<std::string>("padding", padding));
JUST(attrs.SetAttr<std::vector<int32_t>>("padding_before", padding_before));
JUST(attrs.SetAttr<std::vector<int32_t>>("padding_after", padding_after));
JUST(attrs.SetAttr<std::string>("data_format", data_format));
JUST(attrs.SetAttr<std::vector<int32_t>>("padding", padding));
JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size));
JUST(attrs.SetAttr<std::vector<int32_t>>("stride", stride));
JUST(attrs.SetAttr<std::vector<int32_t>>("dilation", dilation));
......@@ -270,6 +266,13 @@ class MaxPool2DFunctor : public PoolNDFunctor {
}
};
class Maxpool1DFunctor : public PoolingNDFunctor {
public:
Maxpool1DFunctor() {
op_ = CHECK_JUST(one::OpBuilder("maxpool_1d").Input("x").Output("y").Output("indice").Build());
}
};
class Maxpool2DFunctor : public PoolingNDFunctor {
public:
Maxpool2DFunctor() {
......@@ -472,6 +475,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::LayerNormFunctor>("LayerNorm");
m.add_functor<impl::LayerNormAffineFunctor>("LayerNormAffine");
m.add_functor<impl::AvgPool2DFunctor>("AvgPool2D");
m.add_functor<impl::Maxpool1DFunctor>("Maxpool1D");
m.add_functor<impl::Maxpool2DFunctor>("Maxpool2D");
m.add_functor<impl::Maxpool3DFunctor>("Maxpool3D");
m.add_functor<impl::MaxPool2DFunctor>("MaxPool2D");
......
......@@ -113,7 +113,7 @@ class PoolingNdGradFunctor {
public:
PoolingNdGradFunctor() {
for (const auto& mode : {"max"}) {
for (int ndims = 2; ndims <= 3; ++ndims) {
for (int ndims = 1; ndims <= 3; ++ndims) {
const auto& op_type_name = GetOpTypeName(mode, ndims);
op_expr_map_[op_type_name] = CHECK_JUST(one::OpBuilder(op_type_name)
.Input("x")
......@@ -133,16 +133,13 @@ class PoolingNdGradFunctor {
const std::shared_ptr<one::Tensor>& indice,
const std::shared_ptr<one::Tensor>& dy, const std::string& mode,
const int32_t& ndims, const std::string& data_format,
const std::string& padding, const std::vector<int32_t>& padding_before,
const std::vector<int32_t>& padding_after,
const std::vector<int32_t>& padding,
const std::vector<int32_t>& kernel_size,
const std::vector<int32_t>& stride, const std::vector<int32_t>& dilation,
const bool& return_indices, const bool& ceil_mode) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<std::string>("padding", padding));
JUST(attrs.SetAttr<std::vector<int32_t>>("padding_before", padding_before));
JUST(attrs.SetAttr<std::vector<int32_t>>("padding_after", padding_after));
JUST(attrs.SetAttr<std::string>("data_format", data_format));
JUST(attrs.SetAttr<std::vector<int32_t>>("padding", padding));
JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size));
JUST(attrs.SetAttr<std::vector<int32_t>>("stride", stride));
JUST(attrs.SetAttr<std::vector<int32_t>>("dilation", dilation));
......
......@@ -316,44 +316,26 @@ class MaxPool1d(Module):
ceil_mode: bool = False,
):
super().__init__()
self.kernel_size = _getint(kernel_size)
self.stride = _getint(stride) if stride is not None else self.kernel_size
data_format = "NCL" # Only suport "NCL" for now!
self.kernel_size = _single(kernel_size)
self.stride = _single(stride) if stride is not None else self.kernel_size
data_format = "NCL" # only support "NCL" for now !
self.channel_pos = "channels_first" if data_format == "NCL" else "channels_last"
self.dilation = _getint(dilation)
self.padding = _getint(padding)
self.dilation = _single(dilation)
self.padding = _single(padding)
self.return_indices = return_indices
self.ceil_mode = ceil_mode
if self.channel_pos == "channels_first":
padding = (0, 0, self.padding, 0)
else:
raise ValueError("error padding param!")
self.padding_type, pads_list = calc_pool_padding(
padding, get_dhw_offset(self.channel_pos), 2
)
self.padding_before = [pad[0] for pad in pads_list]
self.padding_after = [pad[1] for pad in pads_list]
def forward(self, x):
expand_x = x.unsqueeze(dim=-1)
expand_y, expand_indice = flow.F.maxpool_2d(
expand_x,
y, indice = flow.F.maxpool_1d(
x,
data_format=self.channel_pos,
padding=self.padding_type,
padding_before=self.padding_before,
padding_after=self.padding_after,
kernel_size=[self.kernel_size, 1],
stride=[self.stride, 1],
dilation=[self.dilation, 1],
padding=self.padding,
kernel_size=self.kernel_size,
stride=self.stride,
dilation=self.dilation,
return_indices=True,
ceil_mode=self.ceil_mode,
)
y = expand_y.squeeze(dim=-1)
indice = expand_indice.squeeze(dim=-1)
if self.return_indices:
return y, indice
else:
......@@ -454,45 +436,27 @@ class MaxPool2d(Module):
):
super().__init__()
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride) if (stride is not None) else _pair(kernel_size)
data_format = "NCHW" # Only suport "NCHW" for now!
data_format = "NCHW" # only support "NCHW" for now !
self.channel_pos = (
"channels_first" if data_format == "NCHW" else "channels_last"
)
self.stride = _pair(stride) if (stride is not None) else _pair(kernel_size)
self.dilation = _GetSequence(dilation, 2, "dilation")
self.return_indices = return_indices
self.ceil_mode = ceil_mode
padding = _pair(padding)
self.padding = padding
if len(padding) == 2:
if data_format == "NCHW":
padding = (0, 0, padding[0], padding[1])
else:
raise ValueError("error padding param!")
else:
raise ValueError("error padding param!")
self.padding_type, pads_list = calc_pool_padding(
padding, get_dhw_offset(self.channel_pos), 2
)
self.padding_before = [pad[0] for pad in pads_list]
self.padding_after = [pad[1] for pad in pads_list]
self.padding = _pair(padding)
def forward(self, x):
y, indice = flow.F.maxpool_2d(
x,
data_format=self.channel_pos,
padding=self.padding_type,
padding_before=self.padding_before,
padding_after=self.padding_after,
padding=self.padding,
kernel_size=self.kernel_size,
stride=self.stride,
dilation=self.dilation,
return_indices=True,
ceil_mode=self.ceil_mode,
)
if self.return_indices:
return y, indice
else:
......@@ -605,32 +569,15 @@ class MaxPool3d(Module):
"channels_last" if data_format == "NDHWC" else "channels_first"
)
self.dilation = _GetSequence(dilation, 3, "dilation")
padding = _triple(padding)
self.padding = padding
self.padding = _triple(padding)
self.return_indices = return_indices
self.ceil_mode = ceil_mode
if len(padding) == 3:
if data_format == "NCDHW":
padding = (0, 0, padding[0], padding[1], padding[2])
else:
raise ValueError("error padding param!")
else:
raise ValueError("error padding param!")
self.padding_type, pads_list = calc_pool_padding(
padding, get_dhw_offset(self.channel_pos), 3
)
self.padding_before = [pad[0] for pad in pads_list]
self.padding_after = [pad[1] for pad in pads_list]
def forward(self, x):
y, indice = flow.F.maxpool_3d(
x,
data_format=self.channel_pos,
padding=self.padding_type,
padding_before=self.padding_before,
padding_after=self.padding_after,
padding=self.padding,
kernel_size=self.kernel_size,
stride=self.stride,
dilation=self.dilation,
......
......@@ -26,31 +26,46 @@ struct PoolingOpKernelState final : public user_op::OpKernelState {
std::shared_ptr<PoolingOpKernelState> DoCreateOpKernelState(user_op::KernelComputeContext* ctx,
const int32_t& dim) {
const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape();
const std::string& padding = ctx->Attr<std::string>("padding");
const std::string& data_format = ctx->Attr<std::string>("data_format");
const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
const auto& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>("padding");
const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size");
const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>("stride");
const std::vector<int32_t>& dilation = ctx->Attr<std::vector<int32_t>>("dilation");
const bool return_indices = ctx->Attr<bool>("return_indices");
const bool ceil_mode = ctx->Attr<bool>("ceil_mode");
PoolingParams3D params_3d =
PoolingParams3D(dim, x_shape, data_format, padding, padding_before, padding_after,
kernel_size, stride, dilation, return_indices, ceil_mode);
PoolingParams3D params_3d = PoolingParams3D(dim, x_shape, data_format, padding, kernel_size,
stride, dilation, return_indices, ceil_mode);
std::shared_ptr<PoolingOpKernelState> state(new PoolingOpKernelState(params_3d));
return std::move(state);
}
template<typename T>
struct PoolingKernelUtil<DeviceType::kCPU, T> {
static void Maxpool1dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
const PoolingParams3D& params_3d) {
Maxpool1dForwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr,
params_3d.padding()[2], params_3d.num_batch(),
params_3d.num_channel(), params_3d.GetXShape5D().At(4),
params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[2],
params_3d.stride_3d()[2], params_3d.dilation_3d()[2]);
}
static void Maxpool1dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
const int64_t elem_num, const T* src, T* dest,
const int64_t* indice_ptr, const PoolingParams3D& params_3d) {
Maxpool1dBackwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr,
params_3d.num_batch(), params_3d.num_channel(),
params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(4));
}
static void Maxpool2dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper,
const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
const PoolingParams3D& params_3d) {
Maxpool2dFarwardCompute<T>(
index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[1],
params_3d.padding_before_3d()[2], params_3d.num_batch(), params_3d.num_channel(),
Maxpool2dForwardCompute<T>(
index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1],
params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),
params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(3),
params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[1],
params_3d.pooling_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2],
......@@ -69,9 +84,9 @@ struct PoolingKernelUtil<DeviceType::kCPU, T> {
static void Maxpool3dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper,
const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
const PoolingParams3D& params_3d) {
Maxpool3dFarwardCompute<T>(
index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[0],
params_3d.padding_before_3d()[1], params_3d.padding_before_3d()[2], params_3d.num_batch(),
Maxpool3dForwardCompute<T>(
index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[0],
params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(),
params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3),
params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3),
params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[0],
......@@ -91,6 +106,68 @@ struct PoolingKernelUtil<DeviceType::kCPU, T> {
}
};
template<DeviceType device_type, typename T>
class MaxPool1dKernel final : public user_op::OpKernel {
public:
MaxPool1dKernel() = default;
~MaxPool1dKernel() = default;
private:
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0);
const auto& pooling_state = DoCreateOpKernelState(ctx, 1);
const PoolingParams3D& params_3d = pooling_state->GetParams3D();
const int64_t elem_num = y->shape().elem_cnt();
const T* src = x->dptr<T>();
T* dest = y->mut_dptr<T>();
int64_t* indice_ptr = indice->mut_dptr<int64_t>();
DimVector y_vector;
y->shape().ToDimVector(&y_vector);
NdIndexOffsetHelper<int64_t, 3> index_helper(y_vector.data());
PoolingKernelUtil<device_type, T>::Maxpool1dForward(ctx->device_ctx(), index_helper, elem_num,
src, dest, indice_ptr, params_3d);
};
};
template<DeviceType device_type, typename T>
class MaxPool1dGradKernel final : public user_op::OpKernel {
public:
MaxPool1dGradKernel() = default;
~MaxPool1dGradKernel() = default;
private:
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0);
user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
const auto& pooling_state = DoCreateOpKernelState(ctx, 1);
const PoolingParams3D& params_3d = pooling_state->GetParams3D();
const int64_t elem_num = dy->shape().elem_cnt();
const T* src = dy->dptr<T>();
const int64_t* indice_ptr = indice->dptr<int64_t>();
T* dest = dx->mut_dptr<T>();
DimVector dy_vector;
dy->shape().ToDimVector(&dy_vector);
NdIndexOffsetHelper<int64_t, 3> index_helper(dy_vector.data());
size_t out_bytes_size = dx->shape().elem_cnt() * GetSizeOfDataType(dx->data_type());
Memset<device_type>(ctx->device_ctx(), dest, 0, out_bytes_size);
PoolingKernelUtil<device_type, T>::Maxpool1dBackward(ctx->device_ctx(), index_helper, elem_num,
src, dest, indice_ptr, params_3d);
};
};
template<DeviceType device_type, typename T>
class MaxPool2dKernel final : public user_op::OpKernel {
public:
......@@ -217,6 +294,14 @@ class MaxPool3dGradKernel final : public user_op::OpKernel {
};
#define REGISTER_POOLING_KERNELS(device, dtype) \
REGISTER_USER_KERNEL("maxpool_1d") \
.SetCreateFn<MaxPool1dKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)); \
REGISTER_USER_KERNEL("maxpool_1d_grad") \
.SetCreateFn<MaxPool1dGradKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)); \
REGISTER_USER_KERNEL("maxpool_2d") \
.SetCreateFn<MaxPool2dKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
......
......@@ -30,6 +30,17 @@ int GetNumBlocks(int64_t elem_cnt) {
return num_blocks;
}
template<typename T>
__launch_bounds__(kBlockSize) __global__
void DoCUDAMaxPool1dForward(const NdIndexOffsetHelper<int64_t, 3> index_helper,
int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
int32_t padding_l, int64_t n_batch, int64_t n_channel,
int64_t x_length, int64_t y_length, int32_t kernel_size_l,
int32_t stride_l, int32_t dilation_l) {
Maxpool1dForwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_l, n_batch,
n_channel, x_length, y_length, kernel_size_l, stride_l, dilation_l);
};
template<typename T>
__launch_bounds__(kBlockSize) __global__
void DoCUDAMaxPool2dForward(const NdIndexOffsetHelper<int64_t, 4> index_helper,
......@@ -39,7 +50,7 @@ __launch_bounds__(kBlockSize) __global__
int64_t y_height, int64_t y_width, int32_t kernel_size_h,
int32_t kernel_size_w, int32_t stride_h, int32_t stride_w,
int32_t dilation_h, int32_t dilation_w) {
Maxpool2dFarwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_h, padding_w,
Maxpool2dForwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_h, padding_w,
n_batch, n_channel, x_height, x_width, y_height, y_width,
kernel_size_h, kernel_size_w, stride_h, stride_w, dilation_h,
dilation_w);
......@@ -56,12 +67,23 @@ __launch_bounds__(kBlockSize) __global__
int32_t kernel_size_w, int32_t stride_t, int32_t stride_h,
int32_t stride_w, int32_t dilation_t, int32_t dilation_h,
int32_t dilation_w) {
Maxpool3dFarwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_t, padding_h,
Maxpool3dForwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_t, padding_h,
padding_w, n_batch, n_channel, x_time, x_height, x_width, y_time,
y_height, y_width, kernel_size_t, kernel_size_h, kernel_size_w,
stride_t, stride_h, stride_w, dilation_t, dilation_h, dilation_w);
};
template<typename T>
__launch_bounds__(kBlockSize) __global__
void DoCUDAMaxPool1dBackward(const NdIndexOffsetHelper<int64_t, 3> index_helper,
const int64_t elem_num, const T* src, T* dest,
const int64_t* indice_ptr, const int64_t n_batch,
const int64_t n_channel, const int64_t src_length,
const int64_t dst_length) {
Maxpool1dBackwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, n_batch, n_channel,
src_length, dst_length);
};
template<typename T>
__launch_bounds__(kBlockSize) __global__
void DoCUDAMaxPool2dBackward(const NdIndexOffsetHelper<int64_t, 4> index_helper,
......@@ -89,13 +111,33 @@ __launch_bounds__(kBlockSize) __global__
template<typename T>
struct PoolingKernelUtil<DeviceType::kGPU, T> {
static void Maxpool1dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
const PoolingParams3D& params_3d) {
DoCUDAMaxPool1dForward<T>
<<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[2],
params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(4),
params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[2], params_3d.stride_3d()[2],
params_3d.dilation_3d()[2]);
}
static void Maxpool1dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
const int64_t elem_num, const T* src, T* dest,
const int64_t* indice_ptr, const PoolingParams3D& params_3d) {
DoCUDAMaxPool1dBackward<T>
<<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(),
params_3d.num_channel(), params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(4));
}
static void Maxpool2dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper,
const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
const PoolingParams3D& params_3d) {
DoCUDAMaxPool2dForward<T>
<<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[1],
params_3d.padding_before_3d()[2], params_3d.num_batch(), params_3d.num_channel(),
index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1],
params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),
params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4),
params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),
params_3d.pooling_size_3d()[1], params_3d.pooling_size_3d()[2],
......@@ -118,15 +160,15 @@ struct PoolingKernelUtil<DeviceType::kGPU, T> {
const PoolingParams3D& params_3d) {
DoCUDAMaxPool3dForward<T>
<<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[0],
params_3d.padding_before_3d()[1], params_3d.padding_before_3d()[2],
params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(2),
params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4),
params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3),
params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[0],
params_3d.pooling_size_3d()[1], params_3d.pooling_size_3d()[2],
params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2],
params_3d.dilation_3d()[0], params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]);
index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[0],
params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(),
params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3),
params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(2),
params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),
params_3d.pooling_size_3d()[0], params_3d.pooling_size_3d()[1],
params_3d.pooling_size_3d()[2], params_3d.stride_3d()[0], params_3d.stride_3d()[1],
params_3d.stride_3d()[2], params_3d.dilation_3d()[0], params_3d.dilation_3d()[1],
params_3d.dilation_3d()[2]);
}
static void Maxpool3dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper,
......
......@@ -43,19 +43,36 @@ std::vector<int32_t> Get3DPadVec(const std::vector<int32_t>& original_vec, int32
return vec;
}
void GetWindowedOutputShape(int64_t input_size, int32_t filter_size, int32_t stride,
int32_t padding, bool ceil_mode, int32_t dilation_rate,
int64_t* output_ptr) {
*output_ptr = (input_size + 2 * padding - dilation_rate * (filter_size - 1) - 1 + stride
+ (ceil_mode ? stride - 1 : 0))
/ stride;
}
void Get3DOutputShape(const DimVector& in, const std::vector<int32_t>& pool_size,
const std::vector<int32_t>& strides, const std::vector<int32_t>& padding,
const bool ceil_mode, std::vector<int32_t> dilation_rate, DimVector* out) {
out->clear();
out->resize(3);
FOR_RANGE(size_t, i, 0, 3) {
int64_t* out_ptr = &(*out).at(i);
GetWindowedOutputShape(in.at(i), pool_size.at(i), strides.at(i), padding.at(i), ceil_mode,
dilation_rate.at(i), out_ptr);
}
}
PoolingParams3D::PoolingParams3D(const int32_t dim, const ShapeView& x_shape,
const std::string& data_format, const std::string& padding,
const std::vector<int32_t>& padding_before,
const std::vector<int32_t>& padding_after,
const std::string& data_format,
const std::vector<int32_t>& padding,
const std::vector<int32_t>& kernel_size,
const std::vector<int32_t>& stride,
const std::vector<int32_t>& dilation, const bool return_indices,
const bool ceil_mode)
: dim_(dim),
data_format_(data_format),
padding_(padding),
padding_before_3d_(Get3DPadVec(padding_before, dim)),
padding_after_3d_(Get3DPadVec(padding_after, dim)),
padding_(Get3DPadVec(padding, dim)),
pooling_size_3d_(Get3DVec(kernel_size, dim)),
stride_3d_(Get3DVec(stride, dim)),
dilation_3d_(Get3DVec(dilation, dim)),
......@@ -63,8 +80,7 @@ PoolingParams3D::PoolingParams3D(const int32_t dim, const ShapeView& x_shape,
ceil_mode_(ceil_mode) {
x_3d_ = {GetInDim(x_shape, data_format, 0, dim), GetInDim(x_shape, data_format, 1, dim),
GetInDim(x_shape, data_format, 2, dim)};
Get3DOutputSize(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, &dilation_3d_, &y_3d_,
&padding_before_3d_, &padding_after_3d_);
Get3DOutputShape(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, dilation_3d_, &y_3d_);
if (data_format == "channels_first") {
channel_num_ = x_shape.At(1);
} else {
......@@ -78,8 +94,7 @@ PoolingParams3D::PoolingParams3D(const int32_t dim, const ShapeView& x_shape,
void PoolingParams3D::Reset(const ShapeView& x_shape) {
x_3d_ = {GetInDim(x_shape, data_format_, 0, dim_), GetInDim(x_shape, data_format_, 1, dim_),
GetInDim(x_shape, data_format_, 2, dim_)};
Get3DOutputSize(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, &dilation_3d_, &y_3d_,
&padding_before_3d_, &padding_after_3d_);
Get3DOutputShape(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, dilation_3d_, &y_3d_);
}
Shape PoolingParams3D::GetYShape() const {
......
......@@ -53,16 +53,13 @@ struct DeviceAdd {
class PoolingParams3D {
public:
PoolingParams3D(const int32_t dim, const ShapeView& x_shape, const std::string& data_format,
const std::string& padding, const std::vector<int32_t>& padding_before,
const std::vector<int32_t>& padding_after,
const std::vector<int32_t>& kernel_size, const std::vector<int32_t>& stride,
const std::vector<int32_t>& dilation, const bool return_indices,
const bool ceil_mode);
const std::vector<int32_t>& padding, const std::vector<int32_t>& kernel_size,
const std::vector<int32_t>& stride, const std::vector<int32_t>& dilation,
const bool return_indices, const bool ceil_mode);
~PoolingParams3D() = default;
const std::string& data_format() const { return data_format_; }
const std::vector<int32_t>& padding_before_3d() const { return padding_before_3d_; }
const std::vector<int32_t>& padding_after_3d() const { return padding_after_3d_; }
const std::vector<int32_t>& padding() const { return padding_; }
const std::vector<int32_t>& pooling_size_3d() const { return pooling_size_3d_; }
const std::vector<int32_t>& stride_3d() const { return stride_3d_; }
const std::vector<int32_t>& dilation_3d() const { return dilation_3d_; }
......@@ -81,9 +78,7 @@ class PoolingParams3D {
FixedDimVector x_3d_;
FixedDimVector y_3d_;
std::string data_format_;
std::string padding_;
std::vector<int32_t> padding_before_3d_;
std::vector<int32_t> padding_after_3d_;
std::vector<int32_t> padding_;
std::vector<int32_t> pooling_size_3d_;
std::vector<int32_t> stride_3d_;
std::vector<int32_t> dilation_3d_;
......@@ -95,6 +90,14 @@ class PoolingParams3D {
template<DeviceType device_type, typename T>
struct PoolingKernelUtil {
static void Maxpool1dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
const PoolingParams3D& params_3d);
static void Maxpool1dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
const int64_t elem_num, const T* src, T* dest,
const int64_t* indice_ptr, const PoolingParams3D& params_3d);
static void Maxpool2dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper,
const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
const PoolingParams3D& params_3d);
......@@ -113,7 +116,70 @@ struct PoolingKernelUtil {
};
template<typename T>
OF_DEVICE_FUNC void Maxpool2dFarwardCompute(
OF_DEVICE_FUNC void Maxpool1dForwardCompute(const NdIndexOffsetHelper<int64_t, 3> index_helper,
int64_t elem_num, const T* src, T* dest,
int64_t* indice_ptr, const int32_t padding_l,
const int64_t n_batch, const int64_t n_channel,
const int64_t x_length, const int64_t y_length,
const int32_t kernel_size_l, const int32_t stride_l,
const int32_t dilation_l) {
XPU_1D_KERNEL_LOOP(num, elem_num) {
int64_t n, c, l;
index_helper.OffsetToNdIndex(num, n, c, l);
// n, c, l->index = n*c*l + c* l
const int64_t start_idx = (n * n_channel + c) * x_length;
int64_t lstart = l * stride_l - padding_l;
const int64_t lend = (lstart + (kernel_size_l - 1) * dilation_l + 1) <= x_length
? (lstart + (kernel_size_l - 1) * dilation_l + 1)
: x_length;
while (lstart < 0) { lstart += dilation_l; }
/* compute max value(src[src_idx]) in kernel box region, and save the value to dest[num] */
int64_t maxindex = lstart;
int64_t src_idx = 0;
/* equal to -std::numeric_limits<T>::infinity(); */
T max_value = detail::numeric_limits<T>::lower_bound();
for (int64_t idx = lstart; idx < lend; idx += dilation_l) {
const int64_t search_idx = start_idx + idx;
T val = src[search_idx];
if (val > max_value || detail::numerics<T>::isnan(val)) {
max_value = val;
maxindex = idx;
src_idx = search_idx;
}
}
dest[num] = src[src_idx];
indice_ptr[num] = maxindex;
}
}
template<typename T>
OF_DEVICE_FUNC void Maxpool1dBackwardCompute(const NdIndexOffsetHelper<int64_t, 3> index_helper,
const int64_t elem_num, const T* src, T* dest,
const int64_t* indice_ptr, const int64_t n_batch,
const int64_t n_channel, const int64_t src_length,
const int64_t dst_length) {
XPU_1D_KERNEL_LOOP(num, elem_num) {
int64_t n, c, l;
index_helper.OffsetToNdIndex(num, n, c, l);
const int64_t src_start = (n * n_channel + c) * src_length;
const int64_t dst_start = (n * n_channel + c) * dst_length;
const int64_t index = src_start + l;
const int64_t maxindex = dst_start + indice_ptr[index];
if (maxindex != -1) {
/* update gradient, equals to dest[maxindex] += src[index]; */
DeviceAdd<T>::Invoke(src + index, dest + maxindex);
}
}
}
template<typename T>
OF_DEVICE_FUNC void Maxpool2dForwardCompute(
const NdIndexOffsetHelper<int64_t, 4> index_helper, int64_t elem_num, const T* src, T* dest,
int64_t* indice_ptr, const int32_t padding_h, const int32_t padding_w, const int64_t n_batch,
const int64_t n_channel, const int64_t x_height, const int64_t x_width, const int64_t y_height,
......@@ -194,7 +260,7 @@ OF_DEVICE_FUNC void Maxpool2dBackwardCompute(const NdIndexOffsetHelper<int64_t,
}
template<typename T>
OF_DEVICE_FUNC void Maxpool3dFarwardCompute(
OF_DEVICE_FUNC void Maxpool3dForwardCompute(
const NdIndexOffsetHelper<int64_t, 5> index_helper, int64_t elem_num, const T* src, T* dest,
int64_t* indice_ptr, const int32_t padding_t, const int32_t padding_h, const int32_t padding_w,
const int64_t n_batch, const int64_t n_channel, const int64_t x_time, const int64_t x_height,
......
......@@ -28,9 +28,7 @@ TensorDescInferFn MakeForwardTensorDescInferFn(const int32_t dim) {
return [dim](user_op::InferContext* ctx) -> Maybe<void> {
const Shape* x_shape = ctx->Shape4ArgNameAndIndex("x", 0);
const std::string& data_format = ctx->Attr<std::string>("data_format");
const std::string& padding = ctx->Attr<std::string>("padding");
const std::vector<int32_t>& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
const std::vector<int32_t>& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>("padding");
const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size");
const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>("stride");
const std::vector<int32_t>& dilation = ctx->Attr<std::vector<int32_t>>("dilation");
......@@ -41,14 +39,13 @@ TensorDescInferFn MakeForwardTensorDescInferFn(const int32_t dim) {
for (int32_t pool_dim : kernel_size) { CHECK_GT_OR_RETURN(pool_dim, 0); }
CHECK_EQ_OR_RETURN(stride.size(), dim);
for (int32_t stride_dim : stride) { CHECK_GT_OR_RETURN(stride_dim, 0); }
for (int32_t i = 0; i < padding_after.size(); i++) {
CHECK_GE_OR_RETURN(kernel_size[i], 2 * padding_after[i])
for (int32_t i = 0; i < padding.size(); i++) {
CHECK_GE_OR_RETURN(kernel_size[i], 2 * padding[i])
<< "pad should be smaller than half of kernel size";
}
const PoolingParams3D params_3d(dim, *x_shape, data_format, padding, padding_before,
padding_after, kernel_size, stride, dilation, return_indices,
ceil_mode);
const PoolingParams3D params_3d(dim, *x_shape, data_format, padding, kernel_size, stride,
dilation, return_indices, ceil_mode);
user_op::TensorDesc* y_desc = ctx->TensorDesc4ArgNameAndIndex("y", 0);
*y_desc = *ctx->TensorDesc4ArgNameAndIndex("x", 0);
*y_desc->mut_shape() = params_3d.GetYShape();
......@@ -64,10 +61,9 @@ TensorDescInferFn MakeForwardTensorDescInferFn(const int32_t dim) {
Maybe<void> ForwardGetSbpFn(user_op::SbpContext* ctx) {
const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
const auto& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>("padding");
FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) {
if (padding_before[i] == 0 && padding_after[i] == 0) {
if (padding[i] == 0) {
ctx->NewBuilder()
.Split(user_op::OpArg("x", 0), i)
.Split(user_op::OpArg("y", 0), i)
......@@ -85,10 +81,9 @@ Maybe<void> BackwardTensorDescInferFn(user_op::InferContext* ctx) {
Maybe<void> BackwardGetSbpFn(user_op::SbpContext* ctx) {
const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
const auto& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>("padding");
FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) {
if (padding_before[i] == 0 && padding_after[i] == 0) {
if (padding[i] == 0) {
ctx->NewBuilder()
.Split(user_op::OpArg("x", 0), i)
.Split(user_op::OpArg("y", 0), i)
......@@ -123,9 +118,7 @@ GenBackwardOpConfFn MakeBackwardOpConfFn(const std::string& mode, const int32_t
.Input("dy", op.GetGradTensorWithOpOutput("y", 0))
.Output("dx")
.Attr("data_format", op.attr<std::string>("data_format"))
.Attr("padding", op.attr<std::string>("padding"))
.Attr("padding_before", op.attr<std::vector<int32_t>>("padding_before"))
.Attr("padding_after", op.attr<std::vector<int32_t>>("padding_after"))
.Attr("padding", op.attr<std::vector<int32_t>>("padding"))
.Attr("kernel_size", op.attr<std::vector<int32_t>>("kernel_size"))
.Attr("stride", op.attr<std::vector<int32_t>>("stride"))
.Attr("dilation", op.attr<std::vector<int32_t>>("dilation"))
......@@ -141,13 +134,45 @@ GenBackwardOpConfFn MakeBackwardOpConfFn(const std::string& mode, const int32_t
} // namespace
REGISTER_USER_OP("maxpool_1d")
.Input("x")
.Output("y")
.Output("indice")
.Attr<std::vector<int32_t>>("padding")
.Attr<std::string>("data_format")
.Attr<std::vector<int32_t>>("kernel_size")
.Attr<std::vector<int32_t>>("stride")
.Attr<std::vector<int32_t>>("dilation")
.Attr<bool>("return_indices")
.Attr<bool>("ceil_mode")
.SetTensorDescInferFn(MakeForwardTensorDescInferFn(1))
.SetGetSbpFn(ForwardGetSbpFn)
.SetDataTypeInferFn(FwInferDataType);
REGISTER_USER_OP("maxpool_1d_grad")
.Input("x")
.Input("y")
.Input("indice")
.Input("dy")
.Output("dx")
.Attr<std::vector<int32_t>>("padding")
.Attr<std::string>("data_format")
.Attr<std::vector<int32_t>>("kernel_size")
.Attr<std::vector<int32_t>>("stride")
.Attr<std::vector<int32_t>>("dilation")
.Attr<bool>("return_indices")
.Attr<bool>("ceil_mode")
.SetTensorDescInferFn(BackwardTensorDescInferFn)
.SetGetSbpFn(BackwardGetSbpFn)
.SetDataTypeInferFn(BwInferDataType);
REGISTER_USER_OP_GRAD("maxpool_1d").SetGenBackwardOpConfFn(MakeBackwardOpConfFn("max", 1));
REGISTER_USER_OP("maxpool_2d")
.Input("x")
.Output("y")
.Output("indice")
.Attr<std::string>("padding")
.Attr<std::vector<int32_t>>("padding_before")
.Attr<std::vector<int32_t>>("padding_after")
.Attr<std::vector<int32_t>>("padding")
.Attr<std::string>("data_format")
.Attr<std::vector<int32_t>>("kernel_size")
.Attr<std::vector<int32_t>>("stride")
......@@ -164,9 +189,7 @@ REGISTER_USER_OP("maxpool_2d_grad")
.Input("indice")
.Input("dy")
.Output("dx")
.Attr<std::string>("padding")
.Attr<std::vector<int32_t>>("padding_before")
.Attr<std::vector<int32_t>>("padding_after")
.Attr<std::vector<int32_t>>("padding")
.Attr<std::string>("data_format")
.Attr<std::vector<int32_t>>("kernel_size")
.Attr<std::vector<int32_t>>("stride")
......@@ -183,9 +206,7 @@ REGISTER_USER_OP("maxpool_3d")
.Input("x")
.Output("y")
.Output("indice")
.Attr<std::string>("padding")
.Attr<std::vector<int32_t>>("padding_before")
.Attr<std::vector<int32_t>>("padding_after")
.Attr<std::vector<int32_t>>("padding")
.Attr<std::string>("data_format")
.Attr<std::vector<int32_t>>("kernel_size")
.Attr<std::vector<int32_t>>("stride")
......@@ -202,9 +223,7 @@ REGISTER_USER_OP("maxpool_3d_grad")
.Input("indice")
.Input("dy")
.Output("dx")
.Attr<std::string>("padding")
.Attr<std::vector<int32_t>>("padding_before")
.Attr<std::vector<int32_t>>("padding_after")
.Attr<std::vector<int32_t>>("padding")
.Attr<std::string>("data_format")
.Attr<std::vector<int32_t>>("kernel_size")
.Attr<std::vector<int32_t>>("stride")
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment