diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst index 4242190fa01c9c224bdad9c7853546869d3c95d2..e58b0d9c1d66b184a40faad8635144d0d5de7c0e 100644 --- a/docs/source/experimental.rst +++ b/docs/source/experimental.rst @@ -49,7 +49,6 @@ Experimental features .. autofunction:: oneflow.experimental.Tensor.argmax .. autofunction:: oneflow.experimental.nn.BatchNorm1d .. autofunction:: oneflow.experimental.nn.BatchNorm2d -.. autofunction:: oneflow.experimental.nn.ReplicationPad2d .. autofunction:: oneflow.experimental.nn.InstanceNorm1d .. autofunction:: oneflow.experimental.nn.InstanceNorm2d .. autofunction:: oneflow.experimental.nn.InstanceNorm3d @@ -71,7 +70,11 @@ Experimental features .. autofunction:: oneflow.experimental.nn.ModuleDict .. autofunction:: oneflow.experimental.nn.Conv1d .. autofunction:: oneflow.experimental.nn.Conv2d +.. autofunction:: oneflow.experimental.nn.ZeroPad2d +.. autofunction:: oneflow.experimental.nn.ReflectionPad2d +.. autofunction:: oneflow.experimental.nn.ReplicationPad2d .. autofunction:: oneflow.experimental.nn.ConstantPad2d +.. autofunction:: oneflow.experimental.nn.ConstantPad3d .. autofunction:: oneflow.experimental.nn.ConvTranspose2d .. autofunction:: oneflow.experimental.nn.Dropout .. autofunction:: oneflow.experimental.slice @@ -232,7 +235,6 @@ Experimental features .. autofunction:: oneflow.experimental.Tensor.ceil .. autofunction:: oneflow.experimental.expm1 .. autofunction:: oneflow.experimental.Tensor.expm1 -.. autofunction:: oneflow.experimental.nn.ReflectionPad2d .. autofunction:: oneflow.experimental.meshgrid .. autofunction:: oneflow.experimental.topk .. autofunction:: oneflow.experimental.Tensor.topk @@ -241,7 +243,6 @@ Experimental features .. autofunction:: oneflow.experimental.nn.GroupNorm .. autofunction:: oneflow.experimental.gather_nd .. autofunction:: oneflow.experimental.scatter_nd -.. autofunction:: oneflow.experimental.nn.ZeroPad2d .. autofunction:: oneflow.experimental.nn.image.flip .. autofunction:: oneflow.experimental.tensor_buffer_to_tensor .. 
autofunction:: oneflow.experimental.tensor_to_tensor_buffer diff --git a/oneflow/core/autograd/gradient_funcs/pad2d.cpp b/oneflow/core/autograd/gradient_funcs/padding.cpp similarity index 92% rename from oneflow/core/autograd/gradient_funcs/pad2d.cpp rename to oneflow/core/autograd/gradient_funcs/padding.cpp index 7b130315dc1f8c604014b71238d8afb01b97719b..a733cf953c322b89468cb9f0d8124bf516d26109 100644 --- a/oneflow/core/autograd/gradient_funcs/pad2d.cpp +++ b/oneflow/core/autograd/gradient_funcs/padding.cpp @@ -78,13 +78,13 @@ class ReplicationPad2d : public Pad2d { REGISTER_OP_EXPR_GRAD_FUNCTION("reflection_pad2d", ReflectionPad2d); REGISTER_OP_EXPR_GRAD_FUNCTION("replication_pad2d", ReplicationPad2d); -struct ConstantPad2dInterpState : public OpExprInterpState { +struct ConstantPadNdInterpState : public OpExprInterpState { bool requires_grad; std::vector<int64_t> paddings; functional::Scalar padding_value; }; -class ConstantPad2d : public OpExprGradFunction<ConstantPad2dInterpState> { +class ConstantPadNd : public OpExprGradFunction<ConstantPadNdInterpState> { public: Maybe<void> Init(const OpExpr& op) override { const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op); @@ -93,7 +93,7 @@ class ConstantPad2d : public OpExprGradFunction<ConstantPad2dInterpState> { return Maybe<void>::Ok(); } - Maybe<void> Capture(ConstantPad2dInterpState* ctx, const TensorTuple& inputs, + Maybe<void> Capture(ConstantPadNdInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); CHECK_EQ_OR_RETURN(outputs.size(), 1); @@ -112,7 +112,7 @@ class ConstantPad2d : public OpExprGradFunction<ConstantPad2dInterpState> { return Maybe<void>::Ok(); } - Maybe<void> Apply(const ConstantPad2dInterpState* ctx, const TensorTuple& out_grads, + Maybe<void> Apply(const ConstantPadNdInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); @@ -127,7 +127,8 @@ class ConstantPad2d : public OpExprGradFunction<ConstantPad2dInterpState> { AttrMap base_attrs_; }; -REGISTER_OP_EXPR_GRAD_FUNCTION("constant_pad2d", ConstantPad2d); +REGISTER_OP_EXPR_GRAD_FUNCTION("constant_pad2d", ConstantPadNd); +REGISTER_OP_EXPR_GRAD_FUNCTION("constant_pad3d", ConstantPadNd); } // namespace one } // namespace oneflow diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index d181e21af21e2002a87702c56c45db5978d2afb7..535b301820954039b0ec083863780bd904cb3910 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -375,6 +375,7 @@ class NormalizationFunctor { class PadFunctor { public: PadFunctor() { + constant_pad_3d_ = CHECK_JUST(one::OpBuilder("constant_pad3d").Input("x").Output("y").Build()); constant_pad_ = CHECK_JUST(one::OpBuilder("constant_pad2d").Input("x").Output("y").Build()); reflect_pad_ = CHECK_JUST(one::OpBuilder("reflection_pad2d").Input("x").Output("y").Build()); replicate_pad_ = CHECK_JUST(one::OpBuilder("replication_pad2d").Input("x").Output("y").Build()); @@ -396,7 +397,14 @@ class PadFunctor { } else { UNIMPLEMENTED_THEN_RETURN() << "Data type should be floating or integral type."; } - return OpInterpUtil::Dispatch<Tensor>(*constant_pad_, {x}, attrs); + switch (x->shape()->NumAxes()) { + case 4: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_, {x}, attrs); + case 5: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_3d_, {x}, attrs); 
+        default:
+          UNIMPLEMENTED_THEN_RETURN() << "Pad mode is " << mode << ", but " << x->shape()->NumAxes()
+                                      << "d-tensor is not supported yet!";
+      }
+
     } else if (mode == "reflect") {
       return OpInterpUtil::Dispatch<Tensor>(*reflect_pad_, {x}, attrs);
     } else if (mode == "replicate") {
@@ -409,6 +417,7 @@
  private:
   std::shared_ptr<OpExpr> constant_pad_;
+  std::shared_ptr<OpExpr> constant_pad_3d_;
   std::shared_ptr<OpExpr> reflect_pad_;
   std::shared_ptr<OpExpr> replicate_pad_;
 };
diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp
index b47f345d70bda1f4e46761f42409e9086a736143..fa6110e84bd73abbee04a3f32af5e460be7a1c36 100644
--- a/oneflow/core/functional/impl/nn_grad_functor.cpp
+++ b/oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -229,6 +229,8 @@ class PadGradFunctor {
   PadGradFunctor() {
     constant_pad_grad_ =
         CHECK_JUST(one::OpBuilder("constant_pad2d_grad").Input("dy").Output("dx").Build());
+    constant_pad_3d_grad_ =
+        CHECK_JUST(one::OpBuilder("constant_pad3d_grad").Input("dy").Output("dx").Build());
     reflect_pad_grad_ =
         CHECK_JUST(one::OpBuilder("reflection_pad2d_grad").Input("dy").Output("dx").Build());
     replicate_pad_grad_ =
@@ -251,7 +253,13 @@
     } else {
       UNIMPLEMENTED_THEN_RETURN() << "Data type should be floating or integral type.";
     }
-      return OpInterpUtil::Dispatch<Tensor>(*constant_pad_grad_, {dy}, attrs);
+      switch (dy->shape()->NumAxes()) {
+        case 4: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_grad_, {dy}, attrs);
+        case 5: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_3d_grad_, {dy}, attrs);
+        default:
+          UNIMPLEMENTED_THEN_RETURN() << "Pad mode is " << mode << ", but "
+                                      << dy->shape()->NumAxes() << "d-tensor is not supported yet!";
+      }
     } else if (mode == "reflect") {
       return OpInterpUtil::Dispatch<Tensor>(*reflect_pad_grad_, {dy}, attrs);
     } else if (mode == "replicate") {
@@ -266,6 +274,7 @@
   std::shared_ptr<OpExpr> constant_pad_grad_;
   std::shared_ptr<OpExpr> reflect_pad_grad_;
   std::shared_ptr<OpExpr> replicate_pad_grad_;
+  std::shared_ptr<OpExpr> constant_pad_3d_grad_;
 };
 }  // namespace impl
diff --git a/oneflow/python/nn/modules/constantpad2d.py b/oneflow/python/nn/modules/constantpad2d.py
deleted file mode 100644
index 4a7b4eb08df9f9d6636550f8c44d676eea896b44..0000000000000000000000000000000000000000
--- a/oneflow/python/nn/modules/constantpad2d.py
+++ /dev/null
@@ -1,127 +0,0 @@
-"""
-Copyright 2020 The OneFlow Authors. All rights reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-from __future__ import absolute_import
-
-from typing import Union
-
-import oneflow as flow
-from oneflow.python.oneflow_export import oneflow_export, experimental_api
-from oneflow.python.nn.module import Module
-
-
-@oneflow_export("nn.ConstantPad2d")
-@experimental_api
-class ConstantPad2d(Module):
-    r"""The interface is consistent with PyTorch.
- The documentation is referenced from: - https://pytorch.org/docs/stable/generated/torch.nn.ConstantPad2d.html?highlight=constantpad2d#torch.nn.ConstantPad2d - - This operator pads the input with constant value that user specifies. User can set the amount of padding by setting the parameter `paddings`. - - Args: - padding (Union[int, tuple, list]): the size of the padding. If is `int`, uses the same padding in all boundaries. If a 4-`tuple`, uses (:math:`\mathrm{padding_{left}}`, :math:`\mathrm{padding_{right}}`, :math:`\mathrm{padding_{top}}`, :math:`\mathrm{padding_{bottom}}`) - - value (Union[int, float]): The constant value used for padding. Defaults to 0. - - Shape: - - Input: :math:`(N, C, H_{in}, W_{in})` - - Output: :math:`(N, C, H_{out}, W_{out})` where - - :math:`H_{out} = H_{in} + \mathrm{padding_{top}} + \mathrm{padding_{bottom}}` - - :math:`W_{out} = W_{in} + \mathrm{padding_{left}} + \mathrm{padding_{right}}` - - For example: - - .. code-block:: python - - >>> import oneflow.experimental as flow - >>> import numpy as np - >>> flow.enable_eager_execution() - >>> constantpad_layer_0 = flow.nn.ConstantPad2d((2, 2, 1, 1), 1) - >>> input = flow.Tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32)) - >>> input_int = flow.Tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.int32)) - >>> output = constantpad_layer_0(input) - >>> output.shape - flow.Size([1, 2, 5, 7]) - >>> output - tensor([[[[ 1., 1., 1., 1., 1., 1., 1.], - [ 1., 1., 0., 1., 2., 1., 1.], - [ 1., 1., 3., 4., 5., 1., 1.], - [ 1., 1., 6., 7., 8., 1., 1.], - [ 1., 1., 1., 1., 1., 1., 1.]], - <BLANKLINE> - [[ 1., 1., 1., 1., 1., 1., 1.], - [ 1., 1., 9., 10., 11., 1., 1.], - [ 1., 1., 12., 13., 14., 1., 1.], - [ 1., 1., 15., 16., 17., 1., 1.], - [ 1., 1., 1., 1., 1., 1., 1.]]]], dtype=oneflow.float32) - >>> output_int = constantpad_layer_0(input_int) - >>> output_int - tensor([[[[ 1., 1., 1., 1., 1., 1., 1.], - [ 1., 1., 0., 1., 2., 1., 1.], - [ 1., 1., 3., 4., 5., 1., 1.], - [ 1., 1., 6., 7., 8., 1., 1.], - [ 1., 1., 1., 1., 1., 1., 1.]], - <BLANKLINE> - [[ 1., 1., 1., 1., 1., 1., 1.], - [ 1., 1., 9., 10., 11., 1., 1.], - [ 1., 1., 12., 13., 14., 1., 1.], - [ 1., 1., 15., 16., 17., 1., 1.], - [ 1., 1., 1., 1., 1., 1., 1.]]]], dtype=oneflow.float32) - """ - - def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0): - super().__init__() - if isinstance(padding, (tuple, list)): - assert len(padding) == 4, ValueError("Length of padding must be 4") - boundary = [padding[0], padding[1], padding[2], padding[3]] - elif isinstance(padding, int): - boundary = [padding, padding, padding, padding] - else: - raise ValueError("padding must be int or list or tuple!") - - self.padding = boundary - self.value = value - - def forward(self, x): - _, _, h, w = x.shape - - if x.dtype in [flow.float32, flow.float16, flow.float64]: - floating_value = float(self.value) - integral_value = int(0) - else: - floating_value = float(0) - integral_value = int(self.value) - - self._op = ( - flow.builtin_op("constant_pad2d") - .Input("x") - .Output("y") - .Attr("padding", self.padding) - .Attr("floating_value", floating_value) - .Attr("integral_value", integral_value) - .Build() - ) - - res = self._op(x)[0] - return res - - -if __name__ == "__main__": - import doctest - - doctest.testmod(raise_on_error=True) diff --git a/oneflow/python/nn/modules/padding.py b/oneflow/python/nn/modules/padding.py index 1d24752e86098db6f278dafeff51addbfbdf8f8c..edbea801ea479f3d9727d5ab180323dfa164d4fd 100644 --- 
a/oneflow/python/nn/modules/padding.py
+++ b/oneflow/python/nn/modules/padding.py
@@ -190,6 +190,176 @@ class ReflectionPad2d(Module):
         return "{}".format(self.padding)
 
 
+@oneflow_export("nn.ConstantPad2d")
+@experimental_api
+class ConstantPad2d(Module):
+    r"""The interface is consistent with PyTorch.
+    The documentation is referenced from:
+    https://pytorch.org/docs/stable/generated/torch.nn.ConstantPad2d.html?highlight=constantpad2d#torch.nn.ConstantPad2d
+
+    This operator pads the input with a constant value specified by the user. The amount of padding is set through the parameter `padding`.
+
+    Args:
+        padding (Union[int, tuple, list]): the size of the padding. If is `int`, uses the same padding in all boundaries. If a 4-`tuple`, uses (:math:`\mathrm{padding_{left}}`, :math:`\mathrm{padding_{right}}`, :math:`\mathrm{padding_{top}}`, :math:`\mathrm{padding_{bottom}}`)
+
+        value (Union[int, float]): The constant value used for padding. Defaults to 0.
+
+    Shape:
+        - Input: :math:`(N, C, H_{in}, W_{in})`
+        - Output: :math:`(N, C, H_{out}, W_{out})` where
+
+          :math:`H_{out} = H_{in} + \mathrm{padding_{top}} + \mathrm{padding_{bottom}}`
+
+          :math:`W_{out} = W_{in} + \mathrm{padding_{left}} + \mathrm{padding_{right}}`
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow.experimental as flow
+        >>> import numpy as np
+        >>> flow.enable_eager_execution()
+        >>> constantpad_layer_0 = flow.nn.ConstantPad2d((2, 2, 1, 1), 1)
+        >>> input = flow.Tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32))
+        >>> input_int = flow.Tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.int32))
+        >>> output = constantpad_layer_0(input)
+        >>> output.shape
+        flow.Size([1, 2, 5, 7])
+        >>> output
+        tensor([[[[ 1.,  1.,  1.,  1.,  1.,  1.,  1.],
+                  [ 1.,  1.,  0.,  1.,  2.,  1.,  1.],
+                  [ 1.,  1.,  3.,  4.,  5.,  1.,  1.],
+                  [ 1.,  1.,  6.,  7.,  8.,  1.,  1.],
+                  [ 1.,  1.,  1.,  1.,  1.,  1.,  1.]],
+        <BLANKLINE>
+                 [[ 1.,  1.,  1.,  1.,  1.,  1.,  1.],
+                  [ 1.,  1.,  9., 10., 11.,  1.,  1.],
+                  [ 1.,  1., 12., 13., 14.,  1.,  1.],
+                  [ 1.,  1., 15., 16., 17.,  1.,  1.],
+                  [ 1.,  1.,  1.,  1.,  1.,  1.,  1.]]]], dtype=oneflow.float32)
+        >>> output_int = constantpad_layer_0(input_int)
+        >>> output_int
+        tensor([[[[ 1.,  1.,  1.,  1.,  1.,  1.,  1.],
+                  [ 1.,  1.,  0.,  1.,  2.,  1.,  1.],
+                  [ 1.,  1.,  3.,  4.,  5.,  1.,  1.],
+                  [ 1.,  1.,  6.,  7.,  8.,  1.,  1.],
+                  [ 1.,  1.,  1.,  1.,  1.,  1.,  1.]],
+        <BLANKLINE>
+                 [[ 1.,  1.,  1.,  1.,  1.,  1.,  1.],
+                  [ 1.,  1.,  9., 10., 11.,  1.,  1.],
+                  [ 1.,  1., 12., 13., 14.,  1.,  1.],
+                  [ 1.,  1., 15., 16., 17.,  1.,  1.],
+                  [ 1.,  1.,  1.,  1.,  1.,  1.,  1.]]]], dtype=oneflow.float32)
+    """
+
+    def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0):
+        super().__init__()
+        if isinstance(padding, (tuple, list)):
+            assert len(padding) == 4, "Length of padding must be 4"
+            boundary = [padding[0], padding[1], padding[2], padding[3]]
+        elif isinstance(padding, int):
+            boundary = [padding, padding, padding, padding]
+        else:
+            raise ValueError("padding must be int or list or tuple!")
+
+        self.padding = boundary
+        self.value = value
+
+    def forward(self, x):
+        if x.dtype in (flow.float32, flow.float16, flow.float64):
+            value = float(self.value)
+        else:
+            value = int(self.value)
+
+        return flow.F.pad(x, pad=self.padding, mode="constant", value=value)
+
+
+@oneflow_export("nn.ConstantPad3d")
+@experimental_api
+class ConstantPad3d(Module):
+    r"""Pads the input tensor boundaries with a constant value.
+    The interface is consistent with PyTorch, and referenced from:
+    https://pytorch.org/docs/stable/generated/torch.nn.ConstantPad3d.html?highlight=constantpad3d#torch.nn.ConstantPad3d
+
+    For `N`-dimensional padding, use :func:`flow.nn.functional.pad()`.
+
+    Args:
+        padding (int, list, tuple): the size of the padding. If is `int`, uses the same
+            padding in all boundaries. If a 6-`tuple`, uses
+            (:math:`\text{padding_left}`, :math:`\text{padding_right}`,
+            :math:`\text{padding_top}`, :math:`\text{padding_bottom}`,
+            :math:`\text{padding_front}`, :math:`\text{padding_back}`)
+
+        value (Union[int, float]): The constant value used for padding. Defaults to 0.
+
+    Shape:
+        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
+        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` where
+
+          :math:`D_{out} = D_{in} + \text{padding_front} + \text{padding_back}`
+
+          :math:`H_{out} = H_{in} + \text{padding_top} + \text{padding_bottom}`
+
+          :math:`W_{out} = W_{in} + \text{padding_left} + \text{padding_right}`
+
+    Examples::
+
+        >>> import oneflow.experimental as flow
+        >>> import numpy as np
+
+        >>> input = flow.tensor(np.arange(8).reshape(1,1,2,2,2).astype(np.int32))
+        >>> m = flow.nn.ConstantPad3d(padding=1, value=9)
+        >>> output = m(input)
+        >>> output
+        tensor([[[[[9, 9, 9, 9],
+                   [9, 9, 9, 9],
+                   [9, 9, 9, 9],
+                   [9, 9, 9, 9]],
+        <BLANKLINE>
+                  [[9, 9, 9, 9],
+                   [9, 0, 1, 9],
+                   [9, 2, 3, 9],
+                   [9, 9, 9, 9]],
+        <BLANKLINE>
+                  [[9, 9, 9, 9],
+                   [9, 4, 5, 9],
+                   [9, 6, 7, 9],
+                   [9, 9, 9, 9]],
+        <BLANKLINE>
+                  [[9, 9, 9, 9],
+                   [9, 9, 9, 9],
+                   [9, 9, 9, 9],
+                   [9, 9, 9, 9]]]]], dtype=oneflow.int32)
+    """
+
+    def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0):
+        super().__init__()
+        if isinstance(padding, (tuple, list)):
+            assert len(padding) == 6, "Length of padding must be 6"
+            boundary = [
+                padding[0],
+                padding[1],
+                padding[2],
+                padding[3],
+                padding[4],
+                padding[5],
+            ]
+        elif isinstance(padding, int):
+            boundary = [padding, padding, padding, padding, padding, padding]
+        else:
+            raise ValueError("padding must be int or list or tuple!")
+
+        self.padding = boundary
+        self.value = value
+
+    def forward(self, x):
+        if x.dtype in (flow.float32, flow.float16, flow.float64):
+            value = float(self.value)
+        else:
+            value = int(self.value)
+        return flow.F.pad(x, pad=self.padding, mode="constant", value=value)
+
+
 if __name__ == "__main__":
     import doctest
diff --git a/oneflow/python/test/modules/test_constantpad2d.py b/oneflow/python/test/modules/test_constantpad.py
similarity index 81%
rename from oneflow/python/test/modules/test_constantpad2d.py
rename to oneflow/python/test/modules/test_constantpad.py
index d3a42eb45e956e1a47169cd7425efc6787742de1..22391dbb728f52c38a5ffb481473a0b8858adc70 100644
--- a/oneflow/python/test/modules/test_constantpad2d.py
+++ b/oneflow/python/test/modules/test_constantpad.py
@@ -111,7 +111,7 @@ class TestConstantPad2dModule(flow.unittest.TestCase):
     def test_with_random_data(test_case):
         for device in ["cpu", "cuda"]:
-            spatial_size = np.random.randint(10, 20)
+            spatial_size = np.random.randint(1, 6)
             test_module_against_pytorch(
                 test_case,
                 "nn.ConstantPad2d",
@@ -120,8 +120,31 @@
                     "input": random_tensor(
                         ndim=4, dim2=spatial_size, dim3=spatial_size
                     ),
-                    "padding": random(0, 6),
-                    "value": random(0, 6),
+                    "padding": random(0, 3),
+                    "value": random(0, 10),
                 },
                 device=device,
             )
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in
lazy mode", +) +class TestConstantPad3dModule(flow.unittest.TestCase): + def test_with_random_data(test_case): + for device in ["cpu", "cuda"]: + spatial_size = np.random.randint(1, 6) + test_module_against_pytorch( + test_case, + "nn.ConstantPad3d", + extra_annotations={"padding": int, "value": float}, + extra_generators={ + "input": random_tensor( + ndim=5, dim2=spatial_size, dim3=spatial_size, dim4=spatial_size + ), + "padding": random(0, 3), + "value": random(0, 10), }, device=device, ) diff --git a/oneflow/user/kernels/constantpad3d_kernel.cpp b/oneflow/user/kernels/constantpad3d_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d146fb8d9a3b8ee2f87e338762734c4c5b51101d --- /dev/null +++ b/oneflow/user/kernels/constantpad3d_kernel.cpp @@ -0,0 +1,147 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/common/nd_index_offset_helper.h" +#include "oneflow/core/device/memory_copier.h" +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/kernel/new_kernel_util.h" +#include "oneflow/user/kernels/constantpad3d_kernel_util.h" + +namespace oneflow { + +namespace { + +template<typename T> +T GetDtypeMatchedValue(double floating, int64_t integral); + +template<> +float16 GetDtypeMatchedValue(double floating, int64_t integral) { + return static_cast<float16>(floating); +} + +template<> +float GetDtypeMatchedValue(double floating, int64_t integral) { + return static_cast<float>(floating); +} + +template<> +double GetDtypeMatchedValue(double floating, int64_t integral) { + return floating; +} + +template<> +int8_t GetDtypeMatchedValue(double floating, int64_t integral) { + return static_cast<int8_t>(integral); +} + +template<> +int32_t GetDtypeMatchedValue(double floating, int64_t integral) { + return static_cast<int32_t>(integral); +} + +template<> +int64_t GetDtypeMatchedValue(double floating, int64_t integral) { + return integral; +} + +} // namespace + +namespace user_op { + +template<DeviceType device_type, typename IN_T> +class ConstantPad3dKernel final : public OpKernel { + public: + ConstantPad3dKernel() = default; + ~ConstantPad3dKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); + Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); + const ShapeView& x_shape = x->shape(); + const ShapeView& y_shape = y->shape(); + CHECK_EQ(x->shape().NumAxes(), 5); + const std::vector<int64_t>& padding = ctx->Attr<std::vector<int64_t>>("padding"); + CHECK_EQ(padding.size(), 6); + const IN_T constant_value = GetDtypeMatchedValue<IN_T>(ctx->Attr<double>("floating_value"), + ctx->Attr<int64_t>("integral_value")); + + IN_T* dest = y->mut_dptr<IN_T>(); + const IN_T* src = x->dptr<IN_T>(); + DimVector y_vector; + y->shape().ToDimVector(&y_vector); + NdIndexOffsetHelper<int64_t, 5> index_helper(y_vector.data()); + + ConstantPad3dFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper, 
x_shape, + y_shape, padding, constant_value); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +template<DeviceType device_type, typename IN_T> +class ConstantPad3dGradKernel final : public OpKernel { + public: + ConstantPad3dGradKernel() = default; + ~ConstantPad3dGradKernel() = default; + + private: + void Compute(KernelComputeContext* ctx) const override { + const Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); + CHECK_EQ(dy->shape().NumAxes(), 5); + Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); + const ShapeView& dx_shape = dx->shape(); + const ShapeView& dy_shape = dy->shape(); + + const auto& padding = ctx->Attr<std::vector<int64_t>>("padding"); + CHECK_EQ(padding.size(), 6); + + const IN_T* src = dy->dptr<IN_T>(); + IN_T* dest = dx->mut_dptr<IN_T>(); + DimVector dy_vector; + dy->shape().ToDimVector(&dy_vector); + NdIndexOffsetHelper<int64_t, 5> index_helper(dy_vector.data()); + + size_t out_bytes_size = dx->shape().Count(0) * GetSizeOfDataType(dx->data_type()); + Memset<device_type>(ctx->device_ctx(), dest, 0, out_bytes_size); + + ConstantPad3dGradFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper, + dy_shape, dx_shape, padding); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_CONSTANT_PAD3D_KERNELS(device, dtype) \ + REGISTER_USER_KERNEL("constant_pad3d") \ + .SetCreateFn<ConstantPad3dKernel<device, dtype>>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \ + REGISTER_USER_KERNEL("constant_pad3d_grad") \ + .SetCreateFn<ConstantPad3dGradKernel<device, dtype>>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value)); + +#define REGISTER_CONSTANT_PAD3D_WITH_DEVICE(device) \ + REGISTER_CONSTANT_PAD3D_KERNELS(device, float) \ + REGISTER_CONSTANT_PAD3D_KERNELS(device, double) \ + REGISTER_CONSTANT_PAD3D_KERNELS(device, int32_t) + +REGISTER_CONSTANT_PAD3D_WITH_DEVICE(DeviceType::kCPU) +#ifdef WITH_CUDA +REGISTER_CONSTANT_PAD3D_WITH_DEVICE(DeviceType::kGPU) +REGISTER_CONSTANT_PAD3D_KERNELS(DeviceType::kGPU, float16) +#endif + +} // namespace user_op +} // namespace oneflow diff --git a/oneflow/user/kernels/constantpad3d_kernel_util.cpp b/oneflow/user/kernels/constantpad3d_kernel_util.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0282cd2d3d0e266fade148844e1de03d323808ad --- /dev/null +++ b/oneflow/user/kernels/constantpad3d_kernel_util.cpp @@ -0,0 +1,64 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/user/kernels/constantpad3d_kernel_util.h" +#include "oneflow/core/framework/framework.h" + +namespace oneflow { +namespace user_op { + +template<typename IN_T> +struct ConstantPad3dFunctor<DeviceType::kCPU, IN_T> final { + void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest, + const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& x_shape, + const ShapeView& y_shape, const std::vector<int64_t>& padding, + IN_T constant_value) { + // for NCDHW format input tensor, index of n, c, d, h, w is 0, 1, 2, 3, 4 + const int64_t c_idx = 1; + const int64_t d_idx = 2; + const int64_t h_idx = 3; + const int64_t w_idx = 4; + // padding vector: [left, right, top, bottom, front, back] + DoConstantPad3d<IN_T>(src, dest, index_helper, y_shape.Count(0), y_shape.At(c_idx), + y_shape.At(d_idx), y_shape.At(h_idx), y_shape.At(w_idx), + x_shape.At(d_idx), x_shape.At(h_idx), x_shape.At(w_idx), padding[4], + padding[0], padding[2], constant_value); + } +}; + +template<typename IN_T> +struct ConstantPad3dGradFunctor<DeviceType::kCPU, IN_T> final { + void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest, + const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& dy_shape, + const ShapeView& dx_shape, const std::vector<int64_t>& padding) { + const int64_t c_idx = 1; + const int64_t d_idx = 2; + const int64_t h_idx = 3; + const int64_t w_idx = 4; + DoConstantPad3dGrad<IN_T>(src, dest, index_helper, dy_shape.Count(0), dy_shape.At(c_idx), + dy_shape.At(d_idx), dy_shape.At(h_idx), dy_shape.At(w_idx), + dx_shape.At(d_idx), dx_shape.At(h_idx), dx_shape.At(w_idx), + padding[4], padding[0], padding[2]); + } +}; + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_FUNCTOR, (DeviceType::kCPU), + PADDING_DATA_TYPE_CPU_SEQ); + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_GRAD_FUNCTOR, (DeviceType::kCPU), + PADDING_DATA_TYPE_CPU_SEQ); + +} // namespace user_op +} // namespace oneflow diff --git a/oneflow/user/kernels/constantpad3d_kernel_util.cu b/oneflow/user/kernels/constantpad3d_kernel_util.cu new file mode 100644 index 0000000000000000000000000000000000000000..459177be8259dff5947d3a78d7f3a72bf1bb5f87 --- /dev/null +++ b/oneflow/user/kernels/constantpad3d_kernel_util.cu @@ -0,0 +1,125 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+#include <cstdint>
+#include "oneflow/core/common/data_type.h"
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/user/kernels/constantpad3d_kernel_util.h"
+
+namespace oneflow {
+namespace user_op {
+
+template<typename IN_T>
+__global__ void DoCUDAConstantPad3d(const IN_T* src, IN_T* dest,
+                                    const NdIndexOffsetHelper<int64_t, 5> index_helper,
+                                    int64_t elem_num, int64_t n_channel, int64_t y_depth,
+                                    int64_t y_height, int64_t y_width, int64_t x_depth,
+                                    int64_t x_height, int64_t x_width, int64_t pad_front,
+                                    int64_t pad_left, int64_t pad_top, const IN_T const_value) {
+  DoConstantPad3d<IN_T>(src, dest, index_helper, elem_num, n_channel, y_depth, y_height, y_width,
+                        x_depth, x_height, x_width, pad_front, pad_left, pad_top, const_value);
+};
+
+template<typename IN_T>
+__global__ void DoCUDAConstantPad3dGrad(const IN_T* src, IN_T* dest,
+                                        const NdIndexOffsetHelper<int64_t, 5> index_helper,
+                                        int64_t elem_num, int64_t n_channel, int64_t dy_depth,
+                                        int64_t dy_height, int64_t dy_width, int64_t dx_depth,
+                                        int64_t dx_height, int64_t dx_width, int64_t pad_front,
+                                        int64_t pad_left, int64_t pad_top) {
+  // Forward the dx extents in (depth, height, width) order to match the signature
+  // of DoConstantPad3dGrad.
+  DoConstantPad3dGrad<IN_T>(src, dest, index_helper, elem_num, n_channel, dy_depth, dy_height,
+                            dy_width, dx_depth, dx_height, dx_width, pad_front, pad_left, pad_top);
+};
+
+template<typename IN_T>
+struct ConstantPad3dFunctor<DeviceType::kGPU, IN_T> final {
+  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
+                  const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& x_shape,
+                  const ShapeView& y_shape, const std::vector<int64_t>& padding,
+                  IN_T constant_value) {
+    const int64_t c_idx = 1;
+    const int64_t d_idx = 2;
+    const int64_t h_idx = 3;
+    const int64_t w_idx = 4;
+
+    DoCUDAConstantPad3d<IN_T><<<BlocksNum4ThreadsNum(y_shape.Count(0)), kCudaThreadsNumPerBlock, 0,
+                                ctx->cuda_stream()>>>(
+        src, dest, index_helper, y_shape.Count(0), y_shape.At(c_idx), y_shape.At(d_idx),
+        y_shape.At(h_idx), y_shape.At(w_idx), x_shape.At(d_idx), x_shape.At(h_idx),
+        x_shape.At(w_idx), padding[4], padding[0], padding[2], constant_value);
+  }
+};
+
+// float16 implementation
+template<>
+void ConstantPad3dFunctor<DeviceType::kGPU, float16>::operator()(
+    DeviceCtx* ctx, const float16* src, float16* dest,
+    const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& x_shape,
+    const ShapeView& y_shape, const std::vector<int64_t>& padding, float16 constant_value) {
+  const int64_t c_idx = 1;
+  const int64_t d_idx = 2;
+  const int64_t h_idx = 3;
+  const int64_t w_idx = 4;
+  DoCUDAConstantPad3d<half>
+      <<<BlocksNum4ThreadsNum(y_shape.Count(0)), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+          reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper,
+          y_shape.Count(0), y_shape.At(c_idx), y_shape.At(d_idx), y_shape.At(h_idx),
+          y_shape.At(w_idx), x_shape.At(d_idx), x_shape.At(h_idx), x_shape.At(w_idx), padding[4],
+          padding[0], padding[2], static_cast<const half>(constant_value));
+}
+
+template<typename IN_T>
+struct ConstantPad3dGradFunctor<DeviceType::kGPU, IN_T> final {
+  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
+                  const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& dy_shape,
+                  const ShapeView& dx_shape, const std::vector<int64_t>& padding) {
+    const int64_t c_idx = 1;
+    const int64_t d_idx = 2;
+    const int64_t h_idx = 3;
+    const int64_t w_idx = 4;
+    DoCUDAConstantPad3dGrad<IN_T><<<BlocksNum4ThreadsNum(dy_shape.Count(0)),
+                                    kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+        src, dest,
index_helper, dy_shape.Count(0), dy_shape.At(c_idx), dy_shape.At(d_idx), + dy_shape.At(h_idx), dy_shape.At(w_idx), dx_shape.At(d_idx), dx_shape.At(h_idx), + dx_shape.At(w_idx), padding[4], padding[0], padding[2]); + } +}; + +// float16 implementation +template<> +void ConstantPad3dGradFunctor<DeviceType::kGPU, float16>::operator()( + DeviceCtx* ctx, const float16* src, float16* dest, + const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& dy_shape, + const ShapeView& dx_shape, const std::vector<int64_t>& padding) { + const int64_t c_idx = 1; + const int64_t d_idx = 2; + const int64_t h_idx = 3; + const int64_t w_idx = 4; + DoCUDAConstantPad3dGrad<half> + <<<BlocksNum4ThreadsNum(dy_shape.Count(0)), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper, + dy_shape.Count(0), dy_shape.At(c_idx), dy_shape.At(d_idx), dy_shape.At(h_idx), + dy_shape.At(w_idx), dx_shape.At(d_idx), dx_shape.At(h_idx), dx_shape.At(w_idx), + padding[4], padding[0], padding[2]); +} + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_FUNCTOR, + OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ); + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_GRAD_FUNCTOR, + OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ); + +} // namespace user_op +} // namespace oneflow diff --git a/oneflow/user/kernels/constantpad3d_kernel_util.h b/oneflow/user/kernels/constantpad3d_kernel_util.h new file mode 100644 index 0000000000000000000000000000000000000000..0839dd5b8b00c4a9d72ef61b1eb595f26405daa7 --- /dev/null +++ b/oneflow/user/kernels/constantpad3d_kernel_util.h @@ -0,0 +1,122 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#ifndef ONEFLOW_USER_KERNELS_PAD3D_KERNELS_UTIL_H_ +#define ONEFLOW_USER_KERNELS_PAD3D_KERNELS_UTIL_H_ +#ifdef WITH_CUDA +#include "oneflow/core/cuda/atomic.cuh" +#endif // WITH_CUDA +#include "oneflow/core/common/nd_index_offset_helper.h" +#include "oneflow/core/ndarray/xpu_util.h" + +namespace oneflow { + +#define PADDING_DATA_TYPE_CPU_SEQ \ + FLOATING_DATA_TYPE_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) + +#define PADDING_DATA_TYPE_GPU_SEQ \ + FLOAT16_DATA_TYPE_SEQ \ + PADDING_DATA_TYPE_CPU_SEQ + +namespace user_op { + +template<typename T> +struct DeviceAdd { + OF_DEVICE_FUNC static void Invoke(const T* x, T* y) { +#if defined(__CUDA_ARCH__) + cuda::atomic::Add(y, *x); +#else + *y += *x; +#endif + }; +}; + +template<DeviceType device_type, typename IN_T> +struct ConstantPad3dFunctor final { + void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest, + const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& x_shape, + const ShapeView& y_shape, const std::vector<int64_t>& padding, + IN_T constant_value); +}; + +template<DeviceType device_type, typename IN_T> +struct ConstantPad3dGradFunctor final { + void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest, + const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& dy_shape, + const ShapeView& dx_shape, const std::vector<int64_t>& padding); +}; + +template<typename IN_T> +OF_DEVICE_FUNC void DoConstantPad3d(const IN_T* src, IN_T* dest, + const NdIndexOffsetHelper<int64_t, 5>& index_helper, + int64_t elem_num, int64_t n_channel, int64_t y_depth, + int64_t y_height, int64_t y_width, int64_t x_depth, + int64_t x_height, int64_t x_width, int64_t pad_front, + int64_t pad_left, int64_t pad_top, IN_T constant_value) { + XPU_1D_KERNEL_LOOP(num, elem_num) { + int64_t n, c, d, h, w; + index_helper.OffsetToNdIndex(num, n, c, d, h, w); + + const int64_t src_num = n_channel * x_depth * x_height * x_width; + if (pad_front <= d && d < pad_front + x_depth && w >= pad_left && w < x_width + pad_left + && h >= pad_top && h < x_height + pad_top) { + const int64_t len_w = w - pad_left; + const int64_t len_h = h - pad_top; + const int64_t len_d = d - pad_front; + const int64_t src_index = n * src_num + c * x_depth * x_width * x_height + + len_d * x_height * x_width + len_h * x_width + len_w; + dest[num] = src[src_index]; + } else { + dest[num] = constant_value; + } + } +} + +template<typename IN_T> +OF_DEVICE_FUNC void DoConstantPad3dGrad(const IN_T* src, IN_T* dest, + const NdIndexOffsetHelper<int64_t, 5>& index_helper, + int64_t elem_num, int64_t n_channel, int64_t dy_depth, + int64_t dy_height, int64_t dy_width, int64_t dx_depth, + int64_t dx_height, int64_t dx_width, int64_t pad_front, + int64_t pad_left, int64_t pad_top) { + XPU_1D_KERNEL_LOOP(num, elem_num) { + int64_t n, c, d, h, w; + index_helper.OffsetToNdIndex(num, n, c, d, h, w); + + const int64_t dest_num = n_channel * dx_depth * dx_height * dx_width; + if (pad_front <= d && d < pad_front + dx_depth && w >= pad_left && w < dx_width + pad_left + && h >= pad_top && h < dx_height + pad_top) { + const int64_t len_d = d - pad_front; + const int64_t len_w = w - pad_left; + const int64_t len_h = h - pad_top; + const int64_t dest_index = n * dest_num + c * dx_depth * dx_width * dx_height + + len_d * dx_width * dx_height + len_h * dx_width + len_w; + + DeviceAdd<IN_T>::Invoke(src + num, dest + dest_index); + } + } +} + +#define INSTANTIATE_CONSTANT_PAD3D_FUNCTOR(device_type_v, dtype_pair) \ + template struct ConstantPad3dFunctor<device_type_v, 
OF_PP_PAIR_FIRST(dtype_pair)>;
+
+#define INSTANTIATE_CONSTANT_PAD3D_GRAD_FUNCTOR(device_type_v, dtype_pair) \
+  template struct ConstantPad3dGradFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
+
+}  // namespace user_op
+}  // namespace oneflow
+
+#endif  // ONEFLOW_USER_KERNELS_PAD3D_KERNELS_UTIL_H_
diff --git a/oneflow/user/ops/pad2d_ops.cpp b/oneflow/user/ops/padding_ops.cpp
similarity index 75%
rename from oneflow/user/ops/pad2d_ops.cpp
rename to oneflow/user/ops/padding_ops.cpp
index a64e0092d81a8027d1d57d5b3187ded5bfb1f734..4e84c86a514ffe816eff1a349934eae40e760e1e 100644
--- a/oneflow/user/ops/pad2d_ops.cpp
+++ b/oneflow/user/ops/padding_ops.cpp
@@ -343,4 +343,105 @@ REGISTER_USER_OP_GRAD("constant_pad2d")
       return Maybe<void>::Ok();
     });
 
+REGISTER_USER_OP("constant_pad3d")
+    .Input("x")
+    .Output("y")
+    .Attr<std::vector<int64_t>>("padding")
+    .Attr<double>("floating_value")
+    .Attr<int64_t>("integral_value")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& x_shape = ctx->InputShape("x", 0);
+      const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
+      CHECK_EQ_OR_RETURN(x_shape.NumAxes(), 5);
+      // Only NCDHW format input tensors are supported for now!
+      // For NCDHW format, the indices of num, channel, depth, height, width are 0, 1, 2, 3, 4.
+      const int64_t n_idx = 0;
+      const int64_t c_idx = 1;
+      const int64_t d_idx = 2;
+      const int64_t h_idx = 3;
+      const int64_t w_idx = 4;
+
+      DimVector y_dim_vec(x_shape.NumAxes());
+      const int64_t d_x = x_shape.At(d_idx);
+      const int64_t h_x = x_shape.At(h_idx);
+      const int64_t w_x = x_shape.At(w_idx);
+
+      y_dim_vec[n_idx] = x_shape.At(n_idx);
+      y_dim_vec[c_idx] = x_shape.At(c_idx);
+      y_dim_vec[d_idx] = d_x + padding[4] + padding[5];
+      y_dim_vec[h_idx] = h_x + padding[2] + padding[3];
+      y_dim_vec[w_idx] = w_x + padding[0] + padding[1];
+
+      *ctx->OutputShape("y", 0) = Shape(y_dim_vec);
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn(GetOpSbpSignature)
+    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
+                            const user_op::UserOpConfWrapper&) -> Maybe<void> {
+      user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0);
+      CHECK_NOTNULL_OR_RETURN(x_modifier);
+      x_modifier->set_requires_grad(true);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("constant_pad3d_grad")
+    .Input("dy")
+    .Output("dx")
+    .Attr<std::vector<int64_t>>("padding")
+    .Attr<double>("floating_value")
+    .Attr<int64_t>("integral_value")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& dy_shape = ctx->InputShape("dy", 0);
+      const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
+      CHECK_EQ_OR_RETURN(dy_shape.NumAxes(), 5);
+      const int64_t n_idx = 0;
+      const int64_t c_idx = 1;
+      const int64_t d_idx = 2;
+      const int64_t h_idx = 3;
+      const int64_t w_idx = 4;
+
+      DimVector dx_dim_vec(dy_shape.NumAxes());
+      int64_t d_dy, h_dy, w_dy;
+      d_dy = dy_shape.At(d_idx);
+      h_dy = dy_shape.At(h_idx);
+      w_dy = dy_shape.At(w_idx);
+
+      dx_dim_vec[n_idx] = dy_shape.At(0);
+      dx_dim_vec[c_idx] = dy_shape.At(1);
+      dx_dim_vec[d_idx] = d_dy - padding[4] - padding[5];
+      dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3];
+      dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1];
+
+      *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec);
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn(GetOpGradSbpSignature);
+
+REGISTER_USER_OP_GRAD("constant_pad3d")
+    .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
+                               user_op::AddOpFn AddOp) -> Maybe<void> {
+      if (op.NeedGenGradTensor4OpInput("x", 0)) {
+        user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
+        user_op::UserOpConfWrapper grad_op =
+            builder.Op("constant_pad3d_grad")
+                .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
+                .Output("dx")
+                .Attr("padding", op.attr<std::vector<int64_t>>("padding"))
+                .Attr("floating_value", op.attr<double>("floating_value"))
+                .Attr("integral_value", op.attr<int64_t>("integral_value"))
+                .Build();
+        op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
+        AddOp(grad_op);
+      }
+      return Maybe<void>::Ok();
+    });
+
 } // namespace oneflow
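
Reviewer note: a minimal end-to-end sketch of the module this diff adds, written against the experimental eager API used in the docstrings above; it assumes only a build of OneFlow that exposes `oneflow.experimental`.

.. code-block:: python

    import numpy as np
    import oneflow.experimental as flow

    flow.enable_eager_execution()

    # NCDHW input: constant_pad3d requires exactly 5 axes.
    x = flow.Tensor(np.arange(8).reshape(1, 1, 2, 2, 2).astype(np.float32))

    # padding=1 grows every D/H/W boundary by one element, so (2, 2, 2)
    # becomes (4, 4, 4); value=9.0 fills the new border.
    m = flow.nn.ConstantPad3d(padding=1, value=9.0)
    y = m(x)
    print(y.shape)  # expected: flow.Size([1, 1, 4, 4, 4])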
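Both the CPU and CUDA kernels funnel into `DoConstantPad3d`, which tests whether each output index `(n, c, d, h, w)` lies inside the input box shifted by `(pad_front, pad_top, pad_left)` and either copies the source element or writes the constant. A NumPy oracle is useful when reviewing that index arithmetic; the helper names below are hypothetical, not part of the diff, and `padding` follows the op's `[left, right, top, bottom, front, back]` attribute order.

.. code-block:: python

    import numpy as np

    def constant_pad3d_ref(x, padding, value=0.0):
        # High-level oracle built on np.pad; x is NCDHW.
        pl, pr, pt, pb, pf, pk = padding
        return np.pad(
            x,
            ((0, 0), (0, 0), (pf, pk), (pt, pb), (pl, pr)),
            mode="constant",
            constant_values=value,
        )

    def constant_pad3d_loop(x, padding, value=0.0):
        # Direct translation of DoConstantPad3d's per-element test.
        pl, pr, pt, pb, pf, pk = padding
        xn, xc, xd, xh, xw = x.shape
        y = np.empty((xn, xc, xd + pf + pk, xh + pt + pb, xw + pl + pr), x.dtype)
        for n, c, d, h, w in np.ndindex(*y.shape):
            inside = (pf <= d < pf + xd) and (pt <= h < pt + xh) and (pl <= w < pl + xw)
            y[n, c, d, h, w] = x[n, c, d - pf, h - pt, w - pl] if inside else value
        return y

    x = np.arange(8, dtype=np.float32).reshape(1, 1, 2, 2, 2)
    pad = [1, 1, 1, 1, 1, 1]
    assert np.array_equal(constant_pad3d_ref(x, pad, 9), constant_pad3d_loop(x, pad, 9))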
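For the backward pass, `DoConstantPad3dGrad` accumulates each `dy` element whose index falls in the interior into the matching `dx` slot and drops the border, so for constant padding the whole gradient reduces to a slice. A hypothetical NumPy check of that, and of the `dx` shape the `constant_pad3d_grad` inference computes (`D_dx = D_dy - front - back`, and likewise for H and W):

.. code-block:: python

    import numpy as np

    def constant_pad3d_grad_ref(dy, padding):
        # Keep only the interior of dy; padded positions receive no gradient.
        pl, pr, pt, pb, pf, pk = padding
        _, _, yd, yh, yw = dy.shape
        return dy[:, :, pf:yd - pk, pt:yh - pb, pl:yw - pr]

    dy = np.ones((1, 1, 4, 4, 4), dtype=np.float32)
    dx = constant_pad3d_grad_ref(dy, [1, 1, 1, 1, 1, 1])
    assert dx.shape == (1, 1, 2, 2, 2)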