From 2b208ec02d5ee77cdf7e960b7287a24386da41a2 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Tue, 20 Jul 2021 15:29:00 +0800 Subject: [PATCH] Add flip module (#5541) * fix upsample nearest bug * fix upsample nearest bug (#5347) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * fix upsample bilinear bug * init flip op * add flip op register * add flip cpu kernel forward * add flip kernel impl * add flip op functor and gradient_funcs * add test, still need fix bug * fix segmentfault bug * add docs * fix comments * fix comments * fix comments Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- docs/source/experimental.rst | 2 + oneflow/core/autograd/gradient_funcs/flip.cpp | 69 +++++++++ oneflow/core/functional/functional_api.yaml | 8 ++ .../core/functional/impl/array_functor.cpp | 32 +++++ oneflow/python/nn/modules/flip.py | 100 +++++++++++++ oneflow/python/test/modules/test_flip.py | 53 +++++++ oneflow/user/kernels/flip_kernel.cpp | 132 ++++++++++++++++++ oneflow/user/kernels/flip_kernel.cu | 132 ++++++++++++++++++ oneflow/user/ops/flip_op.cpp | 63 +++++++++ 9 files changed, 591 insertions(+) create mode 100644 oneflow/core/autograd/gradient_funcs/flip.cpp create mode 100644 oneflow/python/nn/modules/flip.py create mode 100644 oneflow/python/test/modules/test_flip.py create mode 100644 oneflow/user/kernels/flip_kernel.cpp create mode 100644 oneflow/user/kernels/flip_kernel.cu create mode 100644 oneflow/user/ops/flip_op.cpp diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst index e58b0d9c1..6aacefedd 100644 --- a/docs/source/experimental.rst +++ b/docs/source/experimental.rst @@ -194,6 +194,8 @@ Experimental features .. autofunction:: oneflow.experimental.Tensor.reshape .. autofunction:: oneflow.experimental.squeeze .. autofunction:: oneflow.experimental.Tensor.squeeze +.. autofunction:: oneflow.experimental.flip +.. autofunction:: oneflow.experimental.Tensor.flip .. autofunction:: oneflow.experimental.transpose .. autofunction:: oneflow.experimental.Tensor.transpose .. autofunction:: oneflow.experimental.unsqueeze diff --git a/oneflow/core/autograd/gradient_funcs/flip.cpp b/oneflow/core/autograd/gradient_funcs/flip.cpp new file mode 100644 index 000000000..1022e4900 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/flip.cpp @@ -0,0 +1,69 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct FlipInterpState : public OpExprInterpState { + bool requires_grad; + std::vector<int32_t> dims; +}; + +class Flip : public OpExprGradFunction<FlipInterpState> { + public: + Maybe<void> Init(const OpExpr& op) override; + Maybe<void> Capture(FlipInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, + const AttrMap& attrs) const override; + Maybe<void> Apply(const FlipInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override; + + private: + AttrMap base_attrs_; +}; + +Maybe<void> Flip::Init(const OpExpr& op) { + const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe<void>::Ok(); +} + +Maybe<void> Flip::Capture(FlipInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const { + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->requires_grad) { return Maybe<void>::Ok(); } + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->dims = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dims")); + return Maybe<void>::Ok(); +} + +Maybe<void> Flip::Apply(const FlipInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const { + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(1); + if (ctx->requires_grad) { + in_grads->at(0) = JUST(functional::FlipGrad(out_grads.at(0), ctx->dims)); + } + return Maybe<void>::Ok(); +} + +REGISTER_OP_EXPR_GRAD_FUNCTION("flip", Flip); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index a936c88fd..5c1a6f5ab 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -470,6 +470,14 @@ signature: "Tensor Copy(Tensor x, *, String device_type, Int64 device_id)" bind_python: True +- name: "flip" + signature: "Tensor Flip(Tensor x, *, Int32List dims)" + bind_python: True + +- name: "flip_grad" + signature: "Tensor FlipGrad(Tensor dy, *, Int32List dims)" + bind_python: False + - name: "upsample" signature: "Tensor Upsample(Tensor x, *, Float height_scale, Float width_scale, Bool align_corners, diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index e7d913b2c..b9cd3f395 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -436,6 +436,36 @@ class CopyFunctor { std::shared_ptr<OpExpr> op_; }; +class FlipFunctor { + public: + FlipFunctor() { op_ = CHECK_JUST(one::OpBuilder("flip").Input("x").Output("y").Build()); } + Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, + const std::vector<int32_t>& dims) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr<std::vector<int32_t>>("dims", dims)); + return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs); + } + + private: + std::shared_ptr<OpExpr> op_; +}; + +class FlipGradFunctor { + public: + FlipGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("flip_grad").Input("dy").Output("dx").Build()); + } + Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy, + const std::vector<int32_t>& dims) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr<std::vector<int32_t>>("dims", 
dims));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
 class UpsampleFunctor {
  public:
   UpsampleFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample").Input("x").Output("y").Build()); }
@@ -923,6 +953,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::SliceUpdateFunctor>("SliceUpdate");
   m.add_functor<impl::SqueezeFunctor>("Squeeze");
   m.add_functor<impl::CopyFunctor>("Copy");
+  m.add_functor<impl::FlipFunctor>("Flip");
+  m.add_functor<impl::FlipGradFunctor>("FlipGrad");
   m.add_functor<impl::UpsampleFunctor>("Upsample");
   m.add_functor<impl::UpsampleNearest2DFunctor>("UpsampleNearest2D");
   m.add_functor<impl::UpsampleNearest2DGradFunctor>("UpsampleNearest2DGrad");
diff --git a/oneflow/python/nn/modules/flip.py b/oneflow/python/nn/modules/flip.py
new file mode 100644
index 000000000..6c61a6b83
--- /dev/null
+++ b/oneflow/python/nn/modules/flip.py
@@ -0,0 +1,100 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import collections
+from typing import Optional, Sequence, Union
+
+import oneflow as flow
+from oneflow.python.oneflow_export import oneflow_export, experimental_api
+from oneflow.python.nn.module import Module
+from oneflow.python.framework.tensor import register_tensor_op
+from oneflow.python.nn.modules.utils import _check_axis
+
+
+class Flip(Module):
+    def __init__(self, dims) -> None:
+        super().__init__()
+        assert isinstance(dims, (list, tuple)), "dims must be a list or tuple"
+        self.dims = dims
+
+    def forward(self, x):
+        input_len = len(x.shape)
+        assert (
+            len(self.dims) <= input_len
+        ), "len of dims must be less than or equal to the number of dimensions of the input tensor"
+        new_dims = []
+        for i in self.dims:
+            if i < 0:
+                i += input_len
+            assert (
+                i < input_len
+            ), f"IndexError: Dimension out of range (expected to be in range of {input_len}, but got {i})"
+            new_dims.append(i)
+        return flow.F.flip(x, new_dims)
+
+
+@oneflow_export("flip")
+@experimental_api
+def flip_op(input, dims):
+
+    r"""
+
+    Reverses the order of an n-D tensor along the given axes in dims.
+
+    .. note::
+        `flow.flip` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flip`,
+        which returns a view in constant time. Since copying a tensor's data is more work than viewing that data,
+        `flow.flip` is expected to be slower than `np.flip`.
+
+    Args:
+        input (Tensor): the input tensor
+        dims (a list or tuple): the axes to flip on
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow.experimental as flow
+        >>> import numpy as np
+
+        >>> np_arr = np.arange(0, 8).reshape((2, 2, 2)).astype(np.float32)
+        >>> input = flow.Tensor(np_arr)
+        >>> out = flow.flip(input, [0, 1])
+        >>> out
+        tensor([[[6., 7.],
+                 [4., 5.]],
+        <BLANKLINE>
+                [[2., 3.],
+                 [0., 1.]]], dtype=oneflow.float32)
+
+    """
+
+    return Flip(dims)(input)
+
+
+@register_tensor_op("flip")
+@experimental_api
+def flip_op_tensor(input, dims):
+    r"""
+    See :func:`oneflow.experimental.flip`
+    """
+    return Flip(dims)(input)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(raise_on_error=True)
diff --git a/oneflow/python/test/modules/test_flip.py b/oneflow/python/test/modules/test_flip.py
new file mode 100644
index 000000000..a5ad28f21
--- /dev/null
+++ b/oneflow/python/test/modules/test_flip.py
@@ -0,0 +1,53 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+
+import oneflow.experimental as flow
+from test_util import GenArgList
+from automated_test_util import *
+
+
+def _test_flip(test_case, device):
+    np_arr = np.arange(0, 16).reshape((2, 2, 2, 2)).astype(np.float32)
+    input = flow.Tensor(np_arr, device=flow.device(device), requires_grad=True)
+    out = flow.flip(input, [0, 1, 2])
+    np_out = [
+        [[[14.0, 15.0], [12.0, 13.0]], [[10.0, 11.0], [8.0, 9.0]]],
+        [[[6.0, 7.0], [4.0, 5.0]], [[2.0, 3.0], [0.0, 1.0]]],
+    ]
+    test_case.assertTrue(np.allclose(out.numpy(), np_out, 1e-5, 1e-5))
+    out = out.sum()
+    out.backward()
+    np_grad = np.ones_like(np_arr)
+    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
+
+
+class TestFlip(flow.unittest.TestCase):
+    def test_flip(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [
+            _test_flip,
+        ]
+        arg_dict["device"] = ["cpu", "cuda"]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/user/kernels/flip_kernel.cpp b/oneflow/user/kernels/flip_kernel.cpp
new file mode 100644
index 000000000..89ad832cb
--- /dev/null
+++ b/oneflow/user/kernels/flip_kernel.cpp
@@ -0,0 +1,132 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/core/common/nd_index_offset_helper.h"
+
+namespace oneflow {
+
+namespace {
+
+const int32_t NDIMS = 16;
+
+struct SIZE_V {
+  int32_t val[NDIMS];
+};
+
+struct VIS {
+  bool val[NDIMS] = {false};
+};
+
+template<typename T>
+void FlipCpuForward(const int32_t element, const int64_t total_dims,
+                    const SIZE_V stride_contiguous_v, const SIZE_V sizes_v, const VIS vis,
+                    SIZE_V strides_v, const T* in_dptr, T* out_dptr) {
+  for (int i = 0; i < element; i++) {
+    int32_t cur_indices = i;
+    int32_t rem = 0;
+    int32_t dst_offset = 0;
+    for (int32_t d = 0; d < total_dims; d++) {
+      int32_t temp = cur_indices;
+      cur_indices = cur_indices / stride_contiguous_v.val[d];
+      rem = temp - cur_indices * stride_contiguous_v.val[d];
+      dst_offset += vis.val[d] ? (sizes_v.val[d] - 1 - cur_indices) * strides_v.val[d]
+                               : cur_indices * strides_v.val[d];
+      cur_indices = rem;
+    }
+    out_dptr[i] = in_dptr[dst_offset];
+  }
+}
+
+}  // namespace
+
+template<typename T>
+class FlipCpuKernel final : public user_op::OpKernel {
+ public:
+  FlipCpuKernel() = default;
+  ~FlipCpuKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const int32_t elem_cnt = y_tensor->shape().elem_cnt();
+
+    const int32_t total_dims = y_tensor->shape().NumAxes();
+
+    std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>("dims");
+    VIS vis;
+    for (auto x : dims) { vis.val[x] = true; }
+
+    SIZE_V sizes_v;
+    for (int32_t i = 0; i < total_dims; i++) { sizes_v.val[i] = y_tensor->shape().At(i); }
+
+    // TODO(bbuf): remove the strides calculation once tensor strides are supported
+    SIZE_V strides_v;
+    strides_v.val[total_dims - 1] = 1;
+    for (int32_t i = total_dims - 2; i >= 0; i--) {
+      strides_v.val[i] = strides_v.val[i + 1] * y_tensor->shape().At(i);
+    }
+
+    SIZE_V stride_contiguous_v;
+
+    for (int32_t i = total_dims - 1; i >= 0; i--) {
+      if (i == total_dims - 1) {
+        stride_contiguous_v.val[i] = 1;
+      } else {
+        stride_contiguous_v.val[i] =
+            std::max<int32_t>(x_tensor->shape().At(i + 1), 1) * stride_contiguous_v.val[i + 1];
+      }
+    }
+
+    FlipCpuForward(elem_cnt, total_dims, stride_contiguous_v, sizes_v, vis, strides_v,
+                   x_tensor->dptr<T>(), y_tensor->mut_dptr<T>());
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class FlipGrad1DCpuKernel final : public user_op::OpKernel {
+ public:
+  FlipGrad1DCpuKernel() = default;
+  ~FlipGrad1DCpuKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+    Memset<DeviceType::kCPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    memcpy((void*)dx_tensor->mut_dptr<T>(), (void*)dy_tensor->dptr<T>(),
+           dy_tensor->shape().elem_cnt() * sizeof(T));
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_FLIP_CPU_KERNEL(dtype)                                              \
+  REGISTER_USER_KERNEL("flip").SetCreateFn<FlipCpuKernel<dtype>>().SetIsMatchedHob(  \
+      (user_op::HobDeviceTag() == "cpu")                                             \
+      & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value));                \
+  REGISTER_USER_KERNEL("flip_grad")                                                  \
+      .SetCreateFn<FlipGrad1DCpuKernel<dtype>>()                                     \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                            \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_FLIP_CPU_KERNEL(float)
+REGISTER_FLIP_CPU_KERNEL(double)
+REGISTER_FLIP_CPU_KERNEL(int)
+
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/flip_kernel.cu b/oneflow/user/kernels/flip_kernel.cu
new file mode 100644
index 000000000..00e5ed24d
--- /dev/null
+++ b/oneflow/user/kernels/flip_kernel.cu
@@ -0,0 +1,132 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/core/common/nd_index_offset_helper.h"
+
+namespace oneflow {
+
+namespace {
+
+const int32_t NDIMS = 16;
+struct SIZE_V {
+  int32_t val[NDIMS];
+};
+
+struct VIS {
+  bool val[NDIMS] = {false};
+};
+
+template<typename T>
+__global__ void FlipGpuForward(const int32_t element, const int64_t total_dims,
+                               const SIZE_V stride_contiguous_v, const SIZE_V sizes_v,
+                               const VIS vis, SIZE_V strides_v, const T* in_dptr, T* out_dptr) {
+  CUDA_1D_KERNEL_LOOP(i, element) {
+    int32_t cur_indices = i;
+    int32_t rem = 0;
+    int32_t dst_offset = 0;
+    for (int32_t d = 0; d < total_dims; d++) {
+      int32_t temp = cur_indices;
+      cur_indices = cur_indices / stride_contiguous_v.val[d];
+      rem = temp - cur_indices * stride_contiguous_v.val[d];
+      dst_offset += vis.val[d] ? (sizes_v.val[d] - 1 - cur_indices) * strides_v.val[d]
+                               : cur_indices * strides_v.val[d];
+      cur_indices = rem;
+    }
+    out_dptr[i] = in_dptr[dst_offset];
+  }
+}
+
+}  // namespace
+
+template<typename T>
+class FlipGpuKernel final : public user_op::OpKernel {
+ public:
+  FlipGpuKernel() = default;
+  ~FlipGpuKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const int32_t elem_cnt = y_tensor->shape().elem_cnt();
+
+    const int32_t total_dims = y_tensor->shape().NumAxes();
+
+    std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>("dims");
+    VIS vis;
+    for (auto x : dims) { vis.val[x] = true; }
+
+    SIZE_V sizes_v;
+    for (int32_t i = 0; i < total_dims; i++) { sizes_v.val[i] = y_tensor->shape().At(i); }
+
+    // TODO(bbuf): remove the strides calculation once tensor strides are supported
+    SIZE_V strides_v;
+    strides_v.val[total_dims - 1] = 1;
+    for (int32_t i = total_dims - 2; i >= 0; i--) {
+      strides_v.val[i] = strides_v.val[i + 1] * y_tensor->shape().At(i);
+    }
+
+    SIZE_V stride_contiguous_v;
+
+    for (int32_t i = total_dims - 1; i >= 0; i--) {
+      if (i == total_dims - 1) {
+        stride_contiguous_v.val[i] = 1;
+      } else {
+        stride_contiguous_v.val[i] =
+            std::max<int32_t>(x_tensor->shape().At(i + 1), 1) * stride_contiguous_v.val[i + 1];
+      }
+    }
+    RUN_CUDA_KERNEL((FlipGpuForward<T>), ctx->device_ctx(), elem_cnt, elem_cnt, total_dims,
+                    stride_contiguous_v, sizes_v, vis, strides_v, x_tensor->dptr<T>(),
+                    y_tensor->mut_dptr<T>());
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class FlipGrad1DGpuKernel final : public user_op::OpKernel {
+ public:
+  FlipGrad1DGpuKernel() = default;
+  ~FlipGrad1DGpuKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+    Memset<DeviceType::kGPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    Memcpy<DeviceType::kGPU>(
+        ctx->device_ctx(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
+        dy_tensor->shape().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_FLIP_GPU_KERNEL(dtype)                                              \
+  REGISTER_USER_KERNEL("flip").SetCreateFn<FlipGpuKernel<dtype>>().SetIsMatchedHob(  \
+      (user_op::HobDeviceTag() == "gpu")                                             \
+      & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value));                \
+  REGISTER_USER_KERNEL("flip_grad")                                                  \
+      .SetCreateFn<FlipGrad1DGpuKernel<dtype>>()                                     \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                            \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_FLIP_GPU_KERNEL(float)
+REGISTER_FLIP_GPU_KERNEL(double)
+REGISTER_FLIP_GPU_KERNEL(int)
+
+}  // namespace oneflow
diff --git a/oneflow/user/ops/flip_op.cpp b/oneflow/user/ops/flip_op.cpp
new file mode 100644
index 000000000..417608205
--- /dev/null
+++ b/oneflow/user/ops/flip_op.cpp
@@ -0,0 +1,63 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+
+namespace oneflow {
+
+REGISTER_USER_OP("flip")
+    .Input("x")
+    .Output("y")
+    .Attr<std::vector<int32_t>>("dims")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
+      const int input_dims = x_desc->shape().NumAxes();
+      const std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>("dims");
+      CHECK_OR_RETURN(dims.size() <= input_dims)
+          << "len of dims must be less than or equal to the number of input dimensions";
+      for (auto x : dims) { CHECK_OR_RETURN(x < input_dims) << "dims entry is out of range."; }
+      user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
+      *y_desc->mut_shape() = x_desc->shape();
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("flip_grad")
+    .Input("dy")
+    .Output("dx")
+    .Attr<std::vector<int32_t>>("dims")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& dy_shape = ctx->InputShape("dy", 0);
+      Shape* dx_shape = ctx->OutputShape("dx", 0);
+      *dx_shape = dy_shape;
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+      return Maybe<void>::Ok();
+    });
+
+}  // namespace oneflow
-- 
GitLab
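
Usage sketch: the snippet below shows how the op added by this patch is driven end to end from Python, mirroring the doctest in oneflow/python/nn/modules/flip.py and the backward check in oneflow/python/test/modules/test_flip.py. The `oneflow.experimental` namespace is the one this patch targets; the `flow.enable_eager_execution()` call is assumed to be the usual setup for experimental-API scripts of this era and is not something this patch adds.

.. code-block:: python

    import numpy as np
    import oneflow.experimental as flow

    # Assumed setup for the experimental eager API; not part of this patch.
    flow.enable_eager_execution()

    x = flow.Tensor(
        np.arange(0, 8).reshape(2, 2, 2).astype(np.float32), requires_grad=True
    )

    # Functional form added by functional_api.yaml / FlipFunctor: reverse axes 0 and 1.
    y = flow.flip(x, [0, 1])
    print(y.numpy())  # [[[6. 7.] [4. 5.]] [[2. 3.] [0. 1.]]]

    # Tensor-method form registered via @register_tensor_op("flip").
    z = x.flip([0, 1])

    # The registered backward (flip_grad) flips the gradient back, so a sum()
    # loss yields an all-ones gradient with the input's shape.
    y.sum().backward()
    print(x.grad.numpy())  # ones of shape (2, 2, 2)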