diff --git a/oneflow/core/autograd/gradient_funcs/pooling.cpp b/oneflow/core/autograd/gradient_funcs/pooling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a5d6c8c3737e221741859de4e486b2463bb8a57a --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/pooling.cpp @@ -0,0 +1,120 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_expr_helper.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +namespace { + +struct PoolingInterpState : public OpExprInterpState { + bool requires_grad = false;  // set in Capture from inputs.at(0)->requires_grad() + size_t input_index = 0;  // slots returned by SaveTensorForBackward in Capture + size_t output_index = 0; + size_t indice_index = 0; + + std::string data_format; + std::string padding; + std::vector<int32_t> padding_before; + std::vector<int32_t> padding_after; + std::vector<int32_t> kernel_size; + std::vector<int32_t> stride; + std::vector<int32_t> dilation; + bool return_indices = false; + bool ceil_mode = false; +}; + +class PoolingNdGrad : public OpExprGradFunction<PoolingInterpState> { + public: + virtual ~PoolingNdGrad() = default; + Maybe<void> Init(const OpExpr& op, const std::string& mode); + Maybe<void> Capture(PoolingInterpState* ctx, const 
TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override; + Maybe<void> Apply(const PoolingInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override; + + private: + std::string mode_; + AttrMap base_attrs_; +}; + +Maybe<void> PoolingNdGrad::Init(const OpExpr& op, const std::string& mode) { + const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + mode_ = mode; + return Maybe<void>::Ok(); +} + +Maybe<void> PoolingNdGrad::Capture(PoolingInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const { + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->requires_grad) { return Maybe<void>::Ok(); } + + ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0)); + ctx->output_index = ctx->SaveTensorForBackward(outputs.at(0)); + ctx->indice_index = ctx->SaveTensorForBackward(outputs.at(1)); + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format")); + ctx->padding = JUST(composed_attrs.GetAttr<std::string>("padding")); + ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_before")); + ctx->padding_after = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_after")); + ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size")); + ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("stride")); + ctx->dilation = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation")); + ctx->return_indices = JUST(composed_attrs.GetAttr<bool>("return_indices")); + ctx->ceil_mode = JUST(composed_attrs.GetAttr<bool>("ceil_mode")); + return Maybe<void>::Ok(); +} + +Maybe<void> PoolingNdGrad::Apply(const PoolingInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const { + if (!ctx->requires_grad) { 
return Maybe<void>::Ok(); } + CHECK_LE_OR_RETURN(out_grads.size(), 2); + + int32_t ndims = ctx->kernel_size.size(); + const auto& input = ctx->SavedTensors().at(ctx->input_index); + const auto& output = ctx->SavedTensors().at(ctx->output_index); + const auto& indice = ctx->SavedTensors().at(ctx->indice_index); + + in_grads->resize(1); + in_grads->at(0) = JUST(functional::PoolingNdGrad( + input, output, indice, out_grads.at(0), mode_, ndims, ctx->data_format, ctx->padding, + ctx->padding_before, ctx->padding_after, ctx->kernel_size, ctx->stride, ctx->dilation, + ctx->return_indices, ctx->ceil_mode)); + + return Maybe<void>::Ok(); +} + +} // namespace + +class MaxpoolNdGrad final : public PoolingNdGrad { + public: + Maybe<void> Init(const OpExpr& op) override { return PoolingNdGrad::Init(op, "max"); } +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("maxpool_2d", MaxpoolNdGrad); +REGISTER_OP_EXPR_GRAD_FUNCTION("maxpool_3d", MaxpoolNdGrad); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 979f9556243cba135083f607987050ade3c0cc73..c422c2adfc90e3adf9f9744039c4e96a0170e0b8 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -403,6 +403,29 @@ Int32List strides, Bool ceil_mode)" bind_python: False +- name: "maxpool_2d" + signature: + "TensorTuple Maxpool2D(Tensor x, *, String data_format=\"channels_first\", String padding, + Int32List padding_before, Int32List padding_after, + Int32List kernel_size, Int32List stride, Int32List dilation, + Bool return_indices=True, Bool ceil_mode=False)" + bind_python: True + +- name: "maxpool_3d" + signature: + "TensorTuple Maxpool3D(Tensor x, *, String data_format=\"channels_first\", String padding, + Int32List padding_before, Int32List padding_after, + Int32List kernel_size, Int32List stride, Int32List dilation, + Bool return_indices=True, Bool ceil_mode=False)" + bind_python: True + +- 
name: "pooling_grad" + signature: + "Tensor PoolingNdGrad(Tensor x, Tensor y, Tensor indice, Tensor dy, *, String mode, Int32 ndims, String data_format, + String padding, Int32List padding_before, Int32List padding_after, Int32List kernel_size, + Int32List stride, Int32List dilation, Bool return_indices, Bool ceil_mode)" + bind_python: False + - name: "prelu" signature: "Tensor PRelu(Tensor x, Tensor alpha)" bind_python: True diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index da5619afdafd4ea6ae219d8f53c7453ceeb10747..d181e21af21e2002a87702c56c45db5978d2afb7 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -227,6 +227,35 @@ class PoolNDFunctor { std::shared_ptr<OpExpr> op_; }; +class PoolingNDFunctor { + public: + PoolingNDFunctor() = default; + virtual ~PoolingNDFunctor() = default; + Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x, + const std::string& data_format, const std::string& padding, + const std::vector<int32_t>& padding_before, + const std::vector<int32_t>& padding_after, + const std::vector<int32_t>& kernel_size, + const std::vector<int32_t>& stride, + const std::vector<int32_t>& dilation, const bool& return_indices, + const bool& ceil_mode) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr<std::string>("padding", padding)); + JUST(attrs.SetAttr<std::vector<int32_t>>("padding_before", padding_before)); + JUST(attrs.SetAttr<std::vector<int32_t>>("padding_after", padding_after)); + JUST(attrs.SetAttr<std::string>("data_format", data_format)); + JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size)); + JUST(attrs.SetAttr<std::vector<int32_t>>("stride", stride)); + JUST(attrs.SetAttr<std::vector<int32_t>>("dilation", dilation)); + JUST(attrs.SetAttr<bool>("return_indices", return_indices)); + JUST(attrs.SetAttr<bool>("ceil_mode", ceil_mode)); + return OpInterpUtil::Dispatch<TensorTuple>(*op_, {x}, 
attrs); + } + + protected: + std::shared_ptr<OpExpr> op_; +}; + class AvgPool2DFunctor : public PoolNDFunctor { public: AvgPool2DFunctor() { @@ -241,6 +270,20 @@ class MaxPool2DFunctor : public PoolNDFunctor { } }; +class Maxpool2DFunctor : public PoolingNDFunctor { + public: + Maxpool2DFunctor() { + op_ = CHECK_JUST(one::OpBuilder("maxpool_2d").Input("x").Output("y").Output("indice").Build()); + } +}; + +class Maxpool3DFunctor : public PoolingNDFunctor { + public: + Maxpool3DFunctor() { + op_ = CHECK_JUST(one::OpBuilder("maxpool_3d").Input("x").Output("y").Output("indice").Build()); + } +}; + class SparseSoftmaxCrossEntropyFunctor { public: SparseSoftmaxCrossEntropyFunctor() { @@ -420,6 +463,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::LayerNormFunctor>("LayerNorm"); m.add_functor<impl::LayerNormAffineFunctor>("LayerNormAffine"); m.add_functor<impl::AvgPool2DFunctor>("AvgPool2D"); + m.add_functor<impl::Maxpool2DFunctor>("Maxpool2D"); + m.add_functor<impl::Maxpool3DFunctor>("Maxpool3D"); m.add_functor<impl::MaxPool2DFunctor>("MaxPool2D"); m.add_functor<impl::SparseSoftmaxCrossEntropyFunctor>("SparseSoftmaxCrossEntropy"); m.add_functor<impl::SmoothL1LossFunctor>("SmoothL1Loss"); diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp index 11c750d5e4e104bac054f3c3807d98b2c64cc280..b47f345d70bda1f4e46761f42409e9086a736143 100644 --- a/oneflow/core/functional/impl/nn_grad_functor.cpp +++ b/oneflow/core/functional/impl/nn_grad_functor.cpp @@ -109,6 +109,57 @@ class ConvDataGradFunctor { std::shared_ptr<OpExpr> op_; }; +class PoolingNdGradFunctor { + public: + PoolingNdGradFunctor() { + for (const auto& mode : {"max"}) { + for (int ndims = 2; ndims <= 3; ++ndims) { + const auto& op_type_name = GetOpTypeName(mode, ndims); + op_expr_map_[op_type_name] = CHECK_JUST(one::OpBuilder(op_type_name) + .Input("x") + .Input("y") + .Input("indice") + .Input("dy") + .Output("dx") + .Build()); + } + } + } + static 
std::string GetOpTypeName(const std::string& mode, const int32_t& ndims) { + return mode + "pool_" + std::to_string(ndims) + "d_grad"; + } + Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, + const std::shared_ptr<one::Tensor>& y, + const std::shared_ptr<one::Tensor>& indice, + const std::shared_ptr<one::Tensor>& dy, const std::string& mode, + const int32_t& ndims, const std::string& data_format, + const std::string& padding, const std::vector<int32_t>& padding_before, + const std::vector<int32_t>& padding_after, + const std::vector<int32_t>& kernel_size, + const std::vector<int32_t>& stride, const std::vector<int32_t>& dilation, + const bool& return_indices, const bool& ceil_mode) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr<std::string>("padding", padding)); + JUST(attrs.SetAttr<std::vector<int32_t>>("padding_before", padding_before)); + JUST(attrs.SetAttr<std::vector<int32_t>>("padding_after", padding_after)); + JUST(attrs.SetAttr<std::string>("data_format", data_format)); + JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size)); + JUST(attrs.SetAttr<std::vector<int32_t>>("stride", stride)); + JUST(attrs.SetAttr<std::vector<int32_t>>("dilation", dilation)); + JUST(attrs.SetAttr<bool>("return_indices", return_indices)); + JUST(attrs.SetAttr<bool>("ceil_mode", ceil_mode)); + const auto& op_type_name = GetOpTypeName(mode, ndims); + const auto& it = op_expr_map_.find(op_type_name); + CHECK_OR_RETURN(it != op_expr_map_.end()) + << "Encounter unsupported op " << op_type_name << " in PoolingNdGradFunctor."; + CHECK_NOTNULL_OR_RETURN(it->second); + return OpInterpUtil::Dispatch<Tensor>(*it->second, {x, y, indice, dy}, attrs); + } + + protected: + std::unordered_map<std::string, std::shared_ptr<OpExpr>> op_expr_map_; +}; + class PoolNdGradFunctor { public: PoolNdGradFunctor() { @@ -225,6 +276,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::ConvDataGradFunctor>("ConvDataGrad"); 
m.add_functor<impl::PoolNdGradFunctor>("PoolNdGrad"); m.add_functor<impl::SmoothL1LossGradFunctor>("SmoothL1LossGrad"); + m.add_functor<impl::PoolingNdGradFunctor>("PoolingNdGrad"); m.add_functor<impl::PadGradFunctor>("PadGrad"); }; diff --git a/oneflow/core/kernel/util/numeric_limits.cuh b/oneflow/core/kernel/util/numeric_limits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7ef9e9e3ee0e1f5fbcd4cccaf59898a5ad3609ad --- /dev/null +++ b/oneflow/core/kernel/util/numeric_limits.cuh @@ -0,0 +1,128 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +// reference: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/NumericLimits.cuh +#pragma once +#include <limits.h> +#include <math.h> +#include <float.h> + +#include "oneflow/core/device/cuda_util.h" +#include "oneflow/core/framework/framework.h" + +// numeric_limits.cuh is a holder for numeric limits definitions of commonly used +// types. This header is very specific to ROCm HIP and may be removed in the future. + +// The lower_bound and upper_bound constants are same as lowest and max for +// integral types, but are -inf and +inf for floating point types. They are +// useful in implementing min, max, etc. 
+ +namespace oneflow { +namespace detail { + +#if defined(__CUDACC__) +#define OF_NUMERICS_FUNC static inline __host__ __device__ +#else +#define OF_NUMERICS_FUNC static inline +#endif + +template<typename T> +struct numeric_limits {}; + +// WARNING: the following oneflow::numeric_limits definitions are there only to support +// HIP compilation for the moment. Use std::numeric_limits if you are not +// compiling for ROCm. +// from @colesbury: "The functions on numeric_limits aren't marked with +// __device__ which is why they don't work with ROCm. CUDA allows them +// because they're constexpr." + +namespace { +// ROCm doesn't like INFINITY too. +constexpr double inf = INFINITY; +} // namespace + +template<> +struct numeric_limits<bool> { + OF_NUMERICS_FUNC bool lowest() { return false; } + OF_NUMERICS_FUNC bool max() { return true; } + OF_NUMERICS_FUNC bool lower_bound() { return false; } + OF_NUMERICS_FUNC bool upper_bound() { return true; } +}; + +template<> +struct numeric_limits<uint8_t> { + OF_NUMERICS_FUNC uint8_t lowest() { return 0; } + OF_NUMERICS_FUNC uint8_t max() { return UINT8_MAX; } + OF_NUMERICS_FUNC uint8_t lower_bound() { return 0; } + OF_NUMERICS_FUNC uint8_t upper_bound() { return UINT8_MAX; } +}; + +template<> +struct numeric_limits<int8_t> { + OF_NUMERICS_FUNC int8_t lowest() { return INT8_MIN; } + OF_NUMERICS_FUNC int8_t max() { return INT8_MAX; } + OF_NUMERICS_FUNC int8_t lower_bound() { return INT8_MIN; } + OF_NUMERICS_FUNC int8_t upper_bound() { return INT8_MAX; } +}; + +template<> +struct numeric_limits<int16_t> { + OF_NUMERICS_FUNC int16_t lowest() { return INT16_MIN; } + OF_NUMERICS_FUNC int16_t max() { return INT16_MAX; } + OF_NUMERICS_FUNC int16_t lower_bound() { return INT16_MIN; } + OF_NUMERICS_FUNC int16_t upper_bound() { return INT16_MAX; } +}; + +template<> +struct numeric_limits<int32_t> { + OF_NUMERICS_FUNC int32_t lowest() { return INT32_MIN; } + OF_NUMERICS_FUNC int32_t max() { return INT32_MAX; } + OF_NUMERICS_FUNC int32_t 
lower_bound() { return INT32_MIN; } + OF_NUMERICS_FUNC int32_t upper_bound() { return INT32_MAX; } +}; + +template<> +struct numeric_limits<int64_t> { +#ifdef _MSC_VER + OF_NUMERICS_FUNC int64_t lowest() { return _I64_MIN; } + OF_NUMERICS_FUNC int64_t max() { return _I64_MAX; } + OF_NUMERICS_FUNC int64_t lower_bound() { return _I64_MIN; } + OF_NUMERICS_FUNC int64_t upper_bound() { return _I64_MAX; } +#else + OF_NUMERICS_FUNC int64_t lowest() { return INT64_MIN; } + OF_NUMERICS_FUNC int64_t max() { return INT64_MAX; } + OF_NUMERICS_FUNC int64_t lower_bound() { return INT64_MIN; } + OF_NUMERICS_FUNC int64_t upper_bound() { return INT64_MAX; } +#endif +}; + +template<> +struct numeric_limits<float> { + OF_NUMERICS_FUNC float lowest() { return -FLT_MAX; } + OF_NUMERICS_FUNC float max() { return FLT_MAX; } + OF_NUMERICS_FUNC float lower_bound() { return -static_cast<float>(inf); } + OF_NUMERICS_FUNC float upper_bound() { return static_cast<float>(inf); } +}; + +template<> +struct numeric_limits<double> { + OF_NUMERICS_FUNC double lowest() { return -DBL_MAX; } + OF_NUMERICS_FUNC double max() { return DBL_MAX; } + OF_NUMERICS_FUNC double lower_bound() { return -inf; } + OF_NUMERICS_FUNC double upper_bound() { return inf; } +}; + +} // namespace detail +} // namespace oneflow diff --git a/oneflow/core/kernel/util/numerics.cuh b/oneflow/core/kernel/util/numerics.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d0e46faa070feebded6603961dbe58bf87254199 --- /dev/null +++ b/oneflow/core/kernel/util/numerics.cuh @@ -0,0 +1,249 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +// reference: https://github.com/pytorch/pytorch/blob/master/aten/src/THC/THCNumerics.cuh +#ifndef ONEFLOW_CORE_KERNEL_UTIL_NUMERICS_H +#define ONEFLOW_CORE_KERNEL_UTIL_NUMERICS_H +#pragma once + +#include <limits.h> +#include <math.h> +#include <float.h> +#include <cstdlib> +#include <assert.h> + +#include "oneflow/core/device/cuda_util.h" +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/kernel/util/numeric_limits.cuh" + +namespace oneflow { +namespace detail { + +template<typename T> +struct numerics {}; + +template<typename T> +OF_NUMERICS_FUNC T powi(T a, T b) {  // integer power by repeated squaring; asserts b >= 0 + assert(numerics<T>::ge(b, 0)); + T result = 1; + while (b) { + if (b & 1) { result *= a; } + b /= 2; + if (b) { a *= a; }  // skip the unused final squaring, which could overflow for signed T + } + return result; +} + +template<> +struct numerics<uint8_t> { + OF_NUMERICS_FUNC uint8_t min() { return detail::numeric_limits<uint8_t>::lowest(); } + OF_NUMERICS_FUNC uint8_t max() { return detail::numeric_limits<uint8_t>::max(); } + OF_NUMERICS_FUNC uint8_t lower_bound() { return detail::numeric_limits<uint8_t>::lower_bound(); } + OF_NUMERICS_FUNC uint8_t upper_bound() { return detail::numeric_limits<uint8_t>::upper_bound(); } + + OF_NUMERICS_FUNC bool lt(uint8_t a, uint8_t b) { return a < b; } + OF_NUMERICS_FUNC bool le(uint8_t a, uint8_t b) { return a <= b; } + OF_NUMERICS_FUNC bool gt(uint8_t a, uint8_t b) { return a > b; } + OF_NUMERICS_FUNC bool ge(uint8_t a, uint8_t b) { return a >= b; } + OF_NUMERICS_FUNC bool eq(uint8_t a, uint8_t b) { return a == b; } + OF_NUMERICS_FUNC bool ne(uint8_t a, uint8_t b) { return a != b; } + + OF_NUMERICS_FUNC uint8_t add(uint8_t 
a, uint8_t b) { return a + b; } + OF_NUMERICS_FUNC uint8_t mul(uint8_t a, uint8_t b) { return a * b; } + OF_NUMERICS_FUNC uint8_t sub(uint8_t a, uint8_t b) { return a - b; } + OF_NUMERICS_FUNC uint8_t div(uint8_t a, uint8_t b) { return a / b; } + OF_NUMERICS_FUNC uint8_t pow(uint8_t a, uint8_t b) { return powi<uint8_t>(a, b); } + OF_NUMERICS_FUNC bool isnan(uint8_t a) { return false; } + OF_NUMERICS_FUNC bool isinf(uint8_t a) { return false; } +}; + +#ifdef _MSC_VER +// Suppress warning C4804: '/': unsafe use of type 'bool' in operation +#pragma warning(push) +#pragma warning(disable : 4804) +#endif + +template<> +struct numerics<bool> { + OF_NUMERICS_FUNC bool min() { return detail::numeric_limits<bool>::lowest(); } + OF_NUMERICS_FUNC bool max() { return detail::numeric_limits<bool>::max(); } + OF_NUMERICS_FUNC bool lower_bound() { return detail::numeric_limits<bool>::lower_bound(); } + OF_NUMERICS_FUNC bool upper_bound() { return detail::numeric_limits<bool>::upper_bound(); } + + OF_NUMERICS_FUNC bool lt(bool a, bool b) { return a < b; } + OF_NUMERICS_FUNC bool le(bool a, bool b) { return a <= b; } + OF_NUMERICS_FUNC bool gt(bool a, bool b) { return a > b; } + OF_NUMERICS_FUNC bool ge(bool a, bool b) { return a >= b; } + OF_NUMERICS_FUNC bool eq(bool a, bool b) { return a == b; } + OF_NUMERICS_FUNC bool ne(bool a, bool b) { return a != b; } + OF_NUMERICS_FUNC bool add(bool a, bool b) { return a + b; } + OF_NUMERICS_FUNC bool mul(bool a, bool b) { return a && b; } + OF_NUMERICS_FUNC bool sub(bool a, bool b) { return a - b; } + OF_NUMERICS_FUNC bool div(bool a, bool b) { return a / b; } + OF_NUMERICS_FUNC bool isnan(bool a) { return false; } + OF_NUMERICS_FUNC bool isinf(bool a) { return false; } +}; + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +template<> +struct numerics<int8_t> { + OF_NUMERICS_FUNC int8_t min() { return detail::numeric_limits<int8_t>::lowest(); } + OF_NUMERICS_FUNC int8_t max() { return detail::numeric_limits<int8_t>::max(); } + 
OF_NUMERICS_FUNC int8_t lower_bound() { return detail::numeric_limits<int8_t>::lower_bound(); } + OF_NUMERICS_FUNC int8_t upper_bound() { return detail::numeric_limits<int8_t>::upper_bound(); } + + OF_NUMERICS_FUNC bool lt(int8_t a, int8_t b) { return a < b; } + OF_NUMERICS_FUNC bool le(int8_t a, int8_t b) { return a <= b; } + OF_NUMERICS_FUNC bool gt(int8_t a, int8_t b) { return a > b; } + OF_NUMERICS_FUNC bool ge(int8_t a, int8_t b) { return a >= b; } + OF_NUMERICS_FUNC bool eq(int8_t a, int8_t b) { return a == b; } + OF_NUMERICS_FUNC bool ne(int8_t a, int8_t b) { return a != b; } + + OF_NUMERICS_FUNC int8_t add(int8_t a, int8_t b) { return a + b; } + OF_NUMERICS_FUNC int8_t mul(int8_t a, int8_t b) { return a * b; } + OF_NUMERICS_FUNC int8_t sub(int8_t a, int8_t b) { return a - b; } + OF_NUMERICS_FUNC int8_t div(int8_t a, int8_t b) { return a / b; } + OF_NUMERICS_FUNC int8_t pow(int8_t a, int8_t b) { return powi<int8_t>(a, b); } + OF_NUMERICS_FUNC bool isnan(int8_t a) { return false; } + OF_NUMERICS_FUNC bool isinf(int8_t a) { return false; } +}; + +template<> +struct numerics<int16_t> { + OF_NUMERICS_FUNC int16_t min() { return detail::numeric_limits<int16_t>::lowest(); } + OF_NUMERICS_FUNC int16_t max() { return detail::numeric_limits<int16_t>::max(); } + OF_NUMERICS_FUNC int16_t lower_bound() { return detail::numeric_limits<int16_t>::lower_bound(); } + OF_NUMERICS_FUNC int16_t upper_bound() { return detail::numeric_limits<int16_t>::upper_bound(); } + + OF_NUMERICS_FUNC bool lt(int16_t a, int16_t b) { return a < b; } + OF_NUMERICS_FUNC bool le(int16_t a, int16_t b) { return a <= b; } + OF_NUMERICS_FUNC bool gt(int16_t a, int16_t b) { return a > b; } + OF_NUMERICS_FUNC bool ge(int16_t a, int16_t b) { return a >= b; } + OF_NUMERICS_FUNC bool eq(int16_t a, int16_t b) { return a == b; } + OF_NUMERICS_FUNC bool ne(int16_t a, int16_t b) { return a != b; } + + OF_NUMERICS_FUNC int16_t add(int16_t a, int16_t b) { return a + b; } + OF_NUMERICS_FUNC int16_t mul(int16_t 
a, int16_t b) { return a * b; } + OF_NUMERICS_FUNC int16_t sub(int16_t a, int16_t b) { return a - b; } + OF_NUMERICS_FUNC int16_t div(int16_t a, int16_t b) { return a / b; } + OF_NUMERICS_FUNC int16_t pow(int16_t a, int16_t b) { return powi<int16_t>(a, b); } + OF_NUMERICS_FUNC bool isnan(int16_t a) { return false; } + OF_NUMERICS_FUNC bool isinf(int16_t a) { return false; } +}; + +template<> +struct numerics<int32_t> { + OF_NUMERICS_FUNC int32_t min() { return detail::numeric_limits<int32_t>::lowest(); } + OF_NUMERICS_FUNC int32_t max() { return detail::numeric_limits<int32_t>::max(); } + OF_NUMERICS_FUNC int32_t lower_bound() { return detail::numeric_limits<int32_t>::lower_bound(); } + OF_NUMERICS_FUNC int32_t upper_bound() { return detail::numeric_limits<int32_t>::upper_bound(); } + + OF_NUMERICS_FUNC bool lt(int32_t a, int32_t b) { return a < b; } + OF_NUMERICS_FUNC bool le(int32_t a, int32_t b) { return a <= b; } + OF_NUMERICS_FUNC bool gt(int32_t a, int32_t b) { return a > b; } + OF_NUMERICS_FUNC bool ge(int32_t a, int32_t b) { return a >= b; } + OF_NUMERICS_FUNC bool eq(int32_t a, int32_t b) { return a == b; } + OF_NUMERICS_FUNC bool ne(int32_t a, int32_t b) { return a != b; } + + OF_NUMERICS_FUNC int32_t add(int32_t a, int32_t b) { return a + b; } + OF_NUMERICS_FUNC int32_t mul(int32_t a, int32_t b) { return a * b; } + OF_NUMERICS_FUNC int32_t sub(int32_t a, int32_t b) { return a - b; } + OF_NUMERICS_FUNC int32_t div(int32_t a, int32_t b) { return a / b; } + OF_NUMERICS_FUNC int32_t pow(int32_t a, int32_t b) { return powi<int32_t>(a, b); } + OF_NUMERICS_FUNC bool isnan(int32_t a) { return false; } + OF_NUMERICS_FUNC bool isinf(int32_t a) { return false; } +}; + +template<> +struct numerics<int64_t> { + OF_NUMERICS_FUNC int64_t min() { return detail::numeric_limits<int64_t>::lowest(); } + OF_NUMERICS_FUNC int64_t max() { return detail::numeric_limits<int64_t>::max(); } + OF_NUMERICS_FUNC int64_t lower_bound() { return 
detail::numeric_limits<int64_t>::lower_bound(); } + OF_NUMERICS_FUNC int64_t upper_bound() { return detail::numeric_limits<int64_t>::upper_bound(); } + + OF_NUMERICS_FUNC bool lt(int64_t a, int64_t b) { return a < b; } + OF_NUMERICS_FUNC bool le(int64_t a, int64_t b) { return a <= b; } + OF_NUMERICS_FUNC bool gt(int64_t a, int64_t b) { return a > b; } + OF_NUMERICS_FUNC bool ge(int64_t a, int64_t b) { return a >= b; } + OF_NUMERICS_FUNC bool eq(int64_t a, int64_t b) { return a == b; } + OF_NUMERICS_FUNC bool ne(int64_t a, int64_t b) { return a != b; } + + OF_NUMERICS_FUNC int64_t add(int64_t a, int64_t b) { return a + b; } + OF_NUMERICS_FUNC int64_t mul(int64_t a, int64_t b) { return a * b; } + OF_NUMERICS_FUNC int64_t sub(int64_t a, int64_t b) { return a - b; } + OF_NUMERICS_FUNC int64_t div(int64_t a, int64_t b) { return a / b; } + OF_NUMERICS_FUNC int64_t pow(int64_t a, int64_t b) { return powi<int64_t>(a, b); } + OF_NUMERICS_FUNC bool isnan(int64_t a) { return false; } + OF_NUMERICS_FUNC bool isinf(int64_t a) { return false; } +}; + +// DEPRECATED: use math functions from std and cuda math API (if needed) +template<> +struct numerics<float> { + OF_NUMERICS_FUNC float min() { return detail::numeric_limits<float>::lowest(); } + OF_NUMERICS_FUNC float max() { return detail::numeric_limits<float>::max(); } + OF_NUMERICS_FUNC float lower_bound() { return detail::numeric_limits<float>::lower_bound(); } + OF_NUMERICS_FUNC float upper_bound() { return detail::numeric_limits<float>::upper_bound(); } + + OF_NUMERICS_FUNC bool lt(float a, float b) { return a < b; } + OF_NUMERICS_FUNC bool le(float a, float b) { return a <= b; } + OF_NUMERICS_FUNC bool gt(float a, float b) { return a > b; } + OF_NUMERICS_FUNC bool ge(float a, float b) { return a >= b; } + OF_NUMERICS_FUNC bool eq(float a, float b) { return a == b; } + OF_NUMERICS_FUNC bool ne(float a, float b) { return a != b; } + + OF_NUMERICS_FUNC float sqrt(float a) { return sqrtf(a); } + OF_NUMERICS_FUNC float 
atan(float a) { return atanf(a); } + OF_NUMERICS_FUNC float add(float a, float b) { return a + b; } + OF_NUMERICS_FUNC float div(float a, float b) { return a / b; } + OF_NUMERICS_FUNC float mul(float a, float b) { return a * b; } + OF_NUMERICS_FUNC float sub(float a, float b) { return a - b; } + OF_NUMERICS_FUNC float pow(float a, float b) { return powf(a, b); } + OF_NUMERICS_FUNC bool isnan(float a) { return ::isnan(a); } + OF_NUMERICS_FUNC bool isinf(float a) { return ::isinf(a); } +}; + +template<> +struct numerics<double> { + OF_NUMERICS_FUNC double min() { return detail::numeric_limits<double>::lowest(); } + OF_NUMERICS_FUNC double max() { return detail::numeric_limits<double>::max(); } + OF_NUMERICS_FUNC double lower_bound() { return detail::numeric_limits<double>::lower_bound(); } + OF_NUMERICS_FUNC double upper_bound() { return detail::numeric_limits<double>::upper_bound(); } + + OF_NUMERICS_FUNC bool lt(double a, double b) { return a < b; } + OF_NUMERICS_FUNC bool le(double a, double b) { return a <= b; } + OF_NUMERICS_FUNC bool gt(double a, double b) { return a > b; } + OF_NUMERICS_FUNC bool ge(double a, double b) { return a >= b; } + OF_NUMERICS_FUNC bool eq(double a, double b) { return a == b; } + OF_NUMERICS_FUNC bool ne(double a, double b) { return a != b; } + + OF_NUMERICS_FUNC double sqrt(double a) { return ::sqrt(a); } + OF_NUMERICS_FUNC double atan(double a) { return ::atan(a); } + OF_NUMERICS_FUNC double add(double a, double b) { return a + b; } + OF_NUMERICS_FUNC double div(double a, double b) { return a / b; } + OF_NUMERICS_FUNC double mul(double a, double b) { return a * b; } + OF_NUMERICS_FUNC double sub(double a, double b) { return a - b; } + OF_NUMERICS_FUNC double pow(double a, double b) { return ::pow(a, b); } + OF_NUMERICS_FUNC bool isnan(double a) { return ::isnan(a); } + OF_NUMERICS_FUNC bool isinf(double a) { return ::isinf(a); } +}; + +} // namespace detail +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_UTIL_NUMERICS_H diff 
--git a/oneflow/python/nn/modules/pooling.py b/oneflow/python/nn/modules/pooling.py index 8f2e1b899b5e7ad88cf8e3d5a091d3772e77ef3f..47cfca1d451e529c3a6f2dfc7bd75060da0c3d2d 100644 --- a/oneflow/python/nn/modules/pooling.py +++ b/oneflow/python/nn/modules/pooling.py @@ -20,7 +20,7 @@ from oneflow.python.oneflow_export import oneflow_export, experimental_api from oneflow.python.nn.module import Module from oneflow.python.nn.modules.utils import _single, _pair, _triple from oneflow.python.nn.common_types import _size_1_t, _size_2_t, _size_3_t -from oneflow.python.ops.nn_ops import calc_pool_padding, get_dhw_offset +from oneflow.python.ops.nn_ops import calc_pool_padding, get_dhw_offset, _GetSequence @oneflow_export("nn.AvgPool1d") @@ -200,12 +200,13 @@ class AvgPool3d(Module): >>> import oneflow.experimental as flow >>> import numpy as np - >>> flow.enable_eager_execution() - >>> inputarr = np.random.randn(9, 7, 11, 32, 20) - >>> of_avgpool3d = flow.nn.AvgPool3d(kernel_size=(2,2,2),padding=(0,0,0),stride=(1,1,1),) - >>> x = flow.Tensor(inputarr) - >>> y = of_avgpool3d(x) + + >>> m = flow.nn.AvgPool3d(kernel_size=(2,2,2),padding=(0,0,0),stride=(1,1,1)) + >>> x = flow.Tensor(np.random.randn(9, 7, 11, 32, 20)) + >>> y = m(x) + >>> y.shape + flow.Size([9, 7, 10, 31, 19]) """ @@ -311,8 +312,51 @@ class MaxPool1d(Module): return_indices: bool = False, ceil_mode: bool = False, ): - # TODO: fix cuDNN bugs in pooling_1d - raise NotImplementedError + super().__init__() + self.kernel_size = _pair(_single(kernel_size)[0]) + self.stride = _pair(_single(stride)[0]) if (stride is not None) else self.kernel_size + data_format = "NCL" # Only support "NCL" for now! 
+ self.channel_pos = "channels_first" if data_format == "NCL" else "channels_last" + self.dilation = _GetSequence(dilation, 2, "dilation") + padding = _pair(_single(padding)[0]) + self.return_indices = return_indices + self.ceil_mode = ceil_mode + + if len(padding) == 2: + if self.channel_pos == "channels_first": + padding = (0, 0, padding[0], padding[1]) + else: + raise ValueError("error padding param!") + else: + raise ValueError("error padding param!") + + self.padding_type, pads_list = calc_pool_padding( + padding, get_dhw_offset(self.channel_pos), 2 + ) + self.padding_before = [pad[0] for pad in pads_list] + self.padding_after = [pad[1] for pad in pads_list] + + def forward(self, x): + expand_x = x.unsqueeze(dim=-1) + expand_y, expand_indice = flow.F.maxpool_2d( + expand_x, + data_format=self.channel_pos, + padding=self.padding_type, + padding_before=self.padding_before, + padding_after=self.padding_after, + kernel_size=self.kernel_size, + stride=self.stride, + dilation=self.dilation, + return_indices=True, + ceil_mode=self.ceil_mode, + ) + + y = expand_y.squeeze(dim=-1) + indice = expand_indice.squeeze(dim=-1) + if self.return_indices: + return y, indice + else: + return y @oneflow_export("nn.MaxPool2d") @@ -374,21 +418,21 @@ class MaxPool2d(Module): >>> import numpy as np >>> flow.enable_eager_execution() - >>> kernel_size, stride, padding = (3, 3), (1, 1), (1, 2) + >>> kernel_size, stride, padding = (3, 4), (1, 1), (1, 2) >>> m = flow.nn.MaxPool2d(kernel_size, stride, padding) >>> np.random.seed(0) >>> x = flow.Tensor(np.random.rand(1, 1, 5, 3)) >>> y = m(x) >>> y #doctest: +ELLIPSIS - tensor([[[[0.5488, 0.7152, 0.7152, 0.7152, 0.6459], + tensor([[[[0.7152, 0.7152, 0.7152, 0.7152], ... 
- [0.568 , 0.9256, 0.9256, 0.9256, 0.5289]]]], dtype=oneflow.float32) + [0.9256, 0.9256, 0.9256, 0.9256]]]], dtype=oneflow.float32) - >>> kernel_size, stride, padding = (2, 3), (4, 5), (1, 2) + >>> kernel_size, stride, padding = (2, 4), (4, 5), (1, 2) >>> m = flow.nn.MaxPool2d(kernel_size, stride, padding) >>> x = flow.Tensor(np.random.randn(9, 7, 32, 20)) >>> y = m(x) - >>> y.size() + >>> y.shape flow.Size([9, 7, 9, 5]) """ @@ -404,14 +448,14 @@ class MaxPool2d(Module): ): super().__init__() self.kernel_size = _pair(kernel_size) - self.strides = _pair(stride) if (stride is not None) else kernel_size - data_format = "NCHW" + self.stride = _pair(stride) if (stride is not None) else kernel_size + data_format = "NCHW" # Only suport "NCHW" for now! self.channel_pos = ( - "channels_last" if data_format == "NHWC" else "channels_first" + "channels_first" if data_format == "NCHW" else "channels_last" ) - - assert return_indices is False, "Only support return_indices==False for now!" - assert dilation == 1 or dilation == (1, 1), "Only support dilation==1 for now!" 
+ self.dilation = _GetSequence(dilation, 2, "dilation") + self.return_indices = return_indices + self.ceil_mode = ceil_mode padding = _pair(padding) if len(padding) == 2: @@ -421,25 +465,32 @@ class MaxPool2d(Module): raise ValueError("error padding param!") else: raise ValueError("error padding param!") + self.padding_type, pads_list = calc_pool_padding( padding, get_dhw_offset(self.channel_pos), 2 ) self.padding_before = [pad[0] for pad in pads_list] self.padding_after = [pad[1] for pad in pads_list] - self.ceil_mode = ceil_mode def forward(self, x): - return flow.F.max_pool_2d( + y, indice = flow.F.maxpool_2d( x, - kernel_size=self.kernel_size, - stride=self.strides, + data_format=self.channel_pos, padding=self.padding_type, padding_before=self.padding_before, padding_after=self.padding_after, + kernel_size=self.kernel_size, + stride=self.stride, + dilation=self.dilation, + return_indices=True, ceil_mode=self.ceil_mode, - data_format=self.channel_pos, ) + if self.return_indices: + return y, indice + else: + return y + @oneflow_export("nn.MaxPool3d") @experimental_api @@ -507,21 +558,21 @@ class MaxPool3d(Module): >>> import numpy as np >>> flow.enable_eager_execution() - >>> kernel_size, stride, padding = (3, 3, 3), (1, 1, 1), (1, 1, 2) + >>> kernel_size, stride, padding = (3, 3, 4), (1, 1, 1), (1, 1, 2) >>> m = flow.nn.MaxPool3d(kernel_size, stride, padding) >>> np.random.seed(0) >>> x = flow.Tensor(np.random.rand(1, 1, 3, 5, 3)) >>> y = m(x) >>> y #doctest: +ELLIPSIS - tensor([[[[[0.7782, 0.87 , 0.9786, 0.9786, 0.9786], + tensor([[[[[0.87 , 0.9786, 0.9786, 0.9786], ... 
- [0.9447, 0.9447, 0.9447, 0.6668, 0.6668]]]]], dtype=oneflow.float32) - >>> kernel_size, stride, padding = (2, 2, 3), (3, 4, 5), (2, 1, 2) + [0.9447, 0.9447, 0.9447, 0.6668]]]]], dtype=oneflow.float32) + >>> kernel_size, stride, padding = (4, 2, 4), (3, 4, 5), (2, 1, 2) >>> m = flow.nn.MaxPool3d(kernel_size, stride, padding) >>> x = flow.Tensor(np.random.randn(9, 7, 11, 32, 20)) >>> y = m(x) - >>> y.size() - flow.Size([9, 7, 5, 9, 5]) + >>> y.shape + flow.Size([9, 7, 4, 9, 5]) """ @@ -535,19 +586,17 @@ class MaxPool3d(Module): ceil_mode: bool = False, ): super().__init__() - kernel_size = _triple(kernel_size) - strides = _triple(stride) if (stride is not None) else kernel_size + self.kernel_size = _triple(kernel_size) + self.stride = _triple(stride) if (stride is not None) else kernel_size data_format = "NCDHW" - channel_pos = "channels_last" if data_format == "NDHWC" else "channels_first" - - assert return_indices is False, "Only support return_indices==False for now!" - assert dilation == 1 or dilation == ( - 1, - 1, - 1, - ), "Only support dilation==1 for now!" 
- + self.channel_pos = ( + "channels_last" if data_format == "NDHWC" else "channels_first" + ) + self.dilation = _GetSequence(dilation, 3, "dilation") padding = _triple(padding) + self.return_indices = return_indices + self.ceil_mode = ceil_mode + if len(padding) == 3: if data_format == "NCDHW": padding = (0, 0, padding[0], padding[1], padding[2]) @@ -556,28 +605,30 @@ class MaxPool3d(Module): else: raise ValueError("error padding param!") - padding_type, pads_list = calc_pool_padding( - padding, get_dhw_offset(channel_pos), 3 + self.padding_type, pads_list = calc_pool_padding( + padding, get_dhw_offset(self.channel_pos), 3 ) - padding_before = [pad[0] for pad in pads_list] - padding_after = [pad[1] for pad in pads_list] + self.padding_before = [pad[0] for pad in pads_list] + self.padding_after = [pad[1] for pad in pads_list] - self._op = ( - flow.builtin_op("max_pool_3d") - .Attr("data_format", channel_pos) - .Attr("pool_size", kernel_size) - .Attr("strides", strides) - .Attr("ceil_mode", ceil_mode) - .Attr("padding", padding_type) - .Attr("padding_before", padding_before) - .Attr("padding_after", padding_after) - .Input("x") - .Output("y") - .Build() + def forward(self, x): + y, indice = flow.F.maxpool_3d( + x, + data_format=self.channel_pos, + padding=self.padding_type, + padding_before=self.padding_before, + padding_after=self.padding_after, + kernel_size=self.kernel_size, + stride=self.stride, + dilation=self.dilation, + return_indices=True, + ceil_mode=self.ceil_mode, ) - def forward(self, x): - return self._op(x)[0] + if self.return_indices: + return y, indice + else: + return y if __name__ == "__main__": diff --git a/oneflow/python/ops/nn_ops.py b/oneflow/python/ops/nn_ops.py index c76d3cfef784e2c35e2fcd37777ef7fa687c8277..bcc290bcba79d427149a9f5b058f20376d968119 100644 --- a/oneflow/python/ops/nn_ops.py +++ b/oneflow/python/ops/nn_ops.py @@ -1789,6 +1789,426 @@ def calc_pool_padding(padding, dhw_offset, ndims): return padding_type, ndim_pads_list 
+@oneflow_export("nn.MaxPool1d") +@stable_api +def MaxPool1d( + input: oneflow._oneflow_internal.BlobDesc, + kernel_size: Union[int, IntPair], + stride: Union[int, IntPair], + padding: Union[str, IntPair], + dilation: Union[int, IntPair] = 1, + return_indices: bool = False, + ceil_mode: bool = False, + data_format: str = "NCHW", + name: Optional[str] = None, +) -> oneflow._oneflow_internal.BlobDesc: + r""" Performs the 1d-max pooling on the input `Blob`. + Different from nn.max_pool1d, nn.MaxPool2d supports more params e.g. dilation,return_indices. + + Args: + input (remote_blob_util.BlobDesc): A 3-D `Blob` of the format specified by data_format. + kernel_size (Union[int, IntPair]): An int or list of ints that has length 1, 2. The size of the window for each dimension of the input `Blob`. + stride (Union[int, IntPair]): An int or list of ints that has length 1, 2. The stride of the sliding window for each dimension of the input `Blob`. + padding (str): '`VALID'` or '`SAME'` or '`SAME_LOWER'` or '`SAME_UPPER'` or Tuple[IntPair, IntPair, IntPair, IntPair]`. The padding algorithm. + dilation (Union[int, IntPair]): a parameter that controls the stride of elements in the window. + return_indices (bool): if True, will return the max indices along with the outputs. + ceil_mode (bool): when True, will use ceil instead of floor to compute the output shape. + data_format (str, optional): '`NHWC'`, '`NCHW'` or '`NCHW_VECT_C'`. Defaults to "NCHW", for now only supporr 'NCHW'. + name (Optional[str], optional): This operator's name(optional). Defaults to None. + + Returns: + remote_blob_util.BlobDesc: A `Blob` of format specified by data_format. The max pooled output `Blob`. + + For example: + + .. 
code-block:: python + + import oneflow as flow + import oneflow.typing as tp + from typing import Tuple + import numpy as np + + input_shape = (2, 2, 4) + @flow.global_function(type="train", function_config=func_config) + def maxpool1d_job_with_grad( + input_x: tp.Numpy.Placeholder(input_shape), + ) -> Tuple[tp.Numpy, tp.Numpy]: + x_var = flow.get_variable( + name="input_x", + shape=input_shape, + dtype=flow.float32, + initializer=flow.constant_initializer(0), + trainable=True, + ) + x_var = flow.cast_to_current_logical_view(x_var) + flow.watch_diff(x_var, Setter("x_diff")) + x = x_var + input_x + # x = flow.cast(x, dtype=flow.int32) + with flow.scope.placement("cpu", "0:0"): + (y, indice) = flow.nn.MaxPool1d( + x, + kernel_size=3, + stride=2, + padding=1, + dilation=1, + return_indices=True, + ceil_mode=False, + data_format="NCHW", + ) + flow.optimizer.SGD( + flow.optimizer.PiecewiseConstantScheduler([], [1e-4]), momentum=0 + ).minimize(y) + return (y, indice) + + x = np.arange(16).reshape(input_shape).astype(np.float32) + y, indice = maxpool1d_job(x) + print("in:\n", x, "\ny:\n", y, "\nindice:\n", indice) + + # x: + # [[[ 0. 1. 2. 3.] + # [ 4. 5. 6. 7.]] + + # [[ 8. 9. 10. 11.] + # [12. 13. 14. 15.]]] + # y: + # [[[ 1. 3.] + # [ 5. 7.]] + + # [[ 9. 11.] + # [13. 
15.]]] + # indice: + # [[[1 3] + # [1 3]] + + # [[1 3] + # [1 3]]] + + """ + assert data_format in ["NCHW"] + channel_pos = "channels_last" if data_format == "NHWC" else "channels_first" + kernel_size = _GetSequence(kernel_size, 2, "kernel_size") + dilation = _GetSequence(dilation, 2, "dilation") + stride = _GetSequence(stride, 2, "stride") + assert padding >= 0 or padding in ["SAME", "VALID"] + if padding >= 0: + if data_format == "NCHW": + padding = (0, 0, padding, padding) + elif data_format == "NHWC": + padding = (0, padding, padding, 0) + else: + raise ValueError('data_format must be "NHWC" or "NCHW".') + padding_type, pads_list = calc_pool_padding(padding, get_dhw_offset(channel_pos), 2) + padding_before = [pad[0] for pad in pads_list] + padding_after = [pad[1] for pad in pads_list] + + expand_input = flow.expand_dims(input=input, axis=2) + assert len(pads_list) == len(expand_input.shape) - 2 + y, indice = ( + flow.user_op_builder( + name if name is not None else id_util.UniqueStr("MaxPool1d_") + ) + .Op("maxpool_2d") + .Input("x", [expand_input]) + .Output("y") + .Output("indice") + .Attr("data_format", channel_pos) + .Attr("stride", stride) + .Attr("kernel_size", kernel_size) + .Attr("padding", padding_type) + .Attr("padding_before", padding_before) + .Attr("padding_after", padding_after) + .Attr("dilation", dilation) + .Attr("return_indices", return_indices) + .Attr("ceil_mode", ceil_mode) + .Build() + .InferAndTryRun() + .RemoteBlobList() + ) + y = flow.squeeze(y, axis=(2,)) + indice = flow.squeeze(indice, axis=(2,)) + if return_indices == True: + return y, indice + else: + return y + + +@oneflow_export("nn.MaxPool2d") +@stable_api +def MaxPool2d( + input: oneflow._oneflow_internal.BlobDesc, + kernel_size: Union[int, IntPair], + stride: Union[int, IntPair], + padding: Union[str, int, Tuple[int, int]], + dilation: Union[int, IntPair] = 1, + return_indices: bool = False, + ceil_mode: bool = False, + data_format: str = "NCHW", + name: Optional[str] = None, 
+) -> oneflow._oneflow_internal.BlobDesc: + r""" Performs the 2d-max pooling on the input `Blob`. + Different from nn.max_pool2d, nn.MaxPool2d supports more params e.g. dilation,return_indices. + + Args: + input (remote_blob_util.BlobDesc): A 4-D `Blob` of the format specified by data_format. + kernel_size (Union[int, IntPair]): An int or list of ints that has length 1, 2. The size of the window for each dimension of the input `Blob`. + stride (Union[int, IntPair]): An int or list of ints that has length 1, 2. The stride of the sliding window for each dimension of the input `Blob`. + padding (str): '`VALID'` or '`SAME'` or '`SAME_LOWER'` or '`SAME_UPPER'` or Tuple[IntPair, IntPair, IntPair, IntPair]`. The padding algorithm. + dilation (Union[int, IntPair]): a parameter that controls the stride of elements in the window. + return_indices (bool): if True, will return the max indices along with the outputs. + ceil_mode (bool): when True, will use ceil instead of floor to compute the output shape. + data_format (str, optional): '`NHWC'`, '`NCHW'` or '`NCHW_VECT_C'`. Defaults to "NCHW", for now only supporr 'NCHW'. + name (Optional[str], optional): This operator's name(optional). Defaults to None. + + Returns: + remote_blob_util.BlobDesc: A `Blob` of format specified by data_format. The max pooled output `Blob`. + + For example: + + .. 
code-block:: python + + import oneflow as flow + import oneflow.typing as tp + from typing import Tuple + import numpy as np + + input_shape = (1, 2, 4, 4) + @flow.global_function(type="predict") + def maxpool_job( + x: tp.Numpy.Placeholder(input_shape), + ) -> Tuple[tp.Numpy, tp.Numpy]: + with flow.scope.placement("gpu", "0:0"): + (y, indice) = flow.nn.MaxPool2d( + x, + kernel_size=3, + stride=2, + padding=1, + dilation=1, + return_indices=True, + ceil_mode=False, + data_format="NCHW", + ) + return (y, indice) + + x = np.arange(32).reshape(input_shape).astype(np.float32) + y, indice = maxpool_job(x) + print("in:\n", x, "\ny:\n", y, "\nindice:\n", indice) + + #in: + #[[[[ 0. 1. 2. 3.] + #[ 4. 5. 6. 7.] + #[ 8. 9. 10. 11.] + #[12. 13. 14. 15.]] + + #[[16. 17. 18. 19.] + #[20. 21. 22. 23.] + #[24. 25. 26. 27.] + #[28. 29. 30. 31.]]]] + #y: + #[[[[ 5. 7.] + #[13. 15.]] + + #[[21. 23.] + #[29. 31.]]]] + + #indice: + #[[[[5 7] + #[13 15]] + + #[[5 7] + #[13 15]]]] + + """ + assert data_format in ["NCHW"] + channel_pos = "channels_last" if data_format == "NHWC" else "channels_first" + kernel_size = _GetSequence(kernel_size, 2, "kernel_size") + dilation = _GetSequence(dilation, 2, "dilation") + stride = _GetSequence(stride, 2, "stride") + assert isinstance(padding, int) or len(padding) == 2 or padding in ["SAME", "VALID"] + + if isinstance(padding, int): + padding = [padding, padding] + if len(padding) == 2: + if data_format == "NCHW": + padding = (0, 0, padding[0], padding[1]) + elif data_format == "NHWC": + padding = (0, padding[0], padding[1], 0) + else: + raise ValueError('data_format must be "NHWC" or "NCHW".') + + padding_type, pads_list = calc_pool_padding(padding, get_dhw_offset(channel_pos), 2) + padding_before = [pad[0] for pad in pads_list] + padding_after = [pad[1] for pad in pads_list] + assert len(pads_list) == len(input.shape) - 2 + y, indice = ( + flow.user_op_builder( + name if name is not None else id_util.UniqueStr("MaxPool2d_") + ) + .Op("maxpool_2d") 
+ .Input("x", [input]) + .Output("y") + .Output("indice") + .Attr("data_format", channel_pos) + .Attr("stride", stride) + .Attr("kernel_size", kernel_size) + .Attr("padding", padding_type) + .Attr("padding_before", padding_before) + .Attr("padding_after", padding_after) + .Attr("dilation", dilation) + .Attr("return_indices", return_indices) + .Attr("ceil_mode", ceil_mode) + .Build() + .InferAndTryRun() + .RemoteBlobList() + ) + if return_indices == True: + return y, indice + else: + return y + + +@oneflow_export("nn.MaxPool3d") +@stable_api +def MaxPool3d( + input: oneflow._oneflow_internal.BlobDesc, + kernel_size: Union[int, IntPair], + stride: Union[int, IntPair], + padding: Union[str, int, Tuple[int, int, int]], + dilation: Union[int, IntPair] = 1, + return_indices: bool = False, + ceil_mode: bool = False, + data_format: str = "NCDHW", + name: Optional[str] = None, +) -> oneflow._oneflow_internal.BlobDesc: + r""" Performs the 3d-max pooling on the input `Blob`. + Different from nn.max_pool3d, nn.MaxPool3d supports more params e.g. dilation,return_indices. + + Args: + input (remote_blob_util.BlobDesc): A 5-D `Blob` of the format specified by data_format. + kernel_size (Union[int, IntPair]): An int or list of ints that has length 1, 2. The size of the window for each dimension of the input `Blob`. + stride (Union[int, IntPair]): An int or list of ints that has length 1, 2. The stride of the sliding window for each dimension of the input `Blob`. + padding (str): '`VALID'` or '`SAME'` or '`SAME_LOWER'` or '`SAME_UPPER'` or int value or Tuple[int, int, int]`. The padding algorithm. + dilation (Union[int, IntPair]): a parameter that controls the stride of elements in the window. + return_indices (bool): if True, will return the max indices along with the outputs. + ceil_mode (bool): when True, will use ceil instead of floor to compute the output shape. + data_format (str, optional): '`NCDHW'`, '`NCHWD'`. Defaults to "NCDHW", for now only supporr 'NCDHW'. 
+ name (Optional[str], optional): This operator's name(optional). Defaults to None. + + Returns: + remote_blob_util.BlobDesc: A `Blob` of format specified by data_format. The max pooled output `Blob`. + + For example: + + .. code-block:: python + + import oneflow as flow + import oneflow.typing as tp + from typing import Tuple + import numpy as np + + input_shape = (1, 1, 2, 4, 4) + @flow.global_function(type="predict") + def maxpool3d_job( + x: tp.Numpy.Placeholder(input_shape), + ) -> tp.Numpy: + with flow.scope.placement("gpu", "0:0"): + (y, indice) = flow.nn.MaxPool3d( + input=x, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + return_indices=True, + ceil_mode=False, + data_format="NCDHW", + ) + return (y, indice) + + x = np.arange(32).reshape(input_shape).astype(np.float32) + y, indice = maxpool3d_job(x) + print("in:\n", x, "\ny:\n", y, "\nindice:\n", indice) + + # in: + # [[[[[ 0. 1. 2. 3.] + # [ 4. 5. 6. 7.] + # [ 8. 9. 10. 11.] + # [12. 13. 14. 15.]] + + # [[16. 17. 18. 19.] + # [20. 21. 22. 23.] + # [24. 25. 26. 27.] + # [28. 29. 30. 31.]]]]] + # y: + # [[[[[21. 22. 23. 23.] + # [25. 26. 27. 27.] + # [29. 30. 31. 31.] + # [29. 30. 31. 31.]] + + # [[21. 22. 23. 23.] + # [25. 26. 27. 27.] + # [29. 30. 31. 31.] + # [29. 30. 31. 
31.]]]]] + # indice: + # [[[[[21 22 23 23] + # [25 26 27 27] + # [29 30 31 31] + # [29 30 31 31]] + + # [[21 22 23 23] + # [25 26 27 27] + # [29 30 31 31] + # [29 30 31 31]]]]] + + """ + assert data_format in ["NCDHW"] + channel_pos = "channels_first" if data_format == "NCDHW" else "channels_last" + kernel_size = _GetSequence(kernel_size, 3, "kernel_size") + dilation = _GetSequence(dilation, 3, "dilation") + stride = _GetSequence(stride, 3, "stride") + assert ( + isinstance(padding, int) + or isinstance(padding, Tuple) + or padding in ["SAME", "VALID"] + ) + if isinstance(padding, int): + padding = (padding, padding, padding) + if len(padding) == 3: + if data_format == "NCDHW": + padding = (0, 0, padding[0], padding[1], padding[2]) + elif data_format == "NDHWC": + padding = (0, padding[0], padding[1], padding[2], 0) + else: + raise ValueError('data_format must be "NHWDC" or "NCDHW".') + padding_type, pads_list = calc_pool_padding(padding, get_dhw_offset(channel_pos), 3) + padding_before = [pad[0] for pad in pads_list] + padding_after = [pad[1] for pad in pads_list] + assert len(pads_list) == len(input.shape) - 2 + y, indice = ( + flow.user_op_builder( + name if name is not None else id_util.UniqueStr("MaxPool3d_") + ) + .Op("maxpool_3d") + .Input("x", [input]) + .Output("y") + .Output("indice") + .Attr("data_format", channel_pos) + .Attr("stride", stride) + .Attr("kernel_size", kernel_size) + .Attr("padding", padding_type) + .Attr("padding_before", padding_before) + .Attr("padding_after", padding_after) + .Attr("dilation", dilation) + .Attr("return_indices", return_indices) + .Attr("ceil_mode", ceil_mode) + .Build() + .InferAndTryRun() + .RemoteBlobList() + ) + if return_indices == True: + return y, indice + else: + return y + + @oneflow_export("nn.max_pool2d") def max_pool2d( input: oneflow._oneflow_internal.BlobDesc, diff --git a/oneflow/python/test/modules/test_maxpool.py b/oneflow/python/test/modules/test_pooling.py similarity index 58% rename from 
oneflow/python/test/modules/test_maxpool.py rename to oneflow/python/test/modules/test_pooling.py index 08f23c6a7ae2c69adaf2c91098f60cf6f66f1228..ad1951c8a4f7e45040d711aaf9055389c4017f08 100644 --- a/oneflow/python/test/modules/test_maxpool.py +++ b/oneflow/python/test/modules/test_pooling.py @@ -158,21 +158,224 @@ class MaxPoolNumpy: return dx +def _test_maxpool1d_impl(test_case, device): + input_arr = np.array( + [ + [ + [-0.89042996, 2.33971243, -0.86660827, 0.80398747], + [-1.46769364, -0.78125064, 1.50086563, -0.76278226], + [1.31984534, 0.20741192, -0.86507054, -0.40776015], + [-0.89910823, 0.44932938, 1.49148118, -0.22036761], + ], + [ + [-0.5452334, -0.10255169, -1.42035108, 0.73922913], + [-0.03192764, 0.69341935, 0.96263152, -1.52070843], + [0.02058239, 1.504032, 1.84423001, -0.0130596], + [2.20517719, 0.38449598, 0.85677771, 0.60425179], + ], + [ + [-1.64366213, 0.51370298, -0.21754866, -0.05085382], + [1.17065374, 1.13857674, -1.13070507, 0.44353707], + [-1.30783846, -0.48031445, 0.41807536, -2.13778887], + [0.08259005, 0.5798125, 0.03024696, 1.96100924], + ], + ] + ) + kernel_size, stride, padding = (3,), (1,), (1,) + + output = np.array( + [ + [ + [2.33971243, 2.33971243, 2.33971243, 0.80398747], + [-0.78125064, 1.50086563, 1.50086563, 1.50086563], + [1.31984534, 1.31984534, 0.20741192, -0.40776015], + [0.44932938, 1.49148118, 1.49148118, 1.49148118], + ], + [ + [-0.10255169, -0.10255169, 0.73922913, 0.73922913], + [0.69341935, 0.96263152, 0.96263152, 0.96263152], + [1.504032, 1.84423001, 1.84423001, 1.84423001], + [2.20517719, 2.20517719, 0.85677771, 0.85677771], + ], + [ + [0.51370298, 0.51370298, 0.51370298, -0.05085382], + [1.17065374, 1.17065374, 1.13857674, 0.44353707], + [-0.48031445, 0.41807536, 0.41807536, 0.41807536], + [0.5798125, 0.5798125, 1.96100924, 1.96100924], + ], + ] + ) + + output_indice = np.array( + [ + [[1, 1, 1, 3], [1, 2, 2, 2], [0, 0, 1, 3], [1, 2, 2, 2]], + [[1, 1, 3, 3], [1, 2, 2, 2], [1, 2, 2, 2], [0, 0, 2, 2]], + [[1, 1, 
1, 3], [0, 0, 1, 3], [1, 2, 2, 2], [1, 1, 3, 3]], + ] + ) + + grad = np.array( + [ + [ + [0.0, 3.0, 0.0, 1.0], + [0.0, 1.0, 3.0, 0.0], + [2.0, 1.0, 0.0, 1.0], + [0.0, 1.0, 3.0, 0.0], + ], + [ + [0.0, 2.0, 0.0, 2.0], + [0.0, 1.0, 3.0, 0.0], + [0.0, 1.0, 3.0, 0.0], + [2.0, 0.0, 2.0, 0.0], + ], + [ + [0.0, 3.0, 0.0, 1.0], + [2.0, 1.0, 0.0, 1.0], + [0.0, 1.0, 3.0, 0.0], + [0.0, 2.0, 0.0, 2.0], + ], + ] + ) + + m = flow.nn.MaxPool1d( + kernel_size=kernel_size, stride=stride, padding=padding, return_indices=True + ) + m.to(flow.device(device)) + x = flow.Tensor(input_arr, device=flow.device(device), requires_grad=True) + of_output, of_indice = m(x) + + y = of_output.sum() + y.backward() + + test_case.assertTrue(np.allclose(x.grad.numpy(), grad, 1e-4, 1e-4)) + test_case.assertTrue(np.allclose(of_indice.numpy(), output_indice, 1e-4, 1e-4)) + test_case.assertTrue(np.allclose(of_output.numpy(), output, 1e-4, 1e-4)) + + def _test_maxpool2d(test_case, device): dim = 2 - input_arr = np.random.randn(6, 4, 7, 9) - kernel_size, stride, padding = (4, 4), (1, 1), (1, 2) + + input_arr = np.random.randn(2, 3, 4, 5) + kernel_size, stride, padding = (3, 3), (1, 1), (1, 1) m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding) numpy_output = m_numpy(input_arr) - m = flow.nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=padding) + m = flow.nn.MaxPool2d( + kernel_size=kernel_size, stride=stride, padding=padding, return_indices=True + ) m.to(flow.device(device)) x = flow.Tensor(input_arr, device=flow.device(device)) - output = m(x) + output, indice = m(x) + test_case.assertTrue(indice.shape == x.shape) test_case.assertTrue(np.allclose(numpy_output, output.numpy(), 1e-4, 1e-4)) +def _test_maxpool2d_ceil_mode(test_case, device): + dim = 2 + input_arr = np.array( + [ + [ + [ + [-0.89042996, 2.33971243, -0.86660827, 0.80398747], + [-1.46769364, -0.78125064, 1.50086563, -0.76278226], + [1.31984534, 0.20741192, -0.86507054, -0.40776015], + [-0.89910823, 0.44932938, 1.49148118, 
-0.22036761], + ], + [ + [-0.5452334, -0.10255169, -1.42035108, 0.73922913], + [-0.03192764, 0.69341935, 0.96263152, -1.52070843], + [0.02058239, 1.504032, 1.84423001, -0.0130596], + [2.20517719, 0.38449598, 0.85677771, 0.60425179], + ], + [ + [-1.64366213, 0.51370298, -0.21754866, -0.05085382], + [1.17065374, 1.13857674, -1.13070507, 0.44353707], + [-1.30783846, -0.48031445, 0.41807536, -2.13778887], + [0.08259005, 0.5798125, 0.03024696, 1.96100924], + ], + ], + [ + [ + [0.45173843, -0.34680027, -0.99754943, 0.18539502], + [-0.68451047, -0.03217399, 0.44705642, -0.39016231], + [-0.18062337, 1.82099303, -0.19113869, 0.85298683], + [0.14080452, 0.15306701, -1.02466827, -0.34480665], + ], + [ + [-0.21048489, 0.20933038, -0.09206508, -1.80402519], + [-0.52028985, 0.01140166, -1.13452858, 0.96648332], + [0.26454393, 0.48343972, -1.84055509, -0.01256443], + [0.31024029, 0.11983007, 0.98806488, 0.93557438], + ], + [ + [0.39152445, 0.672159, 0.71289289, -0.68072016], + [0.33711062, -1.78106242, 0.34545201, -1.62029359], + [0.47343899, -2.3433269, -0.44517497, 0.09004267], + [0.26310742, -1.53121271, 0.65028836, 1.3669488], + ], + ], + ] + ) + + ceil_mode_out = np.array( + [ + [ + [ + [2.33971243, 2.33971243, 0.80398747], + [1.31984534, 1.50086563, -0.22036761], + [0.44932938, 1.49148118, -0.22036761], + ], + [ + [0.69341935, 0.96263152, 0.73922913], + [2.20517719, 1.84423001, 0.60425179], + [2.20517719, 0.85677771, 0.60425179], + ], + [ + [1.17065374, 1.13857674, 0.44353707], + [1.17065374, 1.96100924, 1.96100924], + [0.5798125, 1.96100924, 1.96100924], + ], + ], + [ + [ + [0.45173843, 0.44705642, 0.18539502], + [1.82099303, 1.82099303, 0.85298683], + [0.15306701, 0.15306701, -0.34480665], + ], + [ + [0.20933038, 0.96648332, 0.96648332], + [0.48343972, 0.98806488, 0.96648332], + [0.31024029, 0.98806488, 0.93557438], + ], + [ + [0.672159, 0.71289289, -0.68072016], + [0.47343899, 1.3669488, 1.3669488], + [0.26310742, 1.3669488, 1.3669488], + ], + ], + ] + ) + kernel_size, 
stride, padding = (3, 3), (2, 2), (1, 1) + + m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding) + numpy_output = m_numpy(input_arr) + + m1 = flow.nn.MaxPool2d( + kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=False + ) + m2 = flow.nn.MaxPool2d( + kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=True + ) + m1.to(flow.device(device)) + m2.to(flow.device(device)) + x = flow.Tensor(input_arr, device=flow.device(device)) + output1 = m1(x) + output2 = m2(x) + test_case.assertTrue(np.allclose(numpy_output, output1.numpy(), 1e-4, 1e-4)) + test_case.assertTrue(np.allclose(ceil_mode_out, output2.numpy(), 1e-4, 1e-4)) + + def _test_maxpool2d_special_kernel_size(test_case, device): dim = 2 input_arr = np.random.randn(1, 1, 6, 6) @@ -191,7 +394,7 @@ def _test_maxpool2d_special_kernel_size(test_case, device): def _test_maxpool2d_diff_kernel_stride(test_case, device): dim = 2 input_arr = np.random.randn(9, 7, 32, 20) - kernel_size, stride, padding = (2, 3), (4, 5), (1, 2) + kernel_size, stride, padding = (2, 4), (4, 5), (1, 2) m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding) numpy_output = m_numpy(input_arr) @@ -205,7 +408,7 @@ def _test_maxpool2d_diff_kernel_stride(test_case, device): def _test_maxpool2d_negative_input(test_case, device): dim = 2 - input_arr = -1.23456 * np.ones((1, 1, 1, 1), dtype=np.float) + input_arr = -1.23456 * np.ones((1, 1, 1, 1), dtype=np.float32) kernel_size, stride, padding = (5, 5), (5, 5), (2, 2) m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding) @@ -235,7 +438,7 @@ def _test_maxpool2d_backward(test_case, device): output.backward() doutput = np.ones_like(numpy_output, dtype=np.float64) numpy_grad = m_numpy.backward(doutput) - test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-5, 1e-5)) + test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4)) def _test_maxpool2d_special_kernel_size_backward(test_case, device): @@ -255,13 +458,13 @@ def 
_test_maxpool2d_special_kernel_size_backward(test_case, device): output.backward() doutput = np.ones_like(numpy_output, dtype=np.float64) numpy_grad = m_numpy.backward(doutput) - test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-5, 1e-5)) + test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4)) def _test_maxpool2d_diff_kernel_stride_backward(test_case, device): dim = 2 input_arr = np.random.randn(9, 7, 32, 20) - kernel_size, stride, padding = (2, 3), (4, 5), (1, 2) + kernel_size, stride, padding = (2, 4), (4, 5), (1, 2) m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding) numpy_output = m_numpy(input_arr) @@ -275,12 +478,12 @@ def _test_maxpool2d_diff_kernel_stride_backward(test_case, device): output.backward() doutput = np.ones_like(numpy_output, dtype=np.float64) numpy_grad = m_numpy.backward(doutput) - test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-5, 1e-5)) + test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4)) def _test_maxpool2d_negative_input_backward(test_case, device): dim = 2 - input_arr = -1.23456 * np.ones((1, 1, 1, 1), dtype=np.float) + input_arr = -1.23456 * np.ones((1, 1, 1, 1), dtype=np.float32) kernel_size, stride, padding = (5, 5), (5, 5), (2, 2) m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding) @@ -295,7 +498,22 @@ def _test_maxpool2d_negative_input_backward(test_case, device): output.backward() doutput = np.ones_like(numpy_output, dtype=np.float64) numpy_grad = m_numpy.backward(doutput) - test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-5, 1e-5)) + test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4)) + + +def _test_maxpool3d(test_case, device): + dim = 3 + input_arr = np.random.randn(2, 3, 7, 9, 13) + kernel_size, stride, padding = (2, 3, 4), (2, 3, 4), (1, 1, 2) + + m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding) + numpy_output = m_numpy(input_arr) + + m = flow.nn.MaxPool3d(kernel_size=kernel_size, stride=stride, 
padding=padding) + m.to(flow.device(device)) + x = flow.Tensor(input_arr, device=flow.device(device)) + output = m(x) + test_case.assertTrue(np.allclose(numpy_output, output.numpy(), 1e-4, 1e-4)) def _test_maxpool3d_backward(test_case, device): @@ -316,7 +534,7 @@ def _test_maxpool3d_backward(test_case, device): output.backward() doutput = np.ones_like(numpy_output, dtype=np.float64) numpy_grad = m_numpy.backward(doutput) - test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-5, 1e-5)) + test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4)) def _test_maxpool3d_special_kernel_size_backward(test_case, device): @@ -337,12 +555,33 @@ def _test_maxpool3d_special_kernel_size_backward(test_case, device): output.backward() doutput = np.ones_like(numpy_output, dtype=np.float64) numpy_grad = m_numpy.backward(doutput) - test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-5, 1e-5)) + test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4)) + + +def _test_maxpool3d_diff_kernel_stride_backward(test_case, device): + dim = 3 + input_arr = np.random.randn(9, 7, 48, 32, 20) + kernel_size, stride, padding = (6, 2, 4), (5, 4, 5), (3, 1, 2) + + m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding) + numpy_output = m_numpy(input_arr) + + m = flow.nn.MaxPool3d(kernel_size=kernel_size, stride=stride, padding=padding) + m.to(flow.device(device)) + x = flow.Tensor(input_arr, requires_grad=True, device=flow.device(device)) + output = m(x) + test_case.assertTrue(np.allclose(numpy_output, output.numpy(), 1e-4, 1e-4)) + + output = output.sum() + output.backward() + doutput = np.ones_like(numpy_output, dtype=np.float64) + numpy_grad = m_numpy.backward(doutput) + test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4)) def _test_maxpool3d_negative_input_backward(test_case, device): dim = 3 - input_arr = -1.23456 * np.ones((1, 1, 1, 1, 1), dtype=np.float) + input_arr = -1.23456 * np.ones((1, 1, 1, 1, 1), 
dtype=np.float32) kernel_size, stride, padding = (5, 5, 5), (5, 5, 5), (2, 2, 2) m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding) @@ -358,18 +597,28 @@ def _test_maxpool3d_negative_input_backward(test_case, device): output.backward() doutput = np.ones_like(numpy_output, dtype=np.float64) numpy_grad = m_numpy.backward(doutput) - test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-5, 1e-5)) + test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4)) @unittest.skipIf( not flow.unittest.env.eager_execution_enabled(), ".numpy() doesn't work in lazy mode", ) -class TestPoolingModule(flow.unittest.TestCase): +class TestPooling(flow.unittest.TestCase): + def test_maxpool1d(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [ + _test_maxpool1d_impl, + ] + arg_dict["device"] = ["cpu", "cuda"] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + def test_maxpool2d(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_maxpool2d, + _test_maxpool2d_ceil_mode, _test_maxpool2d_special_kernel_size, _test_maxpool2d_diff_kernel_stride, _test_maxpool2d_negative_input, @@ -378,6 +627,7 @@ class TestPoolingModule(flow.unittest.TestCase): _test_maxpool2d_diff_kernel_stride_backward, _test_maxpool2d_negative_input_backward, ] + arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @@ -385,9 +635,11 @@ class TestPoolingModule(flow.unittest.TestCase): def test_maxpool3d(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ + _test_maxpool3d, _test_maxpool3d_backward, _test_maxpool3d_special_kernel_size_backward, _test_maxpool3d_negative_input_backward, + _test_maxpool3d_diff_kernel_stride_backward, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): diff --git a/oneflow/user/kernels/pooling_kernel.cpp b/oneflow/user/kernels/pooling_kernel.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..39213ea3b42db8451331349f978f58933fff8eff --- /dev/null +++ b/oneflow/user/kernels/pooling_kernel.cpp @@ -0,0 +1,252 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/user/kernels/pooling_kernel_util.h" + +namespace oneflow { + +struct PoolingOpKernelState final : public user_op::OpKernelState { + PoolingParams3D params_3d; + PoolingOpKernelState(PoolingParams3D params_3d) : params_3d(params_3d) {} + const PoolingParams3D& GetParams3D() { return params_3d; } +}; + +std::shared_ptr<PoolingOpKernelState> DoCreateOpKernelState(user_op::KernelComputeContext* ctx, + const int32_t& dim) { + const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape(); + const std::string& padding = ctx->Attr<std::string>("padding"); + const std::string& data_format = ctx->Attr<std::string>("data_format"); + const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before"); + const auto& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after"); + const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size"); + const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>("stride"); + const std::vector<int32_t>& dilation = ctx->Attr<std::vector<int32_t>>("dilation"); + const bool return_indices = ctx->Attr<bool>("return_indices"); + const bool ceil_mode = ctx->Attr<bool>("ceil_mode"); + + PoolingParams3D params_3d = + 
PoolingParams3D(dim, x_shape, data_format, padding, padding_before, padding_after, + kernel_size, stride, dilation, return_indices, ceil_mode); + std::shared_ptr<PoolingOpKernelState> state(new PoolingOpKernelState(params_3d)); + return std::move(state); +} + +template<typename T> +struct PoolingKernelUtil<DeviceType::kCPU, T> { + static void Maxpool2dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper, + const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr, + const PoolingParams3D& params_3d) { + Maxpool2dFarwardCompute<T>( + index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[1], + params_3d.padding_before_3d()[2], params_3d.num_batch(), params_3d.num_channel(), + params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(3), + params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[1], + params_3d.pooling_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2], + params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]); + } + + static void Maxpool2dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper, + const int64_t elem_num, const T* src, T* dest, + const int64_t* indice_ptr, const PoolingParams3D& params_3d) { + Maxpool2dBackwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, + params_3d.num_batch(), params_3d.num_channel(), + params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4), + params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4)); + } + + static void Maxpool3dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper, + const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr, + const PoolingParams3D& params_3d) { + Maxpool3dFarwardCompute<T>( + index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[0], + params_3d.padding_before_3d()[1], params_3d.padding_before_3d()[2], params_3d.num_batch(), + params_3d.num_channel(), params_3d.GetXShape5D().At(2), 
params_3d.GetXShape5D().At(3), + params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3), + params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[0], + params_3d.pooling_size_3d()[1], params_3d.pooling_size_3d()[2], params_3d.stride_3d()[0], + params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.dilation_3d()[0], + params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]); + } + + static void Maxpool3dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5> index_helper, + const int64_t elem_num, const T* src, T* dest, + const int64_t* indice_ptr, const PoolingParams3D& params_3d) { + Maxpool3dBackwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, + params_3d.num_batch(), params_3d.num_channel(), + params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3), + params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(2), + params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4)); + } +}; + +template<DeviceType device_type, typename T> +class MaxPool2dKernel final : public user_op::OpKernel { + public: + MaxPool2dKernel() = default; + ~MaxPool2dKernel() = default; + + private: + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); + user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); + user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); + + const auto& pooling_state = DoCreateOpKernelState(ctx, 2); + const PoolingParams3D& params_3d = pooling_state->GetParams3D(); + + const int64_t elem_num = y->shape().elem_cnt(); + const T* src = x->dptr<T>(); + T* dest = y->mut_dptr<T>(); + int64_t* indice_ptr = indice->mut_dptr<int64_t>(); + + DimVector y_vector; + y->shape().ToDimVector(&y_vector); + NdIndexOffsetHelper<int64_t, 4> index_helper(y_vector.data()); + + PoolingKernelUtil<device_type, T>::Maxpool2dForward(ctx->device_ctx(), 
index_helper, elem_num, + src, dest, indice_ptr, params_3d); + }; +}; + +template<DeviceType device_type, typename T> +class MaxPool2dGradKernel final : public user_op::OpKernel { + public: + MaxPool2dGradKernel() = default; + ~MaxPool2dGradKernel() = default; + + private: + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); + const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); + user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); + + const auto& pooling_state = DoCreateOpKernelState(ctx, 2); + const PoolingParams3D& params_3d = pooling_state->GetParams3D(); + + const int64_t elem_num = dy->shape().elem_cnt(); + const T* src = dy->dptr<T>(); + const int64_t* indice_ptr = indice->dptr<int64_t>(); + T* dest = dx->mut_dptr<T>(); + DimVector dy_vector; + dy->shape().ToDimVector(&dy_vector); + NdIndexOffsetHelper<int64_t, 4> index_helper(dy_vector.data()); + + size_t out_bytes_size = dx->shape().elem_cnt() * GetSizeOfDataType(dx->data_type()); + Memset<device_type>(ctx->device_ctx(), dest, 0, out_bytes_size); + + PoolingKernelUtil<device_type, T>::Maxpool2dBackward(ctx->device_ctx(), index_helper, elem_num, + src, dest, indice_ptr, params_3d); + }; +}; + +template<DeviceType device_type, typename T> +class MaxPool3dKernel final : public user_op::OpKernel { + public: + MaxPool3dKernel() = default; + ~MaxPool3dKernel() = default; + + private: + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); + user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); + user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); + + const auto& pooling_state = DoCreateOpKernelState(ctx, 3); + const PoolingParams3D& params_3d = 
pooling_state->GetParams3D(); + + const int64_t elem_num = y->shape().elem_cnt(); + const T* src = x->dptr<T>(); + T* dest = y->mut_dptr<T>(); + int64_t* indice_ptr = indice->mut_dptr<int64_t>(); + + DimVector y_vector; + y->shape().ToDimVector(&y_vector); + NdIndexOffsetHelper<int64_t, 5> index_helper(y_vector.data()); + + PoolingKernelUtil<device_type, T>::Maxpool3dForward(ctx->device_ctx(), index_helper, elem_num, + src, dest, indice_ptr, params_3d); + }; +}; + +template<DeviceType device_type, typename T> +class MaxPool3dGradKernel final : public user_op::OpKernel { + public: + MaxPool3dGradKernel() = default; + ~MaxPool3dGradKernel() = default; + + private: + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); + const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); + user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); + + const auto& pooling_state = DoCreateOpKernelState(ctx, 3); + const PoolingParams3D& params_3d = pooling_state->GetParams3D(); + + const int64_t elem_num = dy->shape().elem_cnt(); + const T* src = dy->dptr<T>(); + const int64_t* indice_ptr = indice->dptr<int64_t>(); + T* dest = dx->mut_dptr<T>(); + + DimVector dy_vector; + dy->shape().ToDimVector(&dy_vector); + NdIndexOffsetHelper<int64_t, 5> index_helper(dy_vector.data()); + + size_t out_bytes_size = dx->shape().elem_cnt() * GetSizeOfDataType(dx->data_type()); + Memset<device_type>(ctx->device_ctx(), dest, 0, out_bytes_size); + + PoolingKernelUtil<device_type, T>::Maxpool3dBackward(ctx->device_ctx(), index_helper, elem_num, + src, dest, indice_ptr, params_3d); + }; +}; + +#define REGISTER_POOLING_KERNELS(device, dtype) \ + REGISTER_USER_KERNEL("maxpool_2d") \ + .SetCreateFn<MaxPool2dKernel<device, dtype>>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("x", 0) == 
GetDataType<dtype>::value)); \ + REGISTER_USER_KERNEL("maxpool_2d_grad") \ + .SetCreateFn<MaxPool2dGradKernel<device, dtype>>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)); \ + REGISTER_USER_KERNEL("maxpool_3d") \ + .SetCreateFn<MaxPool3dKernel<device, dtype>>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)); \ + REGISTER_USER_KERNEL("maxpool_3d_grad") \ + .SetCreateFn<MaxPool3dGradKernel<device, dtype>>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)); + +#define REGISTER_POOLING_WITH_DEVICE(device) \ + REGISTER_POOLING_KERNELS(device, int32_t) \ + REGISTER_POOLING_KERNELS(device, float) \ + REGISTER_POOLING_KERNELS(device, double) + +REGISTER_POOLING_WITH_DEVICE(DeviceType::kCPU) + +#ifdef WITH_CUDA +REGISTER_POOLING_WITH_DEVICE(DeviceType::kGPU) +// TODO: REGISTER_POOLING_KERNELS(DeviceType::kGPU, float16) +#endif + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_POOLING_KERNEL_UTIL, (DeviceType::kCPU), + POOLING_DATA_TYPE_CPU_SEQ); + +} // namespace oneflow diff --git a/oneflow/user/kernels/pooling_kernel.cu b/oneflow/user/kernels/pooling_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..7aef65319fe28441f044e5a1e819ca7779bcc295 --- /dev/null +++ b/oneflow/user/kernels/pooling_kernel.cu @@ -0,0 +1,148 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ +#include <cstdint> +#ifdef WITH_CUDA +#include "oneflow/core/cuda/elementwise.cuh" +#include "oneflow/user/kernels/pooling_kernel_util.h" + +namespace oneflow { + +constexpr int kBlockSize = cuda::elementwise::kBlockSize; + +const int GetMinThreadNum(int64_t elem_num) { return std::min<int64_t>(elem_num, kBlockSize); } + +int GetNumBlocks(int64_t elem_cnt) { + int num_blocks = 0; + OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(elem_cnt, &num_blocks)); + return num_blocks; +} + +template<typename T> +__launch_bounds__(kBlockSize) __global__ + void DoCUDAMaxPool2dForward(const NdIndexOffsetHelper<int64_t, 4> index_helper, + int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr, + int32_t padding_h, int32_t padding_w, int64_t n_batch, + int64_t n_channel, int64_t x_height, int64_t x_width, + int64_t y_height, int64_t y_width, int32_t kernel_size_h, + int32_t kernel_size_w, int32_t stride_h, int32_t stride_w, + int32_t dilation_h, int32_t dilation_w) { + Maxpool2dFarwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_h, padding_w, + n_batch, n_channel, x_height, x_width, y_height, y_width, + kernel_size_h, kernel_size_w, stride_h, stride_w, dilation_h, + dilation_w); +}; + +template<typename T> +__launch_bounds__(kBlockSize) __global__ + void DoCUDAMaxPool3dForward(const NdIndexOffsetHelper<int64_t, 5> index_helper, + int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr, + int32_t padding_t, int32_t padding_h, int32_t padding_w, + int64_t n_batch, int64_t n_channel, int64_t x_time, + int64_t x_height, int64_t x_width, int64_t y_time, int64_t y_height, + int64_t y_width, int32_t kernel_size_t, int32_t kernel_size_h, + int32_t kernel_size_w, int32_t stride_t, int32_t stride_h, + int32_t stride_w, int32_t dilation_t, int32_t dilation_h, + int32_t dilation_w) { + Maxpool3dFarwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, 
padding_t, padding_h, + padding_w, n_batch, n_channel, x_time, x_height, x_width, y_time, + y_height, y_width, kernel_size_t, kernel_size_h, kernel_size_w, + stride_t, stride_h, stride_w, dilation_t, dilation_h, dilation_w); +}; + +template<typename T> +__launch_bounds__(kBlockSize) __global__ + void DoCUDAMaxPool2dBackward(const NdIndexOffsetHelper<int64_t, 4> index_helper, + const int64_t elem_num, const T* src, T* dest, + const int64_t* indice_ptr, const int64_t n_batch, + const int64_t n_channel, const int64_t src_height, + const int64_t src_width, const int64_t dst_height, + const int64_t dst_width) { + Maxpool2dBackwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, n_batch, n_channel, + src_height, src_width, dst_height, dst_width); +}; + +template<typename T> +__launch_bounds__(kBlockSize) __global__ + void DoCUDAMaxPool3dBackward(const NdIndexOffsetHelper<int64_t, 5> index_helper, + const int64_t elem_num, const T* src, T* dest, + const int64_t* indice_ptr, const int64_t n_batch, + const int64_t n_channel, const int64_t src_time, + const int64_t src_height, const int64_t src_width, + const int64_t dst_time, const int64_t dst_height, + const int64_t dst_width) { + Maxpool3dBackwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, n_batch, n_channel, + src_time, src_height, src_width, dst_time, dst_height, dst_width); +}; + +template<typename T> +struct PoolingKernelUtil<DeviceType::kGPU, T> { + static void Maxpool2dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper, + const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr, + const PoolingParams3D& params_3d) { + DoCUDAMaxPool2dForward<T> + <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>( + index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[1], + params_3d.padding_before_3d()[2], params_3d.num_batch(), params_3d.num_channel(), + params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), + 
params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4), + params_3d.pooling_size_3d()[1], params_3d.pooling_size_3d()[2], + params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.dilation_3d()[1], + params_3d.dilation_3d()[2]); + } + + static void Maxpool2dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper, + const int64_t elem_num, const T* src, T* dest, + const int64_t* indice_ptr, const PoolingParams3D& params_3d) { + DoCUDAMaxPool2dBackward<T> + <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>( + index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(), + params_3d.num_channel(), params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4), + params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4)); + } + + static void Maxpool3dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper, + const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr, + const PoolingParams3D& params_3d) { + DoCUDAMaxPool3dForward<T> + <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>( + index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[0], + params_3d.padding_before_3d()[1], params_3d.padding_before_3d()[2], + params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(2), + params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), + params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3), + params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[0], + params_3d.pooling_size_3d()[1], params_3d.pooling_size_3d()[2], + params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2], + params_3d.dilation_3d()[0], params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]); + } + + static void Maxpool3dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper, + const int64_t elem_num, const T* src, T* dest, + const int64_t* indice_ptr, const PoolingParams3D& 
params_3d) { + DoCUDAMaxPool3dBackward<T> + <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>( + index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(), + params_3d.num_channel(), params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3), + params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(2), + params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4)); + } +}; + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_POOLING_KERNEL_UTIL, (DeviceType::kGPU), + POOLING_DATA_TYPE_GPU_SEQ); + +} // namespace oneflow +#endif // WITH_CUDA diff --git a/oneflow/user/kernels/pooling_kernel_util.cpp b/oneflow/user/kernels/pooling_kernel_util.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c2f502ce2928c8fec6c60ccc5c288d1c074fb275 --- /dev/null +++ b/oneflow/user/kernels/pooling_kernel_util.cpp @@ -0,0 +1,115 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/user/kernels/pooling_kernel_util.h" + +namespace oneflow { + +std::vector<int32_t> Get3DVec(const std::vector<int32_t>& original_vec, int32_t NDims) { + std::vector<int32_t> vec; + FOR_RANGE(uint8_t, dim, 0, 3) { + int64_t index = static_cast<int64_t>(dim) - (3 - NDims); + if (index < 0) { + vec.push_back(1); + } else { + vec.push_back(original_vec.at(index)); + } + } + return vec; +} + +std::vector<int32_t> Get3DPadVec(const std::vector<int32_t>& original_vec, int32_t NDims) { + std::vector<int32_t> vec; + FOR_RANGE(uint8_t, dim, 0, 3) { + int64_t index = static_cast<int64_t>(dim) - (3 - NDims); + if (index < 0) { + vec.push_back(0); + } else { + vec.push_back(original_vec.at(index)); + } + } + return vec; +} + +PoolingParams3D::PoolingParams3D(const int32_t dim, const ShapeView& x_shape, + const std::string& data_format, const std::string& padding, + const std::vector<int32_t>& padding_before, + const std::vector<int32_t>& padding_after, + const std::vector<int32_t>& kernel_size, + const std::vector<int32_t>& stride, + const std::vector<int32_t>& dilation, const bool return_indices, + const bool ceil_mode) + : dim_(dim), + data_format_(data_format), + padding_(padding), + padding_before_3d_(Get3DPadVec(padding_before, dim)), + padding_after_3d_(Get3DPadVec(padding_after, dim)), + pooling_size_3d_(Get3DVec(kernel_size, dim)), + stride_3d_(Get3DVec(stride, dim)), + dilation_3d_(Get3DVec(dilation, dim)), + return_indices_(return_indices), + ceil_mode_(ceil_mode) { + x_3d_ = {GetInDim(x_shape, data_format, 0, dim), GetInDim(x_shape, data_format, 1, dim), + GetInDim(x_shape, data_format, 2, dim)}; + Get3DOutputSize(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, &dilation_3d_, &y_3d_, + &padding_before_3d_, &padding_after_3d_); + if (data_format == "channels_first") { + channel_num_ = x_shape.At(1); + } else { + CHECK_EQ(data_format_, "channels_last") + << "data_format must be 'channels_first' or 'channels_last'"; + channel_num_ = 
x_shape.At(x_shape.NumAxes() - 1); + } + batch_num_ = x_shape.At(0); +} + +void PoolingParams3D::Reset(const ShapeView& x_shape) { + x_3d_ = {GetInDim(x_shape, data_format_, 0, dim_), GetInDim(x_shape, data_format_, 1, dim_), + GetInDim(x_shape, data_format_, 2, dim_)}; + Get3DOutputSize(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, &dilation_3d_, &y_3d_, + &padding_before_3d_, &padding_after_3d_); +} + +Shape PoolingParams3D::GetYShape() const { + DimVector y_dim_vec; + if (dim_ == 1) { + y_dim_vec = {y_3d_.at(2)}; + } else if (dim_ == 2) { + y_dim_vec = {y_3d_.at(1), y_3d_.at(2)}; + } else if (dim_ == 3) { + y_dim_vec = {y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)}; + } else { + UNIMPLEMENTED(); + } + if (data_format_ == "channels_first") { + y_dim_vec.insert(y_dim_vec.begin(), channel_num_); + } else { + CHECK_EQ(data_format_, "channels_last") + << "data_format must be 'channels_first' or 'channels_last'"; + y_dim_vec.insert(y_dim_vec.end(), channel_num_); + } + y_dim_vec.insert(y_dim_vec.begin(), batch_num_); + return Shape(y_dim_vec); +} + +Shape PoolingParams3D::GetXShape5D() const { + return Shape({batch_num_, channel_num_, x_3d_.at(0), x_3d_.at(1), x_3d_.at(2)}); +} + +Shape PoolingParams3D::GetYShape5D() const { + return Shape({batch_num_, channel_num_, y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)}); +} + +} // namespace oneflow diff --git a/oneflow/user/kernels/pooling_kernel_util.h b/oneflow/user/kernels/pooling_kernel_util.h new file mode 100644 index 0000000000000000000000000000000000000000..b25ba34e04a3ce8c5027948965c686c69e1134f9 --- /dev/null +++ b/oneflow/user/kernels/pooling_kernel_util.h @@ -0,0 +1,277 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_USER_KERNELS_POOLING_KERNEL_UTIL_H_ +#define ONEFLOW_USER_KERNELS_POOLING_KERNEL_UTIL_H_ +#include "oneflow/core/device/device_context.h" +#include "oneflow/core/ndarray/xpu_util.h" +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/common/nd_index_offset_helper.h" +#include "oneflow/core/operator/operator_util.h" +#include "oneflow/core/kernel/util/numerics.cuh" +#include "oneflow/core/kernel/util/numeric_limits.cuh" +#ifdef WITH_CUDA +#include "oneflow/core/cuda/atomic.cuh" +#endif // WITH_CUDA + +namespace oneflow { + +#define POOLING_DATA_TYPE_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \ + OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) \ + OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble) + +#define POOLING_DATA_TYPE_CPU_SEQ POOLING_DATA_TYPE_SEQ + +#define POOLING_DATA_TYPE_GPU_SEQ POOLING_DATA_TYPE_SEQ + +typedef fixed_vector<int64_t, SHAPE_MAX_AXIS_SIZE> FixedDimVector; + +template<typename T> +struct DeviceAdd { + OF_DEVICE_FUNC static void Invoke(const T* x, T* y) { +#if defined(__CUDA_ARCH__) + cuda::atomic::Add(y, *x); +#else + *y += *x; +#endif + }; +}; + +class PoolingParams3D { + public: + PoolingParams3D(const int32_t dim, const ShapeView& x_shape, const std::string& data_format, + const std::string& padding, const std::vector<int32_t>& padding_before, + const std::vector<int32_t>& padding_after, + const std::vector<int32_t>& kernel_size, const std::vector<int32_t>& stride, + const std::vector<int32_t>& dilation, const bool return_indices, + const bool ceil_mode); + ~PoolingParams3D() = default; + + const 
std::string& data_format() const { return data_format_; } + const std::vector<int32_t>& padding_before_3d() const { return padding_before_3d_; } + const std::vector<int32_t>& padding_after_3d() const { return padding_after_3d_; } + const std::vector<int32_t>& pooling_size_3d() const { return pooling_size_3d_; } + const std::vector<int32_t>& stride_3d() const { return stride_3d_; } + const std::vector<int32_t>& dilation_3d() const { return dilation_3d_; } + const bool& return_indices() const { return return_indices_; } + const bool& ceil_mode() const { return ceil_mode_; } + const int64_t& num_batch() const { return batch_num_; } + const int64_t& num_channel() const { return channel_num_; } + + void Reset(const ShapeView& x_shape); + Shape GetYShape() const; + Shape GetXShape5D() const; + Shape GetYShape5D() const; + + private: + int32_t dim_; + FixedDimVector x_3d_; + FixedDimVector y_3d_; + std::string data_format_; + std::string padding_; + std::vector<int32_t> padding_before_3d_; + std::vector<int32_t> padding_after_3d_; + std::vector<int32_t> pooling_size_3d_; + std::vector<int32_t> stride_3d_; + std::vector<int32_t> dilation_3d_; + bool return_indices_; + bool ceil_mode_; + int64_t batch_num_; + int64_t channel_num_; +}; + +template<DeviceType device_type, typename T> +struct PoolingKernelUtil { + static void Maxpool2dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper, + const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr, + const PoolingParams3D& params_3d); + + static void Maxpool2dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper, + const int64_t elem_num, const T* src, T* dest, + const int64_t* indice_ptr, const PoolingParams3D& params_3d); + + static void Maxpool3dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper, + const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr, + const PoolingParams3D& params_3d); + + static void Maxpool3dBackward(DeviceCtx* 
ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper, + const int64_t elem_num, const T* src, T* dest, + const int64_t* indice_ptr, const PoolingParams3D& params_3d); +}; + +template<typename T> +OF_DEVICE_FUNC void Maxpool2dFarwardCompute( + const NdIndexOffsetHelper<int64_t, 4> index_helper, int64_t elem_num, const T* src, T* dest, + int64_t* indice_ptr, const int32_t padding_h, const int32_t padding_w, const int64_t n_batch, + const int64_t n_channel, const int64_t x_height, const int64_t x_width, const int64_t y_height, + const int64_t y_width, const int32_t kernel_size_h, const int32_t kernel_size_w, + const int32_t stride_h, const int32_t stride_w, const int32_t dilation_h, + const int32_t dilation_w) { + XPU_1D_KERNEL_LOOP(num, elem_num) { + int64_t n, c, h, w; + index_helper.OffsetToNdIndex(num, n, c, h, w); + + const int64_t start_idx = (n * n_channel + c) * x_width * x_height; + int64_t hstart = h * stride_h - padding_h; + int64_t wstart = w * stride_w - padding_w; + const int64_t hend = (hstart + (kernel_size_h - 1) * dilation_h + 1) <= x_height + ? (hstart + (kernel_size_h - 1) * dilation_h + 1) + : x_height; + const int64_t wend = (wstart + (kernel_size_w - 1) * dilation_w + 1) <= x_width + ? 
(wstart + (kernel_size_w - 1) * dilation_w + 1) + : x_width; + + while (hstart < 0) { hstart += dilation_h; } + while (wstart < 0) { wstart += dilation_w; } + + /* compute max value(src[src_idx]) in kernel box region, and save the value to dest[num] */ + int64_t maxindex = hstart * x_width + wstart; + int64_t src_idx = 0; + + /* equal to -std::numeric_limits<T>::infinity(); */ + T max_value = detail::numeric_limits<T>::lower_bound(); + + for (int64_t i = hstart; i < hend; i += dilation_h) { + for (int64_t j = wstart; j < wend; j += dilation_w) { + const int64_t tcntr = i * x_width + j; + const int64_t search_idx = start_idx + tcntr; + T val = src[search_idx]; + /* NOTE: + std::isnan(val) only supports a few data types, see: + https://en.cppreference.com/w/cpp/numeric/math/isnan and when use gcc/g++ 4.x to compile, + the following exception will be throw: + + new_kernel_util.cu:24] Check failed: cudaMemcpyAsync(dst, src, sz, cudaMemcpyDefault, + ctx->cuda_stream() ) : unspecified launch failure (719) + + but if use gcc/g++ 7.x to compile, everything is ok! the exact reason is still unknown! 
+ */ + if (val > max_value || detail::numerics<T>::isnan(val)) { + max_value = val; + maxindex = tcntr; + src_idx = search_idx; + } + } + } + dest[num] = src[src_idx]; + indice_ptr[num] = maxindex; + } +} + +template<typename T> +OF_DEVICE_FUNC void Maxpool2dBackwardCompute(const NdIndexOffsetHelper<int64_t, 4> index_helper, + const int64_t elem_num, const T* src, T* dest, + const int64_t* indice_ptr, const int64_t n_batch, + const int64_t n_channel, const int64_t src_height, + const int64_t src_width, const int64_t dst_height, + const int64_t dst_width) { + XPU_1D_KERNEL_LOOP(num, elem_num) { + int64_t n, c, h, w; + index_helper.OffsetToNdIndex(num, n, c, h, w); + + const int64_t src_start = (n * n_channel + c) * src_height * src_width; + const int64_t dst_start = (n * n_channel + c) * dst_height * dst_width; + const int64_t index = src_start + h * src_width + w; + const int64_t maxindex = dst_start + indice_ptr[index]; + if (maxindex != -1) { + /* update gradient, equals to dest[maxindex] += src[index]; */ + DeviceAdd<T>::Invoke(src + index, dest + maxindex); + } + } +} + +template<typename T> +OF_DEVICE_FUNC void Maxpool3dFarwardCompute( + const NdIndexOffsetHelper<int64_t, 5> index_helper, int64_t elem_num, const T* src, T* dest, + int64_t* indice_ptr, const int32_t padding_t, const int32_t padding_h, const int32_t padding_w, + const int64_t n_batch, const int64_t n_channel, const int64_t x_time, const int64_t x_height, + const int64_t x_width, const int64_t y_time, const int64_t y_height, const int64_t y_width, + const int32_t kernel_size_t, const int32_t kernel_size_h, const int32_t kernel_size_w, + const int32_t stride_t, const int32_t stride_h, const int32_t stride_w, + const int32_t dilation_t, const int32_t dilation_h, const int32_t dilation_w) { + XPU_1D_KERNEL_LOOP(num, elem_num) { + int64_t n, c, t, h, w; + index_helper.OffsetToNdIndex(num, n, c, t, h, w); + + int64_t xstart = n * n_channel * x_time * x_width * x_height; + int64_t start_idx = xstart + 
c * x_time * x_width * x_height; + int64_t tstart = t * stride_t - padding_t; + int64_t hstart = h * stride_h - padding_h; + int64_t wstart = w * stride_w - padding_w; + + const int64_t t1 = tstart + (kernel_size_t - 1) * dilation_t + 1; + const int64_t t2 = hstart + (kernel_size_h - 1) * dilation_h + 1; + const int64_t t3 = wstart + (kernel_size_w - 1) * dilation_w + 1; + const int64_t tend = t1 <= x_time ? t1 : x_time; + const int64_t hend = t2 <= x_height ? t2 : x_height; + const int64_t wend = t3 <= x_width ? t3 : x_width; + + while (tstart < 0) { tstart += dilation_t; } + while (hstart < 0) { hstart += dilation_h; } + while (wstart < 0) { wstart += dilation_w; } + + int64_t maxindex = tstart * x_height * x_width + hstart * x_width + wstart; + int64_t src_idx = 0; + + T max_value = detail::numeric_limits<T>::lower_bound(); + for (int64_t zi = tstart; zi < tend; zi += dilation_t) { + for (int64_t i = hstart; i < hend; i += dilation_h) { + for (int64_t j = wstart; j < wend; j += dilation_w) { + const int64_t tcntr = zi * x_height * x_width + i * x_width + j; + const int64_t search_idx = start_idx + tcntr; + T val = src[search_idx]; + if (val > max_value || detail::numerics<T>::isnan(val)) { + max_value = val; + maxindex = tcntr; + src_idx = search_idx; + } + } + } + /* set output to local max */ + dest[num] = src[src_idx]; + /* store location of max */ + indice_ptr[num] = maxindex; + } + } +} + +template<typename T> +OF_DEVICE_FUNC void Maxpool3dBackwardCompute(const NdIndexOffsetHelper<int64_t, 5> index_helper, + const int64_t elem_num, const T* src, T* dest, + const int64_t* indice_ptr, const int64_t n_batch, + const int64_t n_channel, const int64_t src_time, + const int64_t src_height, const int64_t src_width, + const int64_t dst_time, const int64_t dst_height, + const int64_t dst_width) { + XPU_1D_KERNEL_LOOP(num, elem_num) { + int64_t n, c, t, h, w; + index_helper.OffsetToNdIndex(num, n, c, t, h, w); + + const int64_t src_start = (n * n_channel + c) * 
src_time * src_height * src_width; + const int64_t dst_start = (n * n_channel + c) * dst_time * dst_height * dst_width; + const int64_t index = src_start + t * src_height * src_width + h * src_width + w; + const int64_t maxindex = dst_start + indice_ptr[index]; + + if (maxindex != -1) { DeviceAdd<T>::Invoke(src + index, dest + maxindex); } + } +} + +#define INSTANTIATE_POOLING_KERNEL_UTIL(device_type_v, dtype_pair) \ + template struct PoolingKernelUtil<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>; + +} // namespace oneflow + +#endif // ONEFLOW_USER_KERNELS_POOLING_KERNEL_UTIL_H_ diff --git a/oneflow/user/ops/pooling_op.cpp b/oneflow/user/ops/pooling_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..09577f81ecd0f5a0f1159eb3dc020bab81c12dc6 --- /dev/null +++ b/oneflow/user/ops/pooling_op.cpp @@ -0,0 +1,219 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/framework.h" +#include "oneflow/user/kernels/pooling_kernel_util.h" + +namespace oneflow { + +namespace { + +typedef std::function<Maybe<void>(user_op::InferContext* ctx)> TensorDescInferFn; +typedef std::function<void(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp)> + GenBackwardOpConfFn; + +TensorDescInferFn MakeForwardTensorDescInferFn(const int32_t dim) { + return [dim](user_op::InferContext* ctx) -> Maybe<void> { + const Shape* x_shape = ctx->Shape4ArgNameAndIndex("x", 0); + const std::string& data_format = ctx->Attr<std::string>("data_format"); + const std::string& padding = ctx->Attr<std::string>("padding"); + const std::vector<int32_t>& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before"); + const std::vector<int32_t>& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after"); + const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size"); + const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>("stride"); + const std::vector<int32_t>& dilation = ctx->Attr<std::vector<int32_t>>("dilation"); + const bool return_indices = ctx->Attr<bool>("return_indices"); + const bool ceil_mode = ctx->Attr<bool>("ceil_mode"); + + CHECK_EQ_OR_RETURN(kernel_size.size(), dim); + for (int32_t pool_dim : kernel_size) { CHECK_GT_OR_RETURN(pool_dim, 0); } + CHECK_EQ_OR_RETURN(stride.size(), dim); + for (int32_t stride_dim : stride) { CHECK_GT_OR_RETURN(stride_dim, 0); } + for (int32_t i = 0; i < padding_after.size(); i++) { + CHECK_GE_OR_RETURN(kernel_size[i], 2 * padding_after[i]) + << "pad should be smaller than half of kernel size"; + } + + const PoolingParams3D params_3d(dim, *x_shape, data_format, padding, padding_before, + padding_after, kernel_size, stride, dilation, return_indices, + ceil_mode); + user_op::TensorDesc* y_desc = ctx->TensorDesc4ArgNameAndIndex("y", 0); + *y_desc = *ctx->TensorDesc4ArgNameAndIndex("x", 0); + *y_desc->mut_shape() = 
params_3d.GetYShape(); + + user_op::TensorDesc* indice_desc = ctx->TensorDesc4ArgNameAndIndex("indice", 0); + *indice_desc = *ctx->TensorDesc4ArgNameAndIndex("y", 0); + *indice_desc->mut_shape() = *y_desc->mut_shape(); + DataType* dtype = indice_desc->mut_data_type(); + *dtype = kInt64; + return Maybe<void>::Ok(); + }; +} + +Maybe<void> ForwardGetSbpFn(user_op::SbpContext* ctx) { + const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before"); + const auto& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after"); + FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) { + if (padding_before[i] == 0 && padding_after[i] == 0) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("y", 0), i) + .Split(user_op::OpArg("indice", 0), i) + .Build(); + } + } + return Maybe<void>::Ok(); +} + +Maybe<void> BackwardTensorDescInferFn(user_op::InferContext* ctx) { + *ctx->TensorDesc4ArgNameAndIndex("dx", 0) = *ctx->TensorDesc4ArgNameAndIndex("x", 0); + return Maybe<void>::Ok(); +} + +Maybe<void> BackwardGetSbpFn(user_op::SbpContext* ctx) { + const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before"); + const auto& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after"); + FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) { + if (padding_before[i] == 0 && padding_after[i] == 0) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("y", 0), i) + .Split(user_op::OpArg("indice", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + } + return Maybe<void>::Ok(); +} + +Maybe<void> FwInferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe<void>::Ok(); +} + 
+Maybe<void> BwInferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe<void>::Ok(); +} + +GenBackwardOpConfFn MakeBackwardOpConfFn(const std::string& mode, const int32_t dim) { + return [mode, dim](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) { + if (op.NeedGenGradTensor4OpInput("x", 0)) { + user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); + user_op::UserOpConfWrapper grad_op = + builder.Op(mode + "pool_" + std::to_string(dim) + "d_grad") + .Input("x", op.input("x", 0)) + .Input("y", op.output("y", 0)) + .Input("indice", op.output("indice", 0)) + .Input("dy", op.GetGradTensorWithOpOutput("y", 0)) + .Output("dx") + .Attr("data_format", op.attr<std::string>("data_format")) + .Attr("padding", op.attr<std::string>("padding")) + .Attr("padding_before", op.attr<std::vector<int32_t>>("padding_before")) + .Attr("padding_after", op.attr<std::vector<int32_t>>("padding_after")) + .Attr("kernel_size", op.attr<std::vector<int32_t>>("kernel_size")) + .Attr("stride", op.attr<std::vector<int32_t>>("stride")) + .Attr("dilation", op.attr<std::vector<int32_t>>("dilation")) + .Attr("return_indices", op.attr<bool>("return_indices")) + .Attr("ceil_mode", op.attr<bool>("ceil_mode")) + .Build(); + op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0); + AddOp(grad_op); + } + }; +} + +} // namespace + +REGISTER_USER_OP("maxpool_2d") + .Input("x") + .Output("y") + .Output("indice") + .Attr<std::string>("padding") + .Attr<std::vector<int32_t>>("padding_before") + .Attr<std::vector<int32_t>>("padding_after") + .Attr<std::string>("data_format") + .Attr<std::vector<int32_t>>("kernel_size") + .Attr<std::vector<int32_t>>("stride") + .Attr<std::vector<int32_t>>("dilation") + .Attr<bool>("return_indices") + .Attr<bool>("ceil_mode") + .SetTensorDescInferFn(MakeForwardTensorDescInferFn(2)) + .SetGetSbpFn(ForwardGetSbpFn) + .SetDataTypeInferFn(FwInferDataType); + +REGISTER_USER_OP("maxpool_2d_grad") + 
.Input("x") + .Input("y") + .Input("indice") + .Input("dy") + .Output("dx") + .Attr<std::string>("padding") + .Attr<std::vector<int32_t>>("padding_before") + .Attr<std::vector<int32_t>>("padding_after") + .Attr<std::string>("data_format") + .Attr<std::vector<int32_t>>("kernel_size") + .Attr<std::vector<int32_t>>("stride") + .Attr<std::vector<int32_t>>("dilation") + .Attr<bool>("return_indices") + .Attr<bool>("ceil_mode") + .SetTensorDescInferFn(BackwardTensorDescInferFn) + .SetGetSbpFn(BackwardGetSbpFn) + .SetDataTypeInferFn(BwInferDataType); + +REGISTER_USER_OP_GRAD("maxpool_2d").SetGenBackwardOpConfFn(MakeBackwardOpConfFn("max", 2)); + +REGISTER_USER_OP("maxpool_3d") + .Input("x") + .Output("y") + .Output("indice") + .Attr<std::string>("padding") + .Attr<std::vector<int32_t>>("padding_before") + .Attr<std::vector<int32_t>>("padding_after") + .Attr<std::string>("data_format") + .Attr<std::vector<int32_t>>("kernel_size") + .Attr<std::vector<int32_t>>("stride") + .Attr<std::vector<int32_t>>("dilation") + .Attr<bool>("return_indices") + .Attr<bool>("ceil_mode") + .SetTensorDescInferFn(MakeForwardTensorDescInferFn(3)) + .SetGetSbpFn(ForwardGetSbpFn) + .SetDataTypeInferFn(FwInferDataType); + +REGISTER_USER_OP("maxpool_3d_grad") + .Input("x") + .Input("y") + .Input("indice") + .Input("dy") + .Output("dx") + .Attr<std::string>("padding") + .Attr<std::vector<int32_t>>("padding_before") + .Attr<std::vector<int32_t>>("padding_after") + .Attr<std::string>("data_format") + .Attr<std::vector<int32_t>>("kernel_size") + .Attr<std::vector<int32_t>>("stride") + .Attr<std::vector<int32_t>>("dilation") + .Attr<bool>("return_indices") + .Attr<bool>("ceil_mode") + .SetTensorDescInferFn(BackwardTensorDescInferFn) + .SetGetSbpFn(BackwardGetSbpFn) + .SetDataTypeInferFn(BwInferDataType); + +REGISTER_USER_OP_GRAD("maxpool_3d").SetGenBackwardOpConfFn(MakeBackwardOpConfFn("max", 3)); + +} // namespace oneflow