Skip to content
Snippets Groups Projects
Unverified Commit 87a990a6 authored by Luyang's avatar Luyang Committed by GitHub
Browse files

Dev constantpad1d op (#5579)


* refine and add test case

* support ellipsis type slice

* refine

* refine

* support slice assign ellipsis type

* refine

* register fn to localtensor

* add constantpad1d kernel

* add constantpad1d kernel

* implementation of functional api/gradients/test case

* refine test case

* refine

* format docstring

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 3dcb9939
No related branches found
No related tags found
No related merge requests found
......@@ -16,7 +16,8 @@ Operators for neural networks
BatchNorm2d,
COCOReader,
CTCLoss,
CoinFlip,
CoinFlip,
ConstantPad1d,
ConstantPad2d,
ConstantPad3d,
Conv1d,
......
......@@ -127,6 +127,7 @@ class ConstantPadNd : public OpExprGradFunction<ConstantPadNdInterpState> {
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("constant_pad1d", ConstantPadNd);
REGISTER_OP_EXPR_GRAD_FUNCTION("constant_pad2d", ConstantPadNd);
REGISTER_OP_EXPR_GRAD_FUNCTION("constant_pad3d", ConstantPadNd);
......
......@@ -414,8 +414,9 @@ class NormalizationFunctor {
class PadFunctor {
public:
PadFunctor() {
constant_pad_1d_ = CHECK_JUST(one::OpBuilder("constant_pad1d").Input("x").Output("y").Build());
constant_pad_2d_ = CHECK_JUST(one::OpBuilder("constant_pad2d").Input("x").Output("y").Build());
constant_pad_3d_ = CHECK_JUST(one::OpBuilder("constant_pad3d").Input("x").Output("y").Build());
constant_pad_ = CHECK_JUST(one::OpBuilder("constant_pad2d").Input("x").Output("y").Build());
reflect_pad_ = CHECK_JUST(one::OpBuilder("reflection_pad2d").Input("x").Output("y").Build());
replicate_pad_ = CHECK_JUST(one::OpBuilder("replication_pad2d").Input("x").Output("y").Build());
}
......@@ -437,7 +438,8 @@ class PadFunctor {
UNIMPLEMENTED_THEN_RETURN() << "Data type should be floating or integral type.";
}
switch (x->shape()->NumAxes()) {
case 4: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_, {x}, attrs);
case 3: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_1d_, {x}, attrs);
case 4: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_2d_, {x}, attrs);
case 5: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_3d_, {x}, attrs);
default:
UNIMPLEMENTED_THEN_RETURN() << "Pad mode is " << mode << ", but " << x->shape()->NumAxes()
......@@ -455,7 +457,8 @@ class PadFunctor {
}
private:
std::shared_ptr<OpExpr> constant_pad_;
std::shared_ptr<OpExpr> constant_pad_1d_;
std::shared_ptr<OpExpr> constant_pad_2d_;
std::shared_ptr<OpExpr> constant_pad_3d_;
std::shared_ptr<OpExpr> reflect_pad_;
std::shared_ptr<OpExpr> replicate_pad_;
......
......@@ -253,7 +253,9 @@ class SmoothL1LossGradFunctor {
class PadGradFunctor {
public:
PadGradFunctor() {
constant_pad_grad_ =
constant_pad_1d_grad_ =
CHECK_JUST(one::OpBuilder("constant_pad1d_grad").Input("dy").Output("dx").Build());
constant_pad_2d_grad_ =
CHECK_JUST(one::OpBuilder("constant_pad2d_grad").Input("dy").Output("dx").Build());
constant_pad_3d_grad_ =
CHECK_JUST(one::OpBuilder("constant_pad3d_grad").Input("dy").Output("dx").Build());
......@@ -280,7 +282,8 @@ class PadGradFunctor {
UNIMPLEMENTED_THEN_RETURN() << "Data type should be floating or integral type.";
}
switch (dy->shape()->NumAxes()) {
case 4: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_grad_, {dy}, attrs);
case 3: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_1d_grad_, {dy}, attrs);
case 4: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_2d_grad_, {dy}, attrs);
case 5: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_3d_grad_, {dy}, attrs);
default:
UNIMPLEMENTED_THEN_RETURN() << "Pad mode is " << mode << ", but "
......@@ -297,9 +300,10 @@ class PadGradFunctor {
}
private:
std::shared_ptr<OpExpr> constant_pad_grad_;
std::shared_ptr<OpExpr> reflect_pad_grad_;
std::shared_ptr<OpExpr> replicate_pad_grad_;
std::shared_ptr<OpExpr> constant_pad_1d_grad_;
std::shared_ptr<OpExpr> constant_pad_2d_grad_;
std::shared_ptr<OpExpr> constant_pad_3d_grad_;
};
......
......@@ -186,26 +186,91 @@ class ReflectionPad2d(Module):
return "{}".format(self.padding)
@oneflow_export("nn.ConstantPad1d")
class ConstantPad1d(Module):
r"""Pads the input tensor boundaries with a constant value.
The interface is consistent with PyTorch, and referenced from:
https://pytorch.org/docs/stable/generated/torch.nn.ConstantPad1d.html?highlight=constantpad1d#torch.nn.ConstantPad1d
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
Args:
padding (int, list, tuple): the size of the padding. If is `int`, uses the same
padding in both boundaries. If a 2-`tuple`, uses
(:math:`\text{padding_left}`, :math:`\text{padding_right}`)
value (int, float): The constant value used for padding. Defaults to 0.
Shape:
- Input: :math:`(N, C, W_{in})`
- Output: :math:`(N, C, W_{out})` where
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
For example:
.. code-block:: python
>>> import oneflow.experimental as flow
>>> import numpy as np
>>> input = flow.tensor(np.arange(8).reshape(2,2,2).astype(np.float32))
>>> m = flow.nn.ConstantPad1d(padding=[1, 2], value=9.9999)
>>> output = m(input)
>>> output
tensor([[[9.9999, 0. , 1. , 9.9999, 9.9999],
[9.9999, 2. , 3. , 9.9999, 9.9999]],
<BLANKLINE>
[[9.9999, 4. , 5. , 9.9999, 9.9999],
[9.9999, 6. , 7. , 9.9999, 9.9999]]], dtype=oneflow.float32)
"""
def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0):
    """Normalize ``padding`` into a ``[left, right]`` pair and store the fill value.

    Args:
        padding: an ``int`` (same amount on both sides) or a 2-element
            ``tuple``/``list`` given as ``(padding_left, padding_right)``.
        value: the constant written into the padded positions. Defaults to 0.

    Raises:
        AssertionError: if ``padding`` is a tuple/list whose length is not 2.
        ValueError: if ``padding`` is neither an int nor a tuple/list.
    """
    super().__init__()
    if isinstance(padding, (tuple, list)):
        # 1-D padding has exactly two entries; the message previously
        # claimed 4, which contradicted the check it accompanies.
        assert len(padding) == 2, ValueError("Length of padding must be 2")
        boundary = [padding[0], padding[1]]
    elif isinstance(padding, int):
        boundary = [padding, padding]
    else:
        raise ValueError("padding must be int or list or tuple!")
    self.padding = boundary
    self.value = value
def forward(self, x):
    """Pad ``x`` with the configured constant and return the result.

    The fill value is cast to ``float`` for float16/float32/float64 inputs
    and to ``int`` for every other dtype before being handed to
    ``flow.F.pad``.
    """
    # Cast into a local instead of overwriting ``self.value``: mutating the
    # attribute made the module stateful across calls (an integer input
    # would truncate e.g. 9.9999 to 9 for every later float input as well).
    if x.dtype in (flow.float32, flow.float16, flow.float64):
        fill_value = float(self.value)
    else:
        fill_value = int(self.value)
    return flow.F.pad(x, pad=self.padding, mode="constant", value=fill_value)
@oneflow_export("nn.ConstantPad2d")
class ConstantPad2d(Module):
r"""The interface is consistent with PyTorch.
The documentation is referenced from:
https://pytorch.org/docs/stable/generated/torch.nn.ConstantPad2d.html?highlight=constantpad2d#torch.nn.ConstantPad2d
This operator pads the input with constant value that user specifies. User can set the amount of padding by setting the parameter `paddings`.
This operator pads the input with constant value that user specifies.
User can set the amount of padding by setting the parameter `paddings`.
Args:
padding (Union[int, tuple, list]): the size of the padding. If is `int`, uses the same padding in all boundaries. If a 4-`tuple`, uses (:math:`\mathrm{padding_{left}}`, :math:`\mathrm{padding_{right}}`, :math:`\mathrm{padding_{top}}`, :math:`\mathrm{padding_{bottom}}`)
value (Union[int, float]): The constant value used for padding. Defaults to 0.
padding (int, tuple, list): the size of the padding.
If is `int`, uses the same padding in all boundaries.
If a 4-`tuple`, uses
(:math:`\mathrm{padding_{left}}`, :math:`\mathrm{padding_{right}}`, :math:`\mathrm{padding_{top}}`, :math:`\mathrm{padding_{bottom}}`)
value (int, float): The constant value used for padding. Defaults to 0.
Shape:
- Input: :math:`(N, C, H_{in}, W_{in})`
- Output: :math:`(N, C, H_{out}, W_{out})` where
:math:`H_{out} = H_{in} + \mathrm{padding_{top}} + \mathrm{padding_{bottom}}`
:math:`W_{out} = W_{in} + \mathrm{padding_{left}} + \mathrm{padding_{right}}`
:math:`H_{out} = H_{in} + \mathrm{padding_{top}} + \mathrm{padding_{bottom}}`
:math:`W_{out} = W_{in} + \mathrm{padding_{left}} + \mathrm{padding_{right}}`
For example:
......@@ -213,6 +278,7 @@ class ConstantPad2d(Module):
>>> import oneflow as flow
>>> import numpy as np
>>> constantpad_layer_0 = flow.nn.ConstantPad2d((2, 2, 1, 1), 1)
>>> input = flow.Tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32))
>>> input_int = flow.Tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.int32))
......@@ -244,6 +310,7 @@ class ConstantPad2d(Module):
[ 1., 1., 12., 13., 14., 1., 1.],
[ 1., 1., 15., 16., 17., 1., 1.],
[ 1., 1., 1., 1., 1., 1., 1.]]]], dtype=oneflow.float32)
"""
def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0):
......@@ -283,7 +350,7 @@ class ConstantPad3d(Module):
:math:`\text{padding_top}`, :math:`\text{padding_bottom}`,
:math:`\text{padding_front}`, :math:`\text{padding_back}`)
value (Union[int, float]): The constant value used for padding. Defaults to 0.
value (int, float): The constant value used for padding. Defaults to 0.
Shape:
- Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
......
......@@ -98,7 +98,22 @@ def _test_ConstantPad2d(test_case, shape, padding, value, device):
@flow.unittest.skip_unless_1n1d()
class TestConstantPad2dModule(flow.unittest.TestCase):
class TestConstantPad1d(flow.unittest.TestCase):
# Randomized ConstantPad1d test driven by the `autotest` decorator, configured
# here with rtol=1e-4 and atol=1e-4.
# NOTE(review): presumably `autotest` compares the returned value between
# OneFlow and PyTorch — confirm against the decorator's implementation.
@autotest(rtol=1e-4, atol=1e-4)
def test_constantpad1d_with_random_data(test_case):
# Module under test: random integer padding and random float fill value.
m = torch.nn.ConstantPad1d(padding=random().to(int), value=random().to(float))
# Randomly choose the train/eval flag passed to the module.
m.train(random())
device = random_device()
m.to(device)
# 3-D input; dim1 and dim2 are drawn via random(1, 6), moved to the same device.
x = random_pytorch_tensor(ndim=3, dim1=random(1, 6), dim2=random(1, 6)).to(
device
)
y = m(x)
# The padded output is handed back to the decorator.
return y
@flow.unittest.skip_unless_1n1d()
class TestConstantPad2d(flow.unittest.TestCase):
def test_ConstantPad2d(test_case):
arg_dict = OrderedDict()
arg_dict["shape"] = [(1, 2, 3, 4), (8, 3, 4, 4)]
......@@ -127,11 +142,8 @@ class TestConstantPad2dModule(flow.unittest.TestCase):
)
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestConstantPad3dModule(flow.unittest.TestCase):
@flow.unittest.skip_unless_1n1d()
class TestConstantPad3d(flow.unittest.TestCase):
def test_with_random_data(test_case):
for device in ["cpu", "cuda"]:
spatial_size = np.random.randint(1, 6)
......
......@@ -17,7 +17,7 @@ limitations under the License.
#include "oneflow/core/device/memory_copier.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"
#include "oneflow/user/kernels/constantpad3d_kernel_util.h"
#include "oneflow/user/kernels/constantpad_kernel_util.h"
namespace oneflow {
......@@ -60,6 +60,68 @@ int64_t GetDtypeMatchedValue(double floating, int64_t integral) {
namespace user_op {
// Forward kernel for the "constant_pad1d" op.
//
// Reads the 3-axis input "x", the 2-element "padding" attribute and the
// dtype-matched constant (picked from "floating_value"/"integral_value"),
// then delegates the element-wise work to ConstantPad1dFunctor.
template<DeviceType device_type, typename IN_T>
class ConstantPad1dKernel final : public OpKernel {
 public:
  ConstantPad1dKernel() = default;
  ~ConstantPad1dKernel() = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
    Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
    const ShapeView& in_shape = in_tensor->shape();
    const ShapeView& out_shape = out_tensor->shape();
    CHECK_EQ(in_shape.NumAxes(), 3);

    const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
    CHECK_EQ(padding.size(), 2);

    const IN_T fill_value = GetDtypeMatchedValue<IN_T>(ctx->Attr<double>("floating_value"),
                                                       ctx->Attr<int64_t>("integral_value"));

    IN_T* out_ptr = out_tensor->mut_dptr<IN_T>();
    const IN_T* in_ptr = in_tensor->dptr<IN_T>();

    // The offset helper is built over the OUTPUT dims: the functor walks
    // every output element and maps it back to an input element.
    DimVector out_dims;
    out_shape.ToDimVector(&out_dims);
    NdIndexOffsetHelper<int64_t, 3> out_index_helper(out_dims.data());

    ConstantPad1dFunctor<device_type, IN_T>()(ctx->device_ctx(), in_ptr, out_ptr,
                                              out_index_helper, in_shape, out_shape, padding,
                                              fill_value);
  }

  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
// Backward kernel for the "constant_pad1d_grad" op.
//
// Takes the 3-axis upstream gradient "dy" and produces "dx". The output
// buffer is zero-filled before ConstantPad1dGradFunctor runs, since that
// functor adds values into dx rather than assigning them.
template<DeviceType device_type, typename IN_T>
class ConstantPad1dGradKernel final : public OpKernel {
 public:
  ConstantPad1dGradKernel() = default;
  ~ConstantPad1dGradKernel() = default;

 private:
  void Compute(KernelComputeContext* ctx) const override {
    const Tensor* grad_out = ctx->Tensor4ArgNameAndIndex("dy", 0);
    CHECK_EQ(grad_out->shape().NumAxes(), 3);
    Tensor* grad_in = ctx->Tensor4ArgNameAndIndex("dx", 0);

    const ShapeView& dx_shape = grad_in->shape();
    const ShapeView& dy_shape = grad_out->shape();

    const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
    CHECK_EQ(padding.size(), 2);

    const IN_T* dy_ptr = grad_out->dptr<IN_T>();
    IN_T* dx_ptr = grad_in->mut_dptr<IN_T>();

    // Offsets are decoded in dy's index space (the padded shape).
    DimVector dy_dims;
    dy_shape.ToDimVector(&dy_dims);
    NdIndexOffsetHelper<int64_t, 3> dy_index_helper(dy_dims.data());

    const size_t dx_byte_size = dx_shape.Count(0) * GetSizeOfDataType(grad_in->data_type());
    Memset<device_type>(ctx->device_ctx(), dx_ptr, 0, dx_byte_size);

    ConstantPad1dGradFunctor<device_type, IN_T>()(ctx->device_ctx(), dy_ptr, dx_ptr,
                                                  dy_index_helper, dy_shape, dx_shape, padding);
  }

  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
template<DeviceType device_type, typename IN_T>
class ConstantPad3dKernel final : public OpKernel {
public:
......@@ -122,25 +184,33 @@ class ConstantPad3dGradKernel final : public OpKernel {
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
#define REGISTER_CONSTANT_PAD3D_KERNELS(device, dtype) \
REGISTER_USER_KERNEL("constant_pad3d") \
.SetCreateFn<ConstantPad3dKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
REGISTER_USER_KERNEL("constant_pad3d_grad") \
.SetCreateFn<ConstantPad3dGradKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
#define REGISTER_CONSTANT_PAD_KERNELS(device, dtype) \
REGISTER_USER_KERNEL("constant_pad1d") \
.SetCreateFn<ConstantPad1dKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
REGISTER_USER_KERNEL("constant_pad1d_grad") \
.SetCreateFn<ConstantPad1dGradKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value)); \
REGISTER_USER_KERNEL("constant_pad3d") \
.SetCreateFn<ConstantPad3dKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
REGISTER_USER_KERNEL("constant_pad3d_grad") \
.SetCreateFn<ConstantPad3dGradKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
#define REGISTER_CONSTANT_PAD3D_WITH_DEVICE(device) \
REGISTER_CONSTANT_PAD3D_KERNELS(device, float) \
REGISTER_CONSTANT_PAD3D_KERNELS(device, double) \
REGISTER_CONSTANT_PAD3D_KERNELS(device, int32_t)
#define REGISTER_CONSTANT_PAD_WITH_DEVICE(device) \
REGISTER_CONSTANT_PAD_KERNELS(device, float) \
REGISTER_CONSTANT_PAD_KERNELS(device, double) \
REGISTER_CONSTANT_PAD_KERNELS(device, int32_t)
REGISTER_CONSTANT_PAD3D_WITH_DEVICE(DeviceType::kCPU)
REGISTER_CONSTANT_PAD_WITH_DEVICE(DeviceType::kCPU)
#ifdef WITH_CUDA
REGISTER_CONSTANT_PAD3D_WITH_DEVICE(DeviceType::kGPU)
REGISTER_CONSTANT_PAD3D_KERNELS(DeviceType::kGPU, float16)
REGISTER_CONSTANT_PAD_WITH_DEVICE(DeviceType::kGPU)
REGISTER_CONSTANT_PAD_KERNELS(DeviceType::kGPU, float16)
#endif
} // namespace user_op
......
......@@ -13,12 +13,39 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/constantpad3d_kernel_util.h"
#include "oneflow/user/kernels/constantpad_kernel_util.h"
#include "oneflow/core/framework/framework.h"
namespace oneflow {
namespace user_op {
// CPU specialization of the 1-D constant-pad forward functor.
// Unpacks the shapes and the left padding, then runs DoConstantPad1d
// directly on the host.
template<typename IN_T>
struct ConstantPad1dFunctor<DeviceType::kCPU, IN_T> final {
  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const ShapeView& x_shape,
                  const ShapeView& y_shape, const std::vector<int64_t>& padding,
                  IN_T constant_value) {
    // NCW layout: axis 0 is n, axis 1 is c, axis 2 is w.
    const int64_t channel_axis = 1;
    const int64_t width_axis = 2;

    const int64_t out_elem_cnt = y_shape.Count(0);
    const int64_t channels = y_shape.At(channel_axis);
    const int64_t out_width = y_shape.At(width_axis);
    const int64_t in_width = x_shape.At(width_axis);
    // padding is [left, right]; only the left amount is needed for indexing.
    const int64_t pad_left = padding[0];

    DoConstantPad1d<IN_T>(src, dest, index_helper, out_elem_cnt, channels, out_width, in_width,
                          pad_left, constant_value);
  }
};
// CPU specialization of the 1-D constant-pad backward functor.
// Unpacks the shapes and the left padding, then runs DoConstantPad1dGrad
// directly on the host.
template<typename IN_T>
struct ConstantPad1dGradFunctor<DeviceType::kCPU, IN_T> final {
  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const ShapeView& dy_shape,
                  const ShapeView& dx_shape, const std::vector<int64_t>& padding) {
    // Axis positions used below: 1 for channels, 2 for width.
    const int64_t channel_axis = 1;
    const int64_t width_axis = 2;

    const int64_t dy_elem_cnt = dy_shape.Count(0);
    const int64_t channels = dy_shape.At(channel_axis);
    const int64_t dy_width = dy_shape.At(width_axis);
    const int64_t dx_width = dx_shape.At(width_axis);
    const int64_t pad_left = padding[0];

    DoConstantPad1dGrad<IN_T>(src, dest, index_helper, dy_elem_cnt, channels, dy_width, dx_width,
                              pad_left);
  }
};
template<typename IN_T>
struct ConstantPad3dFunctor<DeviceType::kCPU, IN_T> final {
void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
......@@ -54,10 +81,10 @@ struct ConstantPad3dGradFunctor<DeviceType::kCPU, IN_T> final {
}
};
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_FUNCTOR, (DeviceType::kCPU),
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD_FUNCTOR, (DeviceType::kCPU),
PADDING_DATA_TYPE_CPU_SEQ);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_GRAD_FUNCTOR, (DeviceType::kCPU),
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD_GRAD_FUNCTOR, (DeviceType::kCPU),
PADDING_DATA_TYPE_CPU_SEQ);
} // namespace user_op
......
......@@ -16,11 +16,20 @@ limitations under the License.
#include <cstdint>
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/constantpad3d_kernel_util.h"
#include "oneflow/user/kernels/constantpad_kernel_util.h"
namespace oneflow {
namespace user_op {
// CUDA kernel entry point for 1-D constant padding (forward).
// A thin __global__ wrapper: every argument is forwarded unchanged to the
// shared device function DoConstantPad1d.
template<typename IN_T>
__global__ void DoCUDAConstantPad1d(const IN_T* src, IN_T* dest,
                                    const NdIndexOffsetHelper<int64_t, 3> index_helper,
                                    int64_t elem_num, int64_t n_channel, int64_t y_width,
                                    int64_t x_width, int64_t pad_left, const IN_T const_value) {
  DoConstantPad1d<IN_T>(src, dest, index_helper, elem_num, n_channel, y_width, x_width,
                        pad_left, const_value);
}
template<typename IN_T>
__global__ void DoCUDAConstantPad3d(const IN_T* src, IN_T* dest,
const NdIndexOffsetHelper<int64_t, 5> index_helper,
......@@ -32,6 +41,15 @@ __global__ void DoCUDAConstantPad3d(const IN_T* src, IN_T* dest,
x_depth, x_height, x_width, pad_front, pad_left, pad_top, const_value);
};
// CUDA kernel entry point for 1-D constant padding (backward).
// A thin __global__ wrapper: every argument is forwarded unchanged to the
// shared device function DoConstantPad1dGrad.
template<typename IN_T>
__global__ void DoCUDAConstantPad1dGrad(const IN_T* src, IN_T* dest,
                                        const NdIndexOffsetHelper<int64_t, 3> index_helper,
                                        int64_t elem_num, int64_t n_channel, int64_t dy_width,
                                        int64_t dx_width, int64_t pad_left) {
  DoConstantPad1dGrad<IN_T>(src, dest, index_helper, elem_num, n_channel, dy_width, dx_width,
                            pad_left);
}
template<typename IN_T>
__global__ void DoCUDAConstantPad3dGrad(const IN_T* src, IN_T* dest,
const NdIndexOffsetHelper<int64_t, 5> index_helper,
......@@ -43,6 +61,37 @@ __global__ void DoCUDAConstantPad3dGrad(const IN_T* src, IN_T* dest,
dy_width, dx_height, dx_depth, dx_width, pad_front, pad_left, pad_top);
};
// GPU specialization of the 1-D constant-pad forward functor.
// Launches DoCUDAConstantPad1d on the context's CUDA stream, with the block
// count derived from the number of output elements.
template<typename IN_T>
struct ConstantPad1dFunctor<DeviceType::kGPU, IN_T> final {
  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const ShapeView& x_shape,
                  const ShapeView& y_shape, const std::vector<int64_t>& padding,
                  IN_T constant_value) {
    const int64_t channel_axis = 1;
    const int64_t width_axis = 2;

    const int64_t out_elem_cnt = y_shape.Count(0);
    const int64_t channels = y_shape.At(channel_axis);
    const int64_t out_width = y_shape.At(width_axis);
    const int64_t in_width = x_shape.At(width_axis);
    const int64_t pad_left = padding[0];

    DoCUDAConstantPad1d<IN_T>
        <<<BlocksNum4ThreadsNum(out_elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
            src, dest, index_helper, out_elem_cnt, channels, out_width, in_width, pad_left,
            constant_value);
  }
};
// float16 forward path: the buffers and the pad constant are reinterpreted /
// converted to CUDA `half` so that the kernel is instantiated for `half`.
template<>
void ConstantPad1dFunctor<DeviceType::kGPU, float16>::operator()(
    DeviceCtx* ctx, const float16* src, float16* dest,
    const NdIndexOffsetHelper<int64_t, 3>& index_helper, const ShapeView& x_shape,
    const ShapeView& y_shape, const std::vector<int64_t>& padding, float16 constant_value) {
  const int64_t channel_axis = 1;
  const int64_t width_axis = 2;

  const int64_t out_elem_cnt = y_shape.Count(0);
  const half* half_src = reinterpret_cast<const half*>(src);
  half* half_dest = reinterpret_cast<half*>(dest);

  DoCUDAConstantPad1d<half>
      <<<BlocksNum4ThreadsNum(out_elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
          half_src, half_dest, index_helper, out_elem_cnt, y_shape.At(channel_axis),
          y_shape.At(width_axis), x_shape.At(width_axis), padding[0],
          static_cast<const half>(constant_value));
}
template<typename IN_T>
struct ConstantPad3dFunctor<DeviceType::kGPU, IN_T> final {
void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
......@@ -80,6 +129,35 @@ void ConstantPad3dFunctor<DeviceType::kGPU, float16>::operator()(
padding[0], padding[2], static_cast<const half>(constant_value));
}
// GPU specialization of the 1-D constant-pad backward functor.
// Launches DoCUDAConstantPad1dGrad on the context's CUDA stream, with the
// block count derived from the number of dy elements.
template<typename IN_T>
struct ConstantPad1dGradFunctor<DeviceType::kGPU, IN_T> final {
  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
                  const NdIndexOffsetHelper<int64_t, 3>& index_helper, const ShapeView& dy_shape,
                  const ShapeView& dx_shape, const std::vector<int64_t>& padding) {
    const int64_t channel_axis = 1;
    const int64_t width_axis = 2;

    const int64_t dy_elem_cnt = dy_shape.Count(0);
    const int64_t channels = dy_shape.At(channel_axis);
    const int64_t dy_width = dy_shape.At(width_axis);
    const int64_t dx_width = dx_shape.At(width_axis);
    const int64_t pad_left = padding[0];

    DoCUDAConstantPad1dGrad<IN_T>
        <<<BlocksNum4ThreadsNum(dy_elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
            src, dest, index_helper, dy_elem_cnt, channels, dy_width, dx_width, pad_left);
  }
};
// float16 backward path: the buffers are reinterpreted as CUDA `half` so
// that the kernel is instantiated for `half`.
template<>
void ConstantPad1dGradFunctor<DeviceType::kGPU, float16>::operator()(
    DeviceCtx* ctx, const float16* src, float16* dest,
    const NdIndexOffsetHelper<int64_t, 3>& index_helper, const ShapeView& dy_shape,
    const ShapeView& dx_shape, const std::vector<int64_t>& padding) {
  const int64_t channel_axis = 1;
  const int64_t width_axis = 2;

  const int64_t dy_elem_cnt = dy_shape.Count(0);
  const half* half_src = reinterpret_cast<const half*>(src);
  half* half_dest = reinterpret_cast<half*>(dest);

  DoCUDAConstantPad1dGrad<half>
      <<<BlocksNum4ThreadsNum(dy_elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
          half_src, half_dest, index_helper, dy_elem_cnt, dy_shape.At(channel_axis),
          dy_shape.At(width_axis), dx_shape.At(width_axis), padding[0]);
}
template<typename IN_T>
struct ConstantPad3dGradFunctor<DeviceType::kGPU, IN_T> final {
void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
......@@ -115,10 +193,10 @@ void ConstantPad3dGradFunctor<DeviceType::kGPU, float16>::operator()(
padding[4], padding[0], padding[2]);
}
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_FUNCTOR,
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD_FUNCTOR,
OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_GRAD_FUNCTOR,
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD_GRAD_FUNCTOR,
OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ);
} // namespace user_op
......
......@@ -13,8 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_PAD3D_KERNELS_UTIL_H_
#define ONEFLOW_USER_KERNELS_PAD3D_KERNELS_UTIL_H_
#ifndef ONEFLOW_USER_KERNELS_PAD_KERNELS_UTIL_H_
#define ONEFLOW_USER_KERNELS_PAD_KERNELS_UTIL_H_
#ifdef WITH_CUDA
#include "oneflow/core/cuda/atomic.cuh"
#endif // WITH_CUDA
......@@ -44,6 +44,21 @@ struct DeviceAdd {
};
};
// Primary template of the 1-D constant-pad forward functor.
// Only the call operator is declared here; no body is given in this header.
//   ctx            - device context passed through by the calling kernel
//   src / dest     - input (x) and output (y) element buffers
//   index_helper   - 3-axis offset <-> (n, c, w) index converter
//   x_shape        - shape of the unpadded input
//   y_shape        - shape of the padded output
//   padding        - padding amounts; callers check that it has 2 entries
//   constant_value - value written into padded positions
template<DeviceType device_type, typename IN_T>
struct ConstantPad1dFunctor final {
void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
const NdIndexOffsetHelper<int64_t, 3>& index_helper, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int64_t>& padding,
IN_T constant_value);
};
// Primary template of the 1-D constant-pad backward functor.
// Only the call operator is declared here; no body is given in this header.
//   ctx          - device context passed through by the calling kernel
//   src / dest   - upstream gradient (dy) and input gradient (dx) buffers
//   index_helper - 3-axis offset <-> (n, c, w) index converter
//   dy_shape     - shape of dy
//   dx_shape     - shape of dx
//   padding      - padding amounts; callers check that it has 2 entries
template<DeviceType device_type, typename IN_T>
struct ConstantPad1dGradFunctor final {
void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
const NdIndexOffsetHelper<int64_t, 3>& index_helper, const ShapeView& dy_shape,
const ShapeView& dx_shape, const std::vector<int64_t>& padding);
};
template<DeviceType device_type, typename IN_T>
struct ConstantPad3dFunctor final {
void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
......@@ -59,6 +74,45 @@ struct ConstantPad3dGradFunctor final {
const ShapeView& dx_shape, const std::vector<int64_t>& padding);
};
// Element-wise body of 1-D constant padding (forward).
//
// For every output offset in [0, elem_num) the offset is decoded to (n, c, w)
// through `index_helper`. When w lies in [pad_left, pad_left + x_width) the
// matching input element is copied; otherwise `constant_value` is written.
// `y_width` is accepted for signature symmetry but not read here.
template<typename IN_T>
OF_DEVICE_FUNC void DoConstantPad1d(const IN_T* src, IN_T* dest,
                                    const NdIndexOffsetHelper<int64_t, 3>& index_helper,
                                    int64_t elem_num, int64_t n_channel, int64_t y_width,
                                    int64_t x_width, int64_t pad_left, IN_T constant_value) {
  // Number of input elements per sample (channels * input width).
  const int64_t in_sample_size = n_channel * x_width;
  XPU_1D_KERNEL_LOOP(out_offset, elem_num) {
    int64_t n, c, w;
    index_helper.OffsetToNdIndex(out_offset, n, c, w);
    const bool in_padding = !(w >= pad_left && w < x_width + pad_left);
    if (in_padding) {
      dest[out_offset] = constant_value;
    } else {
      const int64_t in_w = w - pad_left;
      const int64_t in_offset = n * in_sample_size + c * x_width + in_w;
      dest[out_offset] = src[in_offset];
    }
  }
}
// Element-wise body of 1-D constant padding (backward).
//
// For every dy offset in [0, elem_num) the offset is decoded to (n, c, w)
// through `index_helper`. Positions with w in [pad_left, pad_left + dx_width)
// are added into the matching dx element via DeviceAdd; positions outside
// that range are skipped. `dy_width` is accepted but not read here.
template<typename IN_T>
OF_DEVICE_FUNC void DoConstantPad1dGrad(const IN_T* src, IN_T* dest,
                                        const NdIndexOffsetHelper<int64_t, 3>& index_helper,
                                        int64_t elem_num, int64_t n_channel, int64_t dy_width,
                                        int64_t dx_width, int64_t pad_left) {
  // Number of dx elements per sample (channels * dx width).
  const int64_t dx_sample_size = n_channel * dx_width;
  XPU_1D_KERNEL_LOOP(dy_offset, elem_num) {
    int64_t n, c, w;
    index_helper.OffsetToNdIndex(dy_offset, n, c, w);
    const bool inside_input = (w >= pad_left) && (w < dx_width + pad_left);
    if (inside_input) {
      const int64_t dx_w = w - pad_left;
      const int64_t dx_offset = n * dx_sample_size + c * dx_width + dx_w;
      DeviceAdd<IN_T>::Invoke(src + dy_offset, dest + dx_offset);
    }
  }
}
template<typename IN_T>
OF_DEVICE_FUNC void DoConstantPad3d(const IN_T* src, IN_T* dest,
const NdIndexOffsetHelper<int64_t, 5>& index_helper,
......@@ -110,13 +164,15 @@ OF_DEVICE_FUNC void DoConstantPad3dGrad(const IN_T* src, IN_T* dest,
}
}
#define INSTANTIATE_CONSTANT_PAD3D_FUNCTOR(device_type_v, dtype_pair) \
#define INSTANTIATE_CONSTANT_PAD_FUNCTOR(device_type_v, dtype_pair) \
template struct ConstantPad1dFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>; \
template struct ConstantPad3dFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
#define INSTANTIATE_CONSTANT_PAD3D_GRAD_FUNCTOR(device_type_v, dtype_pair) \
#define INSTANTIATE_CONSTANT_PAD_GRAD_FUNCTOR(device_type_v, dtype_pair) \
template struct ConstantPad1dGradFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>; \
template struct ConstantPad3dGradFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
} // namespace user_op
} // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_PAD3D_KERNELS_UTIL_H_
#endif // ONEFLOW_USER_KERNELS_PAD_KERNELS_UTIL_H_
......@@ -242,6 +242,103 @@ REGISTER_USER_OP_GRAD("replication_pad2d")
return Maybe<void>::Ok();
});
// Registration of the "constant_pad1d" user op.
//
// Input "x" must have 3 axes and "padding" must have 2 entries. The output
// "y" keeps axes 0 and 1 of x and widens axis 2 by padding[0] + padding[1].
// The output dtype mirrors the input dtype, and x is marked as requiring grad.
REGISTER_USER_OP("constant_pad1d")
    .Input("x")
    .Output("y")
    .Attr<std::vector<int64_t>>("padding")
    .Attr<double>("floating_value")
    .Attr<int64_t>("integral_value")
    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      const Shape& x_shape = ctx->InputShape("x", 0);
      const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
      CHECK_EQ_OR_RETURN(x_shape.NumAxes(), 3);
      CHECK_EQ_OR_RETURN(padding.size(), 2);
      const int64_t n_idx = 0;
      const int64_t c_idx = 1;
      const int64_t w_idx = 2;

      DimVector y_dim_vec(x_shape.NumAxes());
      const int64_t w_x = x_shape.At(w_idx);

      y_dim_vec[n_idx] = x_shape.At(n_idx);
      y_dim_vec[c_idx] = x_shape.At(c_idx);
      y_dim_vec[w_idx] = w_x + padding[0] + padding[1];

      *ctx->OutputShape("y", 0) = Shape(y_dim_vec);
      return Maybe<void>::Ok();
    })
    // Registered once; the original chain set this identical callback twice.
    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
      return Maybe<void>::Ok();
    })
    .SetGetSbpFn(GetOpSbpSignature)
    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
                            const user_op::UserOpConfWrapper&) -> Maybe<void> {
      user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0);
      CHECK_NOTNULL_OR_RETURN(x_modifier);
      x_modifier->set_requires_grad(true);
      return Maybe<void>::Ok();
    });
// Registration of the "constant_pad1d_grad" user op.
//
// Input "dy" must have 3 axes and "padding" must have 2 entries. The output
// "dx" keeps axes 0 and 1 of dy and narrows axis 2 by padding[0] + padding[1],
// i.e. it undoes the shape change of the forward op. The output dtype mirrors
// the dtype of dy.
REGISTER_USER_OP("constant_pad1d_grad")
    .Input("dy")
    .Output("dx")
    .Attr<std::vector<int64_t>>("padding")
    .Attr<double>("floating_value")
    .Attr<int64_t>("integral_value")
    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      const Shape& dy_shape = ctx->InputShape("dy", 0);
      const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
      CHECK_EQ_OR_RETURN(dy_shape.NumAxes(), 3);
      CHECK_EQ_OR_RETURN(padding.size(), 2);
      const int64_t n_idx = 0;
      const int64_t c_idx = 1;
      const int64_t w_idx = 2;

      DimVector dx_dim_vec(dy_shape.NumAxes());
      // Initialized at declaration instead of declare-then-assign.
      const int64_t w_dy = dy_shape.At(w_idx);

      // Use the named axis constants rather than bare 0 / 1 literals.
      dx_dim_vec[n_idx] = dy_shape.At(n_idx);
      dx_dim_vec[c_idx] = dy_shape.At(c_idx);
      dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1];

      *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec);
      return Maybe<void>::Ok();
    })
    // Registered once; the original chain set this identical callback twice.
    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
      return Maybe<void>::Ok();
    })
    .SetGetSbpFn(GetOpGradSbpSignature);
// Backward-op generator for "constant_pad1d".
//
// When a gradient is needed for input "x", a "constant_pad1d_grad" op named
// "<forward op name>_grad" is built from the gradient of output "y" and the
// forward op's attributes, and its "dx" output is bound as the gradient of x.
REGISTER_USER_OP_GRAD("constant_pad1d")
    .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
                               user_op::AddOpFn AddOp) -> Maybe<void> {
      // Nothing to generate when x does not need a gradient.
      if (!op.NeedGenGradTensor4OpInput("x", 0)) { return Maybe<void>::Ok(); }

      user_op::UserOpConfWrapperBuilder grad_builder(op.op_name() + "_grad");
      user_op::UserOpConfWrapper pad_grad_op =
          grad_builder.Op("constant_pad1d_grad")
              .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
              .Output("dx")
              .Attr("padding", op.attr<std::vector<int64_t>>("padding"))
              .Attr("floating_value", op.attr<double>("floating_value"))
              .Attr("integral_value", op.attr<int64_t>("integral_value"))
              .Build();
      op.BindGradTensorWithOpInput(pad_grad_op.output("dx", 0), "x", 0);
      AddOp(pad_grad_op);
      return Maybe<void>::Ok();
    });
REGISTER_USER_OP("constant_pad2d")
.Input("x")
.Output("y")
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment