diff --git a/oneflow/core/autograd/gradient_funcs/pooling.cpp b/oneflow/core/autograd/gradient_funcs/pooling.cpp
index a5d6c8c3737e221741859de4e486b2463bb8a57a..488aaa9c77339bfe3cf079b6a836da6216bf8f5c 100644
--- a/oneflow/core/autograd/gradient_funcs/pooling.cpp
+++ b/oneflow/core/autograd/gradient_funcs/pooling.cpp
@@ -33,9 +33,7 @@ struct PoolingInterpState : public OpExprInterpState {
   size_t indice_index;
 
   std::string data_format;
-  std::string padding;
-  std::vector<int32_t> padding_before;
-  std::vector<int32_t> padding_after;
+  std::vector<int32_t> padding;
   std::vector<int32_t> kernel_size;
   std::vector<int32_t> stride;
   std::vector<int32_t> dilation;
@@ -76,9 +74,7 @@ Maybe<void> PoolingNdGrad::Capture(PoolingInterpState* ctx, const TensorTuple& i
   ComposedAttrMap composed_attrs(attrs, base_attrs_);
   ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
-  ctx->padding = JUST(composed_attrs.GetAttr<std::string>("padding"));
-  ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_before"));
-  ctx->padding_after = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_after"));
+  ctx->padding = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding"));
   ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
   ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("stride"));
   ctx->dilation = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation"));
@@ -100,8 +96,7 @@ Maybe<void> PoolingNdGrad::Apply(const PoolingInterpState* ctx, const TensorTupl
   in_grads->resize(1);
   in_grads->at(0) = JUST(functional::PoolingNdGrad(
       input, output, indice, out_grads.at(0), mode_, ndims, ctx->data_format, ctx->padding,
-      ctx->padding_before, ctx->padding_after, ctx->kernel_size, ctx->stride, ctx->dilation,
-      ctx->return_indices, ctx->ceil_mode));
+      ctx->kernel_size, ctx->stride, ctx->dilation, ctx->return_indices, ctx->ceil_mode));
 
   return Maybe<void>::Ok();
 }
@@ -113,6 +108,7 @@ class MaxpoolNdGrad final : public PoolingNdGrad {
   Maybe<void> Init(const OpExpr& op) override { return PoolingNdGrad::Init(op, "max"); }
 };
 
+REGISTER_OP_EXPR_GRAD_FUNCTION("maxpool_1d", MaxpoolNdGrad);
 REGISTER_OP_EXPR_GRAD_FUNCTION("maxpool_2d", MaxpoolNdGrad);
 REGISTER_OP_EXPR_GRAD_FUNCTION("maxpool_3d", MaxpoolNdGrad);
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 5c1a6f5ab320936f275bb75426503ca7577e8521..f24b6a5dc654af47e9f6a1371318b63c17c75433 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -411,26 +411,31 @@
       Int32List strides, Bool ceil_mode)"
   bind_python: False
 
+- name: "maxpool_1d"
+  signature:
+    "TensorTuple Maxpool1D(Tensor x, *, String data_format=\"channels_first\", Int32List padding,
+                           Int32List kernel_size, Int32List stride, Int32List dilation,
+                           Bool return_indices=True, Bool ceil_mode=False)"
+  bind_python: True
+
 - name: "maxpool_2d"
   signature:
-    "TensorTuple Maxpool2D(Tensor x, *, String data_format=\"channels_first\", String padding,
-                           Int32List padding_before, Int32List padding_after,
+    "TensorTuple Maxpool2D(Tensor x, *, String data_format=\"channels_first\", Int32List padding,
                            Int32List kernel_size, Int32List stride, Int32List dilation,
                            Bool return_indices=True, Bool ceil_mode=False)"
   bind_python: True
 
 - name: "maxpool_3d"
   signature:
-    "TensorTuple Maxpool3D(Tensor x, *, String data_format=\"channels_first\", String padding,
-                           Int32List padding_before, Int32List padding_after,
+    "TensorTuple Maxpool3D(Tensor x, *, String data_format=\"channels_first\", Int32List padding,
                            Int32List kernel_size, Int32List stride, Int32List dilation,
                            Bool return_indices=True, Bool ceil_mode=False)"
   bind_python: True
 
 - name: "pooling_grad"
   signature:
-    "Tensor PoolingNdGrad(Tensor x, Tensor y, Tensor indice, Tensor dy, *, String mode, Int32 ndims, String data_format,
-     String padding, Int32List padding_before, Int32List padding_after, Int32List kernel_size,
+    "Tensor PoolingNdGrad(Tensor x, Tensor y, Tensor indice, Tensor dy, *, String mode, Int32 ndims,
+                          String data_format, Int32List padding, Int32List kernel_size,
                           Int32List stride, Int32List dilation, Bool return_indices, Bool ceil_mode)"
   bind_python: False
diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp
index 535b301820954039b0ec083863780bd904cb3910..54267dfdb72d24b7de484bf3e4bdec95b2edb94c 100644
--- a/oneflow/core/functional/impl/nn_functor.cpp
+++ b/oneflow/core/functional/impl/nn_functor.cpp
@@ -232,18 +232,14 @@
   PoolingNDFunctor() = default;
   virtual ~PoolingNDFunctor() = default;
   Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x,
-                                const std::string& data_format, const std::string& padding,
-                                const std::vector<int32_t>& padding_before,
-                                const std::vector<int32_t>& padding_after,
+                                const std::string& data_format, const std::vector<int32_t>& padding,
                                 const std::vector<int32_t>& kernel_size,
                                 const std::vector<int32_t>& stride,
                                 const std::vector<int32_t>& dilation, const bool& return_indices,
                                 const bool& ceil_mode) const {
     MutableAttrMap attrs;
-    JUST(attrs.SetAttr<std::string>("padding", padding));
-    JUST(attrs.SetAttr<std::vector<int32_t>>("padding_before", padding_before));
-    JUST(attrs.SetAttr<std::vector<int32_t>>("padding_after", padding_after));
     JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("padding", padding));
     JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size));
     JUST(attrs.SetAttr<std::vector<int32_t>>("stride", stride));
     JUST(attrs.SetAttr<std::vector<int32_t>>("dilation", dilation));
@@ -270,6 +266,13 @@ class MaxPool2DFunctor : public PoolNDFunctor {
   }
 };
 
+class Maxpool1DFunctor : public PoolingNDFunctor {
+ public:
+  Maxpool1DFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("maxpool_1d").Input("x").Output("y").Output("indice").Build());
+  }
+};
+
 class Maxpool2DFunctor : public PoolingNDFunctor {
  public:
   Maxpool2DFunctor() {
@@ -472,6 +475,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::LayerNormFunctor>("LayerNorm");
   m.add_functor<impl::LayerNormAffineFunctor>("LayerNormAffine");
   m.add_functor<impl::AvgPool2DFunctor>("AvgPool2D");
+  m.add_functor<impl::Maxpool1DFunctor>("Maxpool1D");
  m.add_functor<impl::Maxpool2DFunctor>("Maxpool2D");
   m.add_functor<impl::Maxpool3DFunctor>("Maxpool3D");
   m.add_functor<impl::MaxPool2DFunctor>("MaxPool2D");
diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp
index fa6110e84bd73abbee04a3f32af5e460be7a1c36..5ab9a9716d9a10323adfba0d9ab4d4b949d09cb9 100644
--- a/oneflow/core/functional/impl/nn_grad_functor.cpp
+++ b/oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -113,7 +113,7 @@ class PoolingNdGradFunctor {
  public:
  PoolingNdGradFunctor() {
    for (const auto& mode : {"max"}) {
-      for (int ndims = 2; ndims <= 3; ++ndims) {
+      for (int ndims = 1; ndims <= 3; ++ndims) {
        const auto& op_type_name = GetOpTypeName(mode, ndims);
        op_expr_map_[op_type_name] = CHECK_JUST(one::OpBuilder(op_type_name)
                                                    .Input("x")
@@ -133,16 +133,13 @@ class PoolingNdGradFunctor {
                            const std::shared_ptr<one::Tensor>& indice,
                            const std::shared_ptr<one::Tensor>& dy, const std::string& mode,
                            const int32_t& ndims, const std::string& data_format,
-                           const std::string& padding, const std::vector<int32_t>& padding_before,
-                           const std::vector<int32_t>& padding_after,
+                           const std::vector<int32_t>& padding,
                            const std::vector<int32_t>& kernel_size,
                            const std::vector<int32_t>& stride, const std::vector<int32_t>& dilation,
                            const bool& return_indices, const bool& ceil_mode) const {
    MutableAttrMap attrs;
-    JUST(attrs.SetAttr<std::string>("padding", padding));
-    JUST(attrs.SetAttr<std::vector<int32_t>>("padding_before", padding_before));
-    JUST(attrs.SetAttr<std::vector<int32_t>>("padding_after", padding_after));
    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("padding", padding));
    JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size));
    JUST(attrs.SetAttr<std::vector<int32_t>>("stride", stride));
    JUST(attrs.SetAttr<std::vector<int32_t>>("dilation", dilation));
diff --git a/oneflow/python/nn/modules/pooling.py b/oneflow/python/nn/modules/pooling.py
index 57de40a2cbb4fc10f614eb51e3b8e512a5801695..b183d196283efdc530cd926e3db53650ed231635 100644
--- a/oneflow/python/nn/modules/pooling.py
+++ b/oneflow/python/nn/modules/pooling.py
@@ -316,44 +316,26 @@ class MaxPool1d(Module):
         ceil_mode: bool = False,
     ):
         super().__init__()
-        self.kernel_size = _getint(kernel_size)
-        self.stride = _getint(stride) if stride is not None else self.kernel_size
-        data_format = "NCL"  # Only suport "NCL" for now!
+        self.kernel_size = _single(kernel_size)
+        self.stride = _single(stride) if stride is not None else self.kernel_size
+        data_format = "NCL"  # only support "NCL" for now!
         self.channel_pos = "channels_first" if data_format == "NCL" else "channels_last"
-        self.dilation = _getint(dilation)
-        self.padding = _getint(padding)
+        self.dilation = _single(dilation)
+        self.padding = _single(padding)
         self.return_indices = return_indices
         self.ceil_mode = ceil_mode
 
-        if self.channel_pos == "channels_first":
-            padding = (0, 0, self.padding, 0)
-        else:
-            raise ValueError("error padding param!")
-
-        self.padding_type, pads_list = calc_pool_padding(
-            padding, get_dhw_offset(self.channel_pos), 2
-        )
-        self.padding_before = [pad[0] for pad in pads_list]
-        self.padding_after = [pad[1] for pad in pads_list]
-
     def forward(self, x):
-        expand_x = x.unsqueeze(dim=-1)
-
-        expand_y, expand_indice = flow.F.maxpool_2d(
-            expand_x,
+        y, indice = flow.F.maxpool_1d(
+            x,
             data_format=self.channel_pos,
-            padding=self.padding_type,
-            padding_before=self.padding_before,
-            padding_after=self.padding_after,
-            kernel_size=[self.kernel_size, 1],
-            stride=[self.stride, 1],
-            dilation=[self.dilation, 1],
+            padding=self.padding,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            dilation=self.dilation,
             return_indices=True,
             ceil_mode=self.ceil_mode,
         )
-
-        y = expand_y.squeeze(dim=-1)
-        indice = expand_indice.squeeze(dim=-1)
         if self.return_indices:
             return y, indice
         else:
@@ -454,45 +436,27 @@ class MaxPool2d(Module):
     ):
         super().__init__()
         self.kernel_size = _pair(kernel_size)
-        self.stride = _pair(stride) if (stride is not None) else _pair(kernel_size)
-        data_format = "NCHW"  # Only suport "NCHW" for now!
+        data_format = "NCHW"  # only support "NCHW" for now!
         self.channel_pos = (
             "channels_first" if data_format == "NCHW" else "channels_last"
         )
+        self.stride = _pair(stride) if (stride is not None) else _pair(kernel_size)
         self.dilation = _GetSequence(dilation, 2, "dilation")
         self.return_indices = return_indices
         self.ceil_mode = ceil_mode
-
-        padding = _pair(padding)
-        self.padding = padding
-        if len(padding) == 2:
-            if data_format == "NCHW":
-                padding = (0, 0, padding[0], padding[1])
-            else:
-                raise ValueError("error padding param!")
-        else:
-            raise ValueError("error padding param!")
-
-        self.padding_type, pads_list = calc_pool_padding(
-            padding, get_dhw_offset(self.channel_pos), 2
-        )
-        self.padding_before = [pad[0] for pad in pads_list]
-        self.padding_after = [pad[1] for pad in pads_list]
+        self.padding = _pair(padding)
 
     def forward(self, x):
         y, indice = flow.F.maxpool_2d(
             x,
             data_format=self.channel_pos,
-            padding=self.padding_type,
-            padding_before=self.padding_before,
-            padding_after=self.padding_after,
+            padding=self.padding,
             kernel_size=self.kernel_size,
             stride=self.stride,
             dilation=self.dilation,
             return_indices=True,
             ceil_mode=self.ceil_mode,
         )
-
         if self.return_indices:
             return y, indice
         else:
@@ -605,32 +569,15 @@ class MaxPool3d(Module):
             "channels_last" if data_format == "NDHWC" else "channels_first"
         )
         self.dilation = _GetSequence(dilation, 3, "dilation")
-        padding = _triple(padding)
-        self.padding = padding
+        self.padding = _triple(padding)
         self.return_indices = return_indices
         self.ceil_mode = ceil_mode
 
-        if len(padding) == 3:
-            if data_format == "NCDHW":
-                padding = (0, 0, padding[0], padding[1], padding[2])
-            else:
-                raise ValueError("error padding param!")
-        else:
-            raise ValueError("error padding param!")
-
-        self.padding_type, pads_list = calc_pool_padding(
-            padding, get_dhw_offset(self.channel_pos), 3
-        )
-        self.padding_before = [pad[0] for pad in pads_list]
-        self.padding_after = [pad[1] for pad in pads_list]
-
     def forward(self, x):
         y, indice = flow.F.maxpool_3d(
             x,
             data_format=self.channel_pos,
-            padding=self.padding_type,
-            padding_before=self.padding_before,
-            padding_after=self.padding_after,
+            padding=self.padding,
             kernel_size=self.kernel_size,
             stride=self.stride,
             dilation=self.dilation,
diff --git a/oneflow/user/kernels/pooling_kernel.cpp b/oneflow/user/kernels/pooling_kernel.cpp
index 39213ea3b42db8451331349f978f58933fff8eff..2f678241f327c4825861116834c2c17b88312fc4 100644
--- a/oneflow/user/kernels/pooling_kernel.cpp
+++ b/oneflow/user/kernels/pooling_kernel.cpp
@@ -26,31 +26,46 @@ struct PoolingOpKernelState final : public user_op::OpKernelState {
 std::shared_ptr<PoolingOpKernelState> DoCreateOpKernelState(user_op::KernelComputeContext* ctx,
                                                             const int32_t& dim) {
   const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape();
-  const std::string& padding = ctx->Attr<std::string>("padding");
   const std::string& data_format = ctx->Attr<std::string>("data_format");
-  const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
-  const auto& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
+  const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>("padding");
   const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size");
   const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>("stride");
   const std::vector<int32_t>& dilation = ctx->Attr<std::vector<int32_t>>("dilation");
   const bool return_indices = ctx->Attr<bool>("return_indices");
   const bool ceil_mode = ctx->Attr<bool>("ceil_mode");
 
-  PoolingParams3D params_3d =
-      PoolingParams3D(dim, x_shape, data_format, padding, padding_before, padding_after,
-                      kernel_size, stride, dilation, return_indices, ceil_mode);
+  PoolingParams3D params_3d = PoolingParams3D(dim, x_shape, data_format, padding, kernel_size,
+                                              stride, dilation, return_indices, ceil_mode);
   std::shared_ptr<PoolingOpKernelState> state(new PoolingOpKernelState(params_3d));
   return std::move(state);
 }
 
 template<typename T>
 struct PoolingKernelUtil<DeviceType::kCPU, T> {
+  static void Maxpool1dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
+                               const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                               const PoolingParams3D& params_3d) {
+    Maxpool1dForwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr,
+                               params_3d.padding()[2], params_3d.num_batch(),
+                               params_3d.num_channel(), params_3d.GetXShape5D().At(4),
+                               params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[2],
+                               params_3d.stride_3d()[2], params_3d.dilation_3d()[2]);
+  }
+
+  static void Maxpool1dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
+                                const int64_t elem_num, const T* src, T* dest,
+                                const int64_t* indice_ptr, const PoolingParams3D& params_3d) {
+    Maxpool1dBackwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr,
+                                params_3d.num_batch(), params_3d.num_channel(),
+                                params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(4));
+  }
+
   static void Maxpool2dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper,
                                const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
                                const PoolingParams3D& params_3d) {
-    Maxpool2dFarwardCompute<T>(
-        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[1],
-        params_3d.padding_before_3d()[2], params_3d.num_batch(), params_3d.num_channel(),
+    Maxpool2dForwardCompute<T>(
+        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1],
+        params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),
         params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4),
         params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),
         params_3d.pooling_size_3d()[1], params_3d.pooling_size_3d()[2], params_3d.stride_3d()[1],
         params_3d.stride_3d()[2],
@@ -69,9 +84,9 @@ struct PoolingKernelUtil<DeviceType::kCPU, T> {
   static void Maxpool3dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper,
                                const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
                                const PoolingParams3D& params_3d) {
-    Maxpool3dFarwardCompute<T>(
-        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[0],
-        params_3d.padding_before_3d()[1], params_3d.padding_before_3d()[2], params_3d.num_batch(),
+    Maxpool3dForwardCompute<T>(
+        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[0],
+        params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(),
         params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3),
         params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3),
         params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[0],
@@ -91,6 +106,68 @@ struct PoolingKernelUtil<DeviceType::kCPU, T> {
   }
 };
 
+template<DeviceType device_type, typename T>
+class MaxPool1dKernel final : public user_op::OpKernel {
+ public:
+  MaxPool1dKernel() = default;
+  ~MaxPool1dKernel() = default;
+
+ private:
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
+    user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0);
+
+    const auto& pooling_state = DoCreateOpKernelState(ctx, 1);
+    const PoolingParams3D& params_3d = pooling_state->GetParams3D();
+
+    const int64_t elem_num = y->shape().elem_cnt();
+    const T* src = x->dptr<T>();
+    T* dest = y->mut_dptr<T>();
+    int64_t* indice_ptr = indice->mut_dptr<int64_t>();
+
+    DimVector y_vector;
+    y->shape().ToDimVector(&y_vector);
+    NdIndexOffsetHelper<int64_t, 3> index_helper(y_vector.data());
+
+    PoolingKernelUtil<device_type, T>::Maxpool1dForward(ctx->device_ctx(), index_helper, elem_num,
+                                                        src, dest, indice_ptr, params_3d);
+  };
+};
+
+template<DeviceType device_type, typename T>
+class MaxPool1dGradKernel final : public user_op::OpKernel {
+ public:
+  MaxPool1dGradKernel() = default;
+  ~MaxPool1dGradKernel() = default;
+
+ private:
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0);
+    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
+
+    const auto& pooling_state = DoCreateOpKernelState(ctx, 1);
+    const PoolingParams3D& params_3d = pooling_state->GetParams3D();
+
+    const int64_t elem_num = dy->shape().elem_cnt();
+    const T* src = dy->dptr<T>();
+    const int64_t* indice_ptr = indice->dptr<int64_t>();
+    T* dest = dx->mut_dptr<T>();
+    DimVector dy_vector;
+    dy->shape().ToDimVector(&dy_vector);
+    NdIndexOffsetHelper<int64_t, 3> index_helper(dy_vector.data());
+
+    size_t out_bytes_size = dx->shape().elem_cnt() * GetSizeOfDataType(dx->data_type());
+    Memset<device_type>(ctx->device_ctx(), dest, 0, out_bytes_size);
+
+    PoolingKernelUtil<device_type, T>::Maxpool1dBackward(ctx->device_ctx(), index_helper, elem_num,
+                                                         src, dest, indice_ptr, params_3d);
+  };
+};
+
 template<DeviceType device_type, typename T>
 class MaxPool2dKernel final : public user_op::OpKernel {
  public:
@@ -217,6 +294,14 @@ class MaxPool3dGradKernel final : public user_op::OpKernel {
 };
 
 #define REGISTER_POOLING_KERNELS(device, dtype)                                        \
+  REGISTER_USER_KERNEL("maxpool_1d")                                                   \
+      .SetCreateFn<MaxPool1dKernel<device, dtype>>()                                   \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
+                       & (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("maxpool_1d_grad")                                              \
+      .SetCreateFn<MaxPool1dGradKernel<device, dtype>>()                               \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
+                       & (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)); \
   REGISTER_USER_KERNEL("maxpool_2d")                                                   \
       .SetCreateFn<MaxPool2dKernel<device, dtype>>()                                   \
       .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
diff --git a/oneflow/user/kernels/pooling_kernel.cu b/oneflow/user/kernels/pooling_kernel.cu
index 7aef65319fe28441f044e5a1e819ca7779bcc295..5f6b4d42eea293bb54cccd5ee4272b306594a564 100644
--- a/oneflow/user/kernels/pooling_kernel.cu
+++ b/oneflow/user/kernels/pooling_kernel.cu
@@ -30,6 +30,17 @@ int GetNumBlocks(int64_t elem_cnt) {
   return num_blocks;
 }
 
+template<typename T>
+__launch_bounds__(kBlockSize) __global__
+    void DoCUDAMaxPool1dForward(const NdIndexOffsetHelper<int64_t, 3> index_helper,
+                                int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                                int32_t padding_l, int64_t n_batch, int64_t n_channel,
+                                int64_t x_length, int64_t y_length, int32_t kernel_size_l,
+                                int32_t stride_l, int32_t dilation_l) {
+  Maxpool1dForwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_l, n_batch,
+                             n_channel, x_length, y_length, kernel_size_l, stride_l, dilation_l);
+};
+
 template<typename T>
 __launch_bounds__(kBlockSize) __global__
     void DoCUDAMaxPool2dForward(const NdIndexOffsetHelper<int64_t, 4> index_helper,
@@ -39,7 +50,7 @@ __launch_bounds__(kBlockSize) __global__
                                 int64_t y_height, int64_t y_width, int32_t kernel_size_h,
                                 int32_t kernel_size_w, int32_t stride_h, int32_t stride_w,
                                 int32_t dilation_h, int32_t dilation_w) {
-  Maxpool2dFarwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_h, padding_w,
+  Maxpool2dForwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_h, padding_w,
                              n_batch, n_channel, x_height, x_width, y_height, y_width,
                              kernel_size_h, kernel_size_w, stride_h, stride_w, dilation_h,
                             dilation_w);
@@ -56,12 +67,23 @@ __launch_bounds__(kBlockSize) __global__
                                 int32_t kernel_size_w, int32_t stride_t, int32_t stride_h,
                                 int32_t stride_w, int32_t dilation_t, int32_t dilation_h,
                                 int32_t dilation_w) {
-  Maxpool3dFarwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_t, padding_h,
+  Maxpool3dForwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_t, padding_h,
                              padding_w, n_batch, n_channel, x_time, x_height, x_width, y_time,
                              y_height, y_width, kernel_size_t, kernel_size_h, kernel_size_w,
                              stride_t, stride_h, stride_w, dilation_t, dilation_h, dilation_w);
 };
 
+template<typename T>
+__launch_bounds__(kBlockSize) __global__
+    void DoCUDAMaxPool1dBackward(const NdIndexOffsetHelper<int64_t, 3> index_helper,
+                                 const int64_t elem_num, const T* src, T* dest,
+                                 const int64_t* indice_ptr, const int64_t n_batch,
+                                 const int64_t n_channel, const int64_t src_length,
+                                 const int64_t dst_length) {
+  Maxpool1dBackwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, n_batch, n_channel,
+                              src_length, dst_length);
+};
+
 template<typename T>
 __launch_bounds__(kBlockSize) __global__
     void DoCUDAMaxPool2dBackward(const NdIndexOffsetHelper<int64_t, 4> index_helper,
@@ -89,13 +111,33 @@ __launch_bounds__(kBlockSize) __global__
 
 template<typename T>
 struct PoolingKernelUtil<DeviceType::kGPU, T> {
+  static void Maxpool1dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
+                               const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                               const PoolingParams3D& params_3d) {
+    DoCUDAMaxPool1dForward<T>
+        <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
+            index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[2],
+            params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(4),
+            params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[2], params_3d.stride_3d()[2],
+            params_3d.dilation_3d()[2]);
+  }
+
+  static void Maxpool1dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
+                                const int64_t elem_num, const T* src, T* dest,
+                                const int64_t* indice_ptr, const PoolingParams3D& params_3d) {
+    DoCUDAMaxPool1dBackward<T>
+        <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
+            index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(),
+            params_3d.num_channel(), params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(4));
+  }
+
   static void Maxpool2dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper,
                                const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
                                const PoolingParams3D& params_3d) {
     DoCUDAMaxPool2dForward<T>
         <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
-            index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[1],
-            params_3d.padding_before_3d()[2], params_3d.num_batch(), params_3d.num_channel(),
+            index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1],
+            params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),
             params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4),
             params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),
             params_3d.pooling_size_3d()[1], params_3d.pooling_size_3d()[2],
@@ -118,15 +160,15 @@ struct PoolingKernelUtil<DeviceType::kGPU, T> {
                                const PoolingParams3D& params_3d) {
     DoCUDAMaxPool3dForward<T>
         <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
-            index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[0],
-            params_3d.padding_before_3d()[1], params_3d.padding_before_3d()[2],
-            params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(2),
-            params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4),
-            params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3),
-            params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[0],
-            params_3d.pooling_size_3d()[1], params_3d.pooling_size_3d()[2],
-            params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2],
-            params_3d.dilation_3d()[0], params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]);
+            index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[0],
+            params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(),
+            params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3),
+            params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(2),
+            params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),
+            params_3d.pooling_size_3d()[0], params_3d.pooling_size_3d()[1],
+            params_3d.pooling_size_3d()[2], params_3d.stride_3d()[0], params_3d.stride_3d()[1],
+            params_3d.stride_3d()[2], params_3d.dilation_3d()[0], params_3d.dilation_3d()[1],
+            params_3d.dilation_3d()[2]);
   }
 
   static void Maxpool3dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper,
diff --git a/oneflow/user/kernels/pooling_kernel_util.cpp b/oneflow/user/kernels/pooling_kernel_util.cpp
index c2f502ce2928c8fec6c60ccc5c288d1c074fb275..8da50e3b40abace6a31bc7350c59a919d34d0abb 100644
--- a/oneflow/user/kernels/pooling_kernel_util.cpp
+++ b/oneflow/user/kernels/pooling_kernel_util.cpp
@@ -43,19 +43,36 @@ std::vector<int32_t> Get3DPadVec(const std::vector<int32_t>& original_vec, int32
   return vec;
 }
 
+void GetWindowedOutputShape(int64_t input_size, int32_t filter_size, int32_t stride,
+                            int32_t padding, bool ceil_mode, int32_t dilation_rate,
+                            int64_t* output_ptr) {
+  *output_ptr = (input_size + 2 * padding - dilation_rate * (filter_size - 1) - 1 + stride
+                 + (ceil_mode ? stride - 1 : 0))
+                / stride;
+}
+
+void Get3DOutputShape(const DimVector& in, const std::vector<int32_t>& pool_size,
+                      const std::vector<int32_t>& strides, const std::vector<int32_t>& padding,
+                      const bool ceil_mode, std::vector<int32_t> dilation_rate, DimVector* out) {
+  out->clear();
+  out->resize(3);
+  FOR_RANGE(size_t, i, 0, 3) {
+    int64_t* out_ptr = &(*out).at(i);
+    GetWindowedOutputShape(in.at(i), pool_size.at(i), strides.at(i), padding.at(i), ceil_mode,
+                           dilation_rate.at(i), out_ptr);
+  }
+}
+
 PoolingParams3D::PoolingParams3D(const int32_t dim, const ShapeView& x_shape,
-                                 const std::string& data_format, const std::string& padding,
-                                 const std::vector<int32_t>& padding_before,
-                                 const std::vector<int32_t>& padding_after,
+                                 const std::string& data_format,
+                                 const std::vector<int32_t>& padding,
                                  const std::vector<int32_t>& kernel_size,
                                  const std::vector<int32_t>& stride,
                                  const std::vector<int32_t>& dilation, const bool return_indices,
                                  const bool ceil_mode)
     : dim_(dim),
       data_format_(data_format),
-      padding_(padding),
-      padding_before_3d_(Get3DPadVec(padding_before, dim)),
-      padding_after_3d_(Get3DPadVec(padding_after, dim)),
+      padding_(Get3DPadVec(padding, dim)),
       pooling_size_3d_(Get3DVec(kernel_size, dim)),
       stride_3d_(Get3DVec(stride, dim)),
       dilation_3d_(Get3DVec(dilation, dim)),
@@ -63,8 +80,7 @@ PoolingParams3D::PoolingParams3D(const int32_t dim, const ShapeView& x_shape,
       ceil_mode_(ceil_mode) {
   x_3d_ = {GetInDim(x_shape, data_format, 0, dim), GetInDim(x_shape, data_format, 1, dim),
            GetInDim(x_shape, data_format, 2, dim)};
-  Get3DOutputSize(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, &dilation_3d_, &y_3d_,
-                  &padding_before_3d_, &padding_after_3d_);
+  Get3DOutputShape(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, dilation_3d_, &y_3d_);
   if (data_format == "channels_first") {
     channel_num_ = x_shape.At(1);
   } else {
@@ -78,8 +94,7 @@ PoolingParams3D::PoolingParams3D(const int32_t dim, const ShapeView& x_shape,
 void PoolingParams3D::Reset(const ShapeView& x_shape) {
   x_3d_ = {GetInDim(x_shape, data_format_, 0, dim_), GetInDim(x_shape, data_format_, 1, dim_),
            GetInDim(x_shape, data_format_, 2, dim_)};
-  Get3DOutputSize(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, &dilation_3d_, &y_3d_,
-                  &padding_before_3d_, &padding_after_3d_);
+  Get3DOutputShape(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, dilation_3d_, &y_3d_);
 }
 
 Shape PoolingParams3D::GetYShape() const {
diff --git a/oneflow/user/kernels/pooling_kernel_util.h b/oneflow/user/kernels/pooling_kernel_util.h
index b25ba34e04a3ce8c5027948965c686c69e1134f9..0d126ea63fa066469151ed708604ecc29fc4abd1 100644
--- a/oneflow/user/kernels/pooling_kernel_util.h
+++ b/oneflow/user/kernels/pooling_kernel_util.h
@@ -53,16 +53,13 @@ struct DeviceAdd {
 class PoolingParams3D {
  public:
   PoolingParams3D(const int32_t dim, const ShapeView& x_shape, const std::string& data_format,
-                  const std::string& padding, const std::vector<int32_t>& padding_before,
-                  const std::vector<int32_t>& padding_after,
-                  const std::vector<int32_t>& kernel_size, const std::vector<int32_t>& stride,
-                  const std::vector<int32_t>& dilation, const bool return_indices,
-                  const bool ceil_mode);
+                  const std::vector<int32_t>& padding, const std::vector<int32_t>& kernel_size,
+                  const std::vector<int32_t>& stride, const std::vector<int32_t>& dilation,
+                  const bool return_indices, const bool ceil_mode);
   ~PoolingParams3D() = default;
 
   const std::string& data_format() const { return data_format_; }
-  const std::vector<int32_t>& padding_before_3d() const { return padding_before_3d_; }
-  const std::vector<int32_t>& padding_after_3d() const { return padding_after_3d_; }
+  const std::vector<int32_t>& padding() const { return padding_; }
   const std::vector<int32_t>& pooling_size_3d() const { return pooling_size_3d_; }
   const std::vector<int32_t>& stride_3d() const { return stride_3d_; }
   const std::vector<int32_t>& dilation_3d() const { return dilation_3d_; }
@@ -81,9 +78,7 @@ class PoolingParams3D {
   FixedDimVector x_3d_;
   FixedDimVector y_3d_;
   std::string data_format_;
-  std::string padding_;
-  std::vector<int32_t> padding_before_3d_;
-  std::vector<int32_t> padding_after_3d_;
+  std::vector<int32_t> padding_;
   std::vector<int32_t> pooling_size_3d_;
   std::vector<int32_t> stride_3d_;
   std::vector<int32_t> dilation_3d_;
@@ -95,6 +90,14 @@ class PoolingParams3D {
 
 template<DeviceType device_type, typename T>
 struct PoolingKernelUtil {
+  static void Maxpool1dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
+                               const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                               const PoolingParams3D& params_3d);
+
+  static void Maxpool1dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
+                                const int64_t elem_num, const T* src, T* dest,
+                                const int64_t* indice_ptr, const PoolingParams3D& params_3d);
+
   static void Maxpool2dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper,
                                const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
                                const PoolingParams3D& params_3d);
@@ -113,7 +116,70 @@ struct PoolingKernelUtil {
 };
 
 template<typename T>
-OF_DEVICE_FUNC void Maxpool2dFarwardCompute(
+OF_DEVICE_FUNC void Maxpool1dForwardCompute(const NdIndexOffsetHelper<int64_t, 3> index_helper,
+                                            int64_t elem_num, const T* src, T* dest,
+                                            int64_t* indice_ptr, const int32_t padding_l,
+                                            const int64_t n_batch, const int64_t n_channel,
+                                            const int64_t x_length, const int64_t y_length,
+                                            const int32_t kernel_size_l, const int32_t stride_l,
+                                            const int32_t dilation_l) {
+  XPU_1D_KERNEL_LOOP(num, elem_num) {
+    int64_t n, c, l;
+    index_helper.OffsetToNdIndex(num, n, c, l);
+
+    /* offset of the (n, c) row in the flattened input: (n * n_channel + c) * x_length */
+    const int64_t start_idx = (n * n_channel + c) * x_length;
+    int64_t lstart = l * stride_l - padding_l;
+    const int64_t lend = (lstart + (kernel_size_l - 1) * dilation_l + 1) <= x_length
+                             ? (lstart + (kernel_size_l - 1) * dilation_l + 1)
+                             : x_length;
+
+    while (lstart < 0) { lstart += dilation_l; }
+
+    /* scan the kernel window, record the max value (src[src_idx]) and its index, then write them to dest[num] and indice_ptr[num] */
+    int64_t maxindex = lstart;
+    int64_t src_idx = 0;
+
+    /* equivalent to -std::numeric_limits<T>::infinity() */
+    T max_value = detail::numeric_limits<T>::lower_bound();
+
+    for (int64_t idx = lstart; idx < lend; idx += dilation_l) {
+      const int64_t search_idx = start_idx + idx;
+      T val = src[search_idx];
+      if (val > max_value || detail::numerics<T>::isnan(val)) {
+        max_value = val;
+        maxindex = idx;
+        src_idx = search_idx;
+      }
+    }
+    dest[num] = src[src_idx];
+    indice_ptr[num] = maxindex;
+  }
+}
+
+template<typename T>
+OF_DEVICE_FUNC void Maxpool1dBackwardCompute(const NdIndexOffsetHelper<int64_t, 3> index_helper,
+                                             const int64_t elem_num, const T* src, T* dest,
+                                             const int64_t* indice_ptr, const int64_t n_batch,
+                                             const int64_t n_channel, const int64_t src_length,
+                                             const int64_t dst_length) {
+  XPU_1D_KERNEL_LOOP(num, elem_num) {
+    int64_t n, c, l;
+    index_helper.OffsetToNdIndex(num, n, c, l);
+
+    const int64_t src_start = (n * n_channel + c) * src_length;
+    const int64_t dst_start = (n * n_channel + c) * dst_length;
+    const int64_t index = src_start + l;
+    const int64_t maxindex = dst_start + indice_ptr[index];
+    if (maxindex != -1) {
+      /* scatter the gradient: equivalent to dest[maxindex] += src[index] */
+      DeviceAdd<T>::Invoke(src + index, dest + maxindex);
+    }
+  }
+}
+
+template<typename T>
+OF_DEVICE_FUNC void Maxpool2dForwardCompute(
     const NdIndexOffsetHelper<int64_t, 4> index_helper, int64_t elem_num, const T* src, T* dest,
     int64_t* indice_ptr, const int32_t padding_h, const int32_t padding_w, const int64_t n_batch,
     const int64_t n_channel, const int64_t x_height, const int64_t x_width, const int64_t y_height,
@@ -194,7 +260,7 @@ OF_DEVICE_FUNC void Maxpool2dBackwardCompute(const NdIndexOffsetHelper<int64_t,
 }
 
 template<typename T>
-OF_DEVICE_FUNC void Maxpool3dFarwardCompute(
+OF_DEVICE_FUNC void Maxpool3dForwardCompute(
     const NdIndexOffsetHelper<int64_t, 5> index_helper, int64_t elem_num, const T* src, T* dest,
     int64_t* indice_ptr, const int32_t padding_t, const int32_t padding_h, const int32_t padding_w,
     const int64_t n_batch, const int64_t n_channel, const int64_t x_time, const int64_t x_height,
diff --git a/oneflow/user/ops/pooling_op.cpp b/oneflow/user/ops/pooling_op.cpp
index 03896b002c1537ae929d693ffa24ccc4592dabc4..adfbf70be83d39ccd4c50c9d7b7f5556499734ef 100644
--- a/oneflow/user/ops/pooling_op.cpp
+++ b/oneflow/user/ops/pooling_op.cpp
@@ -28,9 +28,7 @@ TensorDescInferFn MakeForwardTensorDescInferFn(const int32_t dim) {
   return [dim](user_op::InferContext* ctx) -> Maybe<void> {
     const Shape* x_shape = ctx->Shape4ArgNameAndIndex("x", 0);
     const std::string& data_format = ctx->Attr<std::string>("data_format");
-    const std::string& padding = ctx->Attr<std::string>("padding");
-    const std::vector<int32_t>& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
-    const std::vector<int32_t>& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
+    const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>("padding");
     const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size");
     const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>("stride");
     const std::vector<int32_t>& dilation = ctx->Attr<std::vector<int32_t>>("dilation");
@@ -41,14 +39,13 @@ TensorDescInferFn MakeForwardTensorDescInferFn(const int32_t dim) {
     for (int32_t pool_dim : kernel_size) { CHECK_GT_OR_RETURN(pool_dim, 0); }
     CHECK_EQ_OR_RETURN(stride.size(), dim);
     for (int32_t stride_dim : stride) { CHECK_GT_OR_RETURN(stride_dim, 0); }
-    for (int32_t i = 0; i < padding_after.size(); i++) {
-      CHECK_GE_OR_RETURN(kernel_size[i], 2 * padding_after[i])
+    for (int32_t i = 0; i < padding.size(); i++) {
+      CHECK_GE_OR_RETURN(kernel_size[i], 2 * padding[i])
          << "pad should be smaller than half of kernel size";
     }
 
-    const PoolingParams3D params_3d(dim, *x_shape, data_format, padding, padding_before,
-                                    padding_after, kernel_size, stride, dilation, return_indices,
-                                    ceil_mode);
+    const PoolingParams3D params_3d(dim, *x_shape, data_format, padding, kernel_size, stride,
+                                    dilation, return_indices, ceil_mode);
     user_op::TensorDesc* y_desc = ctx->TensorDesc4ArgNameAndIndex("y", 0);
     *y_desc = *ctx->TensorDesc4ArgNameAndIndex("x", 0);
     *y_desc->mut_shape() = params_3d.GetYShape();
@@ -64,10 +61,9 @@ TensorDescInferFn MakeForwardTensorDescInferFn(const int32_t dim) {
 
 Maybe<void> ForwardGetSbpFn(user_op::SbpContext* ctx) {
   const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
-  const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
-  const auto& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
+  const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>("padding");
   FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) {
-    if (padding_before[i] == 0 && padding_after[i] == 0) {
+    if (padding[i] == 0) {
       ctx->NewBuilder()
           .Split(user_op::OpArg("x", 0), i)
          .Split(user_op::OpArg("y", 0), i)
@@ -85,10 +81,9 @@ Maybe<void> BackwardTensorDescInferFn(user_op::InferContext* ctx) {
 
 Maybe<void> BackwardGetSbpFn(user_op::SbpContext* ctx) {
   const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
-  const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
-  const auto& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
+  const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>("padding");
   FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) {
-    if (padding_before[i] == 0 && padding_after[i] == 0) {
+    if (padding[i] == 0) {
       ctx->NewBuilder()
          .Split(user_op::OpArg("x", 0), i)
          .Split(user_op::OpArg("y", 0), i)
@@ -123,9 +118,7 @@ GenBackwardOpConfFn MakeBackwardOpConfFn(const std::string& mode, const int32_t
             .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
             .Output("dx")
             .Attr("data_format", op.attr<std::string>("data_format"))
-            .Attr("padding", op.attr<std::string>("padding"))
-            .Attr("padding_before", op.attr<std::vector<int32_t>>("padding_before"))
-            .Attr("padding_after", op.attr<std::vector<int32_t>>("padding_after"))
+            .Attr("padding", op.attr<std::vector<int32_t>>("padding"))
            .Attr("kernel_size", op.attr<std::vector<int32_t>>("kernel_size"))
            .Attr("stride", op.attr<std::vector<int32_t>>("stride"))
            .Attr("dilation", op.attr<std::vector<int32_t>>("dilation"))
@@ -141,13 +134,45 @@ GenBackwardOpConfFn MakeBackwardOpConfFn(const std::string& mode, const int32_t
 
 }  // namespace
 
+REGISTER_USER_OP("maxpool_1d")
+    .Input("x")
+    .Output("y")
+    .Output("indice")
+    .Attr<std::vector<int32_t>>("padding")
+    .Attr<std::string>("data_format")
+    .Attr<std::vector<int32_t>>("kernel_size")
+    .Attr<std::vector<int32_t>>("stride")
+    .Attr<std::vector<int32_t>>("dilation")
+    .Attr<bool>("return_indices")
+    .Attr<bool>("ceil_mode")
+    .SetTensorDescInferFn(MakeForwardTensorDescInferFn(1))
+    .SetGetSbpFn(ForwardGetSbpFn)
+    .SetDataTypeInferFn(FwInferDataType);
+
+REGISTER_USER_OP("maxpool_1d_grad")
+    .Input("x")
+    .Input("y")
+    .Input("indice")
+    .Input("dy")
+    .Output("dx")
+    .Attr<std::vector<int32_t>>("padding")
+    .Attr<std::string>("data_format")
+    .Attr<std::vector<int32_t>>("kernel_size")
+    .Attr<std::vector<int32_t>>("stride")
+    .Attr<std::vector<int32_t>>("dilation")
+    .Attr<bool>("return_indices")
+    .Attr<bool>("ceil_mode")
+    .SetTensorDescInferFn(BackwardTensorDescInferFn)
+    .SetGetSbpFn(BackwardGetSbpFn)
+    .SetDataTypeInferFn(BwInferDataType);
+
+REGISTER_USER_OP_GRAD("maxpool_1d").SetGenBackwardOpConfFn(MakeBackwardOpConfFn("max", 1));
+
 REGISTER_USER_OP("maxpool_2d")
     .Input("x")
    .Output("y")
    .Output("indice")
-    .Attr<std::string>("padding")
-    .Attr<std::vector<int32_t>>("padding_before")
-    .Attr<std::vector<int32_t>>("padding_after")
+    .Attr<std::vector<int32_t>>("padding")
    .Attr<std::string>("data_format")
    .Attr<std::vector<int32_t>>("kernel_size")
    .Attr<std::vector<int32_t>>("stride")
@@ -164,9 +189,7 @@ REGISTER_USER_OP("maxpool_2d_grad")
    .Input("indice")
    .Input("dy")
    .Output("dx")
-    .Attr<std::string>("padding")
-    .Attr<std::vector<int32_t>>("padding_before")
-    .Attr<std::vector<int32_t>>("padding_after")
+    .Attr<std::vector<int32_t>>("padding")
    .Attr<std::string>("data_format")
    .Attr<std::vector<int32_t>>("kernel_size")
    .Attr<std::vector<int32_t>>("stride")
@@ -183,9 +206,7 @@ REGISTER_USER_OP("maxpool_3d")
    .Input("x")
    .Output("y")
    .Output("indice")
-    .Attr<std::string>("padding")
-    .Attr<std::vector<int32_t>>("padding_before")
-    .Attr<std::vector<int32_t>>("padding_after")
+    .Attr<std::vector<int32_t>>("padding")
    .Attr<std::string>("data_format")
    .Attr<std::vector<int32_t>>("kernel_size")
    .Attr<std::vector<int32_t>>("stride")
@@ -202,9 +223,7 @@ REGISTER_USER_OP("maxpool_3d_grad")
    .Input("indice")
    .Input("dy")
    .Output("dx")
-    .Attr<std::string>("padding")
-    .Attr<std::vector<int32_t>>("padding_before")
-    .Attr<std::vector<int32_t>>("padding_after")
+    .Attr<std::vector<int32_t>>("padding")
    .Attr<std::string>("data_format")
    .Attr<std::vector<int32_t>>("kernel_size")
    .Attr<std::vector<int32_t>>("stride")
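
For reference, the integer arithmetic in `GetWindowedOutputShape` above is the standard pooled-output-length formula (the same closed form PyTorch documents for `MaxPool1d`), written out here to make the `ceil_mode` branch easy to check:

```latex
L_{out} = \left\lfloor \frac{L_{in} + 2 \cdot \mathrm{padding} - \mathrm{dilation} \cdot (\mathrm{kernel\_size} - 1) - 1}{\mathrm{stride}} \right\rfloor + 1
```

With `ceil_mode=True` the floor becomes a ceiling; the C++ realizes the `+ 1` by adding `stride` to the numerator before the integer division, and the ceiling by additionally adding `stride - 1`.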
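To make the per-element logic of `Maxpool1dForwardCompute` easier to follow, here is a pure-NumPy mirror of it. `maxpool1d_ref` is a hypothetical helper written only for this note; it assumes an `(N, C, L)` input and skips the kernel's NaN-propagation branch (`detail::numerics<T>::isnan`):

```python
import numpy as np

def maxpool1d_ref(x, kernel_size, stride, padding, dilation):
    # x: (n_batch, n_channel, x_length) in "NCL" layout.
    n_batch, n_channel, x_length = x.shape
    # floor-mode output length; see the formula above
    y_length = (x_length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    y = np.empty((n_batch, n_channel, y_length), dtype=x.dtype)
    indice = np.empty((n_batch, n_channel, y_length), dtype=np.int64)
    for n in range(n_batch):
        for c in range(n_channel):
            for l in range(y_length):
                lstart = l * stride - padding
                lend = min(lstart + (kernel_size - 1) * dilation + 1, x_length)
                while lstart < 0:  # step over the virtual padding, as the kernel does
                    lstart += dilation
                # argmax over the dilated window; ties keep the first index,
                # matching the strict '>' comparison in the C++ loop
                best = max(range(lstart, lend, dilation), key=lambda idx: x[n, c, idx])
                y[n, c, l] = x[n, c, best]
                indice[n, c, l] = best  # index within the (n, c) row, as the kernel stores it
    return y, indice
```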
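The backward kernel is then just a scatter-add over the saved indices; in the same hypothetical NumPy form, `Maxpool1dBackwardCompute` reduces to:

```python
def maxpool1d_grad_ref(dy, indice, x_length):
    # dy, indice: (n_batch, n_channel, y_length); returns dx: (n_batch, n_channel, x_length).
    n_batch, n_channel, y_length = dy.shape
    dx = np.zeros((n_batch, n_channel, x_length), dtype=dy.dtype)
    for n in range(n_batch):
        for c in range(n_channel):
            for l in range(y_length):
                # mirrors DeviceAdd<T>::Invoke: dest[maxindex] += src[index]
                dx[n, c, indice[n, c, l]] += dy[n, c, l]
    return dx
```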
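Finally, a minimal end-to-end sketch of the new Python API, assuming the module is exported as `flow.nn.MaxPool1d` (the export path is not part of this diff):

```python
import numpy as np
import oneflow as flow

x = flow.Tensor(np.random.randn(1, 4, 9))  # (N, C, L), "NCL" layout
pool = flow.nn.MaxPool1d(kernel_size=3, stride=2, padding=1, return_indices=True)
y, indice = pool(x)
print(y.shape)  # flow.Size([1, 4, 5]): floor((9 + 2*1 - 1*(3-1) - 1) / 2) + 1 = 5
```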