diff --git a/oneflow/api/python/framework/throw.h b/oneflow/api/python/framework/throw.h
index 7d8a82aa3397b8c363a891d33fa8418ad38dbbbc..a0def93dd3fe92357fea6d7a692712eb5d99df0e 100644
--- a/oneflow/api/python/framework/throw.h
+++ b/oneflow/api/python/framework/throw.h
@@ -36,7 +36,7 @@ class Throw final {
 #define CHECK_OR_THROW(expr)                                                                      \
   if (!(expr))                                                                                    \
   Throw(oneflow::Error::CheckFailedError().AddStackFrame(MAYBE_FAILED_LOC, __FUNCTION__)).error() \
-      << " Check failed: " << OF_PP_STRINGIZE(expr) << "\t"
+      << " Check failed: " << OF_PP_STRINGIZE(expr) << ": "
 
 #define CHECK_EQ_OR_THROW(lhs, rhs) \
   CHECK_OR_THROW((lhs) == (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") "
diff --git a/oneflow/core/common/error.cpp b/oneflow/core/common/error.cpp
index a15c225961709269d400198b1cd84a780ef36766..b589b33552dce4162f8ff31ae6de0e3b70e1c219 100644
--- a/oneflow/core/common/error.cpp
+++ b/oneflow/core/common/error.cpp
@@ -70,6 +70,12 @@ Error Error::ValueError(const std::string& error_summary) {
   return error;
 }
 
+Error Error::IndexError() {
+  auto error = std::make_shared<cfg::ErrorProto>();
+  error->mutable_index_error();
+  return error;
+}
+
 Error Error::JobNameExistError() {
   auto error = std::make_shared<cfg::ErrorProto>();
   error->mutable_job_name_exist_error();
diff --git a/oneflow/core/common/error.h b/oneflow/core/common/error.h
index 93e4470c4e8450a59442c637e8f9c38d2acb6f0e..5a878e3f9c634507437b90cfb09c780cede64d34 100644
--- a/oneflow/core/common/error.h
+++ b/oneflow/core/common/error.h
@@ -43,6 +43,7 @@ class Error final {
   static Error JobSetEmptyError();
   static Error DeviceTagNotFoundError();
   static Error ValueError(const std::string& error_summary);
+  static Error IndexError();
   static Error JobNameExistError();
   static Error JobNameEmptyError();
   static Error JobNameNotEqualError();
diff --git a/oneflow/core/common/error.proto b/oneflow/core/common/error.proto
index d090915c9daaa50f37747c17a5b6e9ffc815ebce..81dd7b622576a5d19e5e72f0b29fa0790d3b2574 100644
--- a/oneflow/core/common/error.proto
+++ b/oneflow/core/common/error.proto
@@ -125,6 +125,8 @@ message SymbolIdUninitializedError {}
 
 message ValueError {}
 
+message IndexError {}
+
 message ErrorProto {
   optional string error_summary = 1 [default = ""];
   optional string msg = 2 [default = ""];
@@ -145,6 +147,7 @@ message ErrorProto {
     JobSetEmptyError job_set_empty_error = 25;
     DeviceTagNotFoundError device_tag_not_found_error = 26;
     ValueError value_error = 27;
+    IndexError index_error = 28;
     JobNameExistError job_name_exist_error = 100;
     JobNameEmptyError job_name_empty_error = 101;
     JobNameNotEqualError job_name_not_equal_error = 102;
diff --git a/oneflow/core/common/exception.h b/oneflow/core/common/exception.h
index 084d02126c52351f64550970f1f8ef8a31bacab1..e1e1dee1fb7863e11bcaf5903c2e9d66bac1acaa 100644
--- a/oneflow/core/common/exception.h
+++ b/oneflow/core/common/exception.h
@@ -69,6 +69,7 @@ class Exception : public std::exception {
   OF_PP_MAKE_TUPLE_SEQ(Unknown)                   \
   OF_PP_MAKE_TUPLE_SEQ(CompileOptionWrong)        \
   OF_PP_MAKE_TUPLE_SEQ(Value)                     \
+  OF_PP_MAKE_TUPLE_SEQ(Index)                     \
   OF_PP_MAKE_TUPLE_SEQ(InputDeviceNotMatch)
 
 #define DEFINE_EXCEPTION_CLASS(cls)                                         \
diff --git a/oneflow/core/common/shape.cpp b/oneflow/core/common/shape.cpp
index 949935d361acdab2656b00fd6cd8b8e36c05571d..f3f5cec1947f4d85548ee0ee75b03d956793dc1c 100644
--- a/oneflow/core/common/shape.cpp
+++ b/oneflow/core/common/shape.cpp
@@ -184,4 +184,14 @@ bool Shape::Containing(const Shape& small_shape) const {
   return true;
 }
 
+Maybe<Shape> Shape::Slice(int64_t start_dim, int64_t end_dim) const {
+  CHECK_OR_RETURN(start_dim >= 0 && end_dim >= start_dim);
+  int64_t ndims = this->NumAxes();
+  if (start_dim > ndims) { start_dim = ndims; }
+  if (end_dim > ndims) { end_dim = ndims; }
+  DimVector dim_vec;
+  for (int64_t i = start_dim; i < end_dim && i < ndims; ++i) { dim_vec.push_back(this->At(i)); }
+  return std::make_shared<Shape>(dim_vec);
+}
+
 }  // namespace oneflow
diff --git a/oneflow/core/common/shape.h b/oneflow/core/common/shape.h
index 513ec410f281982a86bc05cbc25f9bdc47c54f3a..9b1a9f46c18d07bd196845d0ca5afb657bddddc8 100644
--- a/oneflow/core/common/shape.h
+++ b/oneflow/core/common/shape.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "oneflow/core/common/shape.pb.h"
 #include "oneflow/core/common/util.h"
+#include "oneflow/core/common/maybe.h"
 #include "oneflow/core/common/shape_vec.h"
 
 namespace oneflow {
@@ -68,6 +69,8 @@ class Shape final {
 
   bool Containing(const Shape& small_shape) const;
 
+  Maybe<Shape> Slice(int64_t start_dim, int64_t end_dim) const;
+
  private:
   void UpdateElemCnt();
 
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 9eaa451c9deaac2822aa75aa1bd4749ebd8770a1..a936c88fd3e54870a6754d1cb2e516a01a0be796 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -313,7 +313,7 @@
   bind_python: False
 
 - name: "expand"
-  signature: "Tensor Expand(Tensor x, *, Int32List in_shape, Int32List out_shape, Int32List stride)"
+  signature: "Tensor Expand(Tensor x, *, Shape shape)"
   bind_python: True
 
 - name: "expand_dims"
@@ -694,3 +694,6 @@
   signature: "Tensor TensorGetItem(Tensor x, *, TensorIndex index)"
   bind_python: True
 
+- name: "tensor_setitem"
+  signature: "Void TensorSetItem(Tensor x, *, TensorIndex index, Tensor value)"
+  bind_python: True
diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp
index 53ddfa662b9b5cf2e167a6a04fdf742266962841..e7d913b2c2cc9912f43b632749c854aeced047a7 100644
--- a/oneflow/core/functional/impl/array_functor.cpp
+++ b/oneflow/core/functional/impl/array_functor.cpp
@@ -167,10 +167,42 @@ class ConcatFunctor {
 class ExpandFunctor {
  public:
   ExpandFunctor() { op_ = CHECK_JUST(one::OpBuilder("expand").Input("in").Output("out").Build()); }
-  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
-                           const std::vector<int32_t>& in_shape,
-                           const std::vector<int32_t>& out_shape,
-                           const std::vector<int32_t>& stride) const {
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Shape& shape) const {
+    CHECK_GE_OR_RETURN(shape.NumAxes(), x->shape()->NumAxes())
+        << "The desired expanded dims should not be less than the input dims.";
+    std::vector<int32_t> in_shape(x->shape()->NumAxes());
+    for (int i = 0; i < in_shape.size(); ++i) { in_shape[i] = x->shape()->At(i); }
+
+    // calculate the original stride.
+    std::vector<int32_t> original_stride(in_shape.size(), 1);
+    for (int i = x->shape()->NumAxes() - 2; i >= 0; --i) {
+      original_stride[i] = in_shape.at(i + 1) * original_stride.at(i + 1);
+    }
+    std::vector<int32_t> out_shape(shape.NumAxes());
+    std::vector<int32_t> stride(shape.NumAxes());
+    int shift = out_shape.size() - in_shape.size();
+    for (int i = out_shape.size() - 1; i >= 0; --i) {
+      int index = i - shift;
+      if (index >= 0) {
+        if (shape.At(i) == -1 || shape.At(i) == in_shape.at(index)) {
+          out_shape[i] = in_shape.at(index);
+          stride[i] = original_stride.at(index);
+        } else {
+          CHECK_OR_RETURN(shape.At(i) > 0 && in_shape.at(index) == 1)
+              << "Invalid expand shape " << shape.ToString();
+          out_shape[i] = shape.At(i);
+          stride[i] = 0;
+        }
+      } else {
+        CHECK_GT_OR_RETURN(shape.At(i), 0) << "Invalid expand shape " << shape.ToString();
+        out_shape[i] = shape.At(i);
+        if (shape.At(i) == 1 && i < out_shape.size() - 1) {
+          stride[i] = stride.at(i + 1);
+        } else {
+          stride[i] = 0;
+        }
+      }
+    }
     MutableAttrMap attrs;
     JUST(attrs.SetAttr<std::vector<int32_t>>("in_shape", in_shape));
     JUST(attrs.SetAttr<std::vector<int32_t>>("out_shape", out_shape));
@@ -773,41 +805,24 @@ class TensorGetItemFunctor {
  public:
   TensorGetItemFunctor() {}
   Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const TensorIndex& index) const {
-    const auto& regular_index = JUST(RegularTensorIndex(index, *(x->shape())));
     int64_t ndims = x->shape()->NumAxes();
-    CHECK_GE_OR_RETURN(regular_index->size(), ndims) << "Tensor index failed to be regularlized.";
+    std::vector<detail::Slice> slice_indices;
+    std::vector<std::shared_ptr<one::Tensor>> tensor_indices;
+    std::vector<int64_t> target_dims;
+
+    JUST(PrepareSliceIndices(index, *(x->shape()), &slice_indices, &tensor_indices, &target_dims));
+    CHECK_EQ_OR_RETURN(slice_indices.size(), ndims) << "Failed to prepare slice indices.";
+    Shape target_shape(DimVector(target_dims.begin(), target_dims.end()));
+    CHECK_GT_OR_RETURN(target_shape.Count(0), 0)
+        << "Target shape is zero shape which was not supported yet.";
+
     std::vector<int64_t> start(ndims), end(ndims), step(ndims);
-    int dim = 0;
-    DimVector result_dims;
-    for (int i = 0; i < regular_index->size(); ++i) {
-      const auto& index_item = regular_index->at(i);
-      CHECK_OR_RETURN(!index_item.IsEllipsis())
-          << "Tensor index should not have ellipsis once regularlized.";
-      if (index_item.IsSlice()) {
-        CHECK_LT_OR_RETURN(dim, ndims);
-        start[dim] = index_item.slice().start();
-        end[dim] = index_item.slice().end();
-        step[dim] = index_item.slice().step();
-        int64_t length = (end[dim] - start[dim] + step[dim] - 1) / step[dim];
-        result_dims.emplace_back(length);
-        dim++;
-      } else if (index_item.IsInteger()) {
-        CHECK_LT_OR_RETURN(dim, ndims);
-        start[dim] = index_item.integer();
-        end[dim] = start[dim] + 1;
-        step[dim] = 1;
-        dim++;
-      } else if (index_item.IsNone()) {
-        result_dims.emplace_back(1);
-      } else if (index_item.IsBoolean()) {
-        CHECK_OR_RETURN(index_item.boolean()) << "Index false is not supported.";
-        result_dims.emplace_back(1);
-      }
+    for (int i = 0; i < ndims; ++i) {
+      const auto& slice = slice_indices.at(i);
+      start[i] = slice.start();
+      end[i] = slice.end();
+      step[i] = slice.step();
     }
-    CHECK_EQ_OR_RETURN(dim, ndims)
-        << "Specified dims count for regularlized tensor index should equal to tensor dimension "
-        << ndims;
-
     bool is_identity = [&]() {
       for (int i = 0; i < ndims; ++i) {
         if (start[i] != 0 || end[i] != x->shape()->At(i) || step[i] != 1) { return false; }
@@ -821,14 +836,69 @@ class TensorGetItemFunctor {
       result = JUST(functional::Slice(x, start, end, step));
     }
 
-    Shape shape(result_dims);
+    Shape shape(DimVector(target_dims.begin(), target_dims.end()));
     if (shape.NumAxes() != 0 && shape != *(result->shape())) {
-      return functional::Reshape(result, shape);
+      result = JUST(functional::Reshape(result, shape));
     }
     return result;
   }
 };
 
+class TensorSetItemFunctor {
+ public:
+  TensorSetItemFunctor() {}
+  Maybe<void> operator()(const std::shared_ptr<one::Tensor>& x, const TensorIndex& index,
+                         const std::shared_ptr<one::Tensor>& value) const {
+    int64_t ndims = x->shape()->NumAxes();
+    std::vector<detail::Slice> slice_indices;
+    std::vector<std::shared_ptr<one::Tensor>> tensor_indices;
+    std::vector<int64_t> target_dims;
+
+    JUST(PrepareSliceIndices(index, *(x->shape()), &slice_indices, &tensor_indices, &target_dims));
+    CHECK_EQ_OR_RETURN(slice_indices.size(), ndims) << "Failed to prepare slice indices.";
+    Shape target_shape(DimVector(target_dims.begin(), target_dims.end()));
+    if (target_shape.Count(0) == 0) { return Maybe<void>::Ok(); }
+
+    const auto& value_shape = value->shape();
+    bool matched = [&]() {
+      for (int i = 0; i < value_shape->NumAxes() - target_shape.NumAxes(); ++i) {
+        if (value_shape->At(i) != 1) { return false; }
+      }
+      return true;
+    }();
+    CHECK_OR_RETURN(matched) << "The tensor size mismatch. Target sizes: "
+                             << target_shape.ToString()
+                             << ", value sizes: " << value_shape->ToString();
+    std::shared_ptr<one::Tensor> value_tensor(value);
+    if (target_shape.NumAxes() != 0 &&  // NOLINT
+        /*need_expand=*/value_shape->Count(0) != target_shape.Count(0)) {
+      // Remove the beginning redundant 1-dimensions.
+      if (value_shape->NumAxes() > target_shape.NumAxes()) {
+        int64_t start_axis = value_shape->NumAxes() - target_shape.NumAxes();
+        const auto& shape = JUST(value_shape->Slice(start_axis, value_shape->NumAxes()));
+        value_tensor = JUST(functional::Reshape(value, *shape));
+      }
+      value_tensor = JUST(functional::Expand(value_tensor, target_shape));
+    }
+
+    std::vector<int64_t> start(ndims), end(ndims), step(ndims);
+    DimVector slice_dims(ndims);
+    for (int i = 0; i < ndims; ++i) {
+      const auto& slice = slice_indices.at(i);
+      start[i] = slice.start();
+      end[i] = slice.end();
+      step[i] = slice.step();
+      slice_dims[i] = (end[i] - start[i] + step[i] - 1) / step[i];
+    }
+    Shape slice_shape(slice_dims);
+    if (slice_shape != *(value_tensor->shape())) {
+      value_tensor = JUST(functional::Reshape(value_tensor, slice_shape));
+    }
+    JUST(LogicalSliceAssign(x, value_tensor, start, end, step));
+    return Maybe<void>::Ok();
+  }
+};
+
 }  // namespace impl
 
 ONEFLOW_FUNCTION_LIBRARY(m) {
@@ -873,6 +943,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::DiagFunctor>("Diag");
   m.add_functor<impl::DiagGradFunctor>("DiagGrad");
   m.add_functor<impl::TensorGetItemFunctor>("TensorGetItem");
+  m.add_functor<impl::TensorSetItemFunctor>("TensorSetItem");
 };
 
 }  // namespace functional
diff --git a/oneflow/core/functional/tensor_index.cpp b/oneflow/core/functional/tensor_index.cpp
index 2a24a8319b4d083998cbc1d422cb681ea62a78ef..843991963c6445512833bf122c0e814e29663927 100644
--- a/oneflow/core/functional/tensor_index.cpp
+++ b/oneflow/core/functional/tensor_index.cpp
@@ -28,14 +28,15 @@ int64_t CountSpecifiedDims(const TensorIndex& index) {
   return specified_ndims;
 }
 
-Maybe<TensorIndex> RegularTensorIndex(const TensorIndex& index, const Shape& shape) {
-  int64_t specified_ndims = CountSpecifiedDims(index);
+Maybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& shape,
+                                std::vector<detail::Slice>* slice_indices,
+                                std::vector<std::shared_ptr<Tensor>>* tensor_indices,
+                                std::vector<int64_t>* target_dims) {
   int64_t ndims = shape.NumAxes();
+  int64_t specified_ndims = CountSpecifiedDims(index);
   CHECK_LE_OR_RETURN(specified_ndims, ndims)
       << "Too many indices for tensor of dimension " << ndims;
-
-  auto regular_index = std::make_shared<TensorIndex>();
-  int64_t dim = 0;
+  int dim = 0;
   for (int i = 0; i < index.size(); ++i) {
     const auto& index_item = index.at(i);
     if (index_item.IsSlice()) {
@@ -50,37 +51,40 @@ Maybe<TensorIndex> RegularTensorIndex(const TensorIndex& index, const Shape& sha
       if (start < 0) { start = 0; }
       if (end < 0) { end += shape.At(dim); }
       if (end < start) { end = start; }
-      regular_index->emplace_back(detail::IndexItem(start, end, step));
+      slice_indices->emplace_back(start, end, step);
+      int64_t length = (end - start + step - 1) / step;
+      target_dims->emplace_back(length);
       dim++;
     } else if (index_item.IsInteger()) {
       CHECK_LT_OR_RETURN(dim, ndims) << "Invalid index for tensor of dimension " << ndims;
       int64_t integer = index_item.integer();
       if (integer < 0) { integer += shape.At(dim); }
-      CHECK_OR_RETURN(integer >= 0 && integer < shape.At(dim)) << Error::ValueError(
-          std::string("Index ") + std::to_string(index_item.integer())
-          + std::string(" is out of bounds for dimension ") + std::to_string(dim)
-          + std::string(" with size ") + std::to_string(shape.At(dim)));
-      regular_index->emplace_back(detail::IndexItem(integer));
+      if (integer < 0 || integer >= shape.At(dim)) {
+        return Error::IndexError()
+               << "Index " << index_item.integer() << " is out of bounds for dimension " << dim
+               << " with size " << shape.At(dim);
+      }
+      slice_indices->emplace_back(integer, integer + 1, 1);
       dim++;
     } else if (index_item.IsEllipsis()) {
       int64_t unspecified_ndims = ndims - specified_ndims;
       unspecified_ndims = std::min(ndims - dim, unspecified_ndims);
       for (int j = 0; j < unspecified_ndims; ++j) {
-        regular_index->emplace_back(detail::IndexItem(0, shape.At(dim + j), 1));
+        slice_indices->emplace_back(0, shape.At(dim + j), 1);
+        target_dims->emplace_back(shape.At(dim + j));
       }
       dim += unspecified_ndims;
-    } else {
-      // None or Boolean.
-      if (index_item.IsBoolean()) {
-        CHECK_OR_RETURN(index_item.boolean()) << "Index false is not supported.";
-      }
-      regular_index->emplace_back(index_item);
+    } else if (index_item.IsNone()) {
+      target_dims->emplace_back(1);
+    } else if (index_item.IsBoolean()) {
+      target_dims->emplace_back(index_item.boolean());
     }
   }
   for (int i = dim; i < ndims; ++i) {
-    regular_index->emplace_back(detail::IndexItem(0, shape.At(i), 1));
+    slice_indices->emplace_back(0, shape.At(i), 1);
+    target_dims->emplace_back(shape.At(i));
   }
-  return regular_index;
+  return Maybe<void>::Ok();
 }
 
 }  // namespace functional
diff --git a/oneflow/core/functional/tensor_index.h b/oneflow/core/functional/tensor_index.h
index f2b7e78c8094a6fccfc75f49c27087b884f39568..eba019f8686fb8a110021a0c288c8370f5cabc57 100644
--- a/oneflow/core/functional/tensor_index.h
+++ b/oneflow/core/functional/tensor_index.h
@@ -25,6 +25,9 @@ limitations under the License.
 
 namespace oneflow {
 namespace one {
+
+class Tensor;
+
 namespace functional {
 
 namespace detail {
@@ -94,7 +97,11 @@ class TensorIndex : public std::vector<detail::IndexItem> {
 };
 
 int64_t CountSpecifiedDims(const TensorIndex& index);
-Maybe<TensorIndex> RegularTensorIndex(const TensorIndex& index, const Shape& shape);
+
+Maybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& shape,
+                                std::vector<detail::Slice>* slice_indices,
+                                std::vector<std::shared_ptr<Tensor>>* tensor_indices,
+                                std::vector<int64_t>* target_dims);
 
 }  // namespace functional
 }  // namespace one
diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py
index 5d2db0e7138b576a5ce9d359fe4594301e7bfc76..04dae678da5078f8a1beb26c9bb9ec18390676aa 100644
--- a/oneflow/python/framework/tensor.py
+++ b/oneflow/python/framework/tensor.py
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 import oneflow.core.job.initializer_conf_pb2 as initializer_conf_util
-from oneflow._oneflow_internal.exception import ValueException
+from oneflow._oneflow_internal.exception import IndexException
 
 from oneflow.python.oneflow_export import oneflow_export
 import oneflow.python.framework.remote_blob as remote_blob_util
@@ -454,44 +454,17 @@ class Tensor:
     def __getitem__(self, key):
         try:
             return flow.F.tensor_getitem(self, key)
-        except ValueException as e:
+        except IndexException as e:
             # The stop condition of for in python is IndexError,
-            # so we have to catch ValueException from C++ and throw IndexError
+            # so we have to catch IndexException from C++ and throw IndexError
             raise IndexError(e)
 
     @_auto_determine
     @register_local_tensor_method()
     def __setitem__(self, key, value):
-        if isinstance(key, tuple):
-            key = self._transform_ellipsis_type(key)
-            unsqueeze_dims = list(
-                filter(lambda idx: isinstance(key[idx], int), range(len(key)))
-            )
-        elif isinstance(key, int):
-            if key < 0:
-                key = self.shape[0] + key
-            unsqueeze_dims = [0]
-        else:
-            unsqueeze_dims = []
-
-        start, stop, step, shape = self._get_slice_obj(key)
         if isinstance(value, (int, float)):
-            scalar = value
-            value = flow.Tensor(*shape)
-            value.fill_(scalar)
-        else:
-            prepended_broadcasting_dims = range(
-                len(self.shape) - len(unsqueeze_dims) - len(value.shape)
-            )
-            for dim in prepended_broadcasting_dims:
-                value = flow.experimental.unsqueeze(value, dim)
-            for dim in unsqueeze_dims:
-                value = flow.experimental.unsqueeze(value, dim)
-            value = flow.experimental.expand(value, *shape)
-
-        flow.experimental.tmp.logical_slice_assign(
-            self, value, list(zip(start, stop, step))
-        )
+            value = flow.F.constant([1], value, self.dtype)
+        flow.F.tensor_setitem(self, key, value)
         return self
 
     @register_local_tensor_method()
diff --git a/oneflow/python/nn/modules/expand.py b/oneflow/python/nn/modules/expand.py
index acb03b33ea2f490c44f6b98dfdb7c0d1168e7ec8..11e9695392a2e897ca22e2577ba4a7477e3a45ff 100644
--- a/oneflow/python/nn/modules/expand.py
+++ b/oneflow/python/nn/modules/expand.py
@@ -28,39 +28,7 @@ class Expand(Module):
     def forward(self, x):
         if x.dtype == flow.int8:
             x = flow.experimental.cast(x, flow.int32)
-        expand_size = self.expand_size
-        assert len(expand_size) >= len(
-            x.shape
-        ), "The desired expanded dims should not be less than the input dims."
-        # calculate the original stride
-        original_stride = [1]
-        for i in range(len(x.shape) - 2, -1, -1):
-            original_stride.insert(0, original_stride[0] * x.shape[i + 1])
-
-        # calculate the output shape and stride
-        new_size = []
-        new_stride = []
-        diff = len(expand_size) - len(x.shape)
-        for i in range(len(expand_size) - 1, -1, -1):
-            if i >= diff:
-                if expand_size[i] == -1 or expand_size[i] == x.shape[i - diff]:
-                    new_size.insert(0, x.shape[i - diff])
-                    new_stride.insert(0, original_stride[i - diff])
-                else:
-                    assert expand_size[i] >= 1 and x.shape[i - diff] == 1
-                    new_size.insert(0, expand_size[i])
-                    new_stride.insert(0, 0)
-            else:
-                assert expand_size[i] >= 1
-                new_size.insert(0, expand_size[i])
-                if expand_size[i] == 1:
-                    new_stride.insert(0, new_stride[0])
-                else:
-                    new_stride.insert(0, 0)
-
-        return flow.F.expand(
-            x, in_shape=list(x.shape), out_shape=new_size, stride=new_stride
-        )
+        return flow.F.expand(x, self.expand_size)
 
 
 @oneflow_export("expand")