diff --git a/docs/source/oneflow.rst b/docs/source/oneflow.rst index 521bdd0183ef6e422b5b30b5b7ba994838db2f98..743f904e92c4a997ee8b032f4b310a44adbc86ac 100644 --- a/docs/source/oneflow.rst +++ b/docs/source/oneflow.rst @@ -29,6 +29,8 @@ oneflow reshape, save, saved_model, + scatter, + scatter_add, scatter_nd, selu, silu, diff --git a/oneflow/core/autograd/gradient_funcs/dim_gather.cpp b/oneflow/core/autograd/gradient_funcs/dim_gather.cpp index 09a2d475de7f13cc6dca07e8bb28373869ca2472..4ae5b63e960e61d62058ff6ef95944fcb9cf0d53 100644 --- a/oneflow/core/autograd/gradient_funcs/dim_gather.cpp +++ b/oneflow/core/autograd/gradient_funcs/dim_gather.cpp @@ -72,7 +72,7 @@ Maybe<void> DimGather::Apply(const DimGatherInterpState* ctx, const TensorTuple& MutableAttrMap attrs; JUST(attrs.SetAttr<int32_t>("dim", ctx->dim)); in_grads->at(0) = JUST( - OpInterpUtil::Dispatch<Tensor>(*bw_dim_gather_op_, {like, out_grads.at(0), index}, attrs)); + OpInterpUtil::Dispatch<Tensor>(*bw_dim_gather_op_, {like, index, out_grads.at(0)}, attrs)); return Maybe<void>::Ok(); } diff --git a/oneflow/core/autograd/gradient_funcs/dim_scatter.cpp b/oneflow/core/autograd/gradient_funcs/dim_scatter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6bda00e3abc6ddba5289799dc4f9bbabd2f98755 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/dim_scatter.cpp @@ -0,0 +1,176 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_expr_helper.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct DimScatterInterpState : public OpExprInterpState { + int32_t dim; + bool input_requires_grad; + bool src_requires_grad; +}; + +enum SCATTER_TYPE { SCATTER_UPDATE, SCATTER_ADD }; + +template<SCATTER_TYPE T> +class DimScatter : public OpExprGradFunction<DimScatterInterpState> { + public: + Maybe<void> Init(const OpExpr& op) override; + Maybe<void> Capture(DimScatterInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override; + Maybe<void> Apply(const DimScatterInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override; + Maybe<void> ApplyCommon(const DimScatterInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const; + + private: + AttrMap base_attrs_; +}; + +template<SCATTER_TYPE T> +Maybe<void> DimScatter<T>::Init(const OpExpr& op) { + const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe<void>::Ok(); +} + +template<SCATTER_TYPE T> +Maybe<void> DimScatter<T>::Capture(DimScatterInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const { + CHECK_EQ_OR_RETURN(inputs.size(), 3); + CHECK_EQ_OR_RETURN(outputs.size(), 1); + + ctx->input_requires_grad = inputs.at(0)->requires_grad(); + ctx->src_requires_grad = inputs.at(2)->requires_grad(); + if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); } + + ctx->SaveTensorForBackward(inputs.at(1)); // index saved + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->dim = JUST(composed_attrs.GetAttr<int32_t>("dim")); + return Maybe<void>::Ok(); +} + +template<SCATTER_TYPE T> +Maybe<void> DimScatter<T>::ApplyCommon(const DimScatterInterpState* ctx, + const TensorTuple& out_grads, TensorTuple* in_grads) const { + const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0); + + in_grads->resize(3); + + if (ctx->src_requires_grad) { + in_grads->at(2) = JUST(functional::DimGather(out_grads.at(0), index, ctx->dim)); + } + return Maybe<void>::Ok(); +} + +template<> +Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_UPDATE>::Apply(const DimScatterInterpState* ctx, + const TensorTuple& out_grads, + TensorTuple* in_grads) const { + if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); } + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + JUST(ApplyCommon(ctx, out_grads, in_grads)); + + if (ctx->input_requires_grad) { + const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0); + in_grads->at(0) = + JUST(functional::DimScatterUpdateScalar(out_grads.at(0), index, 0.0f, ctx->dim)); + } + return Maybe<void>::Ok(); +} + +template<> +Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_ADD>::Apply(const DimScatterInterpState* ctx, + const TensorTuple& out_grads, + TensorTuple* in_grads) const { + if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); } + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + + JUST(ApplyCommon(ctx, out_grads, in_grads)); + + if (ctx->input_requires_grad) { 
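/* a clarifying note: scatter-add computes output = input + scatter(src), so the gradient w.r.t. input is the output gradient passed through unchanged */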
in_grads->at(0) = out_grads.at(0); } + + return Maybe<void>::Ok(); +} + +class DimScatterUpdateScalar : public OpExprGradFunction<DimScatterInterpState> { + public: + Maybe<void> Init(const OpExpr& op) override; + Maybe<void> Capture(DimScatterInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override; + Maybe<void> Apply(const DimScatterInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override; + + private: + AttrMap base_attrs_; +}; + +Maybe<void> DimScatterUpdateScalar::Init(const OpExpr& op) { + const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + + return Maybe<void>::Ok(); +} + +Maybe<void> DimScatterUpdateScalar::Capture(DimScatterInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, + const AttrMap& attrs) const { + CHECK_EQ_OR_RETURN(inputs.size(), 2); + CHECK_EQ_OR_RETURN(outputs.size(), 1); + + ctx->input_requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); } + + ctx->SaveTensorForBackward(inputs.at(1)); // index saved + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->dim = JUST(composed_attrs.GetAttr<int32_t>("dim")); + return Maybe<void>::Ok(); +} + +Maybe<void> DimScatterUpdateScalar::Apply(const DimScatterInterpState* ctx, + const TensorTuple& out_grads, + TensorTuple* in_grads) const { + if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); } + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0); + + in_grads->resize(2); + + // Zero out the updated positions of the output gradient: their gradient does not + // flow back to the input. + in_grads->at(0) = + JUST(functional::DimScatterUpdateScalar(out_grads.at(0), index, 0.0f, ctx->dim)); + + return Maybe<void>::Ok(); +} + +REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_update", DimScatter<SCATTER_TYPE::SCATTER_UPDATE>); +REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_add", DimScatter<SCATTER_TYPE::SCATTER_ADD>); +REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_update_scalar", DimScatterUpdateScalar); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/framework/op_expr_helper.cpp b/oneflow/core/framework/op_expr_helper.cpp index ccb56d84c5c279283822d5c0821479d9c8e4d3df..9a951cba7008b4ca9ee5c5c97eb8d22d21c368f6 100644 --- a/oneflow/core/framework/op_expr_helper.cpp +++ b/oneflow/core/framework/op_expr_helper.cpp @@ -637,8 +637,8 @@ Maybe<one::UserOpExpr> DimScatterAddLikeOp(const int32_t dim) { Maybe<one::UserOpExpr> DimScatterAddLikeOp(const int32_t dim, const std::string& name) { return one::OpBuilder("dim_scatter_add_like", name) .Input("like") - .Input("input") .Input("index") + .Input("src") .Output("output") .Attr<int32_t>("dim", dim) .Build(); diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index f69e203e264bc6293ff6f38936c80b02dd5ad5ee..d2105c285b2b408568e39ebc85bfcb5559a181b8 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -775,6 +775,22 @@ signature: "Tensor TensorGetItem(Tensor x, *, TensorIndex index)" bind_python: True +- name: "dim_scatter" + signature: "Tensor DimScatter(Tensor input, Tensor index, Tensor src, *, Int32 dim)" + bind_python: True + +- name: "dim_scatter_add" + signature: "Tensor DimScatterAdd(Tensor
input, Tensor index, Tensor src, *, Int32 dim)" + bind_python: True + +- name: "dim_scatter_scalar" + signature: "Tensor DimScatterUpdateScalar(Tensor input, Tensor index, *, Float src, Int32 dim)" + bind_python: True + +- name: "dim_scatter_add_scalar" + signature: "Tensor DimScatterAddScalar(Tensor input, Tensor index, *, Float src, Int32 dim)" + bind_python: True + - name: "tensor_setitem" signature: "Void TensorSetItem(Tensor x, *, TensorIndex index, Tensor value)" bind_python: True diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index 730b7b868c414382572021e22b612a7ddbf7c4fe..6da924fc614ac4779b441e77549571de1e8ab665 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -344,6 +344,138 @@ class DimGatherFunctor { std::shared_ptr<OpExpr> op_; }; +class DimScatterFunctor { + public: + DimScatterFunctor() { + op_ = CHECK_JUST(one::OpBuilder("dim_scatter_update") + .Input("input") + .Input("index") + .Input("src") + .Output("output") + .Build()); + } + Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, + const std::shared_ptr<one::Tensor>& index, + const std::shared_ptr<one::Tensor>& src, const int32_t& dim) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr<int32_t>("dim", dim)); + return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index, src}, attrs); + } + + private: + std::shared_ptr<OpExpr> op_; +}; + +class DimScatterAddFunctor { + public: + DimScatterAddFunctor() { + op_ = CHECK_JUST(one::OpBuilder("dim_scatter_add") + .Input("input") + .Input("index") + .Input("src") + .Output("output") + .Build()); + } + Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, + const std::shared_ptr<one::Tensor>& index, + const std::shared_ptr<one::Tensor>& src, const int32_t& dim) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr<int32_t>("dim", dim)); + return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index, src}, attrs); + } + + private: + std::shared_ptr<OpExpr> op_; +}; + +class DimScatterMulFunctor { + public: + DimScatterMulFunctor() { + op_ = CHECK_JUST(one::OpBuilder("dim_scatter_mul") + .Input("input") + .Input("index") + .Input("src") + .Output("output") + .Build()); + } + Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, + const std::shared_ptr<one::Tensor>& index, + const std::shared_ptr<one::Tensor>& src, const int32_t& dim) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr<int32_t>("dim", dim)); + return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index, src}, attrs); + } + + private: + std::shared_ptr<OpExpr> op_; +}; + +class DimScatterUpdateScalarFunctor { + public: + DimScatterUpdateScalarFunctor() { + op_ = CHECK_JUST(one::OpBuilder("dim_scatter_update_scalar") + .Input("input") + .Input("index") + .Output("output") + .Build()); + } + Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, + const std::shared_ptr<one::Tensor>& index, const float& src, + const int32_t& dim) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr<int32_t>("dim", dim)); + JUST(attrs.SetAttr<float>("src_scalar", src)); + return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index}, attrs); + } + + private: + std::shared_ptr<OpExpr> op_; +}; + +class DimScatterAddScalarFunctor { + public: + DimScatterAddScalarFunctor() { + op_ = CHECK_JUST(one::OpBuilder("dim_scatter_add_scalar") + .Input("input") + .Input("index") + .Output("output") + .Build()); + } + Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& 
input, + const std::shared_ptr<one::Tensor>& index, const float& src, + const int32_t& dim) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr<int32_t>("dim", dim)); + JUST(attrs.SetAttr<float>("src_scalar", src)); + return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index}, attrs); + } + + private: + std::shared_ptr<OpExpr> op_; +}; + +class DimScatterMulScalarFunctor { + public: + DimScatterMulScalarFunctor() { + op_ = CHECK_JUST(one::OpBuilder("dim_scatter_mul_scalar") + .Input("input") + .Input("index") + .Output("output") + .Build()); + } + Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, + const std::shared_ptr<one::Tensor>& index, const float& src, + const int32_t& dim) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr<int32_t>("dim", dim)); + JUST(attrs.SetAttr<float>("src_scalar", src)); + return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index}, attrs); + } + + private: + std::shared_ptr<OpExpr> op_; +}; + class GatherNdFunctor { public: GatherNdFunctor() { @@ -1153,6 +1285,12 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::DiagFunctor>("Diag"); m.add_functor<impl::DiagGradFunctor>("DiagGrad"); m.add_functor<impl::TensorGetItemFunctor>("TensorGetItem"); + m.add_functor<impl::DimScatterFunctor>("DimScatter"); + m.add_functor<impl::DimScatterAddFunctor>("DimScatterAdd"); + m.add_functor<impl::DimScatterMulFunctor>("DimScatterMul"); + m.add_functor<impl::DimScatterUpdateScalarFunctor>("DimScatterUpdateScalar"); + m.add_functor<impl::DimScatterAddScalarFunctor>("DimScatterAddScalar"); + m.add_functor<impl::DimScatterMulScalarFunctor>("DimScatterMulScalar"); m.add_functor<impl::TensorSetItemFunctor>("TensorSetItem"); m.add_functor<impl::ElementwiseMinimumGradFunctor>("ElementwiseMinGrad"); m.add_functor<impl::ElementwiseMaximumGradFunctor>("ElementwiseMaxGrad"); diff --git a/oneflow/user/kernels/dim_gather_kernel_util.cpp b/oneflow/user/kernels/dim_gather_kernel_util.cpp index d12d10a9e6796914386cd1d7e0727aa9222fb003..8262c70ba9100a46a93bd0e2e76d07142fb04290 100644 --- a/oneflow/user/kernels/dim_gather_kernel_util.cpp +++ b/oneflow/user/kernels/dim_gather_kernel_util.cpp @@ -30,20 +30,7 @@ struct DimGatherFunctor<DeviceType::kCPU, IN_T, IDX_T> final { } }; -template<typename IN_T, typename IDX_T> -struct DimScatterAddFunctor<DeviceType::kCPU, IN_T, IDX_T> final { - void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& input_nd_helper, - const DimOpIndexNdHelper<IDX_T>& output_nd_helper, int ndim, int64_t elem_cnt, - int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output) { - DoDimScatterAdd<IN_T, IDX_T>(input_nd_helper, output_nd_helper, ndim, elem_cnt, dim, index, - input, output); - } -}; - OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_DIM_GATHER_FUNCTOR, (DeviceType::kCPU), DIM_GATHER_SCATTER_DATA_TYPE_CPU_SEQ, INDEX_DATA_TYPE_SEQ); -OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_DIM_SCATTER_ADD_FUNCTOR, (DeviceType::kCPU), - DIM_GATHER_SCATTER_DATA_TYPE_CPU_SEQ, INDEX_DATA_TYPE_SEQ); - } // namespace user_op } // namespace oneflow diff --git a/oneflow/user/kernels/dim_gather_kernel_util.cu b/oneflow/user/kernels/dim_gather_kernel_util.cu index 767023838c5b06c9dbdf3bc27ba37d1944085250..c7b228aa893be2ea22d77cc7c9d451c4eae314d6 100644 --- a/oneflow/user/kernels/dim_gather_kernel_util.cu +++ b/oneflow/user/kernels/dim_gather_kernel_util.cu @@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include <cstdint> #ifdef WITH_CUDA #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/dim_gather_kernel_util.h" @@ -53,41 +52,8 @@ struct DimGatherFunctor<DeviceType::kGPU, float16, IDX_T> final { } }; -template<typename IN_T, typename IDX_T> -__global__ void DoCUDAScatterDimAdd(const DimOpIndexNdHelper<IDX_T> input_nd_helper, - const DimOpIndexNdHelper<IDX_T> output_nd_helper, int ndim, - int64_t elem_cnt, int32_t dim, const IDX_T* index, - const IN_T* input, IN_T* output) { - DoDimScatterAdd<IN_T, IDX_T>(input_nd_helper, output_nd_helper, ndim, elem_cnt, dim, index, input, - output); -} - -template<typename IN_T, typename IDX_T> -struct DimScatterAddFunctor<DeviceType::kGPU, IN_T, IDX_T> final { - void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& input_nd_helper, - const DimOpIndexNdHelper<IDX_T>& output_nd_helper, int ndim, int64_t elem_cnt, - int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output) { - RUN_CUDA_KERNEL((DoCUDAScatterDimAdd<IN_T, IDX_T>), ctx, BlocksNum4ThreadsNum(elem_cnt), - input_nd_helper, output_nd_helper, ndim, elem_cnt, dim, index, input, output); - } -}; - -// float16 special case of DimScatterAddFunctor template -template<typename IDX_T> -struct DimScatterAddFunctor<DeviceType::kGPU, float16, IDX_T> final { - void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& input_nd_helper, - const DimOpIndexNdHelper<IDX_T>& output_nd_helper, int ndim, int64_t elem_cnt, - int32_t dim, const IDX_T* index, const float16* input, float16* output) { - RUN_CUDA_KERNEL((DoCUDAScatterDimAdd<half, IDX_T>), ctx, BlocksNum4ThreadsNum(elem_cnt), - input_nd_helper, output_nd_helper, ndim, elem_cnt, dim, index, - reinterpret_cast<const half*>(input), reinterpret_cast<half*>(output)); - } -}; - OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_DIM_GATHER_FUNCTOR, (DeviceType::kGPU), DIM_GATHER_SCATTER_DATA_TYPE_GPU_SEQ, INDEX_DATA_TYPE_SEQ); -OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_DIM_SCATTER_ADD_FUNCTOR, (DeviceType::kGPU), - DIM_GATHER_SCATTER_DATA_TYPE_GPU_SEQ, INDEX_DATA_TYPE_SEQ); } // namespace user_op } // namespace oneflow diff --git a/oneflow/user/kernels/dim_gather_kernel_util.h b/oneflow/user/kernels/dim_gather_kernel_util.h index 4c52a9235262f8393089bdbd0b29eeb0348cc54c..6a12dcc0fe757b44d9d736d9ce830434969706e2 100644 --- a/oneflow/user/kernels/dim_gather_kernel_util.h +++ b/oneflow/user/kernels/dim_gather_kernel_util.h @@ -99,10 +99,6 @@ OF_DEVICE_FUNC void DoDimScatterAdd(const DimOpIndexNdHelper<IDX_T>& input_nd_he template struct DimGatherFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair), \ OF_PP_PAIR_FIRST(itype_pair)>; -#define INSTANTIATE_DIM_SCATTER_ADD_FUNCTOR(device_type_v, dtype_pair, itype_pair) \ - template struct DimScatterAddFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair), \ - OF_PP_PAIR_FIRST(itype_pair)>; - } // namespace user_op } // namespace oneflow diff --git a/oneflow/user/kernels/dim_gather_kernels.cpp b/oneflow/user/kernels/dim_gather_kernels.cpp index d5616b9673b75c4adb258b4ae5d6a154c051250c..fc37bcff7235c81b655c5d168a30caec4366f14e 100644 --- a/oneflow/user/kernels/dim_gather_kernels.cpp +++ b/oneflow/user/kernels/dim_gather_kernels.cpp @@ -66,44 +66,6 @@ class DimGatherKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -template<DeviceType device_type, typename IN_T, typename IDX_T> -class ScatterDimKernel final : public user_op::OpKernel { - public: - ScatterDimKernel() = default; - ~ScatterDimKernel() override 
= default; - - private: - void Compute(KernelComputeContext* ctx) const override { - const Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input", 0); - const Tensor* index_tensor = ctx->Tensor4ArgNameAndIndex("index", 0); - Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("output", 0); - const int32_t dim = ctx->Attr<int32_t>("dim"); - - const IN_T* src = input_tensor->dptr<IN_T>(); - const IDX_T* index = index_tensor->dptr<IDX_T>(); - IN_T* output = out_tensor->mut_dptr<IN_T>(); - size_t out_bytes_size = - out_tensor->shape().elem_cnt() * GetSizeOfDataType(out_tensor->data_type()); - Memset<device_type>(ctx->device_ctx(), output, 0, out_bytes_size); - - int ndim = input_tensor->shape().NumAxes(); - fixed_vector<IDX_T, kDimGatherMaxDimCount> shape_vec(ndim); - auto shape2dims = [&shape_vec, &ndim](const ShapeView& tensor_shape) -> void { - std::transform(tensor_shape.ptr(), tensor_shape.ptr() + ndim, shape_vec.begin(), - [](int64_t dim) -> IDX_T { return static_cast<IDX_T>(dim); }); - }; - shape2dims(input_tensor->shape()); - DimOpIndexNdHelper<IDX_T> input_nd_helper(shape_vec.data(), ndim); - shape2dims(out_tensor->shape()); - DimOpIndexNdHelper<IDX_T> output_nd_helper(shape_vec.data(), ndim); - - DimScatterAddFunctor<device_type, IN_T, IDX_T>()( - ctx->device_ctx(), input_nd_helper, output_nd_helper, ndim, - input_tensor->shape().elem_cnt(), dim, index, src, output); - } - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } -}; - #define REGISTER_DIM_GATHER_KERNEL(device, dtype, itype) \ REGISTER_USER_KERNEL("dim_gather") \ .SetCreateFn<DimGatherKernel<device, dtype, itype>>() \ @@ -111,13 +73,6 @@ class ScatterDimKernel final : public user_op::OpKernel { & (user_op::HobDataType("input", 0) == GetDataType<dtype>::value) \ & (user_op::HobDataType("index", 0) == GetDataType<itype>::value)); -#define REGISTER_DIM_SCATTER_KERNEL(device, dtype, itype) \ - REGISTER_USER_KERNEL("dim_scatter_add_like") \ - .SetCreateFn<ScatterDimKernel<device, dtype, itype>>() \ - .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ - & (user_op::HobDataType("input", 0) == GetDataType<dtype>::value) \ - & (user_op::HobDataType("index", 0) == GetDataType<itype>::value)); - #define REGISTER_DIM_GATHER_KERNELS_WITH_DEVICE(device) \ REGISTER_DIM_GATHER_KERNEL(device, float, int32_t) \ REGISTER_DIM_GATHER_KERNEL(device, double, int32_t) \ @@ -126,23 +81,11 @@ class ScatterDimKernel final : public user_op::OpKernel { REGISTER_DIM_GATHER_KERNEL(device, double, int64_t) \ REGISTER_DIM_GATHER_KERNEL(device, int32_t, int64_t) -#define REGISTER_DIM_SCATTER_ADD_LIKE_KERNELS_WITH_DEVICE(device) \ - REGISTER_DIM_SCATTER_KERNEL(device, float, int32_t) \ - REGISTER_DIM_SCATTER_KERNEL(device, double, int32_t) \ - REGISTER_DIM_SCATTER_KERNEL(device, int32_t, int32_t) \ - REGISTER_DIM_SCATTER_KERNEL(device, float, int64_t) \ - REGISTER_DIM_SCATTER_KERNEL(device, double, int64_t) \ - REGISTER_DIM_SCATTER_KERNEL(device, int32_t, int64_t) - REGISTER_DIM_GATHER_KERNELS_WITH_DEVICE(DeviceType::kCPU); -REGISTER_DIM_SCATTER_ADD_LIKE_KERNELS_WITH_DEVICE(DeviceType::kCPU); #ifdef WITH_CUDA REGISTER_DIM_GATHER_KERNELS_WITH_DEVICE(DeviceType::kGPU); -REGISTER_DIM_SCATTER_ADD_LIKE_KERNELS_WITH_DEVICE(DeviceType::kGPU); REGISTER_DIM_GATHER_KERNEL(DeviceType::kGPU, float16, int32_t); -REGISTER_DIM_SCATTER_KERNEL(DeviceType::kGPU, float16, int32_t); -REGISTER_DIM_SCATTER_KERNEL(DeviceType::kGPU, float16, int64_t); REGISTER_DIM_GATHER_KERNEL(DeviceType::kGPU, float16, int64_t); #endif // WITH_CUDA diff --git 
a/oneflow/user/kernels/dim_scatter_kernel_util.cpp b/oneflow/user/kernels/dim_scatter_kernel_util.cpp new file mode 100644 index 0000000000000000000000000000000000000000..796ba608f15b8934cd2d0e19df4f5f8599aac972 --- /dev/null +++ b/oneflow/user/kernels/dim_scatter_kernel_util.cpp @@ -0,0 +1,39 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "oneflow/core/framework/framework.h" +#include "oneflow/user/kernels/dim_scatter_kernel_util.h" + +namespace oneflow { +namespace user_op { + +template<typename IN_T, typename IDX_T, template<typename T> class Opt> +struct DimScatterFunctor<DeviceType::kCPU, IN_T, IDX_T, Opt> final { + void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& src_nd_helper, + const DimOpIndexNdHelper<IDX_T>& idx_nd_helper, + const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim, + const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound, + const IDX_T* index, const IN_T* src, IN_T* output) { + DoDimScatter<IN_T, IDX_T, Opt>(src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt, + dim, upper_bound, index, src, output); + } +}; + +INSTANTIATE_DIM_SCATTER_FUNCTORS(DeviceType::kCPU, BinOpAddFunctor); +INSTANTIATE_DIM_SCATTER_FUNCTORS(DeviceType::kCPU, BinOpUpdateFunctor); + +} // namespace user_op +} // namespace oneflow diff --git a/oneflow/user/kernels/dim_scatter_kernel_util.cu b/oneflow/user/kernels/dim_scatter_kernel_util.cu new file mode 100644 index 0000000000000000000000000000000000000000..151a329b3007e152f97a27ecd1ad01f155106752 --- /dev/null +++ b/oneflow/user/kernels/dim_scatter_kernel_util.cu @@ -0,0 +1,66 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#ifdef WITH_CUDA +#include "oneflow/user/kernels/dim_scatter_kernel_util.h" + +namespace oneflow { +namespace user_op { + +template<typename IN_T, typename IDX_T, template<typename T> class Opt> +__global__ void DoCUDADimScatter(const DimOpIndexNdHelper<IDX_T> src_nd_helper, + const DimOpIndexNdHelper<IDX_T> idx_nd_helper, + const DimOpIndexNdHelper<IDX_T> output_nd_helper, const int ndim, + const int64_t elem_cnt, const int32_t dim, + const int64_t upper_bound, const IDX_T* index, const IN_T* src, + IN_T* output) { + DoDimScatter<IN_T, IDX_T, Opt>(src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt, + dim, upper_bound, index, src, output); +} + +template<typename IN_T, typename IDX_T, template<typename T> class Opt> +struct DimScatterFunctor<DeviceType::kGPU, IN_T, IDX_T, Opt> final { + void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& src_nd_helper, + const DimOpIndexNdHelper<IDX_T>& idx_nd_helper, + const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim, + const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound, + const IDX_T* index, const IN_T* src, IN_T* output) { + RUN_CUDA_KERNEL((DoCUDADimScatter<IN_T, IDX_T, Opt>), ctx, BlocksNum4ThreadsNum(elem_cnt), + src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim, + upper_bound, index, src, output); + } +}; + +template<typename IDX_T, template<typename T> class Opt> +struct DimScatterFunctor<DeviceType::kGPU, float16, IDX_T, Opt> final { + void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& src_nd_helper, + const DimOpIndexNdHelper<IDX_T>& idx_nd_helper, + const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim, + const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound, + const IDX_T* index, const float16* src, float16* output) { + RUN_CUDA_KERNEL((DoCUDADimScatter<half, IDX_T, Opt>), ctx, BlocksNum4ThreadsNum(elem_cnt), + src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim, + upper_bound, index, reinterpret_cast<const half*>(src), + reinterpret_cast<half*>(output)); + } +}; + +INSTANTIATE_DIM_SCATTER_FUNCTORS(DeviceType::kGPU, BinOpAddFunctor); +INSTANTIATE_DIM_SCATTER_FUNCTORS(DeviceType::kGPU, BinOpUpdateFunctor); + +} // namespace user_op +} // namespace oneflow + +#endif // WITH_CUDA diff --git a/oneflow/user/kernels/dim_scatter_kernel_util.h b/oneflow/user/kernels/dim_scatter_kernel_util.h new file mode 100644 index 0000000000000000000000000000000000000000..34ff5cd31d3626883a9841d3c3fe48f3f25cf02c --- /dev/null +++ b/oneflow/user/kernels/dim_scatter_kernel_util.h @@ -0,0 +1,100 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#ifndef ONEFLOW_USER_KERNELS_DIM_SCATTER_KERNEL_UTIL_H_ +#define ONEFLOW_USER_KERNELS_DIM_SCATTER_KERNEL_UTIL_H_ +#ifdef WITH_CUDA +#include "oneflow/core/cuda/atomic.cuh" +#endif // WITH_CUDA + +#include "oneflow/core/ndarray/xpu_util.h" +#include "oneflow/core/common/nd_index_offset_helper.h" +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/common/shape_view.h" +#include "oneflow/core/common/error.pb.h" + +namespace oneflow { + +namespace user_op { + +constexpr int kDimGatherMaxDimCount = 8; + +template<typename T> +using DimOpIndexNdHelper = NdIndexOffsetHelper<T, kDimGatherMaxDimCount>; + +#define INSTANTIATE_DIM_SCATTER_FUNCTORS(device_type, opt) \ + template struct DimScatterFunctor<device_type, int32_t, int32_t, opt>; \ + template struct DimScatterFunctor<device_type, float, int32_t, opt>; \ + template struct DimScatterFunctor<device_type, double, int32_t, opt>; \ + template struct DimScatterFunctor<device_type, int32_t, int64_t, opt>; \ + template struct DimScatterFunctor<device_type, float, int64_t, opt>; \ + template struct DimScatterFunctor<device_type, double, int64_t, opt>; + +template<typename T> +struct BinOpAddFunctor { + OF_DEVICE_FUNC static void apply(const T* x, T* y) { +#ifdef __CUDA_ARCH__ + cuda::atomic::Add(y, *x); +#else + *y += *x; +#endif + } +}; + +template<typename T> +struct BinOpUpdateFunctor { + OF_DEVICE_FUNC static void apply(const T* x, T* y) { *y = *x; } +}; + +template<DeviceType device_type, typename IN_T, typename IDX_T, template<typename T> class Opt> +struct DimScatterFunctor final { + void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& src_nd_helper, + const DimOpIndexNdHelper<IDX_T>& idx_nd_helper, + const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim, + const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound, + const IDX_T* index, const IN_T* src, IN_T* output); +}; + +template<typename IN_T, typename IDX_T, template<typename T> class Opt> +OF_DEVICE_FUNC void DoDimScatter(const DimOpIndexNdHelper<IDX_T>& src_nd_helper, + const DimOpIndexNdHelper<IDX_T>& idx_nd_helper, + const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim, + const int64_t elem_cnt, const int32_t dim, int64_t upper_bound, + const IDX_T* index, const IN_T* src, IN_T* output) { + XPU_1D_KERNEL_LOOP(idx_offset, elem_cnt) { + IDX_T coordinate[kDimGatherMaxDimCount] = {0}; + idx_nd_helper.OffsetToNdIndex(idx_offset, coordinate, ndim); // idx_offset -> ijk + IDX_T idx_elem = index[idx_offset]; + if (idx_elem >= upper_bound) { +#if __CUDA_ARCH__ + __trap(); +#else + std::cout << "The index element " << idx_elem << " is out of bounds for dimension " << dim + << " with size " << upper_bound << std::endl; + throw Error::CheckFailedError(); +#endif + } + IDX_T src_offset = src_nd_helper.NdIndexToOffset(coordinate, ndim); + coordinate[dim] = idx_elem; + IDX_T output_offset = output_nd_helper.NdIndexToOffset(coordinate, ndim); + Opt<IN_T>::apply(src + src_offset, output + output_offset); + } +} + +} // namespace user_op +} // namespace oneflow + +#endif // ONEFLOW_USER_KERNELS_DIM_SCATTER_KERNEL_UTIL_H_ diff --git a/oneflow/user/kernels/dim_scatter_kernels.cpp b/oneflow/user/kernels/dim_scatter_kernels.cpp new file mode 100644 index 0000000000000000000000000000000000000000..087d0812ba3292c810cda9524d65d6dedb4949c0 --- /dev/null +++ b/oneflow/user/kernels/dim_scatter_kernels.cpp @@ -0,0 +1,138 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/common/error.pb.h" +#include "oneflow/core/common/util.h" +#include "oneflow/user/kernels/dim_scatter_kernel_util.h" + +namespace oneflow { +namespace user_op { + +template<DeviceType device_type, typename IN_T, typename IDX_T, template<typename T> class Opt> +class DimScatterKernel final : public user_op::OpKernel { + public: + DimScatterKernel() = default; + ~DimScatterKernel() override = default; + + private: + void Compute(KernelComputeContext* ctx) const override { + const Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input", 0); + const Tensor* index_tensor = ctx->Tensor4ArgNameAndIndex("index", 0); + Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("output", 0); + const Tensor* src_tensor = ctx->Tensor4ArgNameAndIndex("src", 0); + const int32_t dim = ctx->Attr<int32_t>("dim"); + + const IDX_T* index = index_tensor->dptr<IDX_T>(); + IN_T* output = out_tensor->mut_dptr<IN_T>(); + size_t out_bytes_size = + out_tensor->shape().elem_cnt() * GetSizeOfDataType(out_tensor->data_type()); + + const Tensor* like_tensor = ctx->Tensor4ArgNameAndIndex("like", 0); + const IN_T* src = src_tensor->dptr<IN_T>(); + + // "dim_scatter_update"/"dim_scatter_add" bind "input"; "dim_scatter_add_like" binds "like". + if (input_tensor) { + Memcpy<device_type>(ctx->device_ctx(), output, input_tensor->dptr<IN_T>(), out_bytes_size); + } else if (like_tensor) { + Memset<device_type>(ctx->device_ctx(), output, 0, out_bytes_size); + } else { + std::cout << "Unimplemented Error" << std::endl; + throw Error::Unimplemented(); + } + + const int ndim = src_tensor->shape().NumAxes(); + fixed_vector<IDX_T, kDimGatherMaxDimCount> shape_vec(ndim); + auto shape2dims = [&shape_vec, &ndim](const ShapeView& tensor_shape) -> void { + std::transform(tensor_shape.ptr(), tensor_shape.ptr() + ndim, shape_vec.begin(), + [](int64_t dim) -> IDX_T { return static_cast<IDX_T>(dim); }); + }; + shape2dims(src_tensor->shape()); + DimOpIndexNdHelper<IDX_T> src_nd_helper(shape_vec.data(), ndim); + shape2dims(index_tensor->shape()); + DimOpIndexNdHelper<IDX_T> idx_nd_helper(shape_vec.data(), ndim); + shape2dims(out_tensor->shape()); + DimOpIndexNdHelper<IDX_T> output_nd_helper(shape_vec.data(), ndim); + + int64_t upper_bound = 0; + if (input_tensor) { + upper_bound = input_tensor->shape().At(dim); // ensure every index is smaller than the upper bound + } else { + upper_bound = like_tensor->shape().At(dim); // ensure every index is smaller than the upper bound + } + + DimScatterFunctor<device_type, IN_T, IDX_T, Opt>()( + ctx->device_ctx(), src_nd_helper, idx_nd_helper, output_nd_helper, ndim, + index_tensor->shape().elem_cnt(), dim, upper_bound, index, src, output); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, device, dtype, itype, opt) \ + REGISTER_USER_KERNEL(op_type) \ + .SetCreateFn<DimScatterKernel<device, dtype, itype, opt>>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("like", 0) == GetDataType<dtype>::value) \ + & (user_op::HobDataType("index", 0) ==
GetDataType<itype>::value)); + +#define REGISTER_DIM_SCATTER_LIKE_CPU_KERNELS(op_type, opt) \ + REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, float, int32_t, opt); \ + REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, double, int32_t, opt); \ + REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, int32_t, int32_t, opt); \ + REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, float, int64_t, opt); \ + REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, double, int64_t, opt); \ + REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, int32_t, int64_t, opt); + +#define REGISTER_DIM_SCATTER_LIKE_GPU_KERNELS(op_type, opt) \ + REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kGPU, float, int32_t, opt); \ + REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kGPU, double, int32_t, opt); \ + REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kGPU, int32_t, int32_t, opt); \ + REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kGPU, float, int64_t, opt); \ + REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kGPU, double, int64_t, opt); \ + REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kGPU, int32_t, int64_t, opt); + +#define REGISTER_DIM_SCATTER_KERNEL(op_type, device, dtype, itype, opt) \ + REGISTER_USER_KERNEL(op_type) \ + .SetCreateFn<DimScatterKernel<device, dtype, itype, opt>>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("input", 0) == GetDataType<dtype>::value) \ + & (user_op::HobDataType("index", 0) == GetDataType<itype>::value)); + +#define REGISTER_DIM_SCATTER_CPU_KERNELS(op_type, opt) \ + REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kCPU, float, int32_t, opt); \ + REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kCPU, double, int32_t, opt); \ + REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kCPU, int32_t, int32_t, opt); \ + REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kCPU, float, int64_t, opt); \ + REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kCPU, double, int64_t, opt); \ + REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kCPU, int32_t, int64_t, opt); + +#define REGISTER_DIM_SCATTER_GPU_KERNELS(op_type, opt) \ + REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kGPU, float, int32_t, opt); \ + REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kGPU, double, int32_t, opt); \ + REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kGPU, int32_t, int32_t, opt); \ + REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kGPU, float, int64_t, opt); \ + REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kGPU, double, int64_t, opt); \ + REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kGPU, int32_t, int64_t, opt); + +REGISTER_DIM_SCATTER_LIKE_CPU_KERNELS("dim_scatter_add_like", BinOpAddFunctor); +REGISTER_DIM_SCATTER_CPU_KERNELS("dim_scatter_add", BinOpAddFunctor); +REGISTER_DIM_SCATTER_CPU_KERNELS("dim_scatter_update", BinOpUpdateFunctor); + +#ifdef WITH_CUDA +REGISTER_DIM_SCATTER_LIKE_GPU_KERNELS("dim_scatter_add_like", BinOpAddFunctor); +REGISTER_DIM_SCATTER_GPU_KERNELS("dim_scatter_add", BinOpAddFunctor); +REGISTER_DIM_SCATTER_GPU_KERNELS("dim_scatter_update", BinOpUpdateFunctor); +#endif // WITH_CUDA + +} // namespace user_op +} // namespace oneflow diff --git a/oneflow/user/kernels/dim_scatter_scalar_kernel_util.cpp b/oneflow/user/kernels/dim_scatter_scalar_kernel_util.cpp new file mode 100644 index 0000000000000000000000000000000000000000..df6c91a03239f133cae38dbfdb5294acd50d1e5e --- /dev/null +++ b/oneflow/user/kernels/dim_scatter_scalar_kernel_util.cpp @@ -0,0 +1,37 @@ 
+/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/user/kernels/dim_scatter_scalar_kernel_util.h" + +namespace oneflow { + +namespace user_op { + +template<typename IN_T, typename IDX_T, template<typename T> class Opt> +struct DimScatterScalarFunctor<DeviceType::kCPU, IN_T, IDX_T, Opt> final { + void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& idx_nd_helper, + const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim, + const int64_t elem_cnt, const int32_t dim, int64_t upper_bound, + const IDX_T* index, const IN_T src, IN_T* output) { + DoScatterScalarFunctor<IN_T, IDX_T, Opt>(idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim, + upper_bound, index, src, output); + } +}; + +INSTANTIATE_DIM_SCATTER_SCALAR_FUNCTORS(DeviceType::kCPU, UpdateScalarFunctor); +INSTANTIATE_DIM_SCATTER_SCALAR_FUNCTORS(DeviceType::kCPU, AddScalarFunctor); + +} // namespace user_op +} // namespace oneflow diff --git a/oneflow/user/kernels/dim_scatter_scalar_kernel_util.cu b/oneflow/user/kernels/dim_scatter_scalar_kernel_util.cu new file mode 100644 index 0000000000000000000000000000000000000000..abe4b3cd301aee654ff56b1db9734e03d4864902 --- /dev/null +++ b/oneflow/user/kernels/dim_scatter_scalar_kernel_util.cu @@ -0,0 +1,50 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+*/ +#ifdef WITH_CUDA +#include "oneflow/user/kernels/dim_scatter_scalar_kernel_util.h" + +namespace oneflow { + +namespace user_op { + +template<typename IN_T, typename IDX_T, template<typename T> class Opt> +__global__ void DoCUDADimScatterScalar(const DimOpIndexNdHelper<IDX_T> idx_nd_helper, + const DimOpIndexNdHelper<IDX_T> output_nd_helper, + const int ndim, const int64_t elem_cnt, const int32_t dim, + const int64_t upper_bound, const IDX_T* index, + const IN_T src_scalar, IN_T* output) { + DoScatterScalarFunctor<IN_T, IDX_T, Opt>(idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim, + upper_bound, index, src_scalar, output); +} + +template<typename IN_T, typename IDX_T, template<typename T> class Opt> +struct DimScatterScalarFunctor<DeviceType::kGPU, IN_T, IDX_T, Opt> final { + void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& idx_nd_helper, + const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim, + const int64_t elem_cnt, const int32_t dim, int64_t upper_bound, + const IDX_T* index, const IN_T src, IN_T* output) { + RUN_CUDA_KERNEL((DoCUDADimScatterScalar<IN_T, IDX_T, Opt>), ctx, BlocksNum4ThreadsNum(elem_cnt), + idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim, upper_bound, index, src, + output); + } +}; + +INSTANTIATE_DIM_SCATTER_SCALAR_FUNCTORS(DeviceType::kGPU, UpdateScalarFunctor); +INSTANTIATE_DIM_SCATTER_SCALAR_FUNCTORS(DeviceType::kGPU, AddScalarFunctor); + +} // namespace user_op +} // namespace oneflow +#endif diff --git a/oneflow/user/kernels/dim_scatter_scalar_kernel_util.h b/oneflow/user/kernels/dim_scatter_scalar_kernel_util.h new file mode 100644 index 0000000000000000000000000000000000000000..199cdc7901db71c5a99e95d12f3db9a6776c7bd5 --- /dev/null +++ b/oneflow/user/kernels/dim_scatter_scalar_kernel_util.h @@ -0,0 +1,97 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+*/ +#ifndef ONEFLOW_USER_KERNELS_DIM_SCATTER_SCALAR_KERNEL_UTIL_H_ +#define ONEFLOW_USER_KERNELS_DIM_SCATTER_SCALAR_KERNEL_UTIL_H_ +#ifdef WITH_CUDA +#include "oneflow/core/cuda/atomic.cuh" +#endif // WITH_CUDA +#include "oneflow/core/device/device_context.h" +#include "oneflow/core/ndarray/xpu_util.h" +#include "oneflow/core/common/nd_index_offset_helper.h" +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/common/data_type.h" + +namespace oneflow { + +namespace user_op { + +constexpr int kDimGatherMaxDimCount = 8; + +template<typename T> +struct AddScalarFunctor { + OF_DEVICE_FUNC static void apply(const T x, T* y) { +#ifdef __CUDA_ARCH__ + cuda::atomic::Add(y, x); +#else + *y += x; +#endif + } +}; + +template<typename T> +struct UpdateScalarFunctor { + OF_DEVICE_FUNC static void apply(const T x, T* y) { *y = x; } +}; + +#define INSTANTIATE_DIM_SCATTER_SCALAR_FUNCTORS(device_type, opt) \ + template struct DimScatterScalarFunctor<device_type, int32_t, int32_t, opt>; \ + template struct DimScatterScalarFunctor<device_type, float, int32_t, opt>; \ + template struct DimScatterScalarFunctor<device_type, double, int32_t, opt>; \ + template struct DimScatterScalarFunctor<device_type, int32_t, int64_t, opt>; \ + template struct DimScatterScalarFunctor<device_type, float, int64_t, opt>; \ + template struct DimScatterScalarFunctor<device_type, double, int64_t, opt>; + +template<typename T> +using DimOpIndexNdHelper = NdIndexOffsetHelper<T, kDimGatherMaxDimCount>; + +template<DeviceType device_type, typename IN_T, typename IDX_T, template<typename T> class Opt> +struct DimScatterScalarFunctor final { + void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& idx_nd_helper, + const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim, + const int64_t elem_cnt, const int32_t dim, int64_t upper_bound, + const IDX_T* index, const IN_T src, IN_T* output); +}; + +template<typename IN_T, typename IDX_T, template<typename T> class Opt> +OF_DEVICE_FUNC void DoScatterScalarFunctor(const DimOpIndexNdHelper<IDX_T>& idx_nd_helper, + const DimOpIndexNdHelper<IDX_T>& output_nd_helper, + const int ndim, const int64_t elem_cnt, + const int32_t dim, int64_t upper_bound, + const IDX_T* index, const IN_T src, IN_T* output) { + XPU_1D_KERNEL_LOOP(idx_offset, elem_cnt) { + IDX_T coordinate[kDimGatherMaxDimCount] = {0}; + + idx_nd_helper.OffsetToNdIndex(idx_offset, coordinate, ndim); // idx_offset -> ijk + IDX_T idx_elem = index[idx_offset]; + if (idx_elem >= upper_bound) { +#if __CUDA_ARCH__ + __trap(); +#else + std::cout << "The index element " << idx_elem << " is out of bounds for dimension " << dim + << " with size " << upper_bound << std::endl; + throw Error::CheckFailedError(); +#endif + } + coordinate[dim] = idx_elem; + IDX_T output_offset = output_nd_helper.NdIndexToOffset(coordinate, ndim); + Opt<IN_T>::apply(src, output + output_offset); + } +} + +} // namespace user_op +} // namespace oneflow + +#endif // ONEFLOW_USER_KERNELS_DIM_SCATTER_SCALAR_KERNEL_UTIL_H_ diff --git a/oneflow/user/kernels/dim_scatter_scalar_kernels.cpp b/oneflow/user/kernels/dim_scatter_scalar_kernels.cpp new file mode 100644 index 0000000000000000000000000000000000000000..154ab129b562aad6140522786a25fdf9ba332a7b --- /dev/null +++ b/oneflow/user/kernels/dim_scatter_scalar_kernels.cpp @@ -0,0 +1,101 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/user/kernels/dim_scatter_scalar_kernel_util.h" + +namespace oneflow { + +namespace user_op { + +template<DeviceType device_type, typename IN_T, typename IDX_T, template<typename T> class Opt> +class DimScatterScalarKernel final : public user_op::OpKernel { + public: + DimScatterScalarKernel() = default; + ~DimScatterScalarKernel() override = default; + + private: + void Compute(KernelComputeContext* ctx) const override { + const Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input", 0); + const Tensor* index_tensor = ctx->Tensor4ArgNameAndIndex("index", 0); + Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("output", 0); + const int32_t dim = ctx->Attr<int32_t>("dim"); + + const IDX_T* index = index_tensor->dptr<IDX_T>(); + IN_T* output = out_tensor->mut_dptr<IN_T>(); + size_t out_bytes_size = + out_tensor->shape().elem_cnt() * GetSizeOfDataType(out_tensor->data_type()); + + const Tensor* like_tensor = ctx->Tensor4ArgNameAndIndex("like", 0); + const IN_T src_scalar = static_cast<IN_T>(ctx->Attr<float>("src_scalar")); + + if (input_tensor) { + Memcpy<device_type>(ctx->device_ctx(), output, input_tensor->dptr<IN_T>(), out_bytes_size); + } else if (like_tensor) { + Memset<device_type>(ctx->device_ctx(), output, 0, out_bytes_size); + } else { + std::cout << "Unimplemented Error" << std::endl; + throw Error::Unimplemented(); + } + + const int ndim = out_tensor->shape().NumAxes(); + fixed_vector<IDX_T, kDimGatherMaxDimCount> shape_vec(ndim); + auto shape2dims = [&shape_vec, &ndim](const ShapeView& tensor_shape) -> void { + std::transform(tensor_shape.ptr(), tensor_shape.ptr() + ndim, shape_vec.begin(), + [](int64_t dim) -> IDX_T { return static_cast<IDX_T>(dim); }); + }; + shape2dims(index_tensor->shape()); + DimOpIndexNdHelper<IDX_T> idx_nd_helper(shape_vec.data(), ndim); + shape2dims(out_tensor->shape()); + DimOpIndexNdHelper<IDX_T> output_nd_helper(shape_vec.data(), ndim); + + // Mirror DimScatterKernel: take the upper bound from whichever of "input"/"like" is + // bound, so a null input_tensor is never dereferenced. + int64_t upper_bound = 0; + if (input_tensor) { + upper_bound = input_tensor->shape().At(dim); + } else { + upper_bound = like_tensor->shape().At(dim); + } + + DimScatterScalarFunctor<device_type, IN_T, IDX_T, Opt>()( + ctx->device_ctx(), idx_nd_helper, output_nd_helper, ndim, index_tensor->shape().elem_cnt(), + dim, upper_bound, index, src_scalar, output); + } + + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_SCATTERSCALAR_KERNEL(op_type_name, device, dtype, itype, opt) \ + REGISTER_USER_KERNEL(op_type_name) \ + .SetCreateFn<DimScatterScalarKernel<device, dtype, itype, opt>>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("input", 0) == GetDataType<dtype>::value) \ + & (user_op::HobDataType("index", 0) == GetDataType<itype>::value)); + +#define REGISTER_SCATTER_SCALAR_CPU_KERNELS(op_type_name, opt) \ + REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kCPU, float, int32_t, opt); \ + REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kCPU, float, int64_t, opt); \ + REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kCPU, double, int32_t, opt); \ + REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kCPU, double, int64_t, opt); + +#define REGISTER_SCATTER_SCALAR_GPU_KERNELS(op_type_name, opt) \
REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kGPU, float, int32_t, opt); \ + REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kGPU, float, int64_t, opt); \ + REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kGPU, double, int32_t, opt); \ + REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kGPU, double, int64_t, opt); + +REGISTER_SCATTER_SCALAR_CPU_KERNELS("dim_scatter_update_scalar", UpdateScalarFunctor); +REGISTER_SCATTER_SCALAR_CPU_KERNELS("dim_scatter_add_scalar", AddScalarFunctor); + +#ifdef WITH_CUDA +REGISTER_SCATTER_SCALAR_GPU_KERNELS("dim_scatter_update_scalar", UpdateScalarFunctor); +REGISTER_SCATTER_SCALAR_GPU_KERNELS("dim_scatter_add_scalar", AddScalarFunctor); +#endif // WITH_CUDA + +} // namespace user_op +} // namespace oneflow diff --git a/oneflow/user/ops/dim_gather_op.cpp b/oneflow/user/ops/dim_gather_op.cpp index 17f670f5764f4f06e9fd01ee94bbd4b1f3f389cc..9c490a97985dbfa5dadb9d313df476a35912fe82 100644 --- a/oneflow/user/ops/dim_gather_op.cpp +++ b/oneflow/user/ops/dim_gather_op.cpp @@ -17,8 +17,8 @@ limitations under the License. #include "oneflow/user/kernels/dim_gather_kernel_util.h" namespace oneflow { - namespace user_op { + REGISTER_USER_OP("dim_gather") .Input("input") .Input("index") @@ -40,11 +40,6 @@ REGISTER_USER_OP("dim_gather") CHECK_EQ_OR_RETURN(in.is_dynamic(), index.is_dynamic()); - FOR_RANGE(int64_t, i, 0, input_num_axes) { - if (i == dim) { continue; } - CHECK_EQ_OR_RETURN(in.shape().At(i), index.shape().At(i)); - } - user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0); *out->mut_shape() = index.shape(); @@ -86,94 +81,10 @@ REGISTER_USER_OP("dim_gather") .Build(); } } - - ctx->NewBuilder() - .PartialSum(user_op::OpArg("input", 0)) - .Broadcast(user_op::OpArg("index", 0)) - .PartialSum(user_op::OpArg("output", 0)) - .Build(); - return Maybe<void>::Ok(); - }); - -REGISTER_USER_OP("dim_scatter_add_like") - .Input("like") - .Input("input") - .Input("index") - .Output("output") - .Attr<int32_t>("dim") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const TensorDesc& input = ctx->InputTensorDesc("input", 0); - const TensorDesc& index = ctx->InputTensorDesc("index", 0); - const TensorDesc& like = ctx->InputTensorDesc("like", 0); - - const Shape& like_shape = like.shape(); - - int64_t input_num_axes = input.shape().NumAxes(); - CHECK_GT_OR_RETURN(input_num_axes, 0); - CHECK_LE_OR_RETURN(input_num_axes, kDimGatherMaxDimCount); - - int64_t index_num_axes = index.shape().NumAxes(); - CHECK_EQ_OR_RETURN(input_num_axes, index_num_axes); - CHECK_EQ_OR_RETURN(input_num_axes, like_shape.NumAxes()); - - FOR_RANGE(int64_t, i, 0, input_num_axes) { - CHECK_EQ_OR_RETURN(index.shape().At(i), input.shape().At(i)); - } - - user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0); - *out->mut_shape() = like_shape; - - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const TensorDesc& input = ctx->InputTensorDesc("input", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0); - *out->mut_data_type() = input.data_type(); - return Maybe<void>::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe<void> { - user_op::InputArgModifier* like_arg_modifier = GetInputArgModifierFn("like", 0); - CHECK_OR_RETURN(like_arg_modifier != nullptr); - like_arg_modifier->set_requires_grad(false); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> 
Maybe<void> { - const user_op::TensorDesc& index_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("index", 0); - int64_t index_num_axes = index_tensor.shape().NumAxes(); - const int32_t dim = ctx->Attr<int32_t>("dim"); - - FOR_RANGE(int64_t, i, 0, index_num_axes) { - if (i != dim) { - ctx->NewBuilder() - .Split(user_op::OpArg("index", 0), i) - .Split(user_op::OpArg("input", 0), i) - .Split(user_op::OpArg("output", 0), i) - .Split(user_op::OpArg("like", 0), i) - .Build(); - } else { - ctx->NewBuilder() - .Split(user_op::OpArg("index", 0), i) - .Split(user_op::OpArg("input", 0), i) - .PartialSum(user_op::OpArg("output", 0)) - .Broadcast(user_op::OpArg("like", 0)) - .Build(); - - ctx->NewBuilder() - .Split(user_op::OpArg("index", 0), i) - .Split(user_op::OpArg("input", 0), i) - .PartialSum(user_op::OpArg("output", 0)) - .PartialSum(user_op::OpArg("like", 0)) - .Build(); - } - } - ctx->NewBuilder() .PartialSum(user_op::OpArg("input", 0)) .Broadcast(user_op::OpArg("index", 0)) .PartialSum(user_op::OpArg("output", 0)) - .PartialSum(user_op::OpArg("like", 0)) .Build(); return Maybe<void>::Ok(); }); @@ -185,10 +96,10 @@ REGISTER_USER_OP_GRAD("dim_gather") ctx->DefineOp(op_grad_name, [&ctx](user_op::BackwardOpBuilder& builder) { return builder .OpTypeName( - "dim_scatter_add_like") // dim_scatter_add_like(like, dim, index, input) -> output + "dim_scatter_add_like") // dim_scatter_add_like(like, dim, index, src) -> output .InputBind("index", ctx->FwOp().input("index", 0)) // scatter.index <- gather.index - .InputBind("input", - ctx->FwOp().output_grad("output", 0)) // scatter.input <- grad of gather.out + .InputBind("src", + ctx->FwOp().output_grad("output", 0)) // scatter.src <- grad of gather.out .InputBind("like", ctx->FwOp().input("input", 0)) .Output("output") .Attr("dim", ctx->FwOp().attr<int32_t>("dim")) diff --git a/oneflow/user/ops/dim_scatter_ops.cpp b/oneflow/user/ops/dim_scatter_ops.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1a8f71d2439a7d88cb1e8be9d13c504f70df4102 --- /dev/null +++ b/oneflow/user/ops/dim_scatter_ops.cpp @@ -0,0 +1,296 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+#include "oneflow/core/common/error.h"
+#include "oneflow/core/common/maybe.h"
+#include "oneflow/core/framework/user_op_registry.h"
+#include "oneflow/user/kernels/dim_scatter_kernel_util.h"
+
+namespace oneflow {
+
+namespace user_op {
+
+namespace {
+Maybe<void> InferTensorDesc(user_op::InferContext* ctx) {
+  const TensorDesc* input = ctx->TensorDesc4ArgNameAndIndex("input", 0);
+  const TensorDesc* index = ctx->TensorDesc4ArgNameAndIndex("index", 0);
+  const TensorDesc* like = ctx->TensorDesc4ArgNameAndIndex("like", 0);
+  const TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("src", 0);
+
+  int32_t dim = ctx->Attr<int32_t>("dim");
+
+  // check that index, src, and input/like all have the same number of axes
+  int64_t src_num_axes = src->shape().NumAxes();
+  CHECK_GT_OR_RETURN(src_num_axes, 0);
+  CHECK_LE_OR_RETURN(src_num_axes, kDimGatherMaxDimCount);
+  int64_t index_num_axes = index->shape().NumAxes();
+  CHECK_EQ_OR_RETURN(src_num_axes, index_num_axes);
+
+  int64_t output_num_axes = 0;
+  if (input) {
+    output_num_axes = input->shape().NumAxes();
+  } else if (like) {
+    output_num_axes = like->shape().NumAxes();
+  } else {
+    return Error::Unimplemented();
+  }
+  CHECK_EQ_OR_RETURN(output_num_axes, index_num_axes);
+
+  // check index.shape(i) <= input/like.shape(i)
+  FOR_RANGE(int64_t, i, 0, index_num_axes) {
+    if (i == dim) continue;
+    if (input) {
+      CHECK_LE_OR_RETURN(index->shape().At(i), input->shape().At(i));
+    } else {
+      CHECK_LE_OR_RETURN(index->shape().At(i), like->shape().At(i));
+    }
+  }
+
+  // check index.shape(i) <= src.shape(i)
+  FOR_RANGE(int64_t, i, 0, index_num_axes) {
+    if (i == dim) continue;
+    CHECK_LE_OR_RETURN(index->shape().At(i), src->shape().At(i));
+  }
+
+  user_op::TensorDesc* out = ctx->TensorDesc4ArgNameAndIndex("output", 0);
+  *out->mut_shape() = input ? input->shape() : like->shape();
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> InferScalarTensorDesc(user_op::InferContext* ctx) {
+  const TensorDesc* input = ctx->TensorDesc4ArgNameAndIndex("input", 0);
+  const TensorDesc* index = ctx->TensorDesc4ArgNameAndIndex("index", 0);
+
+  int32_t dim = ctx->Attr<int32_t>("dim");
+
+  // check that index and input have the same number of axes
+  int64_t output_num_axes = input->shape().NumAxes();
+  int64_t index_num_axes = index->shape().NumAxes();
+  CHECK_EQ_OR_RETURN(output_num_axes, index_num_axes);
+
+  // check index.shape(i) <= input.shape(i)
+  FOR_RANGE(int64_t, i, 0, index_num_axes) {
+    if (i == dim) continue;
+    CHECK_LE_OR_RETURN(index->shape().At(i), input->shape().At(i));
+  }
+
+  TensorDesc* out = ctx->TensorDesc4ArgNameAndIndex("output", 0);
+  *out->mut_shape() = input->shape();
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> InputArgModifierFn(user_op::GetInputArgModifier GetInputArgModifierFn,
+                               const user_op::UserOpConfWrapper&) {
+  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0);
+  CHECK_OR_RETURN(indices_modifier != nullptr);
+  indices_modifier->set_requires_grad(false);
+
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> InputScalarArgModifierFn(user_op::GetInputArgModifier GetInputArgModifierFn,
+                                     const user_op::UserOpConfWrapper&) {
+  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0);
+  CHECK_OR_RETURN(indices_modifier != nullptr);
+  indices_modifier->set_requires_grad(false);
+
+  return Maybe<void>::Ok();
+}
+
+void SetSbp(user_op::SbpContext* ctx, const char* like_or_input) {
+  const user_op::TensorDesc& index_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("index", 0);
+  int64_t index_num_axes = index_tensor.shape().NumAxes();
+  const int32_t dim = ctx->Attr<int32_t>("dim");
+
+  FOR_RANGE(int64_t, i, 0, index_num_axes) {
+    if (i != dim) {
+      ctx->NewBuilder()
+          .Split(user_op::OpArg("index", 0), i)
+          .Split(user_op::OpArg("src", 0), i)
+          .Split(user_op::OpArg("output", 0), i)
+          .Split(user_op::OpArg(like_or_input, 0), i)
+          .Build();
+    } else {
+      ctx->NewBuilder()
+          .Split(user_op::OpArg("index", 0), i)
+          .Split(user_op::OpArg("src", 0), i)
+          .PartialSum(user_op::OpArg("output", 0))
+          .Broadcast(user_op::OpArg(like_or_input, 0))
+          .Build();
+
+      ctx->NewBuilder()
+          .Split(user_op::OpArg("index", 0), i)
+          .Split(user_op::OpArg("src", 0), i)
+          .PartialSum(user_op::OpArg("output", 0))
+          .PartialSum(user_op::OpArg(like_or_input, 0))
+          .Build();
+    }
+  }
+
+  ctx->NewBuilder()
+      .PartialSum(user_op::OpArg("src", 0))
+      .Broadcast(user_op::OpArg("index", 0))
+      .PartialSum(user_op::OpArg("output", 0))
+      .PartialSum(user_op::OpArg(like_or_input, 0))
+      .Build();
+}
+
+Maybe<void> SetSbpLike(user_op::SbpContext* ctx) {
+  SetSbp(ctx, "like");
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> SetSbpScatter(user_op::SbpContext* ctx) {
+  SetSbp(ctx, "input");
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> InferDtype(user_op::InferContext* ctx) {
+  const TensorDesc* index = ctx->TensorDesc4ArgNameAndIndex("index", 0);
+  CHECK_OR_RETURN(IsIndexDataType(index->data_type()));
+  const TensorDesc* input = ctx->TensorDesc4ArgNameAndIndex("input", 0);
+  if (input) {
+    CHECK_EQ_OR_RETURN(ctx->InputDType("input", 0), ctx->InputDType("src", 0));
+  } else {
+    CHECK_EQ_OR_RETURN(ctx->InputDType("like", 0), ctx->InputDType("src", 0));
+  }
+  *ctx->OutputDType("output", 0) = ctx->InputDType("src", 0);
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> InferScalarDtype(user_op::InferContext* ctx) {
+  const TensorDesc* index =
ctx->TensorDesc4ArgNameAndIndex("index", 0); + CHECK_OR_RETURN(IsIndexDataType(index->data_type())); + *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); + return Maybe<void>::Ok(); +} + +Maybe<void> ScatterBackward(user_op::BackwardOpConfContext* ctx) { + const TensorDesc& src = ctx->FwOp().TensorDesc4ArgNameAndIndex("src", 0); + const TensorDesc& index = ctx->FwOp().TensorDesc4ArgNameAndIndex("index", 0); + const int64_t ndim = src.shape().NumAxes(); + + FOR_RANGE(int64_t, i, 0, ndim) { + if (index.shape().At(i) != src.shape().At(i)) { + UNIMPLEMENTED() << "The backward pass is implemented only for src.shape == index.shape.\n"; + } + } + + const auto op_src_grad_name = ctx->FwOp().op_name() + "_src_grad"; + ctx->DefineOp(op_src_grad_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("dim_gather") + .InputBind("index", ctx->FwOp().input("index", 0)) + .InputBind("input", ctx->FwOp().output_grad("output", 0)) + .Output("output") + .Attr("dim", ctx->FwOp().attr<int32_t>("dim")) + .Build(); + }); + ctx->FwOp().InputGradBind(user_op::OpArg("src", 0), + [&ctx, &op_src_grad_name]() -> const std::string& { + return ctx->GetOp(op_src_grad_name).output("output", 0); + }); + const auto op_input_grad_name = ctx->FwOp().op_name() + "_input_grad"; + ctx->DefineOp(op_input_grad_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("dim_scatter_update_scalar") + .InputBind("index", ctx->FwOp().input("index", 0)) + .InputBind("input", ctx->FwOp().output_grad("output", 0)) + .Output("output") + .Attr("dim", ctx->FwOp().attr<int32_t>("dim")) + .Attr("src_scalar", static_cast<float>(0.0)) + .Build(); + }); + ctx->FwOp().InputGradBind(user_op::OpArg("input", 0), + [&ctx, &op_input_grad_name]() -> const std::string& { + return ctx->GetOp(op_input_grad_name).output("output", 0); + }); + return Maybe<void>::Ok(); +} + +} // namespace + +#define REGISTER_SCATTER_LIKE_OP(optypename) \ + REGISTER_USER_OP(optypename) \ + .Input("like") \ + .Input("index") \ + .Input("src") \ + .Output("output") \ + .Attr<int32_t>("dim") \ + .SetTensorDescInferFn(InferTensorDesc) \ + .SetInputArgModifyFn(InputArgModifierFn) \ + .SetDataTypeInferFn(InferDtype) \ + .SetGetSbpFn(SetSbpLike) + +#define REGISTER_SCATTER_OP(optypename) \ + REGISTER_USER_OP(optypename) \ + .Input("input") \ + .Input("index") \ + .Input("src") \ + .Output("output") \ + .Attr<int32_t>("dim") \ + .SetTensorDescInferFn(InferTensorDesc) \ + .SetInputArgModifyFn(InputArgModifierFn) \ + .SetDataTypeInferFn(InferDtype) \ + .SetGetSbpFn(SetSbpScatter) + +#define REGISTER_SCATTER_SCALAR_OP(optypename) \ + REGISTER_USER_OP(optypename) \ + .Input("input") \ + .Input("index") \ + .Attr<float>("src_scalar") \ + .Output("output") \ + .Attr<int32_t>("dim") \ + .SetTensorDescInferFn(InferScalarTensorDesc) \ + .SetInputArgModifyFn(InputScalarArgModifierFn) \ + .SetDataTypeInferFn(InferScalarDtype) \ + .SetGetSbpFn(SetSbpScatter) + +#define REGISTER_SCATTER_GRAD(optypename) \ + REGISTER_USER_OP_GRAD(optypename).SetBackwardOpConfGenFn(ScatterBackward); + +#define REGISTER_SCATTER_SCALAR_GRAD(optypename) \ + REGISTER_USER_OP_GRAD(optypename) \ + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { \ + const auto op_input_grad_name = ctx->FwOp().op_name() + "_input_grad"; \ + ctx->DefineOp(op_input_grad_name, [&ctx](user_op::BackwardOpBuilder& builder) { \ + return builder.OpTypeName("dim_scatter_update_scalar") \ + .InputBind("index", ctx->FwOp().input("index", 0)) \ + 
.InputBind("input", ctx->FwOp().output_grad("output", 0)) \ + .Output("output") \ + .Attr("dim", ctx->FwOp().attr<int32_t>("dim")) \ + .Attr("src_scalar", static_cast<float>(0.0)) \ + .Build(); \ + }); \ + ctx->FwOp().InputGradBind(user_op::OpArg("input", 0), \ + [&ctx, &op_input_grad_name]() -> const std::string& { \ + return ctx->GetOp(op_input_grad_name).output("output", 0); \ + }); \ + return Maybe<void>::Ok(); \ + }); + +REGISTER_SCATTER_LIKE_OP("dim_scatter_add_like"); +REGISTER_SCATTER_OP("dim_scatter_add"); +REGISTER_SCATTER_OP("dim_scatter_update"); +REGISTER_SCATTER_OP("dim_scatter_mul"); + +REGISTER_SCATTER_SCALAR_OP("dim_scatter_update_scalar"); +REGISTER_SCATTER_SCALAR_OP("dim_scatter_add_scalar"); +REGISTER_SCATTER_SCALAR_OP("dim_scatter_mul_scalar"); + +REGISTER_SCATTER_GRAD("dim_scatter_add"); +REGISTER_SCATTER_GRAD("dim_scatter_update"); + +REGISTER_SCATTER_SCALAR_GRAD("dim_scatter_update_scalar"); +} // namespace user_op +} // namespace oneflow diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py index 19f5f388325977b018b7647304f3951f138cef5d..b1d835a5a38ae045a046f96926828b9ad709a8f1 100644 --- a/python/oneflow/__init__.py +++ b/python/oneflow/__init__.py @@ -370,5 +370,6 @@ from oneflow.ops.user_op_builder import api_user_op_builder as user_op_builder from oneflow.ops.user_op_builder import ( api_user_op_module_builder as user_op_module_builder, ) +from oneflow.nn.modules.scatter import * from . import autograd, distributed, linalg, optim, saved_model, sbp diff --git a/python/oneflow/nn/modules/scatter.py b/python/oneflow/nn/modules/scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..7572e9d357a17b17597e8116ddea32c99be07053 --- /dev/null +++ b/python/oneflow/nn/modules/scatter.py @@ -0,0 +1,126 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import oneflow as flow +from oneflow.framework.tensor import Tensor +from oneflow.nn.module import Module + + +__all__ = ["scatter", "scatter_add"] + + +def scatter(input, dim, index, src): + r"""This operator writes the elements specified by `index` along with the axis + `dim` from the `src` into the `input`. + + Take a 3-D blob as example, the output is specified by: + + .. code-block:: python + + input[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + input[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + input[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + input, index and src (if it is a Tensor) should all have the same number of dimensions. + It is also required that index.shape(d) <= src.shape(d) for all dimensions d, + and that index.shape(d) <= self.shape(d) for all dimensions d != dim. + Note that index and src do not broadcast. + + Args: + input (Tensor): The input blob. + dim (int): The axis along which to index + index (Tensor): The index blob of elements to scatter. + src (Tensor or float): The source blob whose elements will be scatterd and updated to output. + + Returns: + Tensor: The scatterd Tensor. 
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow as flow
+        >>> import numpy as np
+
+        >>> input = flow.ones((3, 5)) * 2
+        >>> index = flow.tensor(np.array([[0, 1, 2], [0, 1, 4]]), dtype=flow.int32)
+        >>> src = flow.Tensor(np.array([[0, 10, 20, 30, 40], [50, 60, 70, 80, 90]]))
+        >>> out = flow.scatter(input, 1, index, src)
+        >>> out
+        tensor([[ 0., 10., 20.,  2.,  2.],
+                [50., 60.,  2.,  2., 70.],
+                [ 2.,  2.,  2.,  2.,  2.]], dtype=oneflow.float32)
+
+    """
+    assert isinstance(
+        src, (flow.Tensor, float)
+    ), f"type of src must be oneflow.Tensor or float, but got {type(src)}"
+
+    if isinstance(src, flow.Tensor):
+        return flow.F.dim_scatter(input, index, src, dim)
+    elif isinstance(src, float):
+        return flow.F.dim_scatter_scalar(input, index, src, dim)
+
+
+def scatter_add(input, dim, index, src):
+    r"""This operator scatters the elements of `src` into `input` by addition,
+    at the positions specified by `index` along the axis `dim`.
+
+    Take a 3-D tensor as an example; the output is specified by:
+
+    .. code-block:: python
+
+        input[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
+        input[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
+        input[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2
+
+    Args:
+        input (Tensor): The input tensor.
+        dim (int): The axis along which to index.
+        index (Tensor): The indices of the elements to scatter.
+        src (Tensor): The source elements that will be scattered and added to the output.
+
+    Returns:
+        Tensor: The scattered tensor.
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow as flow
+        >>> import numpy as np
+        >>> input = flow.ones((3, 5)) * 2
+        >>> index = flow.tensor(np.array([[0, 1, 2], [0, 1, 4]]), dtype=flow.int32)
+        >>> src = flow.Tensor(np.array([[0, 10, 20, 30, 40], [50, 60, 70, 80, 90]]))
+        >>> out = flow.scatter_add(input, 1, index, src)
+        >>> out
+        tensor([[ 2., 12., 22.,  2.,  2.],
+                [52., 62.,  2.,  2., 72.],
+                [ 2.,  2.,  2.,  2.,  2.]], dtype=oneflow.float32)
+
+    """
+    assert isinstance(
+        src, flow.Tensor
+    ), f"type of src must be oneflow.Tensor, but got {type(src)}"
+
+    return flow.F.dim_scatter_add(input, index, src, dim)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(raise_on_error=True)
diff --git a/python/oneflow/test/modules/test_scatter_ops.py b/python/oneflow/test/modules/test_scatter_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa1c01e41119d4b66c3b070f3ebb5e388bfc64de
--- /dev/null
+++ b/python/oneflow/test/modules/test_scatter_ops.py
@@ -0,0 +1,94 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" +import unittest + +import oneflow as flow +import oneflow.unittest +import numpy as np +from automated_test_util import * + + +@flow.unittest.skip_unless_1n1d() +class TestScatterOpsModule(flow.unittest.TestCase): + @autotest(n=5) + def test_scatter_random_data_at_dim_0(test_case): + device = random_device() + input = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device) + src = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device) + index = constant( + torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64, device=device) + ) + y = torch.scatter(input, 0, index, src) + return y + + @autotest(n=5) + def test_scatter_random_data_at_dim_1(test_case): + device = random_device() + input = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device) + src = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device) + index = constant( + torch.tensor(np.array([[1, 0], [0, 1]]), dtype=torch.int64, device=device) + ) + y = torch.scatter(input, 1, index, src) + return y + + @autotest(n=5) + def test_scatter_scalar_random_data_at_dim0(test_case): + device = random_device() + input = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device) + src = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device) + index = constant( + torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64, device=device) + ) + y = torch.scatter(input, 0, index, 3.14) + return y + + @autotest(n=5) + def test_scatter_scalar_random_data_at_dim1(test_case): + device = random_device() + input = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device) + src = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device) + index = constant( + torch.tensor(np.array([[1, 0], [0, 1]]), dtype=torch.int64, device=device) + ) + y = torch.scatter(input, 1, index, 3.14) + return y + + @autotest(n=5) + def test_scatter_add_random_data_at_dim0(test_case): + device = random_device() + input = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device) + src = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device) + index = constant( + torch.tensor(np.array([[1, 0], [0, 1]]), dtype=torch.int64, device=device) + ) + y = torch.scatter_add(input, 0, index, src) + return y + + @autotest(n=5) + def test_scatter_add_random_data_at_dim1(test_case): + device = random_device() + input = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device) + src = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device) + index = constant( + torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64, device=device) + ) + y = torch.scatter_add(input, 1, index, src) + return y + + +if __name__ == "__main__": + unittest.main()