From 32ba8001660737cca09313b363c505fbe8cfa4f7 Mon Sep 17 00:00:00 2001
From: ZZK <42901638+MARD1NO@users.noreply.github.com>
Date: Thu, 29 Jul 2021 20:23:14 +0800
Subject: [PATCH] Rewrite activation function (#5465)

* add activation

* rename swish to silu

* add selu

* add four activation op

* add softsign test

* add silu mish selu softsign

* Add softsign docs

* Add functional impl

* small fix for softsign backward

* remove flow.mish test

* add silu module test

* add selu test

* fix docs

* fix softsign docs

* fix format

* fix static cast

* merge master functional api yaml

* add torch style unittest

* Remove assert and add torch unittest

* add tensor def

* remove softsign test temporarily

* add return maybe ok

* migrate nn ops to single_client dir

* migrate unittest

* remove lazy unittest

* add unittest

* fix to new directory

* Remove useless docs and single client test

* add doc

* fix docs

* add docs in oneflow and tensor namespace

* add torch autotest

* fix to new autotest

* remove outdated python code

* remove useless docs

* enlarge unittest tolerance

* Add static cast for const value

* skip softsign unittest

* skip tensor softsign

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 docs/source/nn.rst                            |   7 +-
 docs/source/oneflow.rst                       |   4 +
 docs/source/tensor.rst                        |   3 +
 oneflow/core/autograd/gradient_funcs/mish.cpp |  54 +++++
 oneflow/core/autograd/gradient_funcs/selu.cpp |  54 +++++
 oneflow/core/autograd/gradient_funcs/silu.cpp |  54 +++++
 .../core/autograd/gradient_funcs/softsign.cpp |  54 +++++
 oneflow/core/functional/functional_api.yaml   |  32 +++
 .../functional/impl/activation_functor.cpp    |  58 +++++
 oneflow/user/kernels/activation_kernels.cpp   |   6 +-
 oneflow/user/kernels/activation_kernels.cu    |  78 ++++++-
 oneflow/user/kernels/activation_kernels.h     | 110 +++++++++
 oneflow/user/ops/mish_op.cpp                  |  92 ++++++++
 oneflow/user/ops/selu_op.cpp                  |  91 ++++++++
 oneflow/user/ops/silu_op.cpp                  |  91 ++++++++
 oneflow/user/ops/softsign_op.cpp              |  91 ++++++++
 python/oneflow/__init__.py                    |   6 +
 python/oneflow/nn/__init__.py                 |   3 +
 python/oneflow/nn/modules/activation.py       | 214 +++++++++++++++++-
 .../oneflow/test/modules/test_activation.py   | 152 +++++++++++++
 python/oneflow/test/tensor/test_tensor.py     |  79 +++++++
 21 files changed, 1327 insertions(+), 6 deletions(-)
 create mode 100644 oneflow/core/autograd/gradient_funcs/mish.cpp
 create mode 100644 oneflow/core/autograd/gradient_funcs/selu.cpp
 create mode 100644 oneflow/core/autograd/gradient_funcs/silu.cpp
 create mode 100644 oneflow/core/autograd/gradient_funcs/softsign.cpp
 create mode 100644 oneflow/user/ops/mish_op.cpp
 create mode 100644 oneflow/user/ops/selu_op.cpp
 create mode 100644 oneflow/user/ops/silu_op.cpp
 create mode 100644 oneflow/user/ops/softsign_op.cpp

diff --git a/docs/source/nn.rst b/docs/source/nn.rst
index 1eaa272a5..c3cb12c0c 100644
--- a/docs/source/nn.rst
+++ b/docs/source/nn.rst
@@ -65,11 +65,14 @@ Operators for neural networks
         ReLU6,
         ReflectionPad2d,
         ReplicationPad2d,
-        Sequential,
+        Sequential, 
+        SELU, 
+        SiLU, 
         Sigmoid,
         SmoothL1Loss,
         Softmax,
-        Softplus,
+        Softplus, 
+        Softsign, 
         Tanh,
         Upsample,
         UpsamplingBilinear2d,
diff --git a/docs/source/oneflow.rst b/docs/source/oneflow.rst
index 4e0b646e6..63702bb8e 100644
--- a/docs/source/oneflow.rst
+++ b/docs/source/oneflow.rst
@@ -21,6 +21,7 @@ oneflow
             load, 
             masked_fill, 
             matmul, 
+            mish, 
             ones, 
             ones_like, 
             repeat, 
@@ -28,8 +29,11 @@ oneflow
             save, 
             saved_model, 
             scatter_nd, 
+            selu, 
+            silu, 
             slice, 
             slice_update, 
+            softsign, 
             sort, 
             squeeze, 
             stack, 
diff --git a/docs/source/tensor.rst b/docs/source/tensor.rst
index 8f30ff623..3df7edc06 100644
--- a/docs/source/tensor.rst
+++ b/docs/source/tensor.rst
@@ -102,15 +102,18 @@ OneFlow Tensor Class
             retain_grad, 
             round, 
             rsqrt, 
+            selu, 
             shape, 
             sigmoid, 
             sign, 
+            silu, 
             sin, 
             sin_, 
             sinh, 
             size, 
             softmax, 
             softplus, 
+            softsign, 
             sort, 
             sqrt, 
             square, 
diff --git a/oneflow/core/autograd/gradient_funcs/mish.cpp b/oneflow/core/autograd/gradient_funcs/mish.cpp
new file mode 100644
index 000000000..8cedacb61
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/mish.cpp
@@ -0,0 +1,54 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/functional/functional.h"
+
+namespace oneflow {
+namespace one {
+
+struct MishInterpState : public OpExprInterpState {
+  bool requires_grad;
+};
+
+class Mish : public OpExprGradFunction<MishInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
+
+  Maybe<void> Capture(MishInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+                      const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    CHECK_EQ_OR_RETURN(outputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const MishInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+    in_grads->resize(1);
+    if (ctx->requires_grad) {
+      const auto& x = ctx->SavedTensors().at(0);
+      in_grads->at(0) = JUST(functional::MishGrad(out_grads.at(0), x));
+    }
+    return Maybe<void>::Ok();
+  }
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("mish", Mish);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/autograd/gradient_funcs/selu.cpp b/oneflow/core/autograd/gradient_funcs/selu.cpp
new file mode 100644
index 000000000..663606eff
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/selu.cpp
@@ -0,0 +1,54 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/functional/functional.h"
+
+namespace oneflow {
+namespace one {
+
+struct SeluInterpState : public OpExprInterpState {
+  bool requires_grad;
+};
+
+class Selu : public OpExprGradFunction<SeluInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
+
+  Maybe<void> Capture(SeluInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+                      const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    CHECK_EQ_OR_RETURN(outputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const SeluInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+    in_grads->resize(1);
+    if (ctx->requires_grad) {
+      const auto& x = ctx->SavedTensors().at(0);
+      in_grads->at(0) = JUST(functional::SeluGrad(out_grads.at(0), x));
+    }
+    return Maybe<void>::Ok();
+  }
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("selu", Selu);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/autograd/gradient_funcs/silu.cpp b/oneflow/core/autograd/gradient_funcs/silu.cpp
new file mode 100644
index 000000000..7d4beb638
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/silu.cpp
@@ -0,0 +1,54 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/functional/functional.h"
+
+namespace oneflow {
+namespace one {
+
+struct SiluInterpState : public OpExprInterpState {
+  bool requires_grad;
+};
+
+class Silu : public OpExprGradFunction<SiluInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
+
+  Maybe<void> Capture(SiluInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+                      const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    CHECK_EQ_OR_RETURN(outputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const SiluInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+    in_grads->resize(1);
+    if (ctx->requires_grad) {
+      const auto& x = ctx->SavedTensors().at(0);
+      in_grads->at(0) = JUST(functional::SiluGrad(out_grads.at(0), x));
+    }
+    return Maybe<void>::Ok();
+  }
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("silu", Silu);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/autograd/gradient_funcs/softsign.cpp b/oneflow/core/autograd/gradient_funcs/softsign.cpp
new file mode 100644
index 000000000..087ac118c
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/softsign.cpp
@@ -0,0 +1,54 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/functional/functional.h"
+
+namespace oneflow {
+namespace one {
+
+struct SoftSignInterpState : public OpExprInterpState {
+  bool requires_grad;
+};
+
+class SoftSign : public OpExprGradFunction<SoftSignInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
+
+  Maybe<void> Capture(SoftSignInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    CHECK_EQ_OR_RETURN(outputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const SoftSignInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+    in_grads->resize(1);
+    if (ctx->requires_grad) {
+      const auto& x = ctx->SavedTensors().at(0);
+      in_grads->at(0) = JUST(functional::SoftSignGrad(out_grads.at(0), x));
+    }
+    return Maybe<void>::Ok();
+  }
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("softsign", SoftSign);
+
+}  // namespace one
+}  // namespace oneflow
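
For reference, a derivation sketch of the closed-form derivatives these backward ops compute (they match the MishGradFunctor / SiluGradFunctor / SeluGradFunctor / SoftSignGradFunctor added to oneflow/user/kernels/activation_kernels.h later in this patch), with \sigma(x) = \frac{1}{1+e^{-x}} and \mathrm{sp}(x) = \log(1+e^{x}):

    \frac{d}{dx}\,\mathrm{silu}(x)     = \sigma(x)\,\bigl(1 + x\,(1-\sigma(x))\bigr)
    \frac{d}{dx}\,\mathrm{mish}(x)     = \tanh(\mathrm{sp}(x)) + x\,\bigl(1-\tanh^{2}(\mathrm{sp}(x))\bigr)\,\sigma(x)
    \frac{d}{dx}\,\mathrm{selu}(x)     = \mathrm{scale}\ \ (x>0), \qquad \mathrm{scale}\cdot\alpha\cdot e^{x}\ \ (x\le 0)
    \frac{d}{dx}\,\mathrm{softsign}(x) = \frac{1}{(1+|x|)^{2}}

Each derivative depends only on the forward input, which is why every Capture above saves inputs.at(0) with SaveTensorForBackward and nothing else.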
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index bb7c8bda3..d30e68d95 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -718,6 +718,38 @@
   signature: "Tensor PadGrad(Tensor dy, *, Int64List pad, String mode=\"constant\", Scalar value=0)"
   bind_python: False
 
+- name: "silu"
+  signature: "Tensor Silu(Tensor x, *)"
+  bind_python: True
+
+- name: "silu_grad"
+  signature: "Tensor SiluGrad(Tensor x, Tensor dy, *)"
+  bind_python: False
+
+- name: "mish"
+  signature: "Tensor Mish(Tensor x, *)"
+  bind_python: True
+
+- name: "mish_grad"
+  signature: "Tensor MishGrad(Tensor x, Tensor dy, *)"
+  bind_python: False
+
+- name: "selu"
+  signature: "Tensor Selu(Tensor x, *)"
+  bind_python: True
+
+- name: "selu_grad"
+  signature: "Tensor SeluGrad(Tensor x, Tensor dy, *)"
+  bind_python: False
+
+- name: "softsign"
+  signature: "Tensor SoftSign(Tensor x, *)"
+  bind_python: True
+
+- name: "softsign_grad"
+  signature: "Tensor SoftSignGrad(Tensor x, Tensor dy, *)"
+  bind_python: False
+  
 - name: "diag"
   signature: "Tensor Diag(Tensor x, *, Int32 diagonal=0)"
   bind_python: True
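
A minimal usage sketch for the bindings above (illustrative only, assuming the patch is applied and OneFlow is built with these entries): the forward ops marked bind_python: True surface in the functional namespace as flow.F.silu, flow.F.mish, flow.F.selu and flow.F.softsign, while the *_grad entries stay C++-only and are reached through the gradient functions registered earlier.

    import numpy as np
    import oneflow as flow

    x = flow.Tensor(np.array([-1.0, 0.0, 1.0], dtype=np.float32))
    print(flow.F.silu(x))      # x * sigmoid(x)
    print(flow.F.mish(x))      # x * tanh(softplus(x))
    print(flow.F.selu(x))      # scaled ELU with the fixed SELU constants
    print(flow.F.softsign(x))  # x / (1 + |x|)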
diff --git a/oneflow/core/functional/impl/activation_functor.cpp b/oneflow/core/functional/impl/activation_functor.cpp
index 7f8bcd10d..c12981038 100644
--- a/oneflow/core/functional/impl/activation_functor.cpp
+++ b/oneflow/core/functional/impl/activation_functor.cpp
@@ -210,6 +210,56 @@ class LeakyReluGradFunctor {
   std::shared_ptr<OpExpr> op_;
 };
 
+class SiluFunctor : public UnaryFunctor {
+ public:
+  SiluFunctor() { op_ = CHECK_JUST(one::OpBuilder("silu").Input("in").Output("out").Build()); }
+};
+
+class SiluGradFunctor : public BinaryFunctor {
+ public:
+  SiluGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("silu_grad").Input("dy").Input("x").Output("dx").Build());
+  }
+};
+
+class MishFunctor : public UnaryFunctor {
+ public:
+  MishFunctor() { op_ = CHECK_JUST(one::OpBuilder("mish").Input("in").Output("out").Build()); }
+};
+
+class MishGradFunctor : public BinaryFunctor {
+ public:
+  MishGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("mish_grad").Input("dy").Input("x").Output("dx").Build());
+  }
+};
+
+class SeluFunctor : public UnaryFunctor {
+ public:
+  SeluFunctor() { op_ = CHECK_JUST(one::OpBuilder("selu").Input("in").Output("out").Build()); }
+};
+
+class SeluGradFunctor : public BinaryFunctor {
+ public:
+  SeluGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("selu_grad").Input("dy").Input("x").Output("dx").Build());
+  }
+};
+
+class SoftSignFunctor : public UnaryFunctor {
+ public:
+  SoftSignFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("softsign").Input("in").Output("out").Build());
+  }
+};
+
+class SoftSignGradFunctor : public BinaryFunctor {
+ public:
+  SoftSignGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("softsign_grad").Input("dy").Input("x").Output("dx").Build());
+  }
+};
+
 }  // namespace impl
 
 ONEFLOW_FUNCTION_LIBRARY(m) {
@@ -229,6 +279,14 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::HardSwishGradFunctor>("HardSwishGrad");
   m.add_functor<impl::LeakyReluFunctor>("LeakyRelu");
   m.add_functor<impl::LeakyReluGradFunctor>("LeakyReluGrad");
+  m.add_functor<impl::SiluFunctor>("Silu");
+  m.add_functor<impl::SiluGradFunctor>("SiluGrad");
+  m.add_functor<impl::MishFunctor>("Mish");
+  m.add_functor<impl::MishGradFunctor>("MishGrad");
+  m.add_functor<impl::SeluFunctor>("Selu");
+  m.add_functor<impl::SeluGradFunctor>("SeluGrad");
+  m.add_functor<impl::SoftSignFunctor>("SoftSign");
+  m.add_functor<impl::SoftSignGradFunctor>("SoftSignGrad");
 };
 
 }  // namespace functional
diff --git a/oneflow/user/kernels/activation_kernels.cpp b/oneflow/user/kernels/activation_kernels.cpp
index 997489ab6..f835136a9 100644
--- a/oneflow/user/kernels/activation_kernels.cpp
+++ b/oneflow/user/kernels/activation_kernels.cpp
@@ -21,7 +21,11 @@ namespace oneflow {
   REGISTER_ELU_KERNEL(DeviceType::kCPU, dtype);         \
   REGISTER_HARDSWISH_KERNEL(DeviceType::kCPU, dtype);   \
   REGISTER_HARDSIGMOID_KERNEL(DeviceType::kCPU, dtype); \
-  REGISTER_HARDTANH_KERNEL(DeviceType::kCPU, dtype);
+  REGISTER_HARDTANH_KERNEL(DeviceType::kCPU, dtype);    \
+  REGISTER_MISH_KERNEL(DeviceType::kCPU, dtype);        \
+  REGISTER_SILU_KERNEL(DeviceType::kCPU, dtype);        \
+  REGISTER_SELU_KERNEL(DeviceType::kCPU, dtype);        \
+  REGISTER_SOFTSIGN_KERNEL(DeviceType::kCPU, dtype);
 
 REGISTER_ACTIVATION_CPU_KERNEL(float);
 REGISTER_ACTIVATION_CPU_KERNEL(double);
diff --git a/oneflow/user/kernels/activation_kernels.cu b/oneflow/user/kernels/activation_kernels.cu
index 0001d315b..db292b930 100644
--- a/oneflow/user/kernels/activation_kernels.cu
+++ b/oneflow/user/kernels/activation_kernels.cu
@@ -56,11 +56,87 @@ struct HardswishGradFunctor<half> {
   }
 };
 
+template<>
+struct MishFunctor<half> {
+  OF_DEVICE_FUNC explicit MishFunctor() : float_functor(MishFunctor<float>()) {}
+  OF_DEVICE_FUNC half operator()(half x) const {
+    return __float2half(float_functor(__half2float(x)));
+  }
+  MishFunctor<float> float_functor;
+};
+
+template<>
+struct MishGradFunctor<half> {
+  OF_DEVICE_FUNC explicit MishGradFunctor() : float_functor(MishGradFunctor<float>()) {}
+  OF_DEVICE_FUNC half operator()(half x, half dy) const {
+    return __float2half(float_functor(__half2float(x), __half2float(dy)));
+  }
+  MishGradFunctor<float> float_functor;
+};
+
+template<>
+struct SiluFunctor<half> {
+  OF_DEVICE_FUNC explicit SiluFunctor() : float_functor(SiluFunctor<float>()) {}
+  OF_DEVICE_FUNC half operator()(half x) const {
+    return __float2half(float_functor(__half2float(x)));
+  }
+  SiluFunctor<float> float_functor;
+};
+
+template<>
+struct SiluGradFunctor<half> {
+  OF_DEVICE_FUNC explicit SiluGradFunctor() : float_functor(SiluGradFunctor<float>()) {}
+  OF_DEVICE_FUNC half operator()(half x, half dy) const {
+    return __float2half(float_functor(__half2float(x), __half2float(dy)));
+  }
+  SiluGradFunctor<float> float_functor;
+};
+
+template<>
+struct SeluFunctor<half> {
+  OF_DEVICE_FUNC explicit SeluFunctor() : float_functor(SeluFunctor<float>()) {}
+  OF_DEVICE_FUNC half operator()(half x) const {
+    return __float2half(float_functor(__half2float(x)));
+  }
+  SeluFunctor<float> float_functor;
+};
+
+template<>
+struct SeluGradFunctor<half> {
+  OF_DEVICE_FUNC explicit SeluGradFunctor() : float_functor(SeluGradFunctor<float>()) {}
+  OF_DEVICE_FUNC half operator()(half x, half dy) const {
+    return __float2half(float_functor(__half2float(x), __half2float(dy)));
+  }
+  SeluGradFunctor<float> float_functor;
+};
+
+template<>
+struct SoftSignFunctor<half> {
+  OF_DEVICE_FUNC explicit SoftSignFunctor() : float_functor(SoftSignFunctor<float>()) {}
+  OF_DEVICE_FUNC half operator()(half x) const {
+    return __float2half(float_functor(__half2float(x)));
+  }
+  SoftSignFunctor<float> float_functor;
+};
+
+template<>
+struct SoftSignGradFunctor<half> {
+  OF_DEVICE_FUNC explicit SoftSignGradFunctor() : float_functor(SoftSignGradFunctor<float>()) {}
+  OF_DEVICE_FUNC half operator()(half x, half dy) const {
+    return __float2half(float_functor(__half2float(x), __half2float(dy)));
+  }
+  SoftSignGradFunctor<float> float_functor;
+};
+
 #define REGISTER_ACTIVATION_GPU_KERNEL(dtype)           \
   REGISTER_ELU_KERNEL(DeviceType::kGPU, dtype);         \
   REGISTER_HARDSWISH_KERNEL(DeviceType::kGPU, dtype);   \
   REGISTER_HARDSIGMOID_KERNEL(DeviceType::kGPU, dtype); \
-  REGISTER_HARDTANH_KERNEL(DeviceType::kGPU, dtype);
+  REGISTER_HARDTANH_KERNEL(DeviceType::kGPU, dtype);    \
+  REGISTER_MISH_KERNEL(DeviceType::kGPU, dtype);        \
+  REGISTER_SILU_KERNEL(DeviceType::kGPU, dtype);        \
+  REGISTER_SELU_KERNEL(DeviceType::kGPU, dtype);        \
+  REGISTER_SOFTSIGN_KERNEL(DeviceType::kGPU, dtype);
 
 REGISTER_ACTIVATION_GPU_KERNEL(half);
 REGISTER_ACTIVATION_GPU_KERNEL(float);
diff --git a/oneflow/user/kernels/activation_kernels.h b/oneflow/user/kernels/activation_kernels.h
index 5d4ebb99f..b13d5f9e4 100644
--- a/oneflow/user/kernels/activation_kernels.h
+++ b/oneflow/user/kernels/activation_kernels.h
@@ -111,6 +111,80 @@ struct HardtanhGradFunctor {
   const T max_val;
 };
 
+template<typename T>
+struct MishFunctor {
+  OF_DEVICE_FUNC explicit MishFunctor() {}
+  OF_DEVICE_FUNC T operator()(T x) const {
+    T soft_plus_val = log(static_cast<T>(1) + exp(x));
+    T exp_val = exp(soft_plus_val);
+    T neg_exp_val = exp(-soft_plus_val);
+    T tanh_val = (exp_val - neg_exp_val) / (exp_val + neg_exp_val);
+    return x * tanh_val;
+  }
+};
+
+template<typename T>
+struct MishGradFunctor {
+  OF_DEVICE_FUNC explicit MishGradFunctor() {}
+  OF_DEVICE_FUNC T operator()(T x, T dy) const {
+    T sp = log(static_cast<T>(1) + exp(x));
+    T grad_sp = static_cast<T>(1) - exp(-sp);
+    T tsp = (exp(sp) - exp(-sp)) / (exp(sp) + exp(-sp));
+    T grad_tsp = (static_cast<T>(1) - tsp * tsp) * grad_sp;
+    return dy * (x * grad_tsp + tsp);
+  }
+};
+
+template<typename T>
+struct SiluFunctor {
+  OF_DEVICE_FUNC explicit SiluFunctor() {}
+  OF_DEVICE_FUNC T operator()(T x) const { return (x / (static_cast<T>(1) + exp(-x))); }
+};
+
+template<typename T>
+struct SiluGradFunctor {
+  OF_DEVICE_FUNC explicit SiluGradFunctor() {}
+  OF_DEVICE_FUNC T operator()(T x, T dy) const {
+    T sig = static_cast<T>(1) / (static_cast<T>(1) + exp(-x));
+    return dy * (sig * (static_cast<T>(1) + x * (static_cast<T>(1) - sig)));
+  }
+};
+
+template<typename T>
+struct SeluFunctor {
+  OF_DEVICE_FUNC explicit SeluFunctor() {}
+  OF_DEVICE_FUNC T operator()(T x) const {
+    return (x > static_cast<T>(0)) ? scale * x : scale * alpha * (exp(x) - static_cast<T>(1));
+  }
+  const T scale = 1.0507009873554804934193349852946;
+  const T alpha = 1.6732632423543772848170429916717;
+};
+
+template<typename T>
+struct SeluGradFunctor {
+  OF_DEVICE_FUNC explicit SeluGradFunctor() {}
+  OF_DEVICE_FUNC T operator()(T x, T dy) const {
+    return (x > static_cast<T>(0)) ? scale * dy : dy * scale * alpha * (exp(x));
+  }
+  const T scale = 1.0507009873554804934193349852946;
+  const T alpha = 1.6732632423543772848170429916717;
+};
+
+template<typename T>
+struct SoftSignFunctor {
+  OF_DEVICE_FUNC explicit SoftSignFunctor() {}
+  OF_DEVICE_FUNC T operator()(T x) const { return x / (static_cast<T>(1) + abs(x)); }
+};
+
+template<typename T>
+struct SoftSignGradFunctor {
+  OF_DEVICE_FUNC explicit SoftSignGradFunctor() {}
+  OF_DEVICE_FUNC T operator()(T x, T dy) const {
+    T val = (static_cast<T>(1) + abs(x));
+    return dy / (val * val);
+  }
+};
+
 #define REGISTER_ELU_KERNEL(device, dtype)                        \
   REGISTER_UNARY_ELEMWISE_USER_KERNEL(                            \
       device, "elu", EluFunctor, dtype, dtype,                    \
@@ -179,6 +253,42 @@ struct HardtanhGradFunctor {
         return Maybe<void>::Ok();                                                               \
       });
 
+#define REGISTER_MISH_KERNEL(device, dtype)                                                   \
+  REGISTER_UNARY_ELEMWISE_USER_KERNEL(                                                        \
+      device, "mish", MishFunctor, dtype, dtype,                                              \
+      [](user_op::KernelComputeContext* ctx) { return MishFunctor<dtype>(); }, "out", "in");  \
+  REGISTER_BINARY_ELEMWISE_USER_KERNEL(                                                       \
+      device, "mish_grad", MishGradFunctor, dtype, dtype, dtype,                              \
+      [](user_op::KernelComputeContext* ctx) { return MishGradFunctor<dtype>(); }, "dx", "x", \
+      "dy");
+
+#define REGISTER_SILU_KERNEL(device, dtype)                                                   \
+  REGISTER_UNARY_ELEMWISE_USER_KERNEL(                                                        \
+      device, "silu", SiluFunctor, dtype, dtype,                                              \
+      [](user_op::KernelComputeContext* ctx) { return SiluFunctor<dtype>(); }, "out", "in");  \
+  REGISTER_BINARY_ELEMWISE_USER_KERNEL(                                                       \
+      device, "silu_grad", SiluGradFunctor, dtype, dtype, dtype,                              \
+      [](user_op::KernelComputeContext* ctx) { return SiluGradFunctor<dtype>(); }, "dx", "x", \
+      "dy");
+
+#define REGISTER_SELU_KERNEL(device, dtype)                                                   \
+  REGISTER_UNARY_ELEMWISE_USER_KERNEL(                                                        \
+      device, "selu", SeluFunctor, dtype, dtype,                                              \
+      [](user_op::KernelComputeContext* ctx) { return SeluFunctor<dtype>(); }, "out", "in");  \
+  REGISTER_BINARY_ELEMWISE_USER_KERNEL(                                                       \
+      device, "selu_grad", SeluGradFunctor, dtype, dtype, dtype,                              \
+      [](user_op::KernelComputeContext* ctx) { return SeluGradFunctor<dtype>(); }, "dx", "x", \
+      "dy");
+
+#define REGISTER_SOFTSIGN_KERNEL(device, dtype)                                                   \
+  REGISTER_UNARY_ELEMWISE_USER_KERNEL(                                                            \
+      device, "softsign", SoftSignFunctor, dtype, dtype,                                          \
+      [](user_op::KernelComputeContext* ctx) { return SoftSignFunctor<dtype>(); }, "out", "in");  \
+  REGISTER_BINARY_ELEMWISE_USER_KERNEL(                                                           \
+      device, "softsign_grad", SoftSignGradFunctor, dtype, dtype, dtype,                          \
+      [](user_op::KernelComputeContext* ctx) { return SoftSignGradFunctor<dtype>(); }, "dx", "x", \
+      "dy");
+
 }  // namespace oneflow
 
 #endif  // _ONEFLOW_USER_KERNELS_ACTIVATION_KERNELS_H_
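
As a sanity check on the closed forms above, here is a small NumPy sketch (illustrative only, independent of OneFlow) that compares each gradient functor's formula against a central finite difference of the corresponding forward functor:

    import numpy as np

    SCALE = 1.0507009873554804934193349852946
    ALPHA = 1.6732632423543772848170429916717

    def mish(x): return x * np.tanh(np.log1p(np.exp(x)))
    def mish_grad(x):
        sp = np.log1p(np.exp(x))
        tsp = np.tanh(sp)
        # matches MishGradFunctor: tanh(sp) + x * (1 - tanh^2(sp)) * (1 - exp(-sp))
        return tsp + x * (1.0 - tsp * tsp) * (1.0 - np.exp(-sp))

    def silu(x): return x / (1.0 + np.exp(-x))
    def silu_grad(x):
        sig = 1.0 / (1.0 + np.exp(-x))
        return sig * (1.0 + x * (1.0 - sig))  # matches SiluGradFunctor

    def selu(x): return np.where(x > 0, SCALE * x, SCALE * ALPHA * (np.exp(x) - 1.0))
    def selu_grad(x): return np.where(x > 0, SCALE, SCALE * ALPHA * np.exp(x))

    def softsign(x): return x / (1.0 + np.abs(x))
    def softsign_grad(x): return 1.0 / (1.0 + np.abs(x)) ** 2

    x = np.linspace(-3.0, 3.0, 60)  # 60 points, so x == 0 (where selu's derivative jumps) is skipped
    eps = 1e-6
    for f, g in [(mish, mish_grad), (silu, silu_grad), (selu, selu_grad), (softsign, softsign_grad)]:
        numeric = (f(x + eps) - f(x - eps)) / (2.0 * eps)
        assert np.allclose(numeric, g(x), atol=1e-4), f.__name__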
diff --git a/oneflow/user/ops/mish_op.cpp b/oneflow/user/ops/mish_op.cpp
new file mode 100644
index 000000000..4f51ca760
--- /dev/null
+++ b/oneflow/user/ops/mish_op.cpp
@@ -0,0 +1,92 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+
+namespace oneflow {
+
+namespace {
+
+REGISTER_USER_OP("mish")
+    .Input("in")
+    .Output("out")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+      FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
+        ctx->NewBuilder()
+            .Split(user_op::OpArg("in", 0), i)
+            .Split(user_op::OpArg("out", 0), i)
+            .Build();
+      }
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("mish_grad")
+    .Input("x")
+    .Input("dy")
+    .Output("dx")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& x_shape = ctx->InputShape("x", 0);
+      const Shape& dy_shape = ctx->InputShape("dy", 0);
+      Shape* dx_shape = ctx->OutputShape("dx", 0);
+      CHECK(dy_shape == x_shape);
+      *dx_shape = dy_shape;
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
+      FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {
+        ctx->NewBuilder()
+            .Split(user_op::OpArg("x", 0), i)
+            .Split(user_op::OpArg("dy", 0), i)
+            .Split(user_op::OpArg("dx", 0), i)
+            .Build();
+      }
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0));
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP_GRAD("mish").SetBackwardOpConfGenFn(
+    [](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
+      const auto mish_grad_op_name = ctx->FwOp().op_name() + "_grad";
+      ctx->DefineOp(mish_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
+        return builder.OpTypeName("mish_grad")
+            .InputBind("x", ctx->FwOp().input("in", 0))
+            .InputBind("dy", ctx->FwOp().output_grad("out", 0))
+            .Output("dx")
+            .Build();
+      });
+      ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
+                                [&ctx, &mish_grad_op_name]() -> const std::string& {
+                                  return ctx->GetOp(mish_grad_op_name).output("dx", 0);
+                                });
+      return Maybe<void>::Ok();
+    });
+
+}  // namespace
+
+}  // namespace oneflow
diff --git a/oneflow/user/ops/selu_op.cpp b/oneflow/user/ops/selu_op.cpp
new file mode 100644
index 000000000..8697cbf39
--- /dev/null
+++ b/oneflow/user/ops/selu_op.cpp
@@ -0,0 +1,91 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+
+namespace oneflow {
+
+namespace {
+
+REGISTER_USER_OP("selu")
+    .Input("in")
+    .Output("out")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+      FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
+        ctx->NewBuilder()
+            .Split(user_op::OpArg("in", 0), i)
+            .Split(user_op::OpArg("out", 0), i)
+            .Build();
+      }
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("selu_grad")
+    .Input("x")
+    .Input("dy")
+    .Output("dx")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& x_shape = ctx->InputShape("x", 0);
+      const Shape& dy_shape = ctx->InputShape("dy", 0);
+      Shape* dx_shape = ctx->OutputShape("dx", 0);
+      CHECK(dy_shape == x_shape);
+      *dx_shape = dy_shape;
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
+      FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {
+        ctx->NewBuilder()
+            .Split(user_op::OpArg("x", 0), i)
+            .Split(user_op::OpArg("dy", 0), i)
+            .Split(user_op::OpArg("dx", 0), i)
+            .Build();
+      }
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0));
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP_GRAD("selu").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+  const auto selu_grad_op_name = ctx->FwOp().op_name() + "_grad";
+  ctx->DefineOp(selu_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
+    return builder.OpTypeName("selu_grad")
+        .InputBind("x", ctx->FwOp().input("in", 0))
+        .InputBind("dy", ctx->FwOp().output_grad("out", 0))
+        .Output("dx")
+        .Build();
+  });
+  ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
+                            [&ctx, &selu_grad_op_name]() -> const std::string& {
+                              return ctx->GetOp(selu_grad_op_name).output("dx", 0);
+                            });
+  return Maybe<void>::Ok();
+});
+
+}  // namespace
+
+}  // namespace oneflow
diff --git a/oneflow/user/ops/silu_op.cpp b/oneflow/user/ops/silu_op.cpp
new file mode 100644
index 000000000..eb46ab7b4
--- /dev/null
+++ b/oneflow/user/ops/silu_op.cpp
@@ -0,0 +1,91 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+
+namespace oneflow {
+
+namespace {
+
+REGISTER_USER_OP("silu")
+    .Input("in")
+    .Output("out")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+      FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
+        ctx->NewBuilder()
+            .Split(user_op::OpArg("in", 0), i)
+            .Split(user_op::OpArg("out", 0), i)
+            .Build();
+      }
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("silu_grad")
+    .Input("x")
+    .Input("dy")
+    .Output("dx")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& x_shape = ctx->InputShape("x", 0);
+      const Shape& dy_shape = ctx->InputShape("dy", 0);
+      Shape* dx_shape = ctx->OutputShape("dx", 0);
+      CHECK(dy_shape == x_shape);
+      *dx_shape = dy_shape;
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
+      FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {
+        ctx->NewBuilder()
+            .Split(user_op::OpArg("x", 0), i)
+            .Split(user_op::OpArg("dy", 0), i)
+            .Split(user_op::OpArg("dx", 0), i)
+            .Build();
+      }
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0));
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP_GRAD("silu").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+  const auto silu_grad_op_name = ctx->FwOp().op_name() + "_grad";
+  ctx->DefineOp(silu_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
+    return builder.OpTypeName("silu_grad")
+        .InputBind("x", ctx->FwOp().input("in", 0))
+        .InputBind("dy", ctx->FwOp().output_grad("out", 0))
+        .Output("dx")
+        .Build();
+  });
+  ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
+                            [&ctx, &silu_grad_op_name]() -> const std::string& {
+                              return ctx->GetOp(silu_grad_op_name).output("dx", 0);
+                            });
+  return Maybe<void>::Ok();
+});
+
+}  // namespace
+
+}  // namespace oneflow
diff --git a/oneflow/user/ops/softsign_op.cpp b/oneflow/user/ops/softsign_op.cpp
new file mode 100644
index 000000000..324980302
--- /dev/null
+++ b/oneflow/user/ops/softsign_op.cpp
@@ -0,0 +1,91 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+
+namespace oneflow {
+
+namespace {
+
+REGISTER_USER_OP("softsign")
+    .Input("in")
+    .Output("out")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+      FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
+        ctx->NewBuilder()
+            .Split(user_op::OpArg("in", 0), i)
+            .Split(user_op::OpArg("out", 0), i)
+            .Build();
+      }
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("softsign_grad")
+    .Input("x")
+    .Input("dy")
+    .Output("dx")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& x_shape = ctx->InputShape("x", 0);
+      const Shape& dy_shape = ctx->InputShape("dy", 0);
+      Shape* dx_shape = ctx->OutputShape("dx", 0);
+      CHECK(dy_shape == x_shape);
+      *dx_shape = dy_shape;
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
+      FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {
+        ctx->NewBuilder()
+            .Split(user_op::OpArg("x", 0), i)
+            .Split(user_op::OpArg("dy", 0), i)
+            .Split(user_op::OpArg("dx", 0), i)
+            .Build();
+      }
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0));
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP_GRAD("softsign").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+  const auto softsign_grad_op_name = ctx->FwOp().op_name() + "_grad";
+  ctx->DefineOp(softsign_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
+    return builder.OpTypeName("softsign_grad")
+        .InputBind("x", ctx->FwOp().input("in", 0))
+        .InputBind("dy", ctx->FwOp().output_grad("out", 0))
+        .Output("dx")
+        .Build();
+  });
+  ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
+                            [&ctx, &softsign_grad_op_name]() -> const std::string& {
+                              return ctx->GetOp(softsign_grad_op_name).output("dx", 0);
+                            });
+  return Maybe<void>::Ok();
+});
+
+}  // namespace
+
+}  // namespace oneflow
diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py
index 2a8e696c0..8e6f9d5cf 100644
--- a/python/oneflow/__init__.py
+++ b/python/oneflow/__init__.py
@@ -216,6 +216,12 @@ from oneflow.nn.modules.activation import mish_op as mish
 from oneflow.nn.modules.activation import sigmoid_op as sigmoid
 from oneflow.nn.modules.activation import softmax_op as softmax
 from oneflow.nn.modules.activation import tanh_op as tanh
+from oneflow.nn.modules.activation import silu_op as silu
+from oneflow.nn.modules.activation import selu_op as selu
+from oneflow.nn.modules.activation import softsign_op as softsign
+from oneflow.nn.modules.activation import mish_op as mish
+
+
 from oneflow.nn.modules.adaptive_pool import (
     adaptive_avg_pool1d,
     adaptive_avg_pool2d,
diff --git a/python/oneflow/nn/__init__.py b/python/oneflow/nn/__init__.py
index 0993b8e4e..8401b05fb 100644
--- a/python/oneflow/nn/__init__.py
+++ b/python/oneflow/nn/__init__.py
@@ -32,6 +32,9 @@ from oneflow.nn.modules.activation import (
     Softmax,
     Softplus,
     Tanh,
+    SELU,
+    SiLU,
+    Softsign,
 )
 from oneflow.nn.modules.adaptive_pool import (
     AdaptiveAvgPool1d,
diff --git a/python/oneflow/nn/modules/activation.py b/python/oneflow/nn/modules/activation.py
index 406d0187e..8990c0922 100644
--- a/python/oneflow/nn/modules/activation.py
+++ b/python/oneflow/nn/modules/activation.py
@@ -923,11 +923,11 @@ class Mish(Module):
     """
 
     def __init__(self, inplace: bool = False):
-        assert not inplace, "In-place operation is not currently supported"
+        self.inplace = inplace
         super().__init__()
 
     def forward(self, x):
-        return x * flow.tanh(flow.softplus(x))
+        return flow.F.mish(x)
 
 
 def mish_op(x):
@@ -953,6 +953,216 @@ def mish_op_tensor(x):
     return Mish()(x)
 
 
+class SiLU(Module):
+    r"""SiLU(Swish) activation:
+
+    .. math::
+    
+        \text{SiLU}(x) = x * sigmoid(x)
+    
+    .. note::
+        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
+        where the SiLU (Sigmoid Linear Unit) was originally coined, and see
+        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
+        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
+        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
+        where the SiLU was experimented with later.
+    
+    Shape:
+        - Input: :math:`(N, *)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(N, *)`, same shape as the input
+    
+    For example:
+    
+    .. code-block:: python
+    
+        >>> import numpy as np
+        >>> import oneflow as flow
+
+
+        >>> x = np.array([1, 2, 3]).astype(np.float32)
+        >>> input = flow.Tensor(x)
+        >>> silu = flow.nn.SiLU()
+        >>> out = silu(input)
+        >>> out
+        tensor([0.7311, 1.7616, 2.8577], dtype=oneflow.float32)
+    """
+
+    def __init__(self, inplace: bool = False):
+        self.inplace = inplace
+        super().__init__()
+
+    def forward(self, x):
+        return flow.F.silu(x)
+
+
+def silu_op(x):
+    r"""SiLU(Swish) activation:
+
+    .. math::
+        \text{SiLU}(x) = x * sigmoid(x)
+    
+    .. note::
+    
+        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
+        where the SiLU (Sigmoid Linear Unit) was originally coined, and see
+        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
+        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
+        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
+        where the SiLU was experimented with later.
+    
+    See :mod:`oneflow.nn.SiLU`
+    """
+
+    return SiLU()(x)
+
+
+@register_tensor_op("silu")
+def silu_op_tensor(x):
+    r"""
+    silu() -> Tensor
+    See :func:`oneflow.silu`
+    """
+    return SiLU()(x)
+
+
+class SELU(Module):
+    r"""Applies the element-wise function:
+
+    The formula is: 
+    
+    .. math::  
+    
+        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
+    
+    with :math:`\alpha = 1.6732632423543772848170429916717` and
+    
+    :math:`\text{scale} = 1.0507009873554804934193349852946`.
+    
+    .. warning::
+    
+        When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation,
+        ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'``
+        in order to get `Self-Normalizing Neural Networks`_.
+        See :func:`torch.nn.init.calculate_gain` for more information.
+    
+    More details can be found in the paper `Self-Normalizing Neural Networks <https://arxiv.org/abs/1706.02515>`_.
+    
+    Shape:
+        - Input: :math:`(N, *)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(N, *)`, same shape as the input
+    
+    For example:
+    
+    .. code-block:: python
+    
+        >>> import numpy as np
+        >>> import oneflow as flow
+        >>> x = np.array([1, 2, 3]).astype(np.float32)
+        >>> input = flow.Tensor(x)
+        >>> selu = flow.nn.SELU()
+        >>> out = selu(input)
+        >>> out
+        tensor([1.0507, 2.1014, 3.1521], dtype=oneflow.float32)
+    """
+
+    def __init__(self, inplace: bool = False):
+        self.inplace = inplace
+        super().__init__()
+
+    def forward(self, x):
+        return flow.F.selu(x)
+
+
+def selu_op(x):
+    r"""The SELU activation.
+
+    The formula is: 
+    
+    .. math::  
+    
+        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
+    
+    with :math:`\alpha = 1.6732632423543772848170429916717` and
+    
+    :math:`\text{scale} = 1.0507009873554804934193349852946`.
+
+    See :mod:`oneflow.nn.SELU`
+    """
+    return SELU()(x)
+
+
+@register_tensor_op("selu")
+def selu_op_tensor(x):
+    r"""
+    selu() -> Tensor
+    
+    See :func:`oneflow.selu`
+    """
+    return SELU()(x)
+
+
+class Softsign(Module):
+    r"""The SoftSign activation.
+
+    The formula is: 
+    
+    .. math::  
+    
+        SoftSign(x) = \frac{x}{1 + |x|}
+    
+    Shape:
+        - Input: :math:`(N, *)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(N, *)`, same shape as the input
+    
+    For example:
+    
+    .. code-block:: python
+    
+        >>> import numpy as np
+        >>> import oneflow as flow
+        >>> x = np.array([1, 2, 3]).astype(np.float32)
+        >>> input = flow.Tensor(x)
+        >>> softsign = flow.nn.Softsign()
+        >>> out = softsign(input)
+        >>> out
+        tensor([0.5   , 0.6667, 0.75  ], dtype=oneflow.float32)
+    """
+
+    def __init__(self, inplace: bool = False):
+        self.inplace = inplace
+        super().__init__()
+
+    def forward(self, x):
+        return flow.F.softsign(x)
+
+
+def softsign_op(x):
+    r"""The SoftSign activation.
+
+    The formula is: 
+    
+    .. math::  
+
+        SoftSign(x) = \frac{x}{1 + |x|}
+    
+    See :mod:`oneflow.nn.Softsign`
+    """
+    return Softsign()(x)
+
+
+@register_tensor_op("softsign")
+def softsign_op_tensor(x):
+    r"""
+    softsign() -> Tensor
+    See :func:`oneflow.softsign`
+    """
+    return Softsign()(x)
+
+
 if __name__ == "__main__":
     import doctest
 
diff --git a/python/oneflow/test/modules/test_activation.py b/python/oneflow/test/modules/test_activation.py
index a54d81c5f..27bf0df0d 100644
--- a/python/oneflow/test/modules/test_activation.py
+++ b/python/oneflow/test/modules/test_activation.py
@@ -824,6 +824,158 @@ class TestMishModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    @autotest(n=5)
+    def test_mish_module_with_random_data(test_case):
+        m = torch.nn.Mish()
+        m.train(random())
+        device = random_device()
+        m.to(device)
+        x = random_pytorch_tensor().to(device)
+        y = m(x)
+        return y
+
+
+def _np_silu_grad(x):
+    _sig = 1 / (1 + np.exp(-x))
+    return _sig * (1 + x * (1 - _sig))
+
+
+def _test_silu_impl(test_case, shape, device):
+    m = flow.nn.SiLU()
+    np_input = np.random.randn(*shape)
+    np_out = np_input / (1 + np.exp(-np_input))
+    of_input = flow.Tensor(
+        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
+    )
+    of_out = m(of_input)
+    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+
+    of_out = of_out.sum()
+    of_out.backward()
+    test_case.assertTrue(
+        np.allclose(of_input.grad.numpy(), _np_silu_grad(np_input), 1e-5, 1e-5)
+    )
+
+
+@flow.unittest.skip_unless_1n1d()
+class TestSiluModule(flow.unittest.TestCase):
+    def test_silu(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [_test_silu_impl]
+        arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]
+
+        arg_dict["device"] = ["cpu", "cuda"]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
+    @autotest(n=5)
+    def test_silu_module_with_random_data(test_case):
+        m = torch.nn.SiLU()
+        m.train(random())
+        device = random_device()
+        m.to(device)
+        x = random_pytorch_tensor().to(device)
+        y = m(x)
+        return y
+
+
+def _np_selu(x):
+    scale = 1.0507009873554804934193349852946
+    alpha = 1.6732632423543772848170429916717
+    return np.where(x < 0, scale * alpha * (np.exp(x) - 1), scale * x)
+
+
+def _np_selu_grad(x):
+    scale = 1.0507009873554804934193349852946
+    alpha = 1.6732632423543772848170429916717
+    return np.where(x < 0, scale * alpha * np.exp(x), scale)
+
+
+def _test_selu_impl(test_case, shape, device):
+    m = flow.nn.SELU()
+    np_input = np.random.randn(*shape)
+    np_out = _np_selu(np_input)
+    of_input = flow.Tensor(
+        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
+    )
+    of_out = m(of_input)
+    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+
+    of_out = of_out.sum()
+    of_out.backward()
+    test_case.assertTrue(
+        np.allclose(of_input.grad.numpy(), _np_selu_grad(np_input), 1e-5, 1e-5)
+    )
+
+
+@flow.unittest.skip_unless_1n1d()
+class TestSeluModule(flow.unittest.TestCase):
+    def test_selu(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [_test_selu_impl]
+        arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]
+
+        arg_dict["device"] = ["cpu", "cuda"]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
+    @autotest(n=5)
+    def test_selu_module_with_random_data(test_case):
+        m = torch.nn.SELU()
+        m.train(random())
+        device = random_device()
+        m.to(device)
+        x = random_pytorch_tensor().to(device)
+        y = m(x)
+        return y
+
+
+def _np_softsign(x):
+    return x / (1.0 + np.abs(x))
+
+
+def _np_softsign_grad(x):
+    return 1.0 / (np.square(1.0 + np.abs(x)))
+
+
+def _test_softsign_impl(test_case, shape, device):
+    m = flow.nn.Softsign()
+    np_input = np.random.randn(*shape)
+    np_out = _np_softsign(np_input)
+    of_input = flow.Tensor(
+        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
+    )
+    of_out = m(of_input)
+    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-3, 1e-3))
+
+    of_out = of_out.sum()
+    of_out.backward()
+    test_case.assertTrue(
+        np.allclose(of_input.grad.numpy(), _np_softsign_grad(np_input), 1e-3, 1e-3)
+    )
+
+
+@unittest.skip("still have error in ci test")
+class TestSoftsignModule(flow.unittest.TestCase):
+    @autotest(n=5)
+    def test_softsign_module_with_random_data(test_case):
+        m = torch.nn.Softsign()
+        m.train(random())
+        device = random_device()
+        m.to(device)
+        x = random_pytorch_tensor().to(device)
+        y = m(x)
+        return y
+
+    def test_softsign(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [_test_softsign_impl]
+        arg_dict["shape"] = [(3, 3), (2, 3, 3)]
+
+        arg_dict["device"] = ["cpu", "cuda"]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
 
 if __name__ == "__main__":
     unittest.main()
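
Beyond the unittest harness, each new activation added in this patch is reachable in three equivalent forms: the nn module, the free function exported from python/oneflow/__init__.py, and the tensor method registered with register_tensor_op. A small illustrative sketch, assuming the patch is applied:

    import numpy as np
    import oneflow as flow

    x = flow.Tensor(np.linspace(-2, 2, 5).astype(np.float32))

    # All three forms route to the same functional op (flow.F.silu).
    y_module = flow.nn.SiLU()(x)
    y_func = flow.silu(x)
    y_method = x.silu()
    assert np.allclose(y_module.numpy(), y_func.numpy())
    assert np.allclose(y_func.numpy(), y_method.numpy())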
diff --git a/python/oneflow/test/tensor/test_tensor.py b/python/oneflow/test/tensor/test_tensor.py
index d08f5ac05..6328865ef 100644
--- a/python/oneflow/test/tensor/test_tensor.py
+++ b/python/oneflow/test/tensor/test_tensor.py
@@ -1006,6 +1006,85 @@ class TestTensor(flow.unittest.TestCase):
             )
         )
 
+    @flow.unittest.skip_unless_1n1d()
+    def test_tensor_mish(test_case):
+        def np_mish(x):
+            f = 1 + np.exp(x)
+            y = x * ((f * f - 1) / (f * f + 1))
+            y_grad = (f * f - 1) / (f * f + 1) + x * (4 * f * (f - 1)) / (
+                (f * f + 1) * (f * f + 1)
+            )
+            return [y, y_grad]
+
+        np_input = np.random.randn(2, 4, 5, 6,)
+        of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True)
+        of_out = of_input.mish()
+
+        np_out, np_grad = np_mish(np_input)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+
+        of_out = of_out.sum()
+        of_out.backward()
+        test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-5, 1e-5))
+
+    @flow.unittest.skip_unless_1n1d()
+    def test_tensor_silu(test_case):
+        def np_silu(x):
+            _sig = 1 / (1 + np.exp(-x))
+            y = x * _sig
+            y_grad = _sig * (1 + x * (1 - _sig))
+            return [y, y_grad]
+
+        np_input = np.random.randn(2, 4, 5, 6,)
+        of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True)
+        of_out = of_input.silu()
+
+        np_out, np_grad = np_silu(np_input)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+
+        of_out = of_out.sum()
+        of_out.backward()
+        test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-5, 1e-5))
+
+    @flow.unittest.skip_unless_1n1d()
+    def test_tensor_selu(test_case):
+        _scale = 1.0507009873554804934193349852946
+        _alpha = 1.6732632423543772848170429916717
+
+        def np_selu(x):
+            y = np.where(x < 0, _scale * _alpha * (np.exp(x) - 1), _scale * x)
+            y_grad = np.where(x < 0, _scale * _alpha * np.exp(x), _scale)
+            return [y, y_grad]
+
+        np_input = np.random.randn(2, 4, 5, 6,)
+        of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True)
+        of_out = of_input.selu()
+
+        np_out, np_grad = np_selu(np_input)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+
+        of_out = of_out.sum()
+        of_out.backward()
+        test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-5, 1e-5))
+
+    @unittest.skip("still have error in ci")
+    def test_tensor_softsign(test_case):
+        def np_softsign(x):
+            y = x / (1 + np.abs(x))
+            y_grad = 1 / np.square(1 + np.abs(x))
+            return [y, y_grad]
+
+        np_input = np.random.randn(2, 4, 5, 6,)
+        of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True)
+        of_out = of_input.softsign()
+
+        np_out, np_grad = np_softsign(np_input)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+
+        of_out = of_out.sum()
+        of_out.backward()
+        test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_grad, 1e-5, 1e-5))
+
 
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab