Skip to content
Snippets Groups Projects
Unverified Commit 2b208ec0 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

Add flip module (#5541)


* fix upsample nearest bug

* fix upsample nearest bug (#5347)

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* fix upsample bilinear bug

* init flip op

* add flip op register

* add flip cpu kernel forward

* add flip kernel impl

* add flip op functor and gradient_funcs

* add test, still need fix bug

* fix segmentfault bug

* add docs

* fix comments

* fix comments

* fix comments

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 4bab7aa5
No related branches found
No related tags found
No related merge requests found
...@@ -194,6 +194,8 @@ Experimental features ...@@ -194,6 +194,8 @@ Experimental features
.. autofunction:: oneflow.experimental.Tensor.reshape .. autofunction:: oneflow.experimental.Tensor.reshape
.. autofunction:: oneflow.experimental.squeeze .. autofunction:: oneflow.experimental.squeeze
.. autofunction:: oneflow.experimental.Tensor.squeeze .. autofunction:: oneflow.experimental.Tensor.squeeze
.. autofunction:: oneflow.experimental.flip
.. autofunction:: oneflow.experimental.Tensor.flip
.. autofunction:: oneflow.experimental.transpose .. autofunction:: oneflow.experimental.transpose
.. autofunction:: oneflow.experimental.Tensor.transpose .. autofunction:: oneflow.experimental.Tensor.transpose
.. autofunction:: oneflow.experimental.unsqueeze .. autofunction:: oneflow.experimental.unsqueeze
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct FlipInterpState : public OpExprInterpState {
bool requires_grad;
std::vector<int32_t> dims;
};
class Flip : public OpExprGradFunction<FlipInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(FlipInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const FlipInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Flip::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Flip::Capture(FlipInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->dims = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dims"));
return Maybe<void>::Ok();
}
Maybe<void> Flip::Apply(const FlipInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
if (ctx->requires_grad) {
in_grads->at(0) = JUST(functional::FlipGrad(out_grads.at(0), ctx->dims));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("flip", Flip);
} // namespace one
} // namespace oneflow
...@@ -470,6 +470,14 @@ ...@@ -470,6 +470,14 @@
signature: "Tensor Copy(Tensor x, *, String device_type, Int64 device_id)" signature: "Tensor Copy(Tensor x, *, String device_type, Int64 device_id)"
bind_python: True bind_python: True
- name: "flip"
signature: "Tensor Flip(Tensor x, *, Int32List dims)"
bind_python: True
- name: "flip_grad"
signature: "Tensor FlipGrad(Tensor dy, *, Int32List dims)"
bind_python: False
- name: "upsample" - name: "upsample"
signature: signature:
"Tensor Upsample(Tensor x, *, Float height_scale, Float width_scale, Bool align_corners, "Tensor Upsample(Tensor x, *, Float height_scale, Float width_scale, Bool align_corners,
......
...@@ -436,6 +436,36 @@ class CopyFunctor { ...@@ -436,6 +436,36 @@ class CopyFunctor {
std::shared_ptr<OpExpr> op_; std::shared_ptr<OpExpr> op_;
}; };
class FlipFunctor {
public:
FlipFunctor() { op_ = CHECK_JUST(one::OpBuilder("flip").Input("x").Output("y").Build()); }
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
const std::vector<int32_t>& dims) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<std::vector<int32_t>>("dims", dims));
return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
}
private:
std::shared_ptr<OpExpr> op_;
};
class FlipGradFunctor {
public:
FlipGradFunctor() {
op_ = CHECK_JUST(one::OpBuilder("flip_grad").Input("dy").Output("dx").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
const std::vector<int32_t>& dims) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<std::vector<int32_t>>("dims", dims));
return OpInterpUtil::Dispatch<Tensor>(*op_, {dy}, attrs);
}
private:
std::shared_ptr<OpExpr> op_;
};
class UpsampleFunctor { class UpsampleFunctor {
public: public:
UpsampleFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample").Input("x").Output("y").Build()); } UpsampleFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample").Input("x").Output("y").Build()); }
...@@ -923,6 +953,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ...@@ -923,6 +953,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::SliceUpdateFunctor>("SliceUpdate"); m.add_functor<impl::SliceUpdateFunctor>("SliceUpdate");
m.add_functor<impl::SqueezeFunctor>("Squeeze"); m.add_functor<impl::SqueezeFunctor>("Squeeze");
m.add_functor<impl::CopyFunctor>("Copy"); m.add_functor<impl::CopyFunctor>("Copy");
m.add_functor<impl::FlipFunctor>("Flip");
m.add_functor<impl::FlipGradFunctor>("FlipGrad");
m.add_functor<impl::UpsampleFunctor>("Upsample"); m.add_functor<impl::UpsampleFunctor>("Upsample");
m.add_functor<impl::UpsampleNearest2DFunctor>("UpsampleNearest2D"); m.add_functor<impl::UpsampleNearest2DFunctor>("UpsampleNearest2D");
m.add_functor<impl::UpsampleNearest2DGradFunctor>("UpsampleNearest2DGrad"); m.add_functor<impl::UpsampleNearest2DGradFunctor>("UpsampleNearest2DGrad");
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import collections
from typing import Optional, Sequence, Union
import oneflow as flow
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module
from oneflow.python.framework.tensor import register_tensor_op
from oneflow.python.nn.modules.utils import _check_axis
class Flip(Module):
def __init__(self, dims) -> None:
super().__init__()
assert isinstance(dims, (list, tuple)), f"dims must be list or tuple"
self.dims = dims
def forward(self, x):
input_len = len(x.shape)
assert (
len(self.dims) <= input_len
), f"len of dims must less than len of input tensor"
new_dims = []
for i in self.dims:
if i < 0:
i += input_len
assert (
i < input_len
), f"IndexError: Dimension out of range (expected to be in range of {input_len}, but got {i})"
new_dims.append(i)
return flow.F.flip(x, new_dims)
@oneflow_export("flip")
@experimental_api
def flip_op(input, dims):
r"""
Reverse the order of a n-D tensor along given axis in dims.
.. note::
`flow.flip` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flip`,
which returns a view in constant time. Since copying a tensor's data is more work than viewing that data,
`flow.flip` is expected to be slower than `np.flip`.
Args:
input (Tensor): the input tensor
dims (a list or tuple): axis to flip on
For example:
.. code-block:: python
>>> import oneflow.experimental as flow
>>> import numpy as np
>>> np_arr = np.arange(0, 8).reshape((2, 2, 2)).astype(np.float32)
>>> input = flow.Tensor(np_arr)
>>> out = flow.flip(input, [0, 1])
>>> out
tensor([[[6., 7.],
[4., 5.]],
<BLANKLINE>
[[2., 3.],
[0., 1.]]], dtype=oneflow.float32)
"""
return Flip(dims)(input)
@register_tensor_op("flip")
@experimental_api
def flip_op_tensor(input, dims):
r"""
See :func:`oneflow.experimental.flip`
"""
return Flip(dims)(input)
if __name__ == "__main__":
import doctest
doctest.testmod(raise_on_error=True)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
import oneflow.experimental as flow
from test_util import GenArgList
from automated_test_util import *
def _test_flip(test_case, device):
np_arr = np.arange(0, 16).reshape((2, 2, 2, 2)).astype(np.float32)
input = flow.Tensor(np_arr, device=flow.device(device), requires_grad=True)
out = flow.flip(input, [0, 1, 2])
np_out = [
[[[14.0, 15.0], [12.0, 13.0]], [[10.0, 11.0], [8.0, 9.0]]],
[[[6.0, 7.0], [4.0, 5.0]], [[2.0, 3.0], [0.0, 1.0]]],
]
test_case.assertTrue(np.allclose(out.numpy(), np_out, 1e-5, 1e-5))
out = out.sum()
out = out.backward()
np_grad = np.ones_like(np_arr)
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
class TestFlip(flow.unittest.TestCase):
def test_flip(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_flip,
]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
if __name__ == "__main__":
unittest.main()
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"
#include "oneflow/core/common/nd_index_offset_helper.h"
namespace oneflow {
namespace {
const int32_t NDIMS = 16;
struct SIZE_V {
int32_t val[NDIMS];
};
struct VIS {
bool val[NDIMS] = {false};
};
template<typename T>
void FlipCpuForward(const int32_t element, const int64_t total_dims,
const SIZE_V stride_contiguous_v, const SIZE_V sizes_v, const VIS vis,
SIZE_V strides_v, const T* in_dptr, T* out_dptr) {
for (int i = 0; i < element; i++) {
int32_t cur_indices = i;
int32_t rem = 0;
int32_t dst_offset = 0;
for (int32_t d = 0; d < total_dims; d++) {
int32_t temp = cur_indices;
cur_indices = cur_indices / stride_contiguous_v.val[d];
rem = temp - cur_indices * stride_contiguous_v.val[d];
dst_offset += vis.val[d] ? (sizes_v.val[d] - 1 - cur_indices) * strides_v.val[d]
: cur_indices * strides_v.val[d];
cur_indices = rem;
}
out_dptr[i] = in_dptr[dst_offset];
}
}
} // namespace
template<typename T>
class FlipCpuKernel final : public user_op::OpKernel {
public:
FlipCpuKernel() = default;
~FlipCpuKernel() = default;
private:
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
const int32_t elem_cnt = y_tensor->shape().elem_cnt();
const int32_t total_dims = y_tensor->shape().NumAxes();
std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>("dims");
VIS vis;
for (auto x : dims) { vis.val[x] = true; }
SIZE_V sizes_v;
for (int32_t i = 0; i < total_dims; i++) { sizes_v.val[i] = y_tensor->shape().At(i); }
// TODO(bbuf) delete strides caluculate, after tensor strides supported
SIZE_V strides_v;
strides_v.val[total_dims - 1] = 1;
for (int32_t i = total_dims - 2; i >= 0; i--) {
strides_v.val[i] = strides_v.val[i + 1] * y_tensor->shape().At(i);
}
SIZE_V stride_contiguous_v;
for (int32_t i = total_dims - 1; i >= 0; i--) {
if (i == total_dims - 1) {
stride_contiguous_v.val[i] = 1;
} else {
stride_contiguous_v.val[i] =
std::max<int32_t>(x_tensor->shape().At(i + 1), 1) * stride_contiguous_v.val[i + 1];
}
}
FlipCpuForward(elem_cnt, total_dims, stride_contiguous_v, sizes_v, vis, strides_v,
x_tensor->dptr<T>(), y_tensor->mut_dptr<T>());
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
template<typename T>
class FlipGrad1DCpuKernel final : public user_op::OpKernel {
public:
FlipGrad1DCpuKernel() = default;
~FlipGrad1DCpuKernel() = default;
private:
void Compute(user_op::KernelComputeContext* ctx) const override {
user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
Memset<DeviceType::kCPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
dx_tensor->shape().elem_cnt() * sizeof(T));
const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
memcpy((void*)dx_tensor->mut_dptr<T>(), (void*)dy_tensor->dptr<T>(),
dy_tensor->shape().elem_cnt() * sizeof(T));
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
#define REGISTER_FLIP_CPU_KERNEL(dtype) \
REGISTER_USER_KERNEL("flip").SetCreateFn<FlipCpuKernel<dtype>>().SetIsMatchedHob( \
(user_op::HobDeviceTag() == "cpu") \
& (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
REGISTER_USER_KERNEL("flip_grad") \
.SetCreateFn<FlipGrad1DCpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == "cpu") \
& (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
REGISTER_FLIP_CPU_KERNEL(float)
REGISTER_FLIP_CPU_KERNEL(double)
REGISTER_FLIP_CPU_KERNEL(int)
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"
#include "oneflow/core/common/nd_index_offset_helper.h"
namespace oneflow {
namespace {
const int32_t NDIMS = 16;
struct SIZE_V {
int32_t val[NDIMS];
};
struct VIS {
bool val[NDIMS] = {false};
};
template<typename T>
__global__ void FlipGpuForward(const int32_t element, const int64_t total_dims,
const SIZE_V stride_contiguous_v, const SIZE_V sizes_v,
const VIS vis, SIZE_V strides_v, const T* in_dptr, T* out_dptr) {
CUDA_1D_KERNEL_LOOP(i, element) {
int32_t cur_indices = i;
int32_t rem = 0;
int32_t dst_offset = 0;
for (int32_t d = 0; d < total_dims; d++) {
int32_t temp = cur_indices;
cur_indices = cur_indices / stride_contiguous_v.val[d];
rem = temp - cur_indices * stride_contiguous_v.val[d];
dst_offset += vis.val[d] ? (sizes_v.val[d] - 1 - cur_indices) * strides_v.val[d]
: cur_indices * strides_v.val[d];
cur_indices = rem;
}
out_dptr[i] = in_dptr[dst_offset];
}
}
} // namespace
template<typename T>
class FlipGpuKernel final : public user_op::OpKernel {
public:
FlipGpuKernel() = default;
~FlipGpuKernel() = default;
private:
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
const int32_t elem_cnt = y_tensor->shape().elem_cnt();
const int32_t total_dims = y_tensor->shape().NumAxes();
std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>("dims");
VIS vis;
for (auto x : dims) { vis.val[x] = true; }
SIZE_V sizes_v;
for (int32_t i = 0; i < total_dims; i++) { sizes_v.val[i] = y_tensor->shape().At(i); }
// TODO(bbuf) delete strides caluculate, after tensor strides supported
SIZE_V strides_v;
strides_v.val[total_dims - 1] = 1;
for (int32_t i = total_dims - 2; i >= 0; i--) {
strides_v.val[i] = strides_v.val[i + 1] * y_tensor->shape().At(i);
}
SIZE_V stride_contiguous_v;
for (int32_t i = total_dims - 1; i >= 0; i--) {
if (i == total_dims - 1) {
stride_contiguous_v.val[i] = 1;
} else {
stride_contiguous_v.val[i] =
std::max<int32_t>(x_tensor->shape().At(i + 1), 1) * stride_contiguous_v.val[i + 1];
}
}
RUN_CUDA_KERNEL((FlipGpuForward<T>), ctx->device_ctx(), elem_cnt, elem_cnt, total_dims,
stride_contiguous_v, sizes_v, vis, strides_v, x_tensor->dptr<T>(),
y_tensor->mut_dptr<T>());
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
template<typename T>
class FlipGrad1DGpuKernel final : public user_op::OpKernel {
public:
FlipGrad1DGpuKernel() = default;
~FlipGrad1DGpuKernel() = default;
private:
void Compute(user_op::KernelComputeContext* ctx) const override {
user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
Memset<DeviceType::kGPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
dx_tensor->shape().elem_cnt() * sizeof(T));
const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
Memcpy<DeviceType::kGPU>(
ctx->device_ctx(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
dy_tensor->shape().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
#define REGISTER_FLIP_GPU_KERNEL(dtype) \
REGISTER_USER_KERNEL("flip").SetCreateFn<FlipGpuKernel<dtype>>().SetIsMatchedHob( \
(user_op::HobDeviceTag() == "gpu") \
& (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
REGISTER_USER_KERNEL("flip_grad") \
.SetCreateFn<FlipGrad1DGpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == "gpu") \
& (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
REGISTER_FLIP_GPU_KERNEL(float)
REGISTER_FLIP_GPU_KERNEL(double)
REGISTER_FLIP_GPU_KERNEL(int)
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
namespace oneflow {
REGISTER_USER_OP("flip")
.Input("x")
.Output("y")
.Attr<std::vector<int32_t>>("dims")
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
const int input_dims = x_desc->shape().NumAxes();
const std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>("dims");
CHECK_OR_RETURN(dims.size() <= input_dims)
<< "len of dims must less than len of input tensor";
for (auto x : dims) { CHECK_OR_RETURN(x < input_dims) << "dims parameter is illegal."; }
user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
*y_desc->mut_shape() = x_desc->shape();
return Maybe<void>::Ok();
})
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
return Maybe<void>::Ok();
})
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
*ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
return Maybe<void>::Ok();
});
REGISTER_USER_OP("flip_grad")
.Input("dy")
.Output("dx")
.Attr<std::vector<int32_t>>("dims")
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
const Shape& dy_shape = ctx->InputShape("dy", 0);
Shape* dx_shape = ctx->OutputShape("dx", 0);
*dx_shape = dy_shape;
return Maybe<void>::Ok();
})
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
return Maybe<void>::Ok();
})
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
*ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
return Maybe<void>::Ok();
});
} // namespace oneflow
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment