diff --git a/oneflow/python/ops/pad.py b/oneflow/python/ops/pad.py
index b6d5a5d718dc05e7f49b5bad843b8e1a453750f1..9602ecbc180e1aacac637abc828aaacba2f6cec3 100644
--- a/oneflow/python/ops/pad.py
+++ b/oneflow/python/ops/pad.py
@@ -317,3 +317,84 @@ def reflection_pad2d(
         .InferAndTryRun()
         .RemoteBlobList()[0]
     )
+
+
+@oneflow_export("replication_pad2d")
+def replication_pad2d(
+    x: oneflow_api.BlobDesc,
+    padding: Union[int, tuple, list],
+    name: Optional[str] = None,
+) -> oneflow_api.BlobDesc:
+    """Pads the input tensor using replication of the input boundary.
+
+    Args:
+        x (oneflow_api.BlobDesc): input blob, only supports the "NCHW" format.
+        padding (Union[int, tuple, list]): The size of the padding. If it is an int, the same padding is applied to all four sides;
+            if it is a 4-element tuple or list, it is interpreted as (padding_left, padding_right, padding_top, padding_bottom).
+        name (Optional[str], optional): The name for the operation. Defaults to None.
+
+    Returns:
+        oneflow_api.BlobDesc: The padded Blob.
+
+    For example:
+
+    .. code-block:: python
+
+        import oneflow as flow
+        import oneflow.typing as tp
+        import numpy as np
+
+
+        @flow.global_function()
+        def pad_Job(x: tp.Numpy.Placeholder((1, 2, 3, 3))
+        ) -> tp.Numpy:
+            return flow.replication_pad2d(x, padding=[2, 2, 1, 1])
+
+
+        x = np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32)
+        out = pad_Job(x)
+
+        # out [[[[ 0.  0.  0.  1.  2.  2.  2.]
+        #        [ 0.  0.  0.  1.  2.  2.  2.]
+        #        [ 3.  3.  3.  4.  5.  5.  5.]
+        #        [ 6.  6.  6.  7.  8.  8.  8.]
+        #        [ 6.  6.  6.  7.  8.  8.  8.]]
+
+        #       [[ 9.  9.  9. 10. 11. 11. 11.]
+        #        [ 9.  9.  9. 10. 11. 11. 11.]
+        #        [12. 12. 12. 13. 14. 14. 14.]
+        #        [15. 15. 15. 16. 17. 17. 17.]
+        #        [15. 15. 15. 16. 17. 17. 17.]]]]
+
+    """
+    H, W = x.shape[2], x.shape[3]
+    if isinstance(padding, (tuple, list)):
+        assert len(padding) == len(x.shape), ValueError(
+            "padding boundary must be the same size as the input dims"
+        )
+        assert (
+            padding[2] < H and padding[3] < H and padding[0] < W and padding[1] < W
+        ), ValueError(
+            "Padding size should be less than the corresponding input dimension!"
+        )
+        boundary = [padding[0], padding[1], padding[2], padding[3]]
+    elif isinstance(padding, int):
+        assert padding < H and padding < W, ValueError(
+            "Padding size should be less than the corresponding input dimension!"
+        )
+        boundary = [padding, padding, padding, padding]
+    else:
+        raise ValueError("padding must be int or list or tuple!")
+
+    return (
+        oneflow.user_op_builder(
+            name if name is not None else id_util.UniqueStr("Replication_Pad2d")
+        )
+        .Op("replication_pad2d")
+        .Input("x", [x])
+        .Output("y")
+        .Attr("padding", list(boundary))
+        .Build()
+        .InferAndTryRun()
+        .RemoteBlobList()[0]
+    )
diff --git a/oneflow/python/test/ops/test_replication_pad2d.py b/oneflow/python/test/ops/test_replication_pad2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..e23faec580d0ee0f2a495c5949c3d8948bf0e0e3
--- /dev/null
+++ b/oneflow/python/test/ops/test_replication_pad2d.py
@@ -0,0 +1,329 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +import os +from collections import OrderedDict + +import numpy as np +import oneflow as flow +import oneflow.typing as tp +from test_util import Args, GenArgDict, GenArgList + + +def flatten_array(input_array): + output_array = list() + for x in np.nditer(input_array): + output_array.append(x.tolist()) + return output_array + + +def array_to_numpy(input_array, target_shape): + return np.array(input_array).reshape(target_shape, order="C") + + +def index2coordinate(idx, tensor_shape): + coordinate = [] + tmp = idx + for i in range(len(tensor_shape) - 1, -1, -1): + axis_size = tensor_shape[i] + coor = tmp % axis_size + coordinate.insert(0, int(coor)) + tmp = (tmp - coor) / axis_size + return coordinate + + +def coordinate2index(coordinate, tensor_shape): + if len(coordinate) != len(tensor_shape): + raise "wrong coordinate or shape" + idx = 0 + for i, coor in enumerate(coordinate): + size_at_axis = coor + for j in range(i + 1, len(tensor_shape)): + size_at_axis *= tensor_shape[j] + + idx += size_at_axis + return idx + + +def _make_op_function( + test_case, + input, + padding, + grad, + device_type, + value_type, + machine_ids, + device_counts, +): + flow.clear_default_session() + if device_type == "cpu": + flow.config.cpu_device_num(device_counts) + else: + flow.config.gpu_device_num(device_counts) + + func_config = flow.FunctionConfig() + + # global function needs float32 as type of argument and return value + if value_type == flow.float16: + func_config.default_data_type(flow.float32) + else: + func_config.default_data_type(value_type) + + func_config.default_placement_scope(flow.scope.placement(device_type, machine_ids)) + func_config.default_logical_view(flow.scope.consistent_view()) + + def _compare_diff(blob: tp.Numpy): + test_case.assertTrue(np.allclose(grad, blob, 1e-3, 1e-3)) + + if value_type == flow.float32 or value_type == flow.float64: + + @flow.global_function(type="train", function_config=func_config) + def op_function(x: tp.Numpy.Placeholder(input.shape, dtype=value_type)): + with flow.scope.placement(device_type, "0:0"): + x += flow.get_variable( + name="input", + shape=input.shape, + dtype=value_type, + initializer=flow.zeros_initializer(), + ) + out = flow.replication_pad2d(x, padding) + flow.optimizer.SGD( + flow.optimizer.PiecewiseConstantScheduler([], [0]), momentum=0 + ).minimize(out) + + flow.watch_diff(x, _compare_diff) + return out + + return op_function + + elif value_type == flow.int32: + + @flow.global_function(type="train", function_config=func_config) + def op_function(x: tp.Numpy.Placeholder(input.shape, dtype=flow.float32)): + with flow.scope.placement(device_type, "0:0"): + x += flow.get_variable( + name="input", + shape=input.shape, + dtype=flow.float32, + initializer=flow.zeros_initializer(), + ) + y_int32 = flow.replication_pad2d(x, padding) + y_fp32 = flow.cast(y_int32, dtype=flow.float32) + flow.optimizer.SGD( + flow.optimizer.PiecewiseConstantScheduler([], [0]), momentum=0 + ).minimize(y_fp32) + + flow.watch_diff(x, _compare_diff) + return y_fp32 + + return op_function + + elif value_type == flow.float16: + + @flow.global_function(type="train", function_config=func_config) + def op_function(x: tp.Numpy.Placeholder(input.shape, dtype=flow.float32)): + with flow.scope.placement(device_type, "0:0"): + x_var = flow.get_variable( + name="input", + shape=input.shape, + dtype=flow.float32, + initializer=flow.constant_initializer(0), + ) + 
x_var = flow.cast_to_current_logical_view(x_var) + input_x = x_var + x + x_fp32 = flow.cast(input_x, flow.float32) + x_fp16 = flow.cast(input_x, dtype=flow.float16) + y_fp16 = flow.replication_pad2d(x_fp16, padding) + y_fp32 = flow.cast(y_fp16, dtype=flow.float32) + flow.optimizer.SGD( + flow.optimizer.PiecewiseConstantScheduler([], [0]), momentum=0 + ).minimize(y_fp32) + + flow.watch_diff(x_fp32, _compare_diff) + return y_fp32 + + return op_function + + +def gen_numpy_test_sample(input_shape, padding, is_float=True): + c_idx, h_idx, w_idx = 1, 2, 3 + pad_left = padding[0] + pad_right = padding[1] + pad_top = padding[2] + pad_bottom = padding[3] + pad_shape = ((0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)) + + def _np_replication_pad2d(input, pad_shape): + numpy_replicate = np.pad(input, pad_shape, "edge") + return numpy_replicate + + def _np_replication_pad2d_grad(src, dest): + dx_height, dx_width = input.shape[h_idx], input.shape[w_idx] + dy_height, dy_width = output.shape[h_idx], output.shape[w_idx] + + numpy_src = np.ones(src.shape, np.int32) + numpy_dest = np.zeros(dest.shape, np.int32) + array_src = flatten_array(numpy_src) + array_dest = flatten_array(numpy_dest) + + src_num = src.shape[c_idx] * src.shape[h_idx] * src.shape[w_idx] + dest_num = dest.shape[c_idx] * dest.shape[h_idx] * dest.shape[w_idx] + elements_num = src.shape[0] * src_num + for iter_n in range(elements_num): + coords = index2coordinate(iter_n, src.shape) + n, c, i, j = coords[0], coords[c_idx], coords[h_idx], coords[w_idx] + ip_x = ip_y = 0 + if j < pad_left: + ip_x = pad_left + elif j >= pad_left and j < (dx_width + pad_left): + ip_x = j + else: + ip_x = dx_width + pad_left - 1 + + if i < pad_top: + ip_y = pad_top + elif i >= pad_top and i < (dx_height + pad_top): + ip_y = i + else: + ip_y = dx_height + pad_top - 1 + + ip_x = ip_x - pad_left + ip_y = ip_y - pad_top + src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j + dest_index = ( + n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x + ) + array_dest[dest_index] += array_src[src_index] + + numpy_dest = array_to_numpy(array_dest, dest.shape) + return numpy_dest + + if is_float: + input = np.random.random(input_shape).astype(np.float32) + else: + input = np.random.randint(0, 100, input_shape) + + output = _np_replication_pad2d(input, pad_shape) + grad = _np_replication_pad2d_grad(output, input) + + numpy_results = { + "input": input, + "padding": padding, + "output": output, + "grad": grad, + } + + return numpy_results + + +def _compare_op_function_with_samples( + test_case, device_type, sample, value_type, machine_ids, device_count +): + op_function = _make_op_function( + test_case, + sample["input"].astype(value_type[0]), + sample["padding"], + sample["grad"].astype(value_type[0]), + device_type, + value_type[1], + machine_ids, + device_count, + ) + y = ( + op_function(sample["input"].astype(value_type[0])) + .get() + .numpy() + .astype(value_type[0]) + ) + + if value_type == flow.float16: + test_case.assertTrue( + np.allclose(y, sample["output"].astype(np.float32), 1e-3, 1e-3) + ) + else: + test_case.assertTrue(np.allclose(y, sample["output"].astype(value_type[0]))) + + +def _gen_arg_dict( + device_type="gpu", value_type="float", machine_ids="0:0", device_count=1 +): + arg_dict = OrderedDict() + arg_dict["device_type"] = [device_type] + arg_dict["samples"] = [] + arg_dict["samples"].append(gen_numpy_test_sample((2, 1, 2, 2), [1, 1, 1, 1])) + arg_dict["samples"].append(gen_numpy_test_sample((4, 2, 3, 3), [2, 2, 2, 
2])) + arg_dict["samples"].append(gen_numpy_test_sample((2, 3, 4, 5), [3, 2, 1, 2])) + if value_type == "float": + if device_type == "gpu": + arg_dict["value_type"] = [ + (np.float32, flow.float32), + # (np.float16, flow.float16), + ] + else: + arg_dict["value_type"] = [(np.float32, flow.float32)] + + elif value_type == "int": + arg_dict["value_type"] = [(np.float32, flow.int32)] + else: + raise "float or int for value type only" + + arg_dict["machine_ids"] = [machine_ids] + arg_dict["device_count"] = [device_count] + return arg_dict + + +@flow.unittest.skip_unless_1n1d() +class TestReplicationPad2d1n1d(flow.unittest.TestCase): + def test_op_function_int_cpu(test_case): + arg_dict = _gen_arg_dict("cpu", "int", "0:0", 1) + for arg in GenArgList(arg_dict): + _compare_op_function_with_samples(test_case, *arg) + + def test_op_function_float_cpu(test_case): + arg_dict = _gen_arg_dict("cpu", "float", "0:0", 1) + for arg in GenArgList(arg_dict): + _compare_op_function_with_samples(test_case, *arg) + + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + def test_op_function_int_gpu(test_case): + arg_dict = _gen_arg_dict("gpu", "int", "0:0", 1) + for arg in GenArgList(arg_dict): + _compare_op_function_with_samples(test_case, *arg) + + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + def test_op_function_float_gpu(test_case): + arg_dict = _gen_arg_dict("gpu", "float", "0:0", 1) + for arg in GenArgList(arg_dict): + _compare_op_function_with_samples(test_case, *arg) + + +@flow.unittest.skip_unless_1n2d() +class TestReplicationPad2d1n2d(flow.unittest.TestCase): + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + def test_op_function_float(test_case): + arg_dict = _gen_arg_dict("gpu", "float", "0:0-1", 2) + for arg in GenArgList(arg_dict): + _compare_op_function_with_samples(test_case, *arg) + + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + def test_op_function_int(test_case): + arg_dict = _gen_arg_dict("gpu", "int", "0:0-1", 2) + for arg in GenArgList(arg_dict): + _compare_op_function_with_samples(test_case, *arg) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/user/kernels/pad2d_kernels.cpp b/oneflow/user/kernels/pad2d_kernels.cpp new file mode 100644 index 0000000000000000000000000000000000000000..87949e7c5246fe613a2cf45d9126e3b2a713e820 --- /dev/null +++ b/oneflow/user/kernels/pad2d_kernels.cpp @@ -0,0 +1,250 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/common/nd_index_offset_helper.h" +#include "oneflow/core/device/memory_copier.h" +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/kernel/new_kernel_util.h" +#include "oneflow/user/kernels/pad2d_kernels_util.h" + +namespace oneflow { +namespace user_op { + +template<DeviceType device_type, typename IN_T> +class ReflectionPad2dKernel final : public OpKernel { + public: + ReflectionPad2dKernel() = default; + ~ReflectionPad2dKernel() = default; + + private: + void Compute(user_op::KernelComputeContext *ctx) const override { + const Tensor *x = ctx->Tensor4ArgNameAndIndex("x", 0); + Tensor *y = ctx->Tensor4ArgNameAndIndex("y", 0); + const auto &padding = ctx->Attr<std::vector<int64_t>>("padding"); + const int64_t ndims = x->shape().NumAxes(); + CHECK_EQ(padding.size(), ndims); + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t h_idx = 2; + const int64_t w_idx = 3; + + const int64_t pad_left = padding[0]; + const int64_t pad_top = padding[2]; + + const int64_t n_batch = y->shape().At(n_idx); + const int64_t n_channel = y->shape().At(c_idx); + const int64_t y_height = y->shape().At(h_idx); + const int64_t y_width = y->shape().At(w_idx); + const int64_t x_height = x->shape().At(h_idx); + const int64_t x_width = x->shape().At(w_idx); + + IN_T *dest = y->mut_dptr<IN_T>(); + const IN_T *src = x->dptr<IN_T>(); + DimVector y_vector; + y->shape().ToDimVector(&y_vector); + NdIndexOffsetHelper<int64_t, 4> index_helper(y_vector.data()); + + ReflectionPad2dFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper, n_batch, + n_channel, y_height, y_width, x_height, x_width, + pad_left, pad_top); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +template<DeviceType device_type, typename IN_T> +class ReflectionPad2dGradKernel final : public OpKernel { + public: + ReflectionPad2dGradKernel() = default; + ~ReflectionPad2dGradKernel() = default; + + private: + void Compute(KernelComputeContext *ctx) const override { + const Tensor *dy = ctx->Tensor4ArgNameAndIndex("dy", 0); + Tensor *dx = ctx->Tensor4ArgNameAndIndex("dx", 0); + const auto &padding = ctx->Attr<std::vector<int64_t>>("padding"); + const int64_t ndims = dy->shape().NumAxes(); + CHECK_EQ(padding.size(), ndims); + + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t h_idx = 2; + const int64_t w_idx = 3; + + int64_t pad_left = padding[0]; + int64_t pad_top = padding[2]; + int64_t n_batch = dy->shape().At(n_idx); + int64_t n_channel = dy->shape().At(c_idx); + int64_t dy_height = dy->shape().At(h_idx); + int64_t dy_width = dy->shape().At(w_idx); + int64_t dx_height = dx->shape().At(h_idx); + int64_t dx_width = dx->shape().At(w_idx); + + const IN_T *src = dy->dptr<IN_T>(); + IN_T *dest = dx->mut_dptr<IN_T>(); + DimVector dy_vector; + dy->shape().ToDimVector(&dy_vector); + NdIndexOffsetHelper<int64_t, 4> index_helper(dy_vector.data()); + + size_t out_bytes_size = dx->shape().elem_cnt() * GetSizeOfDataType(dx->data_type()); + Memset<device_type>(ctx->device_ctx(), dest, 0, out_bytes_size); + + ReflectionPad2dGradFunctor<device_type, 
IN_T>()(ctx->device_ctx(), src, dest, index_helper, + n_batch, n_channel, dy_height, dy_width, + dx_height, dx_width, pad_left, pad_top); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_REFLECTION_PAD2D_KERNELS(device, dtype) \ + REGISTER_USER_KERNEL("reflection_pad2d") \ + .SetCreateFn<ReflectionPad2dKernel<device, dtype>>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \ + REGISTER_USER_KERNEL("reflection_pad2d_grad") \ + .SetCreateFn<ReflectionPad2dGradKernel<device, dtype>>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value)); + +#define REGISTER_REFLECTION_PAD2D_WITH_DEVICE(device) \ + REGISTER_REFLECTION_PAD2D_KERNELS(device, float) \ + REGISTER_REFLECTION_PAD2D_KERNELS(device, double) \ + REGISTER_REFLECTION_PAD2D_KERNELS(device, int32_t) + +REGISTER_REFLECTION_PAD2D_WITH_DEVICE(DeviceType::kCPU) +#ifdef WITH_CUDA +REGISTER_REFLECTION_PAD2D_WITH_DEVICE(DeviceType::kGPU) +REGISTER_REFLECTION_PAD2D_KERNELS(DeviceType::kGPU, float16) +#endif + +template<DeviceType device_type, typename IN_T> +class ReplicationPad2dKernel final : public OpKernel { + public: + ReplicationPad2dKernel() = default; + ~ReplicationPad2dKernel() = default; + + private: + void Compute(user_op::KernelComputeContext *ctx) const override { + const Tensor *x = ctx->Tensor4ArgNameAndIndex("x", 0); + Tensor *y = ctx->Tensor4ArgNameAndIndex("y", 0); + const auto &padding = ctx->Attr<std::vector<int64_t>>("padding"); + const int64_t ndims = x->shape().NumAxes(); + CHECK_EQ(padding.size(), ndims); + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t h_idx = 2; + const int64_t w_idx = 3; + + const int64_t pad_left = padding[0]; + const int64_t pad_top = padding[2]; + + const int64_t n_batch = y->shape().At(n_idx); + const int64_t n_channel = y->shape().At(c_idx); + const int64_t y_height = y->shape().At(h_idx); + const int64_t y_width = y->shape().At(w_idx); + const int64_t x_height = x->shape().At(h_idx); + const int64_t x_width = x->shape().At(w_idx); + + IN_T *dest = y->mut_dptr<IN_T>(); + const IN_T *src = x->dptr<IN_T>(); + DimVector y_vector; + y->shape().ToDimVector(&y_vector); + NdIndexOffsetHelper<int64_t, 4> index_helper(y_vector.data()); + + ReplicationPad2dFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper, + n_batch, n_channel, y_height, y_width, x_height, + x_width, pad_left, pad_top); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +template<DeviceType device_type, typename IN_T> +class ReplicationPad2dGradKernel final : public OpKernel { + public: + ReplicationPad2dGradKernel() = default; + ~ReplicationPad2dGradKernel() = default; + + private: + void Compute(KernelComputeContext *ctx) const override { + const Tensor *dy = ctx->Tensor4ArgNameAndIndex("dy", 0); + Tensor *dx = ctx->Tensor4ArgNameAndIndex("dx", 0); + const auto &padding = ctx->Attr<std::vector<int64_t>>("padding"); + const int64_t ndims = dy->shape().NumAxes(); + CHECK_EQ(padding.size(), ndims); + + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t h_idx = 2; + const int64_t w_idx = 3; + + int64_t pad_left = padding[0]; + int64_t pad_top = padding[2]; + int64_t n_batch = dy->shape().At(n_idx); + int64_t n_channel = dy->shape().At(c_idx); + int64_t dy_height = dy->shape().At(h_idx); + int64_t dy_width = dy->shape().At(w_idx); + int64_t dx_height = 
dx->shape().At(h_idx); + int64_t dx_width = dx->shape().At(w_idx); + + const IN_T *src = dy->dptr<IN_T>(); + IN_T *dest = dx->mut_dptr<IN_T>(); + DimVector dy_vector; + dy->shape().ToDimVector(&dy_vector); + NdIndexOffsetHelper<int64_t, 4> index_helper(dy_vector.data()); + + size_t out_bytes_size = dx->shape().elem_cnt() * GetSizeOfDataType(dx->data_type()); + Memset<device_type>(ctx->device_ctx(), dest, 0, out_bytes_size); + + ReplicationPad2dGradFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper, + n_batch, n_channel, dy_height, dy_width, + dx_height, dx_width, pad_left, pad_top); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_REPLICATION_PAD2D_KERNELS(device, dtype) \ + REGISTER_USER_KERNEL("replication_pad2d") \ + .SetCreateFn<ReplicationPad2dKernel<device, dtype>>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \ + REGISTER_USER_KERNEL("replication_pad2d_grad") \ + .SetCreateFn<ReplicationPad2dGradKernel<device, dtype>>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value)); + +#define REGISTER_REPLICATION_PAD2D_WITH_DEVICE(device) \ + REGISTER_REPLICATION_PAD2D_KERNELS(device, float) \ + REGISTER_REPLICATION_PAD2D_KERNELS(device, double) \ + REGISTER_REPLICATION_PAD2D_KERNELS(device, int32_t) + +REGISTER_REPLICATION_PAD2D_WITH_DEVICE(DeviceType::kCPU) +#ifdef WITH_CUDA +REGISTER_REPLICATION_PAD2D_WITH_DEVICE(DeviceType::kGPU) +REGISTER_REPLICATION_PAD2D_KERNELS(DeviceType::kGPU, float16) +#endif + +} // namespace user_op +} // namespace oneflow diff --git a/oneflow/user/kernels/pad2d_kernels_util.cpp b/oneflow/user/kernels/pad2d_kernels_util.cpp new file mode 100644 index 0000000000000000000000000000000000000000..58ab1cd3ac6277ed8c3a2287438165c28e9f2205 --- /dev/null +++ b/oneflow/user/kernels/pad2d_kernels_util.cpp @@ -0,0 +1,92 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/user/kernels/pad2d_kernels_util.h" +#include "oneflow/core/framework/framework.h" + +namespace oneflow { + +namespace user_op { + +template<typename IN_T> +struct ReflectionPad2dFunctor<DeviceType::kCPU, IN_T> final { + void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, + int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height, + int64_t x_width, int64_t pad_left, int64_t pad_top) { + int64_t dest_num = n_channel * y_height * y_width; + int64_t src_num = n_channel * x_height * x_width; + int64_t elem_num = n_batch * dest_num; + DoReflectionPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, + x_height, x_width, pad_left, pad_top); + } +}; + +template<typename IN_T> +struct ReflectionPad2dGradFunctor<DeviceType::kCPU, IN_T> final { + void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, + int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height, + int64_t dx_width, int64_t pad_left, int64_t pad_top) { + int64_t dest_num = n_channel * dx_height * dx_width; + int64_t src_num = n_channel * dy_height * dy_width; + int64_t elem_num = n_batch * src_num; + DoReflectionPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height, + dy_width, dx_height, dx_width, pad_left, pad_top); + } +}; + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_FUNCTOR, (DeviceType::kCPU), + PADDING_DATA_TYPE_CPU_SEQ); + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_GRAD_FUNCTOR, (DeviceType::kCPU), + PADDING_DATA_TYPE_CPU_SEQ); + +template<typename IN_T> +struct ReplicationPad2dFunctor<DeviceType::kCPU, IN_T> final { + void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, + int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height, + int64_t x_width, int64_t pad_left, int64_t pad_top) { + int64_t dest_num = n_channel * y_height * y_width; + int64_t src_num = n_channel * x_height * x_width; + int64_t elem_num = n_batch * dest_num; + DoReplicationPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height, + y_width, x_height, x_width, pad_left, pad_top); + } +}; + +template<typename IN_T> +struct ReplicationPad2dGradFunctor<DeviceType::kCPU, IN_T> final { + void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, + int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height, + int64_t dx_width, int64_t pad_left, int64_t pad_top) { + int64_t dest_num = n_channel * dx_height * dx_width; + int64_t src_num = n_channel * dy_height * dy_width; + int64_t elem_num = n_batch * src_num; + DoReplicationPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height, + dy_width, dx_height, dx_width, pad_left, pad_top); + } +}; + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD2D_FUNCTOR, (DeviceType::kCPU), + PADDING_DATA_TYPE_CPU_SEQ); + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD2D_GRAD_FUNCTOR, (DeviceType::kCPU), + PADDING_DATA_TYPE_CPU_SEQ); + +} // namespace user_op +} // namespace oneflow diff --git a/oneflow/user/kernels/pad2d_kernels_util.cu b/oneflow/user/kernels/pad2d_kernels_util.cu new file mode 100644 index 
0000000000000000000000000000000000000000..0dbbfb7ff6448e918724db09479a08863665590c --- /dev/null +++ b/oneflow/user/kernels/pad2d_kernels_util.cu @@ -0,0 +1,208 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include <cstdint> +#ifdef WITH_CUDA +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/framework/framework.h" +#include "oneflow/user/kernels/pad2d_kernels_util.h" + +namespace oneflow { +namespace user_op { + +template<typename IN_T> +__global__ void DoCUDAReflectionPad2d(const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> index_helper, + int64_t elem_num, int64_t src_num, int64_t dest_num, + int64_t y_height, int64_t y_width, int64_t x_height, + int64_t x_width, int64_t pad_left, int64_t pad_top) { + DoReflectionPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, + x_height, x_width, pad_left, pad_top); +}; + +template<typename IN_T> +__global__ void DoCUDAReflectionPad2dGrad(const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> index_helper, + int64_t elem_num, int64_t src_num, int64_t dest_num, + int64_t dy_height, int64_t dy_width, int64_t dx_height, + int64_t dx_width, int64_t pad_left, int64_t pad_top) { + DoReflectionPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height, + dy_width, dx_height, dx_width, pad_left, pad_top); +}; + +template<typename IN_T> +struct ReflectionPad2dFunctor<DeviceType::kGPU, IN_T> final { + void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, + int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height, + int64_t x_width, int64_t pad_left, int64_t pad_top) { + int64_t dest_num = n_channel * y_height * y_width; + int64_t src_num = n_channel * x_height * x_width; + int64_t elem_num = n_batch * dest_num; + DoCUDAReflectionPad2d<IN_T> + <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, x_height, + x_width, pad_left, pad_top); + } +}; + +// float16 implementation +template<> +void ReflectionPad2dFunctor<DeviceType::kGPU, float16>::operator()( + DeviceCtx *ctx, const float16 *src, float16 *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, int64_t n_channel, + int64_t y_height, int64_t y_width, int64_t x_height, int64_t x_width, int64_t pad_left, + int64_t pad_top) { + int64_t dest_num = n_channel * y_height * y_width; + int64_t src_num = n_channel * x_height * x_width; + int64_t elem_num = n_batch * dest_num; + DoCUDAReflectionPad2d<half> + <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + reinterpret_cast<const half *>(src), reinterpret_cast<half *>(dest), index_helper, + elem_num, src_num, dest_num, y_height, y_width, x_height, x_width, pad_left, pad_top); +} + +template<typename IN_T> +struct ReflectionPad2dGradFunctor<DeviceType::kGPU, 
IN_T> final { + void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, + int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height, + int64_t dx_width, int64_t pad_left, int64_t pad_top) { + int64_t dest_num = n_channel * dx_height * dx_width; + int64_t src_num = n_channel * dy_height * dy_width; + int64_t elem_num = n_batch * src_num; + DoCUDAReflectionPad2dGrad<IN_T> + <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + src, dest, index_helper, elem_num, src_num, dest_num, dy_height, dy_width, dx_height, + dx_width, pad_left, pad_top); + } +}; + +// float16 implementation +template<> +void ReflectionPad2dGradFunctor<DeviceType::kGPU, float16>::operator()( + DeviceCtx *ctx, const float16 *src, float16 *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, int64_t n_channel, + int64_t dy_height, int64_t dy_width, int64_t dx_height, int64_t dx_width, int64_t pad_left, + int64_t pad_top) { + int64_t dest_num = n_channel * dx_height * dx_width; + int64_t src_num = n_channel * dy_height * dy_width; + int64_t elem_num = n_batch * src_num; + DoCUDAReflectionPad2dGrad<half> + <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + reinterpret_cast<const half *>(src), reinterpret_cast<half *>(dest), index_helper, + elem_num, src_num, dest_num, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top); +} + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_FUNCTOR, + OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ); + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_GRAD_FUNCTOR, + OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ); + +template<typename IN_T> +__global__ void DoCUDAReplicationPad2d(const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> index_helper, + int64_t elem_num, int64_t src_num, int64_t dest_num, + int64_t y_height, int64_t y_width, int64_t x_height, + int64_t x_width, int64_t pad_left, int64_t pad_top) { + DoReplicationPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, + x_height, x_width, pad_left, pad_top); +}; + +template<typename IN_T> +__global__ void DoCUDAReplicationPad2dGrad(const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> index_helper, + int64_t elem_num, int64_t src_num, int64_t dest_num, + int64_t dy_height, int64_t dy_width, int64_t dx_height, + int64_t dx_width, int64_t pad_left, int64_t pad_top) { + DoReplicationPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height, + dy_width, dx_height, dx_width, pad_left, pad_top); +}; + +template<typename IN_T> +struct ReplicationPad2dFunctor<DeviceType::kGPU, IN_T> final { + void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, + int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height, + int64_t x_width, int64_t pad_left, int64_t pad_top) { + int64_t dest_num = n_channel * y_height * y_width; + int64_t src_num = n_channel * x_height * x_width; + int64_t elem_num = n_batch * dest_num; + DoCUDAReplicationPad2d<IN_T> + <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, x_height, + x_width, pad_left, pad_top); + } +}; + +// float16 implementation +template<> +void 
ReplicationPad2dFunctor<DeviceType::kGPU, float16>::operator()( + DeviceCtx *ctx, const float16 *src, float16 *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, int64_t n_channel, + int64_t y_height, int64_t y_width, int64_t x_height, int64_t x_width, int64_t pad_left, + int64_t pad_top) { + int64_t dest_num = n_channel * y_height * y_width; + int64_t src_num = n_channel * x_height * x_width; + int64_t elem_num = n_batch * dest_num; + DoCUDAReplicationPad2d<half> + <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + reinterpret_cast<const half *>(src), reinterpret_cast<half *>(dest), index_helper, + elem_num, src_num, dest_num, y_height, y_width, x_height, x_width, pad_left, pad_top); +} + +template<typename IN_T> +struct ReplicationPad2dGradFunctor<DeviceType::kGPU, IN_T> final { + void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, + int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height, + int64_t dx_width, int64_t pad_left, int64_t pad_top) { + int64_t dest_num = n_channel * dx_height * dx_width; + int64_t src_num = n_channel * dy_height * dy_width; + int64_t elem_num = n_batch * src_num; + DoCUDAReplicationPad2dGrad<IN_T> + <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + src, dest, index_helper, elem_num, src_num, dest_num, dy_height, dy_width, dx_height, + dx_width, pad_left, pad_top); + } +}; + +// float16 implementation +template<> +void ReplicationPad2dGradFunctor<DeviceType::kGPU, float16>::operator()( + DeviceCtx *ctx, const float16 *src, float16 *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, int64_t n_channel, + int64_t dy_height, int64_t dy_width, int64_t dx_height, int64_t dx_width, int64_t pad_left, + int64_t pad_top) { + int64_t dest_num = n_channel * dx_height * dx_width; + int64_t src_num = n_channel * dy_height * dy_width; + int64_t elem_num = n_batch * src_num; + DoCUDAReplicationPad2dGrad<half> + <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + reinterpret_cast<const half *>(src), reinterpret_cast<half *>(dest), index_helper, + elem_num, src_num, dest_num, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top); +} + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD2D_FUNCTOR, + OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ); + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD2D_GRAD_FUNCTOR, + OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ); + +} // namespace user_op +} // namespace oneflow + +#endif // WITH_CUDA diff --git a/oneflow/user/kernels/pad2d_kernels_util.h b/oneflow/user/kernels/pad2d_kernels_util.h new file mode 100644 index 0000000000000000000000000000000000000000..d527cd0626e32f2278dde9ac9e1c33b3a651f7f3 --- /dev/null +++ b/oneflow/user/kernels/pad2d_kernels_util.h @@ -0,0 +1,247 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_USER_KERNELS_PAD2D_KERNELS_UTIL_H_ +#define ONEFLOW_USER_KERNELS_PAD2D_KERNELS_UTIL_H_ +#ifdef WITH_CUDA +#include "oneflow/core/cuda/atomic.cuh" +#endif // WITH_CUDA +#include "oneflow/core/common/nd_index_offset_helper.h" +#include "oneflow/core/ndarray/xpu_util.h" + +namespace oneflow { + +#define PADDING_DATA_TYPE_CPU_SEQ \ + FLOATING_DATA_TYPE_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) + +#define PADDING_DATA_TYPE_GPU_SEQ \ + FLOAT16_DATA_TYPE_SEQ \ + PADDING_DATA_TYPE_CPU_SEQ + +namespace user_op { + +template<typename T> +struct DeviceAdd { + OF_DEVICE_FUNC static void Invoke(const T *x, T *y) { +#if defined(__CUDA_ARCH__) + cuda::atomic::Add(y, *x); +#else + *y += *x; +#endif + }; +}; + +template<DeviceType device_type, typename IN_T> +struct ReflectionPad2dFunctor final { + void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, + int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height, + int64_t x_width, int64_t pad_left, int64_t pad_top); +}; + +template<DeviceType device_type, typename IN_T> +struct ReflectionPad2dGradFunctor final { + void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, + int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height, + int64_t dx_width, int64_t pad_left, int64_t pad_top); +}; + +template<typename IN_T> +OF_DEVICE_FUNC void DoReflectionPad2d(const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, + int64_t elem_num, int64_t src_num, int64_t dest_num, + int64_t y_height, int64_t y_width, int64_t x_height, + int64_t x_width, int64_t pad_left, int64_t pad_top) { + XPU_1D_KERNEL_LOOP(k, elem_num) { + int64_t n, c, i, j, ip_x, ip_y; + int64_t coord_y[4]; + index_helper.OffsetToNdIndex(k, coord_y); + n = coord_y[0]; + c = coord_y[1]; + i = coord_y[2]; + j = coord_y[3]; + if (j < pad_left) { + ip_x = pad_left * 2 - j; + } else if (j >= pad_left && j < x_width + pad_left) { + ip_x = j; + } else { + ip_x = (x_width + pad_left - 1) * 2 - j; + } + + if (i < pad_top) { + ip_y = pad_top * 2 - i; + } else if (i >= pad_top && i < x_height + pad_top) { + ip_y = i; + } else { + ip_y = (x_height + pad_top - 1) * 2 - i; + } + ip_x = ip_x - pad_left; + ip_y = ip_y - pad_top; + int64_t dest_index = n * dest_num + c * y_width * y_height + i * y_width + j; + int64_t src_index = n * src_num + c * x_width * x_height + ip_y * x_width + ip_x; + dest[dest_index] = src[src_index]; + } +} + +template<typename IN_T> +OF_DEVICE_FUNC void DoReflectionPad2dGrad(const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, + int64_t elem_num, int64_t src_num, int64_t dest_num, + int64_t dy_height, int64_t dy_width, int64_t dx_height, + int64_t dx_width, int64_t pad_left, int64_t pad_top) { + XPU_1D_KERNEL_LOOP(k, elem_num) { + int64_t n, c, i, j, ip_x, ip_y; + int64_t coord[4]; + index_helper.OffsetToNdIndex(k, coord); + n = coord[0]; + c = coord[1]; + i = coord[2]; + j = coord[3]; + if (j < pad_left) { + ip_x = pad_left * 2 - j; + } else if (j >= pad_left && j < dx_width + pad_left) { + ip_x = j; + } else { + ip_x = (dx_width + pad_left - 1) * 2 - j; + } + + if (i < pad_top) { + ip_y = pad_top * 2 - i; + } else if (i >= pad_top && i < dx_height + pad_top) { + ip_y = i; + } else { + ip_y = (dx_height + pad_top - 1) * 
2 - i; + } + ip_x = ip_x - pad_left; + ip_y = ip_y - pad_top; + + int64_t src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j; + int64_t dest_index = n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x; + DeviceAdd<IN_T>::Invoke(src + src_index, dest + dest_index); + } +} + +// macros for functors instantiate(used by pad2d_kernels_util.cu) +#define INSTANTIATE_REFLECTION_PAD2D_FUNCTOR(device_type_v, dtype_pair) \ + template struct ReflectionPad2dFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>; + +#define INSTANTIATE_REFLECTION_PAD2D_GRAD_FUNCTOR(device_type_v, dtype_pair) \ + template struct ReflectionPad2dGradFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>; + +template<DeviceType device_type, typename IN_T> +struct ReplicationPad2dFunctor final { + void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, + int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height, + int64_t x_width, int64_t pad_left, int64_t pad_top); +}; + +template<DeviceType device_type, typename IN_T> +struct ReplicationPad2dGradFunctor final { + void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, + int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height, + int64_t dx_width, int64_t pad_left, int64_t pad_top); +}; + +template<typename IN_T> +OF_DEVICE_FUNC void DoReplicationPad2d(const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, + int64_t elem_num, int64_t src_num, int64_t dest_num, + int64_t y_height, int64_t y_width, int64_t x_height, + int64_t x_width, int64_t pad_left, int64_t pad_top) { + XPU_1D_KERNEL_LOOP(k, elem_num) { + int64_t n, c, i, j, ip_x, ip_y; + int64_t coord_y[4]; + index_helper.OffsetToNdIndex(k, coord_y); + n = coord_y[0]; + c = coord_y[1]; + i = coord_y[2]; + j = coord_y[3]; + if (j < pad_left) { + ip_x = pad_left; + } else if (j >= pad_left && j < x_width + pad_left) { + ip_x = j; + } else { + ip_x = x_width + pad_left - 1; + } + + if (i < pad_top) { + ip_y = pad_top; + } else if (i >= pad_top && i < x_height + pad_top) { + ip_y = i; + } else { + ip_y = x_height + pad_top - 1; + } + ip_x = ip_x - pad_left; + ip_y = ip_y - pad_top; + + int64_t dest_index = n * dest_num + c * y_width * y_height + i * y_width + j; + int64_t src_index = n * src_num + c * x_width * x_height + ip_y * x_width + ip_x; + dest[dest_index] = src[src_index]; + } +} + +template<typename IN_T> +OF_DEVICE_FUNC void DoReplicationPad2dGrad(const IN_T *src, IN_T *dest, + const NdIndexOffsetHelper<int64_t, 4> &index_helper, + int64_t elem_num, int64_t src_num, int64_t dest_num, + int64_t dy_height, int64_t dy_width, int64_t dx_height, + int64_t dx_width, int64_t pad_left, int64_t pad_top) { + XPU_1D_KERNEL_LOOP(k, elem_num) { + int64_t n, c, i, j, ip_x, ip_y; + int64_t coord[4]; + index_helper.OffsetToNdIndex(k, coord); + n = coord[0]; + c = coord[1]; + i = coord[2]; + j = coord[3]; + if (j < pad_left) { + ip_x = pad_left; + } else if (j >= pad_left && j < dx_width + pad_left) { + ip_x = j; + } else { + ip_x = dx_width + pad_left - 1; + } + + if (i < pad_top) { + ip_y = pad_top; + } else if (i >= pad_top && i < dx_height + pad_top) { + ip_y = i; + } else { + ip_y = dx_height + pad_top - 1; + } + ip_x = ip_x - pad_left; + ip_y = ip_y - pad_top; + + int64_t src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j; + int64_t dest_index = n * dest_num + c * 
dx_width * dx_height + ip_y * dx_width + ip_x; + DeviceAdd<IN_T>::Invoke(src + src_index, dest + dest_index); + } +} + +// macros for functors instantiate(used by pad2d_kernels_util.cu) +#define INSTANTIATE_REPLICATION_PAD2D_FUNCTOR(device_type_v, dtype_pair) \ + template struct ReplicationPad2dFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>; + +#define INSTANTIATE_REPLICATION_PAD2D_GRAD_FUNCTOR(device_type_v, dtype_pair) \ + template struct ReplicationPad2dGradFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>; + +} // namespace user_op +} // namespace oneflow + +#endif // ONEFLOW_USER_KERNELS_PAD2D_KERNELS_UTIL_H_ diff --git a/oneflow/user/kernels/reflection_pad2d_kernel.cpp b/oneflow/user/kernels/reflection_pad2d_kernel.cpp deleted file mode 100644 index e9a876781d1fc36ed39f9f2ad8910a56cee8a090..0000000000000000000000000000000000000000 --- a/oneflow/user/kernels/reflection_pad2d_kernel.cpp +++ /dev/null @@ -1,151 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ -#include "oneflow/core/framework/framework.h" -#include "oneflow/core/device/memory_copier.h" -#include "oneflow/core/kernel/new_kernel_util.h" -#include "oneflow/core/common/nd_index_offset_helper.h" -#include "oneflow/user/kernels/reflection_pad2d_kernel_util.h" - -namespace oneflow { -namespace user_op { - -// Fill ShapeView into dim vector -DimVector ShapeViewToDimVector(const ShapeView& tensor_shape) { - int64_t ndims = tensor_shape.NumAxes(); - DimVector shape_vec(ndims); - for (int64_t i = 0; i < ndims; ++i) { shape_vec[i] = tensor_shape.At(i); } - shape_vec[ndims - 1] = shape_vec[ndims - 1]; - return shape_vec; -} - -template<DeviceType device_type, typename IN_T> -class ReflectionPad2dKernel final : public OpKernel { - public: - ReflectionPad2dKernel() = default; - ~ReflectionPad2dKernel() = default; - - private: - void Compute(user_op::KernelComputeContext* ctx) const override { - const Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); - Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); - const auto& padding = ctx->Attr<std::vector<int64_t>>("padding"); - const int64_t ndims = x->shape().NumAxes(); - CHECK_EQ(padding.size(), ndims); - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t h_idx = 2; - const int64_t w_idx = 3; - - const int64_t pad_left = padding[0]; - const int64_t pad_top = padding[2]; - - const int64_t n_batch = y->shape().At(n_idx); - const int64_t n_channel = y->shape().At(c_idx); - const int64_t y_height = y->shape().At(h_idx); - const int64_t y_width = y->shape().At(w_idx); - const int64_t x_height = x->shape().At(h_idx); - const int64_t x_width = x->shape().At(w_idx); - - IN_T* dest = y->mut_dptr<IN_T>(); - const IN_T* src = x->dptr<IN_T>(); - DimVector y_vector = ShapeViewToDimVector(y->shape()); - NdIndexOffsetHelper<int64_t, 4> index_helper(y_vector.data()); - - ReflectionPad2dFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper, n_batch, - n_channel, y_height, y_width, x_height, x_width, - pad_left, pad_top); - } - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } -}; - -template<DeviceType device_type, typename IN_T> -class ReflectionPad2dGradKernel final : public OpKernel { - public: - ReflectionPad2dGradKernel() = default; - ~ReflectionPad2dGradKernel() = default; - - private: - void Compute(KernelComputeContext* ctx) const override { - const Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); - Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); - const auto& padding = ctx->Attr<std::vector<int64_t>>("padding"); - const int64_t ndims = dy->shape().NumAxes(); - CHECK_EQ(padding.size(), ndims); - - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t h_idx = 2; - const int64_t w_idx = 3; - - int64_t pad_left = padding[0]; - int64_t pad_top = padding[2]; - int64_t n_batch = dy->shape().At(n_idx); - int64_t n_channel = dy->shape().At(c_idx); - int64_t dy_height = dy->shape().At(h_idx); - int64_t dy_width = dy->shape().At(w_idx); - int64_t dx_height = dx->shape().At(h_idx); - int64_t dx_width = dx->shape().At(w_idx); - - const IN_T* src = dy->dptr<IN_T>(); - IN_T* dest = dx->mut_dptr<IN_T>(); - DimVector dy_vector = ShapeViewToDimVector(dy->shape()); - NdIndexOffsetHelper<int64_t, 4> index_helper(dy_vector.data()); - - size_t out_bytes_size = dx->shape().elem_cnt() * GetSizeOfDataType(dx->data_type()); - Memset<device_type>(ctx->device_ctx(), dest, 0, out_bytes_size); - - ReflectionPad2dGradFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper, - n_batch, n_channel, 
dy_height, dy_width, - dx_height, dx_width, pad_left, pad_top); - } - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } -}; - -#define REGISTER_REFLECTION_PAD2D_KERNELS(device, dtype) \ - REGISTER_USER_KERNEL("reflection_pad2d") \ - .SetCreateFn<ReflectionPad2dKernel<device, dtype>>() \ - .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ - & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \ - REGISTER_USER_KERNEL("reflection_pad2d_grad") \ - .SetCreateFn<ReflectionPad2dGradKernel<device, dtype>>() \ - .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ - & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value)); - -#define REGISTER_REFLECTION_PAD2D_WITH_DEVICE(device) \ - REGISTER_REFLECTION_PAD2D_KERNELS(device, float) \ - REGISTER_REFLECTION_PAD2D_KERNELS(device, double) \ - REGISTER_REFLECTION_PAD2D_KERNELS(device, int32_t) - -REGISTER_REFLECTION_PAD2D_WITH_DEVICE(DeviceType::kCPU) -#ifdef WITH_CUDA -REGISTER_REFLECTION_PAD2D_WITH_DEVICE(DeviceType::kGPU) -REGISTER_REFLECTION_PAD2D_KERNELS(DeviceType::kGPU, float16) -#endif - -} // namespace user_op -} // namespace oneflow \ No newline at end of file diff --git a/oneflow/user/kernels/reflection_pad2d_kernel_util.cpp b/oneflow/user/kernels/reflection_pad2d_kernel_util.cpp deleted file mode 100644 index 73ec583430c10eb9ebc1ea2aca66c3393714385b..0000000000000000000000000000000000000000 --- a/oneflow/user/kernels/reflection_pad2d_kernel_util.cpp +++ /dev/null @@ -1,58 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ -#include "oneflow/core/framework/framework.h" -#include "oneflow/user/kernels/reflection_pad2d_kernel_util.h" - -namespace oneflow { - -namespace user_op { - -template<typename IN_T> -struct ReflectionPad2dFunctor<DeviceType::kCPU, IN_T> final { - void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest, - const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch, - int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height, - int64_t x_width, int64_t pad_left, int64_t pad_top) { - int64_t dest_num = n_channel * y_height * y_width; - int64_t src_num = n_channel * x_height * x_width; - int64_t elem_num = n_batch * dest_num; - DoReflectionPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, - x_height, x_width, pad_left, pad_top); - } -}; - -template<typename IN_T> -struct ReflectionPad2dGradFunctor<DeviceType::kCPU, IN_T> final { - void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest, - const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch, - int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height, - int64_t dx_width, int64_t pad_left, int64_t pad_top) { - int64_t dest_num = n_channel * dx_height * dx_width; - int64_t src_num = n_channel * dy_height * dy_width; - int64_t elem_num = n_batch * src_num; - DoReflectionPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height, - dy_width, dx_height, dx_width, pad_left, pad_top); - } -}; - -OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_FUNCTOR, (DeviceType::kCPU), - REFLECTION_PAD2D_DATA_TYPE_CPU_SEQ); - -OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_GRAD_FUNCTOR, (DeviceType::kCPU), - REFLECTION_PAD2D_GRAD_DATA_TYPE_CPU_SEQ); - -} // namespace user_op -} // namespace oneflow diff --git a/oneflow/user/kernels/reflection_pad2d_kernel_util.cu b/oneflow/user/kernels/reflection_pad2d_kernel_util.cu deleted file mode 100644 index 9c36b88061712b134f77b1e551aa5c3b80d7a8eb..0000000000000000000000000000000000000000 --- a/oneflow/user/kernels/reflection_pad2d_kernel_util.cu +++ /dev/null @@ -1,120 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ -#include <cstdint> -#ifdef WITH_CUDA -#include "oneflow/core/framework/framework.h" -#include "oneflow/core/common/data_type.h" -#include "oneflow/user/kernels/reflection_pad2d_kernel_util.h" - -namespace oneflow { -namespace user_op { - -template<typename IN_T> -__global__ void DoCUDAReflectionPad2d(const IN_T* src, IN_T* dest, - const NdIndexOffsetHelper<int64_t, 4> index_helper, - int64_t elem_num, int64_t src_num, int64_t dest_num, - int64_t y_height, int64_t y_width, int64_t x_height, - int64_t x_width, int64_t pad_left, int64_t pad_top) { - DoReflectionPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, - x_height, x_width, pad_left, pad_top); -}; - -template<typename IN_T> -__global__ void DoCUDAReflectionPad2dGrad(const IN_T* src, IN_T* dest, - const NdIndexOffsetHelper<int64_t, 4> index_helper, - int64_t elem_num, int64_t src_num, int64_t dest_num, - int64_t dy_height, int64_t dy_width, int64_t dx_height, - int64_t dx_width, int64_t pad_left, int64_t pad_top) { - DoReflectionPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height, - dy_width, dx_height, dx_width, pad_left, pad_top); -}; - -template<typename IN_T> -struct ReflectionPad2dFunctor<DeviceType::kGPU, IN_T> final { - void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest, - const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch, - int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height, - int64_t x_width, int64_t pad_left, int64_t pad_top) { - int64_t dest_num = n_channel * y_height * y_width; - int64_t src_num = n_channel * x_height * x_width; - int64_t elem_num = n_batch * dest_num; - DoCUDAReflectionPad2d<IN_T> - <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( - src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, x_height, - x_width, pad_left, pad_top); - } -}; - -// float16 implementation -template<> -void ReflectionPad2dFunctor<DeviceType::kGPU, float16>::operator()( - DeviceCtx* ctx, const float16* src, float16* dest, - const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch, int64_t n_channel, - int64_t y_height, int64_t y_width, int64_t x_height, int64_t x_width, int64_t pad_left, - int64_t pad_top) { - int64_t dest_num = n_channel * y_height * y_width; - int64_t src_num = n_channel * x_height * x_width; - int64_t elem_num = n_batch * dest_num; - DoCUDAReflectionPad2d<half> - <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( - reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper, elem_num, - src_num, dest_num, y_height, y_width, x_height, x_width, pad_left, pad_top); -} - -template<typename IN_T> -struct ReflectionPad2dGradFunctor<DeviceType::kGPU, IN_T> final { - void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest, - const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch, - int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height, - int64_t dx_width, int64_t pad_left, int64_t pad_top) { - int64_t dest_num = n_channel * dx_height * dx_width; - int64_t src_num = n_channel * dy_height * dy_width; - int64_t elem_num = n_batch * src_num; - DoCUDAReflectionPad2dGrad<IN_T> - <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( - src, dest, index_helper, elem_num, src_num, dest_num, dy_height, dy_width, dx_height, - dx_width, pad_left, pad_top); - } -}; - -// float16 implementation -template<> -void 
-void ReflectionPad2dGradFunctor<DeviceType::kGPU, float16>::operator()(
-    DeviceCtx* ctx, const float16* src, float16* dest,
-    const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch, int64_t n_channel,
-    int64_t dy_height, int64_t dy_width, int64_t dx_height, int64_t dx_width, int64_t pad_left,
-    int64_t pad_top) {
-  int64_t dest_num = n_channel * dx_height * dx_width;
-  int64_t src_num = n_channel * dy_height * dy_width;
-  int64_t elem_num = n_batch * src_num;
-  DoCUDAReflectionPad2dGrad<half>
-      <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
-          reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper, elem_num,
-          src_num, dest_num, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top);
-}
-
-OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_FUNCTOR,
-                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU),
-                                 REFLECTION_PAD2D_DATA_TYPE_GPU_SEQ);
-
-OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_GRAD_FUNCTOR,
-                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU),
-                                 REFLECTION_PAD2D_GRAD_DATA_TYPE_GPU_SEQ);
-
-} // namespace user_op
-} // namespace oneflow
-
-#endif // WITH_CUDA
\ No newline at end of file
diff --git a/oneflow/user/kernels/reflection_pad2d_kernel_util.h b/oneflow/user/kernels/reflection_pad2d_kernel_util.h
deleted file mode 100644
index 3fe1ccf52cbe8e3be676720ef9ce7492e1b8c2e0..0000000000000000000000000000000000000000
--- a/oneflow/user/kernels/reflection_pad2d_kernel_util.h
+++ /dev/null
@@ -1,156 +0,0 @@
-/*
-Copyright 2020 The OneFlow Authors. All rights reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-#ifndef ONEFLOW_USER_KERNELS_REFLECTION_PAD2D_KERNEL_UTIL_H_
-#define ONEFLOW_USER_KERNELS_REFLECTION_PAD2D_KERNEL_UTIL_H_
-#ifdef WITH_CUDA
-#include "oneflow/core/cuda/atomic.cuh"
-#endif // WITH_CUDA
-#include "oneflow/core/ndarray/xpu_util.h"
-#include "oneflow/core/common/nd_index_offset_helper.h"
-
-namespace oneflow {
-
-#define REFLECTION_PAD2D_DATA_TYPE_CPU_SEQ \
-  FLOATING_DATA_TYPE_SEQ \
-  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)
-
-#define REFLECTION_PAD2D_DATA_TYPE_GPU_SEQ \
-  FLOAT16_DATA_TYPE_SEQ \
-  REFLECTION_PAD2D_DATA_TYPE_CPU_SEQ
-
-#define REFLECTION_PAD2D_GRAD_DATA_TYPE_CPU_SEQ \
-  FLOATING_DATA_TYPE_SEQ \
-  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)
-
-#define REFLECTION_PAD2D_GRAD_DATA_TYPE_GPU_SEQ \
-  FLOAT16_DATA_TYPE_SEQ \
-  REFLECTION_PAD2D_GRAD_DATA_TYPE_CPU_SEQ
-
-namespace user_op {
-
-template<typename T>
-struct DeviceAdd {
-  OF_DEVICE_FUNC static void Invoke(const T* x, T* y) {
-#if defined(__CUDA_ARCH__)
-    cuda::atomic::Add(y, *x);
-#else
-    *y += *x;
-#endif
-  };
-};
-
-template<DeviceType device_type, typename IN_T>
-struct ReflectionPad2dFunctor final {
-  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
-                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch,
-                  int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height,
-                  int64_t x_width, int64_t pad_left, int64_t pad_top);
-};
-
-template<DeviceType device_type, typename IN_T>
-struct ReflectionPad2dGradFunctor final {
-  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
-                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch,
-                  int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height,
-                  int64_t dx_width, int64_t pad_left, int64_t pad_top);
-};
-
-template<typename IN_T>
-OF_DEVICE_FUNC void DoReflectionPad2d(const IN_T* src, IN_T* dest,
-                                      const NdIndexOffsetHelper<int64_t, 4>& index_helper,
-                                      int64_t elem_num, int64_t src_num, int64_t dest_num,
-                                      int64_t y_height, int64_t y_width, int64_t x_height,
-                                      int64_t x_width, int64_t pad_left, int64_t pad_top) {
-  XPU_1D_KERNEL_LOOP(k, elem_num) {
-    int64_t n, c, i, j, ip_x, ip_y;
-    int64_t coord_y[4];
-    index_helper.OffsetToNdIndex(k, coord_y);
-    n = coord_y[0];
-    c = coord_y[1];
-    i = coord_y[2];
-    j = coord_y[3];
-    if (j < pad_left) {
-      ip_x = pad_left * 2 - j;
-    } else if (j >= pad_left && j < x_width + pad_left) {
-      ip_x = j;
-    } else {
-      ip_x = (x_width + pad_left - 1) * 2 - j;
-    }
-
-    if (i < pad_top) {
-      ip_y = pad_top * 2 - i;
-    } else if (i >= pad_top && i < x_height + pad_top) {
-      ip_y = i;
-    } else {
-      ip_y = (x_height + pad_top - 1) * 2 - i;
-    }
-    ip_x = ip_x - pad_left;
-    ip_y = ip_y - pad_top;
-    int64_t dest_index = n * dest_num + c * y_width * y_height + i * y_width + j;
-    int64_t src_index = n * src_num + c * x_width * x_height + ip_y * x_width + ip_x;
-    dest[dest_index] = src[src_index];
-  }
-}
-
-template<typename IN_T>
-OF_DEVICE_FUNC void DoReflectionPad2dGrad(const IN_T* src, IN_T* dest,
-                                          const NdIndexOffsetHelper<int64_t, 4>& index_helper,
-                                          int64_t elem_num, int64_t src_num, int64_t dest_num,
-                                          int64_t dy_height, int64_t dy_width, int64_t dx_height,
-                                          int64_t dx_width, int64_t pad_left, int64_t pad_top) {
-  XPU_1D_KERNEL_LOOP(k, elem_num) {
-    int64_t n, c, i, j, ip_x, ip_y;
-    int64_t coord[4];
-    index_helper.OffsetToNdIndex(k, coord);
-    n = coord[0];
-    c = coord[1];
-    i = coord[2];
-    j = coord[3];
-    if (j < pad_left) {
-      ip_x = pad_left * 2 - j;
-    } else if (j >= pad_left && j < dx_width + pad_left) {
-      ip_x = j;
-    } else {
-      ip_x = (dx_width + pad_left - 1) * 2 - j;
-    }
-
-    if (i < pad_top) {
-      ip_y = pad_top * 2 - i;
-    } else if (i >= pad_top && i < dx_height + pad_top) {
-      ip_y = i;
-    } else {
-      ip_y = (dx_height + pad_top - 1) * 2 - i;
-    }
-    ip_x = ip_x - pad_left;
-    ip_y = ip_y - pad_top;
-
-    int64_t src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j;
-    int64_t dest_index = n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x;
-    DeviceAdd<IN_T>::Invoke(src + src_index, dest + dest_index);
-  }
-}
-
-// macros for functors instantiate(used by reflection_pad2d_kernel_util.cu)
-#define INSTANTIATE_REFLECTION_PAD2D_FUNCTOR(device_type_v, dtype_pair) \
-  template struct ReflectionPad2dFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
-
-#define INSTANTIATE_REFLECTION_PAD2D_GRAD_FUNCTOR(device_type_v, dtype_pair) \
-  template struct ReflectionPad2dGradFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
-
-} // namespace user_op
-} // namespace oneflow
-
-#endif // ONEFLOW_USER_KERNELS_REFLECTION_PAD2D_KERNEL_UTIL_H_
\ No newline at end of file
diff --git a/oneflow/user/ops/pad2d_ops.cpp b/oneflow/user/ops/pad2d_ops.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..05d5e12a4bd8de53fbfa8f946d605d0c7e0b88c7
--- /dev/null
+++ b/oneflow/user/ops/pad2d_ops.cpp
@@ -0,0 +1,134 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ +#include "oneflow/core/common/balanced_splitter.h" +#include "oneflow/core/framework/framework.h" +#include "oneflow/user/ops/nn_util.h" +#include "pad_2d_seq.h" + +namespace oneflow { + +namespace { + +Maybe<void> GetOpSbpSignature(user_op::SbpContext *ctx) { + const user_op::TensorDesc &x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + const auto &padding = ctx->Attr<std::vector<int64_t>>("padding"); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + if (padding[i] == 0) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + } + return Maybe<void>::Ok(); +} + +Maybe<void> GetOpGradSbpSignature(user_op::SbpContext *ctx) { + const user_op::TensorDesc &dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); + const auto &padding = ctx->Attr<std::vector<int64_t>>("padding"); + FOR_RANGE(int64_t, i, 0, dy_tensor.shape().NumAxes()) { + if (padding[i] == 0) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + } + return Maybe<void>::Ok(); +} + +} // namespace + +#define REGISTER_PAD_2D_OP_AND_GRAD(pad_2d_type) \ + REGISTER_USER_OP(pad_2d_type) \ + .Input("x") \ + .Output("y") \ + .Attr<std::vector<int64_t>>("padding") \ + .SetTensorDescInferFn([](user_op::InferContext *ctx) -> Maybe<void> { \ + Shape *x_shape = ctx->Shape4ArgNameAndIndex("x", 0); \ + const auto &padding = ctx->Attr<std::vector<int64_t>>("padding"); \ + CHECK_EQ_OR_RETURN(padding.size(), x_shape->NumAxes()); \ + const int64_t n_idx = 0; \ + const int64_t c_idx = 1; \ + const int64_t h_idx = 2; \ + const int64_t w_idx = 3; \ + CHECK_LT_OR_RETURN(padding[0], x_shape->At(w_idx)); \ + CHECK_LT_OR_RETURN(padding[1], x_shape->At(w_idx)); \ + CHECK_LT_OR_RETURN(padding[2], x_shape->At(h_idx)); \ + CHECK_LT_OR_RETURN(padding[3], x_shape->At(h_idx)); \ + \ + DimVector y_dim_vec(x_shape->NumAxes()); \ + const int64_t h_x = x_shape->At(h_idx); \ + const int64_t w_x = x_shape->At(w_idx); \ + \ + y_dim_vec[n_idx] = x_shape->At(n_idx); \ + y_dim_vec[c_idx] = x_shape->At(c_idx); \ + y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; \ + y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; \ + \ + *ctx->Shape4ArgNameAndIndex("y", 0) = Shape(y_dim_vec); \ + *ctx->Dtype4ArgNameAndIndex("y", 0) = *ctx->Dtype4ArgNameAndIndex("x", 0); \ + return Maybe<void>::Ok(); \ + }) \ + .SetGetSbpFn(GetOpSbpSignature) \ + .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, \ + const user_op::UserOpConfWrapper &) { \ + user_op::InputArgModifier *x_modifier = GetInputArgModifierFn("x", 0); \ + CHECK_NOTNULL(x_modifier); \ + x_modifier->set_requires_grad(true); \ + }); \ + \ + REGISTER_USER_OP((std::string("") + pad_2d_type + "_grad")) \ + .Input("dy") \ + .Output("dx") \ + .Attr<std::vector<int64_t>>("padding") \ + .SetTensorDescInferFn([](user_op::InferContext *ctx) -> Maybe<void> { \ + Shape *dy_shape = ctx->Shape4ArgNameAndIndex("dy", 0); \ + const auto &padding = ctx->Attr<std::vector<int64_t>>("padding"); \ + CHECK_EQ_OR_RETURN(padding.size(), dy_shape->NumAxes()); \ + const int64_t n_idx = 0; \ + const int64_t c_idx = 1; \ + const int64_t h_idx = 2; \ + const int64_t w_idx = 3; \ + \ + DimVector dx_dim_vec(dy_shape->NumAxes()); \ + int64_t h_dy, w_dy; \ + h_dy = dy_shape->At(h_idx); \ + w_dy = dy_shape->At(w_idx); \ + \ + dx_dim_vec[n_idx] = dy_shape->At(0); \ + dx_dim_vec[c_idx] = dy_shape->At(1); \ + dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; \ + dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; \ + \ + 
*ctx->Shape4ArgNameAndIndex("dx", 0) = Shape(dx_dim_vec); \ + *ctx->Dtype4ArgNameAndIndex("dx", 0) = *ctx->Dtype4ArgNameAndIndex("dy", 0); \ + return Maybe<void>::Ok(); \ + }) \ + .SetGetSbpFn(GetOpGradSbpSignature); \ + \ + REGISTER_USER_OP_GRAD(pad_2d_type) \ + .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper &op, user_op::AddOpFn AddOp) { \ + if (op.NeedGenGradTensor4OpInput("x", 0)) { \ + user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); \ + user_op::UserOpConfWrapper grad_op = \ + builder.Op((std::string("") + pad_2d_type + "_grad")) \ + .Input("dy", op.GetGradTensorWithOpOutput("y", 0)) \ + .Output("dx") \ + .Attr("padding", op.attr<std::vector<int64_t>>("padding")) \ + .Build(); \ + op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0); \ + AddOp(grad_op); \ + } \ + }); + +OF_PP_FOR_EACH_TUPLE(REGISTER_PAD_2D_OP_AND_GRAD, PAD_2D_TYPE_SEQ) + +} // namespace oneflow diff --git a/oneflow/user/ops/pad_2d_seq.h b/oneflow/user/ops/pad_2d_seq.h new file mode 100644 index 0000000000000000000000000000000000000000..fb750a1e4b9eaac5c80b8697ecc8784eba0fbee0 --- /dev/null +++ b/oneflow/user/ops/pad_2d_seq.h @@ -0,0 +1,28 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_USER_OPS_PAD_2D_SEQ_H_ +#define ONEFLOW_USER_OPS_PAD_2D_SEQ_H_ + +#include "oneflow/core/common/util.h" + +namespace oneflow { + +#define PAD_2D_TYPE_SEQ \ + OF_PP_MAKE_TUPLE_SEQ("reflection_pad2d") \ + OF_PP_MAKE_TUPLE_SEQ("replication_pad2d") +} // namespace oneflow + +#endif // ONEFLOW_USER_OPS_PAD_2D_SEQ_H_ diff --git a/oneflow/user/ops/reflection_pad2d_op.cpp b/oneflow/user/ops/reflection_pad2d_op.cpp deleted file mode 100644 index 8380e4830957377cb171f1da8d2a0d1edc30c8ee..0000000000000000000000000000000000000000 --- a/oneflow/user/ops/reflection_pad2d_op.cpp +++ /dev/null @@ -1,131 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ -#include "oneflow/core/framework/framework.h" -#include "oneflow/core/common/balanced_splitter.h" -#include "oneflow/user/ops/nn_util.h" - -namespace oneflow { - -namespace { - -Maybe<void> GetOpSbpSignature(user_op::SbpContext* ctx) { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - const auto& padding = ctx->Attr<std::vector<int64_t>>("padding"); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - if (padding[i] == 0) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - } - return Maybe<void>::Ok(); -} - -Maybe<void> GetOpGradSbpSignature(user_op::SbpContext* ctx) { - const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); - const auto& padding = ctx->Attr<std::vector<int64_t>>("padding"); - FOR_RANGE(int64_t, i, 0, dy_tensor.shape().NumAxes()) { - if (padding[i] == 0) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - } - return Maybe<void>::Ok(); -} - -} // namespace - -REGISTER_USER_OP("reflection_pad2d") - .Input("x") - .Output("y") - .Attr<std::vector<int64_t>>("padding") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - Shape* x_shape = ctx->Shape4ArgNameAndIndex("x", 0); - const auto& padding = ctx->Attr<std::vector<int64_t>>("padding"); - CHECK_EQ_OR_RETURN(padding.size(), x_shape->NumAxes()); - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t h_idx = 2; - const int64_t w_idx = 3; - // (padding_left, padding_right, padding_top, padding_bottom) - CHECK_LT_OR_RETURN(padding[0], x_shape->At(w_idx)); - CHECK_LT_OR_RETURN(padding[1], x_shape->At(w_idx)); - CHECK_LT_OR_RETURN(padding[2], x_shape->At(h_idx)); - CHECK_LT_OR_RETURN(padding[3], x_shape->At(h_idx)); - - DimVector y_dim_vec(x_shape->NumAxes()); - const int64_t h_x = x_shape->At(h_idx); - const int64_t w_x = x_shape->At(w_idx); - - y_dim_vec[n_idx] = x_shape->At(n_idx); - y_dim_vec[c_idx] = x_shape->At(c_idx); - y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; - y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; - - *ctx->Shape4ArgNameAndIndex("y", 0) = Shape(y_dim_vec); - *ctx->Dtype4ArgNameAndIndex("y", 0) = *ctx->Dtype4ArgNameAndIndex("x", 0); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn(GetOpSbpSignature) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { - user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); - CHECK_NOTNULL(x_modifier); - x_modifier->set_requires_grad(true); - }); - -REGISTER_USER_OP("reflection_pad2d_grad") - .Input("dy") - .Output("dx") - .Attr<std::vector<int64_t>>("padding") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - Shape* dy_shape = ctx->Shape4ArgNameAndIndex("dy", 0); - const auto& padding = ctx->Attr<std::vector<int64_t>>("padding"); - CHECK_EQ_OR_RETURN(padding.size(), dy_shape->NumAxes()); - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t h_idx = 2; - const int64_t w_idx = 3; - - DimVector dx_dim_vec(dy_shape->NumAxes()); - int64_t h_dy, w_dy; - h_dy = dy_shape->At(h_idx); - w_dy = dy_shape->At(w_idx); - - dx_dim_vec[n_idx] = dy_shape->At(0); - dx_dim_vec[c_idx] = dy_shape->At(1); - dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; - dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; - - *ctx->Shape4ArgNameAndIndex("dx", 0) = Shape(dx_dim_vec); - *ctx->Dtype4ArgNameAndIndex("dx", 0) = *ctx->Dtype4ArgNameAndIndex("dy", 0); - return Maybe<void>::Ok(); - }) 
-    .SetGetSbpFn(GetOpGradSbpSignature);
-
-REGISTER_USER_OP_GRAD("reflection_pad2d")
-    .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
-      if (op.NeedGenGradTensor4OpInput("x", 0)) {
-        user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
-        user_op::UserOpConfWrapper grad_op =
-            builder.Op("reflection_pad2d_grad")
-                .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
-                .Output("dx")
-                .Attr("padding", op.attr<std::vector<int64_t>>("padding"))
-                .Build();
-        op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
-        AddOp(grad_op);
-      }
-    });
-
-} // namespace oneflow
\ No newline at end of file