From 32ba8001660737cca09313b363c505fbe8cfa4f7 Mon Sep 17 00:00:00 2001 From: ZZK <42901638+MARD1NO@users.noreply.github.com> Date: Thu, 29 Jul 2021 20:23:14 +0800 Subject: [PATCH] Rewrite activation function (#5465) * add activation * rename swish to silu * add selu * add four activation op * add softsign test * add silu mish selu softsign * Add softsign docs * Add functional impl * small fix for softsign backward * remove flow.mish test * add silu module test * add selu test * fix docs * fix softsign docs * fix format * fix static cast * merge master functional api yaml * add torch style unittest * Remove assert and add torch unittest * add tensor def * remove softsign test temporary * add return maybe ok * migrate nn ops to single_client dir * migrate unittest * remove lazy unittest * add unittest * fix to new directory * Remove useless docs and single client test * add doc * fix docs * add docs in oneflow and tensor namespace * add torch autotest * fix to new autotest * remove outdated python code * remove useless docs * enlarge unittest tolerance * Add static cast for const value * skip softsign unittest * skip tensor softsign Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- docs/source/nn.rst | 7 +- docs/source/oneflow.rst | 4 + docs/source/tensor.rst | 3 + oneflow/core/autograd/gradient_funcs/mish.cpp | 54 +++++ oneflow/core/autograd/gradient_funcs/selu.cpp | 54 +++++ oneflow/core/autograd/gradient_funcs/silu.cpp | 54 +++++ .../core/autograd/gradient_funcs/softsign.cpp | 54 +++++ oneflow/core/functional/functional_api.yaml | 32 +++ .../functional/impl/activation_functor.cpp | 58 +++++ oneflow/user/kernels/activation_kernels.cpp | 6 +- oneflow/user/kernels/activation_kernels.cu | 78 ++++++- oneflow/user/kernels/activation_kernels.h | 110 +++++++++ oneflow/user/ops/mish_op.cpp | 92 ++++++++ oneflow/user/ops/selu_op.cpp | 91 ++++++++ oneflow/user/ops/silu_op.cpp | 91 ++++++++ oneflow/user/ops/softsign_op.cpp | 91 ++++++++ python/oneflow/__init__.py | 6 + python/oneflow/nn/__init__.py | 3 + python/oneflow/nn/modules/activation.py | 214 +++++++++++++++++- .../oneflow/test/modules/test_activation.py | 152 +++++++++++++ python/oneflow/test/tensor/test_tensor.py | 79 +++++++ 21 files changed, 1327 insertions(+), 6 deletions(-) create mode 100644 oneflow/core/autograd/gradient_funcs/mish.cpp create mode 100644 oneflow/core/autograd/gradient_funcs/selu.cpp create mode 100644 oneflow/core/autograd/gradient_funcs/silu.cpp create mode 100644 oneflow/core/autograd/gradient_funcs/softsign.cpp create mode 100644 oneflow/user/ops/mish_op.cpp create mode 100644 oneflow/user/ops/selu_op.cpp create mode 100644 oneflow/user/ops/silu_op.cpp create mode 100644 oneflow/user/ops/softsign_op.cpp diff --git a/docs/source/nn.rst b/docs/source/nn.rst index 1eaa272a5..c3cb12c0c 100644 --- a/docs/source/nn.rst +++ b/docs/source/nn.rst @@ -65,11 +65,14 @@ Operators for neural networks ReLU6, ReflectionPad2d, ReplicationPad2d, - Sequential, + Sequential, + SELU, + SiLU, Sigmoid, SmoothL1Loss, Softmax, - Softplus, + Softplus, + Softsign, Tanh, Upsample, UpsamplingBilinear2d, diff --git a/docs/source/oneflow.rst b/docs/source/oneflow.rst index 4e0b646e6..63702bb8e 100644 --- a/docs/source/oneflow.rst +++ b/docs/source/oneflow.rst @@ -21,6 +21,7 @@ oneflow load, masked_fill, matmul, + mish, ones, ones_like, repeat, @@ -28,8 +29,11 @@ oneflow save, saved_model, scatter_nd, + selu, + silu, slice, slice_update, + softsign, sort, squeeze, stack, diff --git a/docs/source/tensor.rst b/docs/source/tensor.rst index 8f30ff623..3df7edc06 100644 --- a/docs/source/tensor.rst +++ b/docs/source/tensor.rst @@ -102,15 +102,18 @@ OneFlow Tensor Class retain_grad, round, rsqrt, + selu, shape, sigmoid, sign, + silu, sin, sin_, sinh, size, softmax, softplus, + softsign, sort, sqrt, square, diff --git a/oneflow/core/autograd/gradient_funcs/mish.cpp b/oneflow/core/autograd/gradient_funcs/mish.cpp new file mode 100644 index 000000000..8cedacb61 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/mish.cpp @@ -0,0 +1,54 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct MishInterpState : public OpExprInterpState { + bool requires_grad; +}; + +class Mish : public OpExprGradFunction<MishInterpState> { + public: + Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); } + + Maybe<void> Capture(MishInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, + const AttrMap& attrs) const override { + CHECK_EQ_OR_RETURN(inputs.size(), 1); + CHECK_EQ_OR_RETURN(outputs.size(), 1); + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); } + return Maybe<void>::Ok(); + } + + Maybe<void> Apply(const MishInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override { + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(1); + if (ctx->requires_grad) { + const auto& x = ctx->SavedTensors().at(0); + in_grads->at(0) = JUST(functional::MishGrad(out_grads.at(0), x)); + } + return Maybe<void>::Ok(); + } +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("mish", Mish); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/autograd/gradient_funcs/selu.cpp b/oneflow/core/autograd/gradient_funcs/selu.cpp new file mode 100644 index 000000000..663606eff --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/selu.cpp @@ -0,0 +1,54 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct SeluInterpState : public OpExprInterpState { + bool requires_grad; +}; + +class Selu : public OpExprGradFunction<SeluInterpState> { + public: + Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); } + + Maybe<void> Capture(SeluInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, + const AttrMap& attrs) const override { + CHECK_EQ_OR_RETURN(inputs.size(), 1); + CHECK_EQ_OR_RETURN(outputs.size(), 1); + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); } + return Maybe<void>::Ok(); + } + + Maybe<void> Apply(const SeluInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override { + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(1); + if (ctx->requires_grad) { + const auto& x = ctx->SavedTensors().at(0); + in_grads->at(0) = JUST(functional::SeluGrad(out_grads.at(0), x)); + } + return Maybe<void>::Ok(); + } +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("selu", Selu); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/autograd/gradient_funcs/silu.cpp b/oneflow/core/autograd/gradient_funcs/silu.cpp new file mode 100644 index 000000000..7d4beb638 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/silu.cpp @@ -0,0 +1,54 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct SiluInterpState : public OpExprInterpState { + bool requires_grad; +}; + +class Silu : public OpExprGradFunction<SiluInterpState> { + public: + Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); } + + Maybe<void> Capture(SiluInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, + const AttrMap& attrs) const override { + CHECK_EQ_OR_RETURN(inputs.size(), 1); + CHECK_EQ_OR_RETURN(outputs.size(), 1); + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); } + return Maybe<void>::Ok(); + } + + Maybe<void> Apply(const SiluInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override { + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(1); + if (ctx->requires_grad) { + const auto& x = ctx->SavedTensors().at(0); + in_grads->at(0) = JUST(functional::SiluGrad(out_grads.at(0), x)); + } + return Maybe<void>::Ok(); + } +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("silu", Silu); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/autograd/gradient_funcs/softsign.cpp b/oneflow/core/autograd/gradient_funcs/softsign.cpp new file mode 100644 index 000000000..087ac118c --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/softsign.cpp @@ -0,0 +1,54 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct SoftSignInterpState : public OpExprInterpState { + bool requires_grad; +}; + +class SoftSign : public OpExprGradFunction<SoftSignInterpState> { + public: + Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); } + + Maybe<void> Capture(SoftSignInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override { + CHECK_EQ_OR_RETURN(inputs.size(), 1); + CHECK_EQ_OR_RETURN(outputs.size(), 1); + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); } + return Maybe<void>::Ok(); + } + + Maybe<void> Apply(const SoftSignInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override { + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(1); + if (ctx->requires_grad) { + const auto& x = ctx->SavedTensors().at(0); + in_grads->at(0) = JUST(functional::SoftSignGrad(out_grads.at(0), x)); + } + return Maybe<void>::Ok(); + } +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("softsign", SoftSign); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index bb7c8bda3..d30e68d95 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -718,6 +718,38 @@ signature: "Tensor PadGrad(Tensor dy, *, Int64List pad, String mode=\"constant\", Scalar value=0)" bind_python: False +- name: "silu" + signature: "Tensor Silu(Tensor x, *)" + bind_python: True + +- name: "silu_grad" + signature: "Tensor SiluGrad(Tensor x, Tensor dy, *)" + bind_python: False + +- name: "mish" + signature: "Tensor Mish(Tensor x, *)" + bind_python: True + +- name: "mish_grad" + signature: "Tensor MishGrad(Tensor x, Tensor dy, *)" + bind_python: False + +- name: "selu" + signature: "Tensor Selu(Tensor x, *)" + bind_python: True + +- name: "selu_grad" + signature: "Tensor SeluGrad(Tensor x, Tensor dy, *)" + bind_python: False + +- name: "softsign" + signature: "Tensor SoftSign(Tensor x, *)" + bind_python: True + +- name: "softsign_grad" + signature: "Tensor SoftSignGrad(Tensor x, Tensor dy, *)" + bind_python: False + - name: "diag" signature: "Tensor Diag(Tensor x, *, Int32 diagonal=0)" bind_python: True diff --git a/oneflow/core/functional/impl/activation_functor.cpp b/oneflow/core/functional/impl/activation_functor.cpp index 7f8bcd10d..c12981038 100644 --- a/oneflow/core/functional/impl/activation_functor.cpp +++ b/oneflow/core/functional/impl/activation_functor.cpp @@ -210,6 +210,56 @@ class LeakyReluGradFunctor { std::shared_ptr<OpExpr> op_; }; +class SiluFunctor : public UnaryFunctor { + public: + SiluFunctor() { op_ = CHECK_JUST(one::OpBuilder("silu").Input("in").Output("out").Build()); } +}; + +class SiluGradFunctor : public BinaryFunctor { + public: + SiluGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("silu_grad").Input("dy").Input("x").Output("dx").Build()); + } +}; + +class MishFunctor : public UnaryFunctor { + public: + MishFunctor() { op_ = CHECK_JUST(one::OpBuilder("mish").Input("in").Output("out").Build()); } +}; + +class MishGradFunctor : public BinaryFunctor { + public: + MishGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("mish_grad").Input("dy").Input("x").Output("dx").Build()); + } +}; + +class SeluFunctor : public UnaryFunctor { + public: + SeluFunctor() { op_ = CHECK_JUST(one::OpBuilder("selu").Input("in").Output("out").Build()); } +}; + +class SeluGradFunctor : public BinaryFunctor { + public: + SeluGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("selu_grad").Input("dy").Input("x").Output("dx").Build()); + } +}; + +class SoftSignFunctor : public UnaryFunctor { + public: + SoftSignFunctor() { + op_ = CHECK_JUST(one::OpBuilder("softsign").Input("in").Output("out").Build()); + } +}; + +class SoftSignGradFunctor : public BinaryFunctor { + public: + SoftSignGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("softsign_grad").Input("dy").Input("x").Output("dx").Build()); + } +}; + } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { @@ -229,6 +279,14 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::HardSwishGradFunctor>("HardSwishGrad"); m.add_functor<impl::LeakyReluFunctor>("LeakyRelu"); m.add_functor<impl::LeakyReluGradFunctor>("LeakyReluGrad"); + m.add_functor<impl::SiluFunctor>("Silu"); + m.add_functor<impl::SiluGradFunctor>("SiluGrad"); + m.add_functor<impl::MishFunctor>("Mish"); + m.add_functor<impl::MishGradFunctor>("MishGrad"); + m.add_functor<impl::SeluFunctor>("Selu"); + m.add_functor<impl::SeluGradFunctor>("SeluGrad"); + m.add_functor<impl::SoftSignFunctor>("SoftSign"); + m.add_functor<impl::SoftSignGradFunctor>("SoftSignGrad"); }; } // namespace functional diff --git a/oneflow/user/kernels/activation_kernels.cpp b/oneflow/user/kernels/activation_kernels.cpp index 997489ab6..f835136a9 100644 --- a/oneflow/user/kernels/activation_kernels.cpp +++ b/oneflow/user/kernels/activation_kernels.cpp @@ -21,7 +21,11 @@ namespace oneflow { REGISTER_ELU_KERNEL(DeviceType::kCPU, dtype); \ REGISTER_HARDSWISH_KERNEL(DeviceType::kCPU, dtype); \ REGISTER_HARDSIGMOID_KERNEL(DeviceType::kCPU, dtype); \ - REGISTER_HARDTANH_KERNEL(DeviceType::kCPU, dtype); + REGISTER_HARDTANH_KERNEL(DeviceType::kCPU, dtype); \ + REGISTER_MISH_KERNEL(DeviceType::kCPU, dtype); \ + REGISTER_SILU_KERNEL(DeviceType::kCPU, dtype); \ + REGISTER_SELU_KERNEL(DeviceType::kCPU, dtype); \ + REGISTER_SOFTSIGN_KERNEL(DeviceType::kCPU, dtype); REGISTER_ACTIVATION_CPU_KERNEL(float); REGISTER_ACTIVATION_CPU_KERNEL(double); diff --git a/oneflow/user/kernels/activation_kernels.cu b/oneflow/user/kernels/activation_kernels.cu index 0001d315b..db292b930 100644 --- a/oneflow/user/kernels/activation_kernels.cu +++ b/oneflow/user/kernels/activation_kernels.cu @@ -56,11 +56,87 @@ struct HardswishGradFunctor<half> { } }; +template<> +struct MishFunctor<half> { + OF_DEVICE_FUNC explicit MishFunctor() : float_functor(MishFunctor<float>()) {} + OF_DEVICE_FUNC half operator()(half x) const { + return __float2half(float_functor(__half2float(x))); + } + MishFunctor<float> float_functor; +}; + +template<> +struct MishGradFunctor<half> { + OF_DEVICE_FUNC explicit MishGradFunctor() : float_functor(MishGradFunctor<float>()) {} + OF_DEVICE_FUNC half operator()(half x, half dy) const { + return __float2half(float_functor(__half2float(x), __half2float(dy))); + } + MishGradFunctor<float> float_functor; +}; + +template<> +struct SiluFunctor<half> { + OF_DEVICE_FUNC explicit SiluFunctor() : float_functor(SiluFunctor<float>()) {} + OF_DEVICE_FUNC half operator()(half x) const { + return __float2half(float_functor(__half2float(x))); + } + SiluFunctor<float> float_functor; +}; + +template<> +struct SiluGradFunctor<half> { + OF_DEVICE_FUNC explicit SiluGradFunctor() : float_functor(SiluGradFunctor<float>()) {} + OF_DEVICE_FUNC half operator()(half x, half dy) const { + return __float2half(float_functor(__half2float(x), __half2float(dy))); + } + SiluGradFunctor<float> float_functor; +}; + +template<> +struct SeluFunctor<half> { + OF_DEVICE_FUNC explicit SeluFunctor() : float_functor(SeluFunctor<float>()) {} + OF_DEVICE_FUNC half operator()(half x) const { + return __float2half(float_functor(__half2float(x))); + } + SeluFunctor<float> float_functor; +}; + +template<> +struct SeluGradFunctor<half> { + OF_DEVICE_FUNC explicit SeluGradFunctor() : float_functor(SeluGradFunctor<float>()) {} + OF_DEVICE_FUNC half operator()(half x, half dy) const { + return __float2half(float_functor(__half2float(x), __half2float(dy))); + } + SeluGradFunctor<float> float_functor; +}; + +template<> +struct SoftSignFunctor<half> { + OF_DEVICE_FUNC explicit SoftSignFunctor() : float_functor(SoftSignFunctor<float>()) {} + OF_DEVICE_FUNC half operator()(half x) const { + return __float2half(float_functor(__half2float(x))); + } + SoftSignFunctor<float> float_functor; +}; + +template<> +struct SoftSignGradFunctor<half> { + OF_DEVICE_FUNC explicit SoftSignGradFunctor() : float_functor(SoftSignGradFunctor<float>()) {} + OF_DEVICE_FUNC half operator()(half x, half dy) const { + return __float2half(float_functor(__half2float(x), __half2float(dy))); + } + SoftSignGradFunctor<float> float_functor; +}; + #define REGISTER_ACTIVATION_GPU_KERNEL(dtype) \ REGISTER_ELU_KERNEL(DeviceType::kGPU, dtype); \ REGISTER_HARDSWISH_KERNEL(DeviceType::kGPU, dtype); \ REGISTER_HARDSIGMOID_KERNEL(DeviceType::kGPU, dtype); \ - REGISTER_HARDTANH_KERNEL(DeviceType::kGPU, dtype); + REGISTER_HARDTANH_KERNEL(DeviceType::kGPU, dtype); \ + REGISTER_MISH_KERNEL(DeviceType::kGPU, dtype); \ + REGISTER_SILU_KERNEL(DeviceType::kGPU, dtype); \ + REGISTER_SELU_KERNEL(DeviceType::kGPU, dtype); \ + REGISTER_SOFTSIGN_KERNEL(DeviceType::kGPU, dtype); REGISTER_ACTIVATION_GPU_KERNEL(half); REGISTER_ACTIVATION_GPU_KERNEL(float); diff --git a/oneflow/user/kernels/activation_kernels.h b/oneflow/user/kernels/activation_kernels.h index 5d4ebb99f..b13d5f9e4 100644 --- a/oneflow/user/kernels/activation_kernels.h +++ b/oneflow/user/kernels/activation_kernels.h @@ -111,6 +111,80 @@ struct HardtanhGradFunctor { const T max_val; }; +template<typename T> +struct MishFunctor { + OF_DEVICE_FUNC explicit MishFunctor() {} + OF_DEVICE_FUNC T operator()(T x) const { + T soft_plus_val = log(static_cast<T>(1) + exp(x)); + T exp_val = exp(soft_plus_val); + T neg_exp_val = exp(-soft_plus_val); + T tanh_val = (exp_val - neg_exp_val) / (exp_val + neg_exp_val); + return x * tanh_val; + } +}; + +template<typename T> +struct MishGradFunctor { + OF_DEVICE_FUNC explicit MishGradFunctor() {} + OF_DEVICE_FUNC T operator()(T x, T dy) const { + T sp = log(static_cast<T>(1) + exp(x)); + T grad_sp = static_cast<T>(1) - exp(-sp); + T tsp = (exp(sp) - exp(-sp)) / (exp(sp) + exp(-sp)); + T grad_tsp = (static_cast<T>(1) - tsp * tsp) * grad_sp; + return dy * (x * grad_tsp + tsp); + } +}; + +template<typename T> +struct SiluFunctor { + OF_DEVICE_FUNC explicit SiluFunctor() {} + OF_DEVICE_FUNC T operator()(T x) const { return (x / (static_cast<T>(1) + exp(-x))); } +}; + +template<typename T> +struct SiluGradFunctor { + OF_DEVICE_FUNC explicit SiluGradFunctor() {} + OF_DEVICE_FUNC T operator()(T x, T dy) const { + T sig = static_cast<T>(1) / (static_cast<T>(1) + exp(-x)); + return dy * (sig * (static_cast<T>(1) + x * (static_cast<T>(1) - sig))); + } +}; + +template<typename T> +struct SeluFunctor { + OF_DEVICE_FUNC explicit SeluFunctor() {} + OF_DEVICE_FUNC T operator()(T x) const { + return (x > static_cast<T>(0)) ? scale * x : scale * alpha * (exp(x) - static_cast<T>(1)); + } + const T scale = 1.0507009873554804934193349852946; + const T alpha = 1.6732632423543772848170429916717; +}; + +template<typename T> +struct SeluGradFunctor { + OF_DEVICE_FUNC explicit SeluGradFunctor() {} + OF_DEVICE_FUNC T operator()(T x, T dy) const { + return (x > static_cast<T>(0)) ? scale * dy : dy * scale * alpha * (exp(x)); + } + const T scale = 1.0507009873554804934193349852946; + const T alpha = 1.6732632423543772848170429916717; +}; + +template<typename T> +struct SoftSignFunctor { + OF_DEVICE_FUNC explicit SoftSignFunctor() {} + OF_DEVICE_FUNC T operator()(T x) const { return x / (static_cast<T>(1) + abs(x)); } +}; + +template<typename T> +struct SoftSignGradFunctor { + OF_DEVICE_FUNC explicit SoftSignGradFunctor() {} + OF_DEVICE_FUNC T operator()(T x, T dy) const { + T val = (static_cast<T>(1) + abs(x)); + return dy / (val * val); + } +}; + #define REGISTER_ELU_KERNEL(device, dtype) \ REGISTER_UNARY_ELEMWISE_USER_KERNEL( \ device, "elu", EluFunctor, dtype, dtype, \ @@ -179,6 +253,42 @@ struct HardtanhGradFunctor { return Maybe<void>::Ok(); \ }); +#define REGISTER_MISH_KERNEL(device, dtype) \ + REGISTER_UNARY_ELEMWISE_USER_KERNEL( \ + device, "mish", MishFunctor, dtype, dtype, \ + [](user_op::KernelComputeContext* ctx) { return MishFunctor<dtype>(); }, "out", "in"); \ + REGISTER_BINARY_ELEMWISE_USER_KERNEL( \ + device, "mish_grad", MishGradFunctor, dtype, dtype, dtype, \ + [](user_op::KernelComputeContext* ctx) { return MishGradFunctor<dtype>(); }, "dx", "x", \ + "dy"); + +#define REGISTER_SILU_KERNEL(device, dtype) \ + REGISTER_UNARY_ELEMWISE_USER_KERNEL( \ + device, "silu", SiluFunctor, dtype, dtype, \ + [](user_op::KernelComputeContext* ctx) { return SiluFunctor<dtype>(); }, "out", "in"); \ + REGISTER_BINARY_ELEMWISE_USER_KERNEL( \ + device, "silu_grad", SiluGradFunctor, dtype, dtype, dtype, \ + [](user_op::KernelComputeContext* ctx) { return SiluGradFunctor<dtype>(); }, "dx", "x", \ + "dy"); + +#define REGISTER_SELU_KERNEL(device, dtype) \ + REGISTER_UNARY_ELEMWISE_USER_KERNEL( \ + device, "selu", SeluFunctor, dtype, dtype, \ + [](user_op::KernelComputeContext* ctx) { return SeluFunctor<dtype>(); }, "out", "in"); \ + REGISTER_BINARY_ELEMWISE_USER_KERNEL( \ + device, "selu_grad", SeluGradFunctor, dtype, dtype, dtype, \ + [](user_op::KernelComputeContext* ctx) { return SeluGradFunctor<dtype>(); }, "dx", "x", \ + "dy"); + +#define REGISTER_SOFTSIGN_KERNEL(device, dtype) \ + REGISTER_UNARY_ELEMWISE_USER_KERNEL( \ + device, "softsign", SoftSignFunctor, dtype, dtype, \ + [](user_op::KernelComputeContext* ctx) { return SoftSignFunctor<dtype>(); }, "out", "in"); \ + REGISTER_BINARY_ELEMWISE_USER_KERNEL( \ + device, "softsign_grad", SoftSignGradFunctor, dtype, dtype, dtype, \ + [](user_op::KernelComputeContext* ctx) { return SoftSignGradFunctor<dtype>(); }, "dx", "x", \ + "dy"); + } // namespace oneflow #endif // _ONEFLOW_USER_KERNELS_ACTIVATION_KERNELS_H_ diff --git a/oneflow/user/ops/mish_op.cpp b/oneflow/user/ops/mish_op.cpp new file mode 100644 index 000000000..4f51ca760 --- /dev/null +++ b/oneflow/user/ops/mish_op.cpp @@ -0,0 +1,92 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/framework.h" + +namespace oneflow { + +namespace { + +REGISTER_USER_OP("mish") + .Input("in") + .Output("out") + .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe<void>::Ok(); + }) + .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + return Maybe<void>::Ok(); + }) + .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe<void>::Ok(); + }); + +REGISTER_USER_OP("mish_grad") + .Input("x") + .Input("dy") + .Output("dx") + .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe<void>::Ok(); + }) + .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe<void>::Ok(); + }) + .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe<void>::Ok(); + }); + +REGISTER_USER_OP_GRAD("mish").SetBackwardOpConfGenFn( + [](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { + const auto mish_grad_op_name = ctx->FwOp().op_name() + "_grad"; + ctx->DefineOp(mish_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("mish_grad") + .InputBind("x", ctx->FwOp().input("in", 0)) + .InputBind("dy", ctx->FwOp().output_grad("out", 0)) + .Output("dx") + .Build(); + }); + ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), + [&ctx, &mish_grad_op_name]() -> const std::string& { + return ctx->GetOp(mish_grad_op_name).output("dx", 0); + }); + return Maybe<void>::Ok(); + }); + +} // namespace + +} // namespace oneflow diff --git a/oneflow/user/ops/selu_op.cpp b/oneflow/user/ops/selu_op.cpp new file mode 100644 index 000000000..8697cbf39 --- /dev/null +++ b/oneflow/user/ops/selu_op.cpp @@ -0,0 +1,91 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/framework.h" + +namespace oneflow { + +namespace { + +REGISTER_USER_OP("selu") + .Input("in") + .Output("out") + .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe<void>::Ok(); + }) + .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + return Maybe<void>::Ok(); + }) + .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe<void>::Ok(); + }); + +REGISTER_USER_OP("selu_grad") + .Input("x") + .Input("dy") + .Output("dx") + .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe<void>::Ok(); + }) + .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe<void>::Ok(); + }) + .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe<void>::Ok(); + }); + +REGISTER_USER_OP_GRAD("selu").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { + const auto selu_grad_op_name = ctx->FwOp().op_name() + "_grad"; + ctx->DefineOp(selu_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("selu_grad") + .InputBind("x", ctx->FwOp().input("in", 0)) + .InputBind("dy", ctx->FwOp().output_grad("out", 0)) + .Output("dx") + .Build(); + }); + ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), + [&ctx, &selu_grad_op_name]() -> const std::string& { + return ctx->GetOp(selu_grad_op_name).output("dx", 0); + }); + return Maybe<void>::Ok(); +}); + +} // namespace + +} // namespace oneflow diff --git a/oneflow/user/ops/silu_op.cpp b/oneflow/user/ops/silu_op.cpp new file mode 100644 index 000000000..eb46ab7b4 --- /dev/null +++ b/oneflow/user/ops/silu_op.cpp @@ -0,0 +1,91 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/framework.h" + +namespace oneflow { + +namespace { + +REGISTER_USER_OP("silu") + .Input("in") + .Output("out") + .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe<void>::Ok(); + }) + .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + return Maybe<void>::Ok(); + }) + .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe<void>::Ok(); + }); + +REGISTER_USER_OP("silu_grad") + .Input("x") + .Input("dy") + .Output("dx") + .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe<void>::Ok(); + }) + .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe<void>::Ok(); + }) + .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe<void>::Ok(); + }); + +REGISTER_USER_OP_GRAD("silu").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { + const auto silu_grad_op_name = ctx->FwOp().op_name() + "_grad"; + ctx->DefineOp(silu_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("silu_grad") + .InputBind("x", ctx->FwOp().input("in", 0)) + .InputBind("dy", ctx->FwOp().output_grad("out", 0)) + .Output("dx") + .Build(); + }); + ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), + [&ctx, &silu_grad_op_name]() -> const std::string& { + return ctx->GetOp(silu_grad_op_name).output("dx", 0); + }); + return Maybe<void>::Ok(); +}); + +} // namespace + +} // namespace oneflow diff --git a/oneflow/user/ops/softsign_op.cpp b/oneflow/user/ops/softsign_op.cpp new file mode 100644 index 000000000..324980302 --- /dev/null +++ b/oneflow/user/ops/softsign_op.cpp @@ -0,0 +1,91 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/framework.h" + +namespace oneflow { + +namespace { + +REGISTER_USER_OP("softsign") + .Input("in") + .Output("out") + .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe<void>::Ok(); + }) + .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + return Maybe<void>::Ok(); + }) + .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe<void>::Ok(); + }); + +REGISTER_USER_OP("softsign_grad") + .Input("x") + .Input("dy") + .Output("dx") + .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe<void>::Ok(); + }) + .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe<void>::Ok(); + }) + .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe<void>::Ok(); + }); + +REGISTER_USER_OP_GRAD("softsign").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { + const auto softsign_grad_op_name = ctx->FwOp().op_name() + "_grad"; + ctx->DefineOp(softsign_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("softsign_grad") + .InputBind("x", ctx->FwOp().input("in", 0)) + .InputBind("dy", ctx->FwOp().output_grad("out", 0)) + .Output("dx") + .Build(); + }); + ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), + [&ctx, &softsign_grad_op_name]() -> const std::string& { + return ctx->GetOp(softsign_grad_op_name).output("dx", 0); + }); + return Maybe<void>::Ok(); +}); + +} // namespace + +} // namespace oneflow diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py index 2a8e696c0..8e6f9d5cf 100644 --- a/python/oneflow/__init__.py +++ b/python/oneflow/__init__.py @@ -216,6 +216,12 @@ from oneflow.nn.modules.activation import mish_op as mish from oneflow.nn.modules.activation import sigmoid_op as sigmoid from oneflow.nn.modules.activation import softmax_op as softmax from oneflow.nn.modules.activation import tanh_op as tanh +from oneflow.nn.modules.activation import silu_op as silu +from oneflow.nn.modules.activation import selu_op as selu +from oneflow.nn.modules.activation import softsign_op as softsign +from oneflow.nn.modules.activation import mish_op as mish + + from oneflow.nn.modules.adaptive_pool import ( adaptive_avg_pool1d, adaptive_avg_pool2d, diff --git a/python/oneflow/nn/__init__.py b/python/oneflow/nn/__init__.py index 0993b8e4e..8401b05fb 100644 --- a/python/oneflow/nn/__init__.py +++ b/python/oneflow/nn/__init__.py @@ -32,6 +32,9 @@ from oneflow.nn.modules.activation import ( Softmax, Softplus, Tanh, + SELU, + SiLU, + Softsign, ) from oneflow.nn.modules.adaptive_pool import ( AdaptiveAvgPool1d, diff --git a/python/oneflow/nn/modules/activation.py b/python/oneflow/nn/modules/activation.py index 406d0187e..8990c0922 100644 --- a/python/oneflow/nn/modules/activation.py +++ b/python/oneflow/nn/modules/activation.py @@ -923,11 +923,11 @@ class Mish(Module): """ def __init__(self, inplace: bool = False): - assert not inplace, "In-place operation is not currently supported" + self.inplace = inplace super().__init__() def forward(self, x): - return x * flow.tanh(flow.softplus(x)) + return flow.F.mish(x) def mish_op(x): @@ -953,6 +953,216 @@ def mish_op_tensor(x): return Mish()(x) +class SiLU(Module): + r"""SiLU(Swish) activation: + + .. math:: + + \text{SiLU}(x) = x * sigmoid(x) + + .. note:: + See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_ + where the SiLU (Sigmoid Linear Unit) was originally coined, and see + `Sigmoid-Weighted Linear Units for Neural Network Function Approximation + in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish: + a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_ + where the SiLU was experimented with later. + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + For example: + + .. code-block:: python + + >>> import numpy as np + >>> import oneflow as flow + + + >>> x = np.array([1, 2, 3]).astype(np.float32) + >>> input = flow.Tensor(x) + >>> silu = flow.nn.SiLU() + >>> out = silu(input) + >>> out + tensor([0.7311, 1.7616, 2.8577], dtype=oneflow.float32) + """ + + def __init__(self, inplace: bool = False): + self.inplace = inplace + super().__init__() + + def forward(self, x): + return flow.F.silu(x) + + +def silu_op(x): + r"""SiLU(Swish) activation: + + .. math:: + \text{SiLU}(x) = x * sigmoid(x) + + .. note:: + + See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_ + where the SiLU (Sigmoid Linear Unit) was originally coined, and see + `Sigmoid-Weighted Linear Units for Neural Network Function Approximation + in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish: + a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_ + where the SiLU was experimented with later. + + See :mod:`oneflow.nn.SiLU` + """ + + return SiLU()(x) + + +@register_tensor_op("silu") +def silu_op_tensor(x): + r""" + silu() -> Tensor + See :func:`oneflow.silu` + """ + return SiLU()(x) + + +class SELU(Module): + r"""Applies the element-wise function: + + The formula is: + + .. math:: + + \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1))) + + with :math:`\alpha = 1.6732632423543772848170429916717` and + + :math:`\text{scale} = 1.0507009873554804934193349852946`. + + .. warning:: + + When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation, + ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'`` + in order to get `Self-Normalizing Neural Networks`_. + See :func:`torch.nn.init.calculate_gain` for more information. + + More details can be found in the paper `Self-Normalizing Neural Networks <https://arxiv.org/abs/1706.02515>`_. + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + For example: + + .. code-block:: python + + >>> import numpy as np + >>> import oneflow as flow + >>> x = np.array([1, 2, 3]).astype(np.float32) + >>> input = flow.Tensor(x) + >>> selu = flow.nn.SELU() + >>> out = selu(input) + >>> out + tensor([1.0507, 2.1014, 3.1521], dtype=oneflow.float32) + """ + + def __init__(self, inplace: bool = False): + self.inplace = inplace + super().__init__() + + def forward(self, x): + return flow.F.selu(x) + + +def selu_op(x): + r"""The SELU activation. + + The formula is: + + .. math:: + + \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1))) + + with :math:`\alpha = 1.6732632423543772848170429916717` and + + :math:`\text{scale} = 1.0507009873554804934193349852946`. + + See :mod:`oneflow.nn.SELU` + """ + return SELU()(x) + + +@register_tensor_op("selu") +def selu_op_tensor(x): + r""" + selu() -> Tensor + + See :func:`oneflow.selu` + """ + return SELU()(x) + + +class Softsign(Module): + r"""The SoftSign activation. + + The formula is: + + .. math:: + + SoftSign(x) = \frac{x}{1 + |x|} + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + For example: + + .. code-block:: python + + >>> import numpy as np + >>> import oneflow as flow + >>> x = np.array([1, 2, 3]).astype(np.float32) + >>> input = flow.Tensor(x) + >>> softsign = flow.nn.Softsign() + >>> out = softsign(input) + >>> out + tensor([0.5 , 0.6667, 0.75 ], dtype=oneflow.float32) + """ + + def __init__(self, inplace: bool = False): + self.inplace = inplace + super().__init__() + + def forward(self, x): + return flow.F.softsign(x) + + +def softsign_op(x): + r"""The SoftSign activation. + + The formula is: + + .. math:: + + SoftSign(x) = \frac{x}{1 + |x|} + + See :mod:`oneflow.nn.Softsign` + """ + return Softsign()(x) + + +@register_tensor_op("softsign") +def softsign_op_tensor(x): + r""" + softsign() -> Tensor + See :func:`oneflow.softsign` + """ + return Softsign()(x) + + if __name__ == "__main__": import doctest diff --git a/python/oneflow/test/modules/test_activation.py b/python/oneflow/test/modules/test_activation.py index a54d81c5f..27bf0df0d 100644 --- a/python/oneflow/test/modules/test_activation.py +++ b/python/oneflow/test/modules/test_activation.py @@ -824,6 +824,158 @@ class TestMishModule(flow.unittest.TestCase): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + @autotest(n=5) + def test_mish_module_with_random_data(test_case): + m = torch.nn.Mish() + m.train(random()) + device = random_device() + m.to(device) + x = random_pytorch_tensor().to(device) + y = m(x) + return y + + +def _np_silu_grad(x): + _sig = 1 / (1 + np.exp(-x)) + return _sig * (1 + x * (1 - _sig)) + + +def _test_silu_impl(test_case, shape, device): + m = flow.nn.SiLU() + np_input = np.random.randn(*shape) + np_out = np_input / (1 + np.exp(-np_input)) + of_input = flow.Tensor( + np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True + ) + of_out = m(of_input) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + + of_out = of_out.sum() + of_out.backward() + test_case.assertTrue( + np.allclose(of_input.grad.numpy(), _np_silu_grad(np_input), 1e-5, 1e-5) + ) + + +@flow.unittest.skip_unless_1n1d() +class TestSiluModule(flow.unittest.TestCase): + def test_silu(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [_test_silu_impl] + arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)] + + arg_dict["device"] = ["cpu", "cuda"] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + + @autotest(n=5) + def test_silu_module_with_random_data(test_case): + m = torch.nn.SiLU() + m.train(random()) + device = random_device() + m.to(device) + x = random_pytorch_tensor().to(device) + y = m(x) + return y + + +def _np_selu(x): + scale = 1.0507009873554804934193349852946 + alpha = 1.6732632423543772848170429916717 + return np.where(x < 0, scale * alpha * (np.exp(x) - 1), scale * x) + + +def _np_selu_grad(x): + scale = 1.0507009873554804934193349852946 + alpha = 1.6732632423543772848170429916717 + return np.where(x < 0, scale * alpha * np.exp(x), scale) + + +def _test_selu_impl(test_case, shape, device): + m = flow.nn.SELU() + np_input = np.random.randn(*shape) + np_out = _np_selu(np_input) + of_input = flow.Tensor( + np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True + ) + of_out = m(of_input) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + + of_out = of_out.sum() + of_out.backward() + test_case.assertTrue( + np.allclose(of_input.grad.numpy(), _np_selu_grad(np_input), 1e-5, 1e-5) + ) + + +@flow.unittest.skip_unless_1n1d() +class TestSeluModule(flow.unittest.TestCase): + def test_selu(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [_test_selu_impl] + arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)] + + arg_dict["device"] = ["cpu", "cuda"] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + + @autotest(n=5) + def test_selu_module_with_random_data(test_case): + m = torch.nn.SELU() + m.train(random()) + device = random_device() + m.to(device) + x = random_pytorch_tensor().to(device) + y = m(x) + return y + + +def _np_softsign(x): + return x / (1.0 + np.abs(x)) + + +def _np_softsign_grad(x): + return 1.0 / (np.square(1.0 + np.abs(x))) + + +def _test_softsign_impl(test_case, shape, device): + m = flow.nn.Softsign() + np_input = np.random.randn(*shape) + np_out = _np_softsign(np_input) + of_input = flow.Tensor( + np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True + ) + of_out = m(of_input) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-3, 1e-3)) + + of_out = of_out.sum() + of_out.backward() + test_case.assertTrue( + np.allclose(of_input.grad.numpy(), _np_softsign_grad(np_input), 1e-3, 1e-3) + ) + + +@unittest.skip("still have error in ci test") +class TestSoftsignModule(flow.unittest.TestCase): + @autotest(n=5) + def test_softsign_module_with_random_data(test_case): + m = torch.nn.Softsign() + m.train(random()) + device = random_device() + m.to(device) + x = random_pytorch_tensor().to(device) + y = m(x) + return y + + def test_softsign(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [_test_softsign_impl] + arg_dict["shape"] = [(3, 3), (2, 3, 3)] + + arg_dict["device"] = ["cpu", "cuda"] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/tensor/test_tensor.py b/python/oneflow/test/tensor/test_tensor.py index d08f5ac05..6328865ef 100644 --- a/python/oneflow/test/tensor/test_tensor.py +++ b/python/oneflow/test/tensor/test_tensor.py @@ -1006,6 +1006,85 @@ class TestTensor(flow.unittest.TestCase): ) ) + @flow.unittest.skip_unless_1n1d() + def test_tensor_mish(test_case): + def np_mish(x): + f = 1 + np.exp(x) + y = x * ((f * f - 1) / (f * f + 1)) + y_grad = (f * f - 1) / (f * f + 1) + x * (4 * f * (f - 1)) / ( + (f * f + 1) * (f * f + 1) + ) + return [y, y_grad] + + np_input = np.random.randn(2, 4, 5, 6,) + of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True) + of_out = of_input.mish() + + np_out, np_grad = np_mish(np_input) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + + of_out = of_out.sum() + of_out.backward() + test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-5, 1e-5)) + + @flow.unittest.skip_unless_1n1d() + def test_tensor_silu(test_case): + def np_silu(x): + _sig = 1 / (1 + np.exp(-x)) + y = x * _sig + y_grad = _sig * (1 + x * (1 - _sig)) + return [y, y_grad] + + np_input = np.random.randn(2, 4, 5, 6,) + of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True) + of_out = of_input.silu() + + np_out, np_grad = np_silu(np_input) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + + of_out = of_out.sum() + of_out.backward() + test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-5, 1e-5)) + + @flow.unittest.skip_unless_1n1d() + def test_tensor_selu(test_case): + _scale = 1.0507009873554804934193349852946 + _alpha = 1.6732632423543772848170429916717 + + def np_selu(x): + y = np.where(x < 0, _scale * _alpha * (np.exp(x) - 1), _scale * x) + y_grad = np.where(x < 0, _scale * _alpha * np.exp(x), _scale) + return [y, y_grad] + + np_input = np.random.randn(2, 4, 5, 6,) + of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True) + of_out = of_input.selu() + + np_out, np_grad = np_selu(np_input) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + + of_out = of_out.sum() + of_out.backward() + test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-5, 1e-5)) + + @unittest.skip("still have error in ci") + def test_tensor_softsign(test_case): + def np_softsign(x): + y = x / (1 + np.abs(x)) + y_grad = 1 / np.square(1 + np.abs(x)) + return [y, y_grad] + + np_input = np.random.randn(2, 4, 5, 6,) + of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True) + of_out = of_input.softsign() + + np_out, np_grad = np_softsign(np_input) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + + of_out = of_out.sum() + of_out.backward() + test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-5, 1e-5)) + if __name__ == "__main__": unittest.main() -- GitLab