Unverified commit 4bab7aa5, authored by Luyang, committed by GitHub

implementation of constantpad-3d op (#5529)


* implementation of constantpad-3d op

* format

* refine

* rename files

* refine

* refine test case

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 73dd59d6
Showing with 793 additions and 140 deletions
@@ -49,7 +49,6 @@ Experimental features
.. autofunction:: oneflow.experimental.Tensor.argmax
.. autofunction:: oneflow.experimental.nn.BatchNorm1d
.. autofunction:: oneflow.experimental.nn.BatchNorm2d
.. autofunction:: oneflow.experimental.nn.ReplicationPad2d
.. autofunction:: oneflow.experimental.nn.InstanceNorm1d
.. autofunction:: oneflow.experimental.nn.InstanceNorm2d
.. autofunction:: oneflow.experimental.nn.InstanceNorm3d
@@ -71,7 +70,11 @@ Experimental features
.. autofunction:: oneflow.experimental.nn.ModuleDict
.. autofunction:: oneflow.experimental.nn.Conv1d
.. autofunction:: oneflow.experimental.nn.Conv2d
.. autofunction:: oneflow.experimental.nn.ZeroPad2d
.. autofunction:: oneflow.experimental.nn.ReflectionPad2d
.. autofunction:: oneflow.experimental.nn.ReplicationPad2d
.. autofunction:: oneflow.experimental.nn.ConstantPad2d
.. autofunction:: oneflow.experimental.nn.ConstantPad3d
.. autofunction:: oneflow.experimental.nn.ConvTranspose2d
.. autofunction:: oneflow.experimental.nn.Dropout
.. autofunction:: oneflow.experimental.slice
@@ -232,7 +235,6 @@ Experimental features
.. autofunction:: oneflow.experimental.Tensor.ceil
.. autofunction:: oneflow.experimental.expm1
.. autofunction:: oneflow.experimental.Tensor.expm1
.. autofunction:: oneflow.experimental.nn.ReflectionPad2d
.. autofunction:: oneflow.experimental.meshgrid
.. autofunction:: oneflow.experimental.topk
.. autofunction:: oneflow.experimental.Tensor.topk
@@ -241,7 +243,6 @@ Experimental features
.. autofunction:: oneflow.experimental.nn.GroupNorm
.. autofunction:: oneflow.experimental.gather_nd
.. autofunction:: oneflow.experimental.scatter_nd
.. autofunction:: oneflow.experimental.nn.ZeroPad2d
.. autofunction:: oneflow.experimental.nn.image.flip
.. autofunction:: oneflow.experimental.tensor_buffer_to_tensor
.. autofunction:: oneflow.experimental.tensor_to_tensor_buffer
......
@@ -78,13 +78,13 @@ class ReplicationPad2d : public Pad2d {
REGISTER_OP_EXPR_GRAD_FUNCTION("reflection_pad2d", ReflectionPad2d);
REGISTER_OP_EXPR_GRAD_FUNCTION("replication_pad2d", ReplicationPad2d);
struct ConstantPad2dInterpState : public OpExprInterpState {
struct ConstantPadNdInterpState : public OpExprInterpState {
bool requires_grad;
std::vector<int64_t> paddings;
functional::Scalar padding_value;
};
class ConstantPad2d : public OpExprGradFunction<ConstantPad2dInterpState> {
class ConstantPadNd : public OpExprGradFunction<ConstantPadNdInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -93,7 +93,7 @@ class ConstantPad2d : public OpExprGradFunction<ConstantPad2dInterpState> {
return Maybe<void>::Ok();
}
Maybe<void> Capture(ConstantPad2dInterpState* ctx, const TensorTuple& inputs,
Maybe<void> Capture(ConstantPadNdInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -112,7 +112,7 @@ class ConstantPad2d : public OpExprGradFunction<ConstantPad2dInterpState> {
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ConstantPad2dInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const ConstantPadNdInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -127,7 +127,8 @@ class ConstantPad2d : public OpExprGradFunction<ConstantPad2dInterpState> {
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("constant_pad2d", ConstantPad2d);
REGISTER_OP_EXPR_GRAD_FUNCTION("constant_pad2d", ConstantPadNd);
REGISTER_OP_EXPR_GRAD_FUNCTION("constant_pad3d", ConstantPadNd);
} // namespace one
} // namespace oneflow
@@ -375,6 +375,7 @@ class NormalizationFunctor {
class PadFunctor {
public:
PadFunctor() {
constant_pad_3d_ = CHECK_JUST(one::OpBuilder("constant_pad3d").Input("x").Output("y").Build());
constant_pad_ = CHECK_JUST(one::OpBuilder("constant_pad2d").Input("x").Output("y").Build());
reflect_pad_ = CHECK_JUST(one::OpBuilder("reflection_pad2d").Input("x").Output("y").Build());
replicate_pad_ = CHECK_JUST(one::OpBuilder("replication_pad2d").Input("x").Output("y").Build());
@@ -396,7 +397,14 @@ class PadFunctor {
} else {
UNIMPLEMENTED_THEN_RETURN() << "Data type should be floating or integral type.";
}
return OpInterpUtil::Dispatch<Tensor>(*constant_pad_, {x}, attrs);
switch (x->shape()->NumAxes()) {
case 4: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_, {x}, attrs);
case 5: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_3d_, {x}, attrs);
default:
UNIMPLEMENTED_THEN_RETURN() << "Pad mode is " << mode << ", but " << x->shape()->NumAxes()
<< "d-tensor is not support yet! ";
}
} else if (mode == "reflect") {
return OpInterpUtil::Dispatch<Tensor>(*reflect_pad_, {x}, attrs);
} else if (mode == "replicate") {
@@ -409,6 +417,7 @@ class PadFunctor {
private:
std::shared_ptr<OpExpr> constant_pad_;
std::shared_ptr<OpExpr> constant_pad_3d_;
std::shared_ptr<OpExpr> reflect_pad_;
std::shared_ptr<OpExpr> replicate_pad_;
};
......
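From the Python side, this dispatch means flow.F.pad with mode="constant" now handles both 4-D (NCHW) and 5-D (NCDHW) inputs. A minimal sketch, assuming eager mode and using flow.F.pad the same way the nn modules below do:

import numpy as np
import oneflow.experimental as flow

flow.enable_eager_execution()

x4d = flow.Tensor(np.ones((1, 2, 3, 3), dtype=np.float32))     # NCHW  -> constant_pad2d
x5d = flow.Tensor(np.ones((1, 2, 3, 3, 3), dtype=np.float32))  # NCDHW -> constant_pad3d

# pad order: (left, right, top, bottom[, front, back])
y4d = flow.F.pad(x4d, pad=[1, 1, 2, 2], mode="constant", value=0.0)
y5d = flow.F.pad(x5d, pad=[1, 1, 2, 2, 3, 3], mode="constant", value=0.0)

print(y4d.shape)  # expected flow.Size([1, 2, 7, 5])
print(y5d.shape)  # expected flow.Size([1, 2, 9, 7, 5])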
@@ -229,6 +229,8 @@ class PadGradFunctor {
PadGradFunctor() {
constant_pad_grad_ =
CHECK_JUST(one::OpBuilder("constant_pad2d_grad").Input("dy").Output("dx").Build());
constant_pad_3d_grad_ =
CHECK_JUST(one::OpBuilder("constant_pad3d_grad").Input("dy").Output("dx").Build());
reflect_pad_grad_ =
CHECK_JUST(one::OpBuilder("reflection_pad2d_grad").Input("dy").Output("dx").Build());
replicate_pad_grad_ =
@@ -251,7 +253,13 @@ class PadGradFunctor {
} else {
UNIMPLEMENTED_THEN_RETURN() << "Data type should be floating or integral type.";
}
return OpInterpUtil::Dispatch<Tensor>(*constant_pad_grad_, {dy}, attrs);
switch (dy->shape()->NumAxes()) {
case 4: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_grad_, {dy}, attrs);
case 5: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_3d_grad_, {dy}, attrs);
default:
UNIMPLEMENTED_THEN_RETURN() << "Pad mode is " << mode << ", but "
<< dy->shape()->NumAxes() << "d-tensor is not supported yet!";
}
} else if (mode == "reflect") {
return OpInterpUtil::Dispatch<Tensor>(*reflect_pad_grad_, {dy}, attrs);
} else if (mode == "replicate") {
@@ -266,6 +274,7 @@ class PadGradFunctor {
std::shared_ptr<OpExpr> constant_pad_grad_;
std::shared_ptr<OpExpr> reflect_pad_grad_;
std::shared_ptr<OpExpr> replicate_pad_grad_;
std::shared_ptr<OpExpr> constant_pad_3d_grad_;
};
} // namespace impl
......
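For intuition, the backward of a constant pad registered here is just a crop: the padded border contributes nothing to the input gradient, so dx is the upstream gradient dy sliced back to x's shape. A NumPy restatement (illustrative, not the shipped kernel):

import numpy as np

def constant_pad3d_grad_reference(dy, padding):
    # padding follows the op's order: [left, right, top, bottom, front, back]
    left, right, top, bottom, front, back = padding
    _, _, yd, yh, yw = dy.shape
    return dy[:, :, front:yd - back, top:yh - bottom, left:yw - right]

dy = np.ones((1, 1, 4, 4, 4), dtype=np.float32)  # upstream grad of a (1, 1, 2, 2, 2) input padded by 1
print(constant_pad3d_grad_reference(dy, (1, 1, 1, 1, 1, 1)).shape)  # (1, 1, 2, 2, 2)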
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import absolute_import
from typing import Union
import oneflow as flow
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module
@oneflow_export("nn.ConstantPad2d")
@experimental_api
class ConstantPad2d(Module):
r"""The interface is consistent with PyTorch.
The documentation is referenced from:
https://pytorch.org/docs/stable/generated/torch.nn.ConstantPad2d.html?highlight=constantpad2d#torch.nn.ConstantPad2d
This operator pads the input with a constant value specified by the user. The amount of padding is controlled by the parameter `padding`.
Args:
padding (Union[int, tuple, list]): the size of the padding. If `int`, uses the same padding on all boundaries. If a 4-`tuple`, uses (:math:`\mathrm{padding_{left}}`, :math:`\mathrm{padding_{right}}`, :math:`\mathrm{padding_{top}}`, :math:`\mathrm{padding_{bottom}}`)
value (Union[int, float]): The constant value used for padding. Defaults to 0.
Shape:
- Input: :math:`(N, C, H_{in}, W_{in})`
- Output: :math:`(N, C, H_{out}, W_{out})` where
:math:`H_{out} = H_{in} + \mathrm{padding_{top}} + \mathrm{padding_{bottom}}`
:math:`W_{out} = W_{in} + \mathrm{padding_{left}} + \mathrm{padding_{right}}`
For example:
.. code-block:: python
>>> import oneflow.experimental as flow
>>> import numpy as np
>>> flow.enable_eager_execution()
>>> constantpad_layer_0 = flow.nn.ConstantPad2d((2, 2, 1, 1), 1)
>>> input = flow.Tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32))
>>> input_int = flow.Tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.int32))
>>> output = constantpad_layer_0(input)
>>> output.shape
flow.Size([1, 2, 5, 7])
>>> output
tensor([[[[ 1., 1., 1., 1., 1., 1., 1.],
[ 1., 1., 0., 1., 2., 1., 1.],
[ 1., 1., 3., 4., 5., 1., 1.],
[ 1., 1., 6., 7., 8., 1., 1.],
[ 1., 1., 1., 1., 1., 1., 1.]],
<BLANKLINE>
[[ 1., 1., 1., 1., 1., 1., 1.],
[ 1., 1., 9., 10., 11., 1., 1.],
[ 1., 1., 12., 13., 14., 1., 1.],
[ 1., 1., 15., 16., 17., 1., 1.],
[ 1., 1., 1., 1., 1., 1., 1.]]]], dtype=oneflow.float32)
>>> output_int = constantpad_layer_0(input_int)
>>> output_int
tensor([[[[ 1., 1., 1., 1., 1., 1., 1.],
[ 1., 1., 0., 1., 2., 1., 1.],
[ 1., 1., 3., 4., 5., 1., 1.],
[ 1., 1., 6., 7., 8., 1., 1.],
[ 1., 1., 1., 1., 1., 1., 1.]],
<BLANKLINE>
[[ 1., 1., 1., 1., 1., 1., 1.],
[ 1., 1., 9., 10., 11., 1., 1.],
[ 1., 1., 12., 13., 14., 1., 1.],
[ 1., 1., 15., 16., 17., 1., 1.],
[ 1., 1., 1., 1., 1., 1., 1.]]]], dtype=oneflow.float32)
"""
def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0):
super().__init__()
if isinstance(padding, (tuple, list)):
assert len(padding) == 4, ValueError("Length of padding must be 4")
boundary = [padding[0], padding[1], padding[2], padding[3]]
elif isinstance(padding, int):
boundary = [padding, padding, padding, padding]
else:
raise ValueError("padding must be int or list or tuple!")
self.padding = boundary
self.value = value
def forward(self, x):
_, _, h, w = x.shape
if x.dtype in [flow.float32, flow.float16, flow.float64]:
floating_value = float(self.value)
integral_value = int(0)
else:
floating_value = float(0)
integral_value = int(self.value)
self._op = (
flow.builtin_op("constant_pad2d")
.Input("x")
.Output("y")
.Attr("padding", self.padding)
.Attr("floating_value", floating_value)
.Attr("integral_value", integral_value)
.Build()
)
res = self._op(x)[0]
return res
if __name__ == "__main__":
import doctest
doctest.testmod(raise_on_error=True)
@@ -190,6 +190,176 @@ class ReflectionPad2d(Module):
return "{}".format(self.padding)
@oneflow_export("nn.ConstantPad2d")
@experimental_api
class ConstantPad2d(Module):
r"""The interface is consistent with PyTorch.
The documentation is referenced from:
https://pytorch.org/docs/stable/generated/torch.nn.ConstantPad2d.html?highlight=constantpad2d#torch.nn.ConstantPad2d
This operator pads the input with a constant value specified by the user. The amount of padding is controlled by the parameter `padding`.
Args:
padding (Union[int, tuple, list]): the size of the padding. If `int`, uses the same padding on all boundaries. If a 4-`tuple`, uses (:math:`\mathrm{padding_{left}}`, :math:`\mathrm{padding_{right}}`, :math:`\mathrm{padding_{top}}`, :math:`\mathrm{padding_{bottom}}`)
value (Union[int, float]): The constant value used for padding. Defaults to 0.
Shape:
- Input: :math:`(N, C, H_{in}, W_{in})`
- Output: :math:`(N, C, H_{out}, W_{out})` where
:math:`H_{out} = H_{in} + \mathrm{padding_{top}} + \mathrm{padding_{bottom}}`
:math:`W_{out} = W_{in} + \mathrm{padding_{left}} + \mathrm{padding_{right}}`
For example:
.. code-block:: python
>>> import oneflow.experimental as flow
>>> import numpy as np
>>> flow.enable_eager_execution()
>>> constantpad_layer_0 = flow.nn.ConstantPad2d((2, 2, 1, 1), 1)
>>> input = flow.Tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32))
>>> input_int = flow.Tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.int32))
>>> output = constantpad_layer_0(input)
>>> output.shape
flow.Size([1, 2, 5, 7])
>>> output
tensor([[[[ 1., 1., 1., 1., 1., 1., 1.],
[ 1., 1., 0., 1., 2., 1., 1.],
[ 1., 1., 3., 4., 5., 1., 1.],
[ 1., 1., 6., 7., 8., 1., 1.],
[ 1., 1., 1., 1., 1., 1., 1.]],
<BLANKLINE>
[[ 1., 1., 1., 1., 1., 1., 1.],
[ 1., 1., 9., 10., 11., 1., 1.],
[ 1., 1., 12., 13., 14., 1., 1.],
[ 1., 1., 15., 16., 17., 1., 1.],
[ 1., 1., 1., 1., 1., 1., 1.]]]], dtype=oneflow.float32)
>>> output_int = constantpad_layer_0(input_int)
>>> output_int
tensor([[[[ 1., 1., 1., 1., 1., 1., 1.],
[ 1., 1., 0., 1., 2., 1., 1.],
[ 1., 1., 3., 4., 5., 1., 1.],
[ 1., 1., 6., 7., 8., 1., 1.],
[ 1., 1., 1., 1., 1., 1., 1.]],
<BLANKLINE>
[[ 1., 1., 1., 1., 1., 1., 1.],
[ 1., 1., 9., 10., 11., 1., 1.],
[ 1., 1., 12., 13., 14., 1., 1.],
[ 1., 1., 15., 16., 17., 1., 1.],
[ 1., 1., 1., 1., 1., 1., 1.]]]], dtype=oneflow.float32)
"""
def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0):
super().__init__()
if isinstance(padding, (tuple, list)):
assert len(padding) == 4, ValueError("Length of padding must be 4")
boundary = [padding[0], padding[1], padding[2], padding[3]]
elif isinstance(padding, int):
boundary = [padding, padding, padding, padding]
else:
raise ValueError("padding must be int or list or tuple!")
self.padding = boundary
self.value = value
def forward(self, x):
if x.dtype in (flow.float32, flow.float16, flow.float64):
self.value = float(self.value)
else:
self.value = int(self.value)
return flow.F.pad(x, pad=self.padding, mode="constant", value=self.value)
@oneflow_export("nn.ConstantPad3d")
@experimental_api
class ConstantPad3d(Module):
r"""Pads the input tensor boundaries with a constant value.
The interface is consistent with PyTorch, and referenced from:
https://pytorch.org/docs/stable/generated/torch.nn.ConstantPad3d.html?highlight=constantpad3d#torch.nn.ConstantPad3d
For `N`-dimensional padding, use :func:`flow.nn.functional.pad()`.
Args:
padding (int, list, tuple): the size of the padding. If `int`, uses the same
padding on all boundaries. If a 6-`tuple`, uses
(:math:`\text{padding_left}`, :math:`\text{padding_right}`,
:math:`\text{padding_top}`, :math:`\text{padding_bottom}`,
:math:`\text{padding_front}`, :math:`\text{padding_back}`)
value (Union[int, float]): The constant value used for padding. Defaults to 0.
Shape:
- Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
- Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` where
:math:`D_{out} = D_{in} + \text{padding_front} + \text{padding_back}`
:math:`H_{out} = H_{in} + \text{padding_top} + \text{padding_bottom}`
:math:`W_{out} = W_{in} + \text{padding_left} + \text{padding_right}`
Examples::
>>> import oneflow.experimental as flow
>>> import numpy as np
>>> input = flow.tensor(np.arange(8).reshape(1,1,2,2,2).astype(np.int32))
>>> m = flow.nn.ConstantPad3d(padding=1, value=9)
>>> output = m(input)
>>> output
tensor([[[[[9, 9, 9, 9],
[9, 9, 9, 9],
[9, 9, 9, 9],
[9, 9, 9, 9]],
<BLANKLINE>
[[9, 9, 9, 9],
[9, 0, 1, 9],
[9, 2, 3, 9],
[9, 9, 9, 9]],
<BLANKLINE>
[[9, 9, 9, 9],
[9, 4, 5, 9],
[9, 6, 7, 9],
[9, 9, 9, 9]],
<BLANKLINE>
[[9, 9, 9, 9],
[9, 9, 9, 9],
[9, 9, 9, 9],
[9, 9, 9, 9]]]]], dtype=oneflow.int32)
"""
def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0):
super().__init__()
if isinstance(padding, (tuple, list)):
assert len(padding) == 6, ValueError("Length of padding must be 6")
boundary = [
padding[0],
padding[1],
padding[2],
padding[3],
padding[4],
padding[5],
]
elif isinstance(padding, int):
boundary = [padding, padding, padding, padding, padding, padding]
else:
raise ValueError("padding must be int or list or tuple!")
self.padding = boundary
self.value = value
def forward(self, x):
if x.dtype in (flow.float32, flow.float16, flow.float64):
self.value = float(self.value)
else:
self.value = int(self.value)
return flow.F.pad(x, pad=self.padding, mode="constant", value=self.value)
if __name__ == "__main__":
import doctest
......
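As an illustrative cross-check (not part of this PR's test suite), the 6-tuple ordering above maps onto numpy.pad's per-axis (before, after) pairs for an NCDHW tensor as follows:

import numpy as np
import oneflow.experimental as flow

flow.enable_eager_execution()

arr = np.arange(8).reshape(1, 1, 2, 2, 2).astype(np.float32)
left, right, top, bottom, front, back = 1, 2, 1, 2, 1, 2

of_out = flow.nn.ConstantPad3d((left, right, top, bottom, front, back), value=0.0)(
    flow.Tensor(arr)
).numpy()
np_out = np.pad(arr, ((0, 0), (0, 0), (front, back), (top, bottom), (left, right)),
                constant_values=0.0)
assert np.array_equal(of_out, np_out)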
@@ -111,7 +111,7 @@ class TestConstantPad2dModule(flow.unittest.TestCase):
def test_with_random_data(test_case):
for device in ["cpu", "cuda"]:
spatial_size = np.random.randint(10, 20)
spatial_size = np.random.randint(1, 6)
test_module_against_pytorch(
test_case,
"nn.ConstantPad2d",
@@ -120,8 +120,31 @@ class TestConstantPad2dModule(flow.unittest.TestCase):
"input": random_tensor(
ndim=4, dim2=spatial_size, dim3=spatial_size
),
"padding": random(0, 6),
"value": random(0, 6),
"padding": random(0, 3),
"value": random(0, 10),
},
device=device,
)
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestConstantPad3dModule(flow.unittest.TestCase):
def test_with_random_data(test_case):
for device in ["cpu", "cuda"]:
spatial_size = np.random.randint(1, 6)
test_module_against_pytorch(
test_case,
"nn.ConstantPad3d",
extra_annotations={"padding": int, "value": float},
extra_generators={
"input": random_tensor(
ndim=5, dim2=spatial_size, dim3=spatial_size, dim4=spatial_size
),
"padding": random(0, 3),
"value": random(0, 10),
},
device=device,
)
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/nd_index_offset_helper.h"
#include "oneflow/core/device/memory_copier.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"
#include "oneflow/user/kernels/constantpad3d_kernel_util.h"
namespace oneflow {
namespace {
template<typename T>
T GetDtypeMatchedValue(double floating, int64_t integral);
template<>
float16 GetDtypeMatchedValue(double floating, int64_t integral) {
return static_cast<float16>(floating);
}
template<>
float GetDtypeMatchedValue(double floating, int64_t integral) {
return static_cast<float>(floating);
}
template<>
double GetDtypeMatchedValue(double floating, int64_t integral) {
return floating;
}
template<>
int8_t GetDtypeMatchedValue(double floating, int64_t integral) {
return static_cast<int8_t>(integral);
}
template<>
int32_t GetDtypeMatchedValue(double floating, int64_t integral) {
return static_cast<int32_t>(integral);
}
template<>
int64_t GetDtypeMatchedValue(double floating, int64_t integral) {
return integral;
}
} // namespace
namespace user_op {
template<DeviceType device_type, typename IN_T>
class ConstantPad3dKernel final : public OpKernel {
public:
ConstantPad3dKernel() = default;
~ConstantPad3dKernel() = default;
private:
void Compute(user_op::KernelComputeContext* ctx) const override {
const Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
const ShapeView& x_shape = x->shape();
const ShapeView& y_shape = y->shape();
CHECK_EQ(x->shape().NumAxes(), 5);
const std::vector<int64_t>& padding = ctx->Attr<std::vector<int64_t>>("padding");
CHECK_EQ(padding.size(), 6);
const IN_T constant_value = GetDtypeMatchedValue<IN_T>(ctx->Attr<double>("floating_value"),
ctx->Attr<int64_t>("integral_value"));
IN_T* dest = y->mut_dptr<IN_T>();
const IN_T* src = x->dptr<IN_T>();
DimVector y_vector;
y->shape().ToDimVector(&y_vector);
NdIndexOffsetHelper<int64_t, 5> index_helper(y_vector.data());
ConstantPad3dFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper, x_shape,
y_shape, padding, constant_value);
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
template<DeviceType device_type, typename IN_T>
class ConstantPad3dGradKernel final : public OpKernel {
public:
ConstantPad3dGradKernel() = default;
~ConstantPad3dGradKernel() = default;
private:
void Compute(KernelComputeContext* ctx) const override {
const Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
CHECK_EQ(dy->shape().NumAxes(), 5);
Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
const ShapeView& dx_shape = dx->shape();
const ShapeView& dy_shape = dy->shape();
const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
CHECK_EQ(padding.size(), 6);
const IN_T* src = dy->dptr<IN_T>();
IN_T* dest = dx->mut_dptr<IN_T>();
DimVector dy_vector;
dy->shape().ToDimVector(&dy_vector);
NdIndexOffsetHelper<int64_t, 5> index_helper(dy_vector.data());
size_t out_bytes_size = dx->shape().Count(0) * GetSizeOfDataType(dx->data_type());
Memset<device_type>(ctx->device_ctx(), dest, 0, out_bytes_size);
ConstantPad3dGradFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper,
dy_shape, dx_shape, padding);
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
#define REGISTER_CONSTANT_PAD3D_KERNELS(device, dtype) \
REGISTER_USER_KERNEL("constant_pad3d") \
.SetCreateFn<ConstantPad3dKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
REGISTER_USER_KERNEL("constant_pad3d_grad") \
.SetCreateFn<ConstantPad3dGradKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
#define REGISTER_CONSTANT_PAD3D_WITH_DEVICE(device) \
REGISTER_CONSTANT_PAD3D_KERNELS(device, float) \
REGISTER_CONSTANT_PAD3D_KERNELS(device, double) \
REGISTER_CONSTANT_PAD3D_KERNELS(device, int32_t)
REGISTER_CONSTANT_PAD3D_WITH_DEVICE(DeviceType::kCPU)
#ifdef WITH_CUDA
REGISTER_CONSTANT_PAD3D_WITH_DEVICE(DeviceType::kGPU)
REGISTER_CONSTANT_PAD3D_KERNELS(DeviceType::kGPU, float16)
#endif
} // namespace user_op
} // namespace oneflow
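The op carries both a floating_value and an integral_value attribute; GetDtypeMatchedValue above selects one based on the element type, mirroring the dtype split in the Python module's forward. A small sketch of that rule:

def pick_fill_value(dtype, floating_value, integral_value):
    # floating dtypes read "floating_value"; integer dtypes read "integral_value"
    floating_dtypes = ("float16", "float32", "float64")
    return float(floating_value) if dtype in floating_dtypes else int(integral_value)

print(pick_fill_value("float32", 1.5, 0))  # 1.5
print(pick_fill_value("int32", 0.0, 9))    # 9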
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/constantpad3d_kernel_util.h"
#include "oneflow/core/framework/framework.h"
namespace oneflow {
namespace user_op {
template<typename IN_T>
struct ConstantPad3dFunctor<DeviceType::kCPU, IN_T> final {
void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int64_t>& padding,
IN_T constant_value) {
// for an NCDHW-format input tensor, the indices of n, c, d, h, w are 0, 1, 2, 3, 4
const int64_t c_idx = 1;
const int64_t d_idx = 2;
const int64_t h_idx = 3;
const int64_t w_idx = 4;
// padding vector: [left, right, top, bottom, front, back]
DoConstantPad3d<IN_T>(src, dest, index_helper, y_shape.Count(0), y_shape.At(c_idx),
y_shape.At(d_idx), y_shape.At(h_idx), y_shape.At(w_idx),
x_shape.At(d_idx), x_shape.At(h_idx), x_shape.At(w_idx), padding[4],
padding[0], padding[2], constant_value);
}
};
template<typename IN_T>
struct ConstantPad3dGradFunctor<DeviceType::kCPU, IN_T> final {
void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& dy_shape,
const ShapeView& dx_shape, const std::vector<int64_t>& padding) {
const int64_t c_idx = 1;
const int64_t d_idx = 2;
const int64_t h_idx = 3;
const int64_t w_idx = 4;
DoConstantPad3dGrad<IN_T>(src, dest, index_helper, dy_shape.Count(0), dy_shape.At(c_idx),
dy_shape.At(d_idx), dy_shape.At(h_idx), dy_shape.At(w_idx),
dx_shape.At(d_idx), dx_shape.At(h_idx), dx_shape.At(w_idx),
padding[4], padding[0], padding[2]);
}
};
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_FUNCTOR, (DeviceType::kCPU),
PADDING_DATA_TYPE_CPU_SEQ);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_GRAD_FUNCTOR, (DeviceType::kCPU),
PADDING_DATA_TYPE_CPU_SEQ);
} // namespace user_op
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <cstdint>
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/constantpad3d_kernel_util.h"
namespace oneflow {
namespace user_op {
template<typename IN_T>
__global__ void DoCUDAConstantPad3d(const IN_T* src, IN_T* dest,
const NdIndexOffsetHelper<int64_t, 5> index_helper,
int64_t elem_num, int64_t n_channel, int64_t y_depth,
int64_t y_height, int64_t y_width, int64_t x_depth,
int64_t x_height, int64_t x_width, int64_t pad_front,
int64_t pad_left, int64_t pad_top, const IN_T const_value) {
DoConstantPad3d<IN_T>(src, dest, index_helper, elem_num, n_channel, y_depth, y_height, y_width,
x_depth, x_height, x_width, pad_front, pad_left, pad_top, const_value);
};
template<typename IN_T>
__global__ void DoCUDAConstantPad3dGrad(const IN_T* src, IN_T* dest,
const NdIndexOffsetHelper<int64_t, 5> index_helper,
int64_t elem_num, int64_t n_channel, int64_t dy_depth,
int64_t dy_height, int64_t dy_width, int64_t dx_depth,
int64_t dx_height, int64_t dx_width, int64_t pad_front,
int64_t pad_left, int64_t pad_top) {
// argument order must match DoConstantPad3dGrad: dx_depth precedes dx_height
DoConstantPad3dGrad<IN_T>(src, dest, index_helper, elem_num, n_channel, dy_depth, dy_height,
                          dy_width, dx_depth, dx_height, dx_width, pad_front, pad_left, pad_top);
};
template<typename IN_T>
struct ConstantPad3dFunctor<DeviceType::kGPU, IN_T> final {
void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int64_t>& padding,
IN_T constant_value) {
const int64_t c_idx = 1;
const int64_t d_idx = 2;
const int64_t h_idx = 3;
const int64_t w_idx = 4;
DoCUDAConstantPad3d<IN_T><<<BlocksNum4ThreadsNum(y_shape.Count(0)), kCudaThreadsNumPerBlock, 0,
ctx->cuda_stream()>>>(
src, dest, index_helper, y_shape.Count(0), y_shape.At(c_idx), y_shape.At(d_idx),
y_shape.At(h_idx), y_shape.At(w_idx), x_shape.At(d_idx), x_shape.At(h_idx),
x_shape.At(w_idx), padding[4], padding[0], padding[2], constant_value);
}
};
// float16 implementation
template<>
void ConstantPad3dFunctor<DeviceType::kGPU, float16>::operator()(
DeviceCtx* ctx, const float16* src, float16* dest,
const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int64_t>& padding, float16 constant_value) {
const int64_t c_idx = 1;
const int64_t d_idx = 2;
const int64_t h_idx = 3;
const int64_t w_idx = 4;
DoCUDAConstantPad3d<half>
<<<BlocksNum4ThreadsNum(y_shape.Count(0)), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper,
y_shape.Count(0), y_shape.At(c_idx), y_shape.At(d_idx), y_shape.At(h_idx),
y_shape.At(w_idx), x_shape.At(d_idx), x_shape.At(h_idx), x_shape.At(w_idx), padding[4],
padding[0], padding[2], static_cast<const half>(constant_value));
}
template<typename IN_T>
struct ConstantPad3dGradFunctor<DeviceType::kGPU, IN_T> final {
void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& dy_shape,
const ShapeView& dx_shape, const std::vector<int64_t>& padding) {
const int64_t c_idx = 1;
const int64_t d_idx = 2;
const int64_t h_idx = 3;
const int64_t w_idx = 4;
DoCUDAConstantPad3dGrad<IN_T><<<BlocksNum4ThreadsNum(dy_shape.Count(0)),
kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
src, dest, index_helper, dy_shape.Count(0), dy_shape.At(c_idx), dy_shape.At(d_idx),
dy_shape.At(h_idx), dy_shape.At(w_idx), dx_shape.At(d_idx), dx_shape.At(h_idx),
dx_shape.At(w_idx), padding[4], padding[0], padding[2]);
}
};
// float16 implementation
template<>
void ConstantPad3dGradFunctor<DeviceType::kGPU, float16>::operator()(
DeviceCtx* ctx, const float16* src, float16* dest,
const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& dy_shape,
const ShapeView& dx_shape, const std::vector<int64_t>& padding) {
const int64_t c_idx = 1;
const int64_t d_idx = 2;
const int64_t h_idx = 3;
const int64_t w_idx = 4;
DoCUDAConstantPad3dGrad<half>
<<<BlocksNum4ThreadsNum(dy_shape.Count(0)), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper,
dy_shape.Count(0), dy_shape.At(c_idx), dy_shape.At(d_idx), dy_shape.At(h_idx),
dy_shape.At(w_idx), dx_shape.At(d_idx), dx_shape.At(h_idx), dx_shape.At(w_idx),
padding[4], padding[0], padding[2]);
}
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_FUNCTOR,
OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_GRAD_FUNCTOR,
OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ);
} // namespace user_op
} // namespace oneflow
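For reference, the kernels above launch a 1-D grid over the output (or gradient) elements. A rough sketch of the launch arithmetic, assuming kCudaThreadsNumPerBlock is 512 and ignoring the cap that BlocksNum4ThreadsNum places on the block count:

elem_num = 1 * 2 * 9 * 7 * 5  # y_shape.Count(0) for an example (1, 2, 9, 7, 5) output
threads_per_block = 512
blocks = (elem_num + threads_per_block - 1) // threads_per_block
print(blocks)  # 2; XPU_1D_KERNEL_LOOP lets threads stride over elements when the block count is capped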
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_PAD3D_KERNELS_UTIL_H_
#define ONEFLOW_USER_KERNELS_PAD3D_KERNELS_UTIL_H_
#ifdef WITH_CUDA
#include "oneflow/core/cuda/atomic.cuh"
#endif // WITH_CUDA
#include "oneflow/core/common/nd_index_offset_helper.h"
#include "oneflow/core/ndarray/xpu_util.h"
namespace oneflow {
#define PADDING_DATA_TYPE_CPU_SEQ \
FLOATING_DATA_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)
#define PADDING_DATA_TYPE_GPU_SEQ \
FLOAT16_DATA_TYPE_SEQ \
PADDING_DATA_TYPE_CPU_SEQ
namespace user_op {
template<typename T>
struct DeviceAdd {
OF_DEVICE_FUNC static void Invoke(const T* x, T* y) {
#if defined(__CUDA_ARCH__)
cuda::atomic::Add(y, *x);
#else
*y += *x;
#endif
};
};
template<DeviceType device_type, typename IN_T>
struct ConstantPad3dFunctor final {
void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int64_t>& padding,
IN_T constant_value);
};
template<DeviceType device_type, typename IN_T>
struct ConstantPad3dGradFunctor final {
void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& dy_shape,
const ShapeView& dx_shape, const std::vector<int64_t>& padding);
};
template<typename IN_T>
OF_DEVICE_FUNC void DoConstantPad3d(const IN_T* src, IN_T* dest,
const NdIndexOffsetHelper<int64_t, 5>& index_helper,
int64_t elem_num, int64_t n_channel, int64_t y_depth,
int64_t y_height, int64_t y_width, int64_t x_depth,
int64_t x_height, int64_t x_width, int64_t pad_front,
int64_t pad_left, int64_t pad_top, IN_T constant_value) {
XPU_1D_KERNEL_LOOP(num, elem_num) {
int64_t n, c, d, h, w;
index_helper.OffsetToNdIndex(num, n, c, d, h, w);
const int64_t src_num = n_channel * x_depth * x_height * x_width;
if (pad_front <= d && d < pad_front + x_depth && w >= pad_left && w < x_width + pad_left
&& h >= pad_top && h < x_height + pad_top) {
const int64_t len_w = w - pad_left;
const int64_t len_h = h - pad_top;
const int64_t len_d = d - pad_front;
const int64_t src_index = n * src_num + c * x_depth * x_width * x_height
+ len_d * x_height * x_width + len_h * x_width + len_w;
dest[num] = src[src_index];
} else {
dest[num] = constant_value;
}
}
}
template<typename IN_T>
OF_DEVICE_FUNC void DoConstantPad3dGrad(const IN_T* src, IN_T* dest,
const NdIndexOffsetHelper<int64_t, 5>& index_helper,
int64_t elem_num, int64_t n_channel, int64_t dy_depth,
int64_t dy_height, int64_t dy_width, int64_t dx_depth,
int64_t dx_height, int64_t dx_width, int64_t pad_front,
int64_t pad_left, int64_t pad_top) {
XPU_1D_KERNEL_LOOP(num, elem_num) {
int64_t n, c, d, h, w;
index_helper.OffsetToNdIndex(num, n, c, d, h, w);
const int64_t dest_num = n_channel * dx_depth * dx_height * dx_width;
if (pad_front <= d && d < pad_front + dx_depth && w >= pad_left && w < dx_width + pad_left
&& h >= pad_top && h < dx_height + pad_top) {
const int64_t len_d = d - pad_front;
const int64_t len_w = w - pad_left;
const int64_t len_h = h - pad_top;
const int64_t dest_index = n * dest_num + c * dx_depth * dx_width * dx_height
+ len_d * dx_width * dx_height + len_h * dx_width + len_w;
DeviceAdd<IN_T>::Invoke(src + num, dest + dest_index);
}
}
}
#define INSTANTIATE_CONSTANT_PAD3D_FUNCTOR(device_type_v, dtype_pair) \
template struct ConstantPad3dFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
#define INSTANTIATE_CONSTANT_PAD3D_GRAD_FUNCTOR(device_type_v, dtype_pair) \
template struct ConstantPad3dGradFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
} // namespace user_op
} // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_PAD3D_KERNELS_UTIL_H_
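A NumPy restatement of the per-element rule in DoConstantPad3d (illustrative only): each output index (n, c, d, h, w) is copied from the input when it falls inside the window offset by (pad_front, pad_top, pad_left); otherwise it receives the constant value.

import numpy as np

def constant_pad3d_reference(x, padding, value):
    left, right, top, bottom, front, back = padding
    n, c, xd, xh, xw = x.shape
    y = np.empty((n, c, xd + front + back, xh + top + bottom, xw + left + right), dtype=x.dtype)
    for idx in np.ndindex(y.shape):
        _, _, d, h, w = idx
        inside = (front <= d < front + xd) and (top <= h < top + xh) and (left <= w < left + xw)
        y[idx] = x[idx[0], idx[1], d - front, h - top, w - left] if inside else value
    return y

x = np.arange(8, dtype=np.float32).reshape(1, 1, 2, 2, 2)
ref = constant_pad3d_reference(x, (1, 1, 1, 1, 1, 1), 9.0)
assert np.array_equal(ref, np.pad(x, ((0, 0), (0, 0), (1, 1), (1, 1), (1, 1)), constant_values=9.0))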
@@ -343,4 +343,113 @@ REGISTER_USER_OP_GRAD("constant_pad2d")
return Maybe<void>::Ok();
});
REGISTER_USER_OP("constant_pad3d")
.Input("x")
.Output("y")
.Attr<std::vector<int64_t>>("padding")
.Attr<double>("floating_value")
.Attr<int64_t>("integral_value")
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
const Shape& x_shape = ctx->InputShape("x", 0);
const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
CHECK_EQ_OR_RETURN(x_shape.NumAxes(), 5);
// only supports NCDHW-format input tensors for now!
// for NCDHW format, the indices of num, channel, depth, height, width are 0, 1, 2, 3, 4
const int64_t n_idx = 0;
const int64_t c_idx = 1;
const int64_t d_idx = 2;
const int64_t h_idx = 3;
const int64_t w_idx = 4;
DimVector y_dim_vec(x_shape.NumAxes());
const int64_t d_x = x_shape.At(d_idx);
const int64_t h_x = x_shape.At(h_idx);
const int64_t w_x = x_shape.At(w_idx);
y_dim_vec[n_idx] = x_shape.At(n_idx);
y_dim_vec[c_idx] = x_shape.At(c_idx);
y_dim_vec[d_idx] = d_x + padding[4] + padding[5];
y_dim_vec[h_idx] = h_x + padding[2] + padding[3];
y_dim_vec[w_idx] = w_x + padding[0] + padding[1];
*ctx->OutputShape("y", 0) = Shape(y_dim_vec);
return Maybe<void>::Ok();
})
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
*ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
return Maybe<void>::Ok();
})
.SetGetSbpFn(GetOpSbpSignature)
.SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
const user_op::UserOpConfWrapper&) -> Maybe<void> {
user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0);
CHECK_NOTNULL_OR_RETURN(x_modifier);
x_modifier->set_requires_grad(true);
return Maybe<void>::Ok();
});
REGISTER_USER_OP("constant_pad3d_grad")
.Input("dy")
.Output("dx")
.Attr<std::vector<int64_t>>("padding")
.Attr<double>("floating_value")
.Attr<int64_t>("integral_value")
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
const Shape& dy_shape = ctx->InputShape("dy", 0);
const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
CHECK_EQ_OR_RETURN(dy_shape.NumAxes(), 5);
const int64_t n_idx = 0;
const int64_t c_idx = 1;
const int64_t d_idx = 2;
const int64_t h_idx = 3;
const int64_t w_idx = 4;
DimVector dx_dim_vec(dy_shape.NumAxes());
int64_t d_dy, h_dy, w_dy;
d_dy = dy_shape.At(d_idx);
h_dy = dy_shape.At(h_idx);
w_dy = dy_shape.At(w_idx);
dx_dim_vec[n_idx] = dy_shape.At(0);
dx_dim_vec[c_idx] = dy_shape.At(1);
dx_dim_vec[d_idx] = d_dy - padding[4] - padding[5];
dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3];
dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1];
*ctx->OutputShape("dx", 0) = Shape(dx_dim_vec);
return Maybe<void>::Ok();
})
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
*ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
return Maybe<void>::Ok();
})
.SetGetSbpFn(GetOpGradSbpSignature);
REGISTER_USER_OP_GRAD("constant_pad3d")
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("x", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op =
builder.Op("constant_pad3d_grad")
.Input("dy", op.GetGradTensorWithOpOutput("y", 0))
.Output("dx")
.Attr("padding", op.attr<std::vector<int64_t>>("padding"))
.Attr("floating_value", op.attr<double>("floating_value"))
.Attr("integral_value", op.attr<int64_t>("integral_value"))
.Build();
op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
AddOp(grad_op);
}
return Maybe<void>::Ok();
});
} // namespace oneflow
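A worked example of the shape rule in constant_pad3d's SetTensorDescInferFn above: with padding = [left, right, top, bottom, front, back], the NCDHW output shape is (N, C, D + front + back, H + top + bottom, W + left + right).

def infer_constant_pad3d_shape(x_shape, padding):
    n, c, d, h, w = x_shape
    left, right, top, bottom, front, back = padding
    return (n, c, d + front + back, h + top + bottom, w + left + right)

print(infer_constant_pad3d_shape((1, 2, 3, 4, 5), (1, 1, 2, 2, 3, 3)))  # (1, 2, 9, 8, 7)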