From d57de627dcdb51a5904fca631fc9a46ad527f159 Mon Sep 17 00:00:00 2001
From: Yinggang Wang <wyg19970408@gmail.com>
Date: Sun, 1 Aug 2021 00:43:07 -0500
Subject: [PATCH] Support 0shape tensor (#5620)

* feat(Tensor): support 0shape tensor

* math binary broadcast supports empty tensor input

* slice supports empty tensor input and output

* fix check in slice

* test(Cat): add 0shape cat module test

* fix return type error on gcc 4.8.5

Signed-off-by: daquexian <daquexian566@gmail.com>

* auto format by CI

* add module op tests for empty tensors; cuda kernels support empty tensors

* format

* feat(ReduceOp): reduce op kernels support 0shape tensor

* delete files added by mistake

* refine if

* refine if

* feat(ConstantOp): constant ops support 0shape tensor

* feat(ReshapeOp): reshape kernel supports 0shape tensor

* math binary and unary backward skip when element count equals zero

* fix(ReduceOp): fix reduce not memset bug

* support getitem output empty tensor

* fix comment

* getitem supports empty input

* reduce_like kernel supports empty tensors

* fix op test bug

* feat(ReduceOp): refine reduce ops' initial value

* format code

* fix triu bug when input is empty

* test(AbsOp): fix test bug

* test(DivOp): fix test bug

* fix clamp bug

* fix test_sub bug

* fix(ReduceOp): fix reduce op memset bug

* auto format by CI

* fix random

Co-authored-by: liufengwei <2472937968@qq.com>
Co-authored-by: daquexian <daquexian566@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
---
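
A minimal sketch of the semantics the commits above target, assuming the public oneflow Python API at this revision (flow.ones, tensor slicing, broadcast add); the shapes below are illustrative only:

    import oneflow as flow

    x = flow.ones(2, 3)

    # Slicing an empty range now yields a 0-size tensor instead of tripping the
    # CHECKs removed from indexing.cpp / tensor_index.cpp.
    y = x[:, 3:3]
    assert tuple(y.shape) == (2, 0)

    # Broadcast binary ops propagate 0-size dims: an output dim is 0 whenever
    # either input dim is 0 (math_binary_broadcast_ops.cpp), and the CUDA
    # kernels return early when elem_cnt == 0.
    z = y + flow.ones(2, 1)
    assert tuple(z.shape) == (2, 0)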
 oneflow/api/python/functional/indexing.cpp    |  2 -
 .../core/functional/impl/array_functor.cpp    |  2 -
 oneflow/core/functional/tensor_index.cpp      |  5 +-
 oneflow/core/kernel/kernel_util.cu            |  3 ++
 .../kernel/util/cuda_arithemetic_interface.cu |  1 +
 .../core/kernel/util/cuda_dnn_interface.cu    |  1 +
 .../core/ndarray/ndarray_apply_binary_core.cu |  2 +
 .../ndarray/ndarray_apply_broadcast_binary.h  |  4 +-
 .../ndarray_apply_broadcast_binary_core.cu    |  2 +
 .../ndarray_apply_broadcast_unary_core.cu     |  1 +
 .../core/ndarray/ndarray_apply_unary_core.cu  |  1 +
 oneflow/core/ndarray/ndarray_assign_core.cu   |  1 +
 oneflow/user/kernels/add_n_kernel.cu          |  2 +-
 oneflow/user/kernels/clip_by_value_kernel.cu  |  2 +
 oneflow/user/kernels/concat_kernel.cpp        |  1 +
 oneflow/user/kernels/constant_kernel.cpp      |  3 +-
 oneflow/user/kernels/empty_kernel.cpp         |  4 --
 .../kernels/math_binary_elementwise_kernel.cu |  6 +++
 .../kernels/math_unary_elementwise_kernel.cu  |  4 ++
 oneflow/user/kernels/reduce_kernel.cpp        | 11 +++++
 oneflow/user/kernels/reduce_like_kernels.cpp  | 10 ++++
 oneflow/user/kernels/slice_util.cu            |  1 +
 oneflow/user/kernels/triu_kernel.cu           |  1 +
 .../user/ops/math_binary_broadcast_ops.cpp    |  4 +-
 oneflow/user/ops/reshape_op.cpp               |  2 +-
 oneflow/user/ops/slice_op.cpp                 | 10 ++--
 python/oneflow/framework/tensor.py            |  3 +-
 python/oneflow/nn/modules/reduce_ops.py       |  6 +++
 python/oneflow/ops/array_ops.py               |  2 +-
 python/oneflow/test/modules/test_abs.py       |  8 ++++
 .../oneflow/test/modules/test_activation.py   | 47 +++++++++++++++++++
 python/oneflow/test/modules/test_add.py       |  9 ++++
 python/oneflow/test/modules/test_argwhere.py  |  2 +-
 python/oneflow/test/modules/test_cast.py      | 11 +++++
 python/oneflow/test/modules/test_ceil.py      |  7 +++
 python/oneflow/test/modules/test_clamp.py     |  7 +++
 python/oneflow/test/modules/test_concat.py    |  9 ++++
 python/oneflow/test/modules/test_constant.py  |  8 ++--
 python/oneflow/test/modules/test_div.py       |  9 ++++
 python/oneflow/test/modules/test_eq.py        |  9 ++++
 python/oneflow/test/modules/test_expm1.py     |  7 +++
 python/oneflow/test/modules/test_fmod.py      |  7 +++
 python/oneflow/test/modules/test_greater.py   |  9 ++++
 python/oneflow/test/modules/test_ne.py        | 11 +++++
 python/oneflow/test/modules/test_negative.py  | 10 ++++
 python/oneflow/test/modules/test_pow.py       |  2 +-
 python/oneflow/test/modules/test_reshape.py   | 10 ++++
 python/oneflow/test/modules/test_sign.py      |  8 ++++
 python/oneflow/test/modules/test_squeeze.py   |  7 +++
 python/oneflow/test/modules/test_sub.py       | 12 +++++
 python/oneflow/test/modules/test_sum.py       | 17 +++++--
 python/oneflow/test/modules/test_transpose.py |  7 +++
 python/oneflow/test/modules/test_triu.py      |  8 ++++
 python/oneflow/test/modules/test_unsqueeze.py |  7 +++
 .../torch_flow_dual_object.py                 | 17 +++++--
 55 files changed, 326 insertions(+), 36 deletions(-)
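
For the reduce changes, a sketch of the intended behavior, assuming flow.sum and flow.min front the Sum/Min modules touched in reduce_ops.py below; the exact entry points may differ:

    import oneflow as flow

    x = flow.ones(2, 0, 3)

    # Reducing over the 0-size axis produces a non-empty output; the reduce
    # kernel now memsets it to zero instead of reducing over no elements.
    s = flow.sum(x, dim=1)
    assert tuple(s.shape) == (2, 3)
    assert s.numpy().sum() == 0

    # min/max have no identity value on an empty input, so the Python module raises.
    try:
        flow.min(x)
    except RuntimeError as e:
        assert "does not have an identity" in str(e)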

diff --git a/oneflow/api/python/functional/indexing.cpp b/oneflow/api/python/functional/indexing.cpp
index 0457fcaca..b8e53d844 100644
--- a/oneflow/api/python/functional/indexing.cpp
+++ b/oneflow/api/python/functional/indexing.cpp
@@ -50,8 +50,6 @@ Maybe<void> PySliceUnpack(PyObject* object, Py_ssize_t* start, Py_ssize_t* stop,
     CHECK_OR_RETURN(_PyEval_SliceIndex(obj->stop, stop))
         << "Invalid slice " << PyStringAsString(PyObject_Repr(object));
   }
-  CHECK_LT_OR_RETURN(*start, *stop)
-      << "Slice stop must be greater than start since 0 size shape is not allowed currently.";
   return Maybe<void>::Ok();
 }
 
diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp
index 41ac1c247..001cf7d3f 100644
--- a/oneflow/core/functional/impl/array_functor.cpp
+++ b/oneflow/core/functional/impl/array_functor.cpp
@@ -1083,8 +1083,6 @@ class TensorGetItemFunctor {
     JUST(PrepareSliceIndices(index, *(x->shape()), &slice_indices, &tensor_indices, &target_dims));
     CHECK_EQ_OR_RETURN(slice_indices.size(), ndims) << "Failed to prepare slice indices.";
     Shape target_shape(DimVector(target_dims.begin(), target_dims.end()));
-    CHECK_GT_OR_RETURN(target_shape.Count(0), 0)
-        << "Target shape is zero shape which was not supported yet.";
 
     std::vector<int64_t> start(ndims), end(ndims), step(ndims);
     for (int i = 0; i < ndims; ++i) {
diff --git a/oneflow/core/functional/tensor_index.cpp b/oneflow/core/functional/tensor_index.cpp
index 251635eee..0fa495fe7 100644
--- a/oneflow/core/functional/tensor_index.cpp
+++ b/oneflow/core/functional/tensor_index.cpp
@@ -65,18 +65,17 @@ Maybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& shape,
     }
     CHECK_LT_OR_RETURN(dim, ndims) << "Invalid index for tensor of dimension " << ndims;
     if (index_item.IsSlice()) {
-      CHECK_GT_OR_RETURN(shape.At(dim), 0) << "Slice cannot be applied to a 0-dim tensor.";
       const auto& slice = index_item.slice();
       int64_t step = std::min(slice.step(), shape.At(dim));
-      CHECK_GT_OR_RETURN(step, 0) << "Step must be greater than zero.";
       int64_t end = std::min(slice.end(), shape.At(dim));
       int64_t start = std::min(slice.start(), shape.At(dim));
       if (start < 0) { start += shape.At(dim); }
       if (start < 0) { start = 0; }
       if (end < 0) { end += shape.At(dim); }
       if (end < start) { end = start; }
+      if (start == end) { step = 1; }
       slice_indices->emplace_back(start, end, step);
-      int64_t length = (end - start + step - 1) / step;
+      int64_t length = start == end ? 0 : (end - start + step - 1) / step;
       target_dims->emplace_back(length);
       dim++;
     } else if (index_item.IsInteger()) {
diff --git a/oneflow/core/kernel/kernel_util.cu b/oneflow/core/kernel/kernel_util.cu
index bb8770894..8c517e4d5 100644
--- a/oneflow/core/kernel/kernel_util.cu
+++ b/oneflow/core/kernel/kernel_util.cu
@@ -655,6 +655,7 @@ __global__ void CastOnGpu<half, float>(const half* in, float* out, int64_t elem_
 
 template<typename T, typename U>
 void CopyElemOnGpu(DeviceCtx* ctx, const T* in_dptr, U* out_dptr, int64_t elem_num) {
+  if (elem_num == 0) { return; }
   if (std::is_same<T, U>::value) {
     Memcpy<DeviceType::kGPU>(ctx, out_dptr, in_dptr, elem_num * sizeof(T));
   } else {
@@ -667,6 +668,7 @@ void CopyElemOnGpu(DeviceCtx* ctx, const T* in_dptr, U* out_dptr, int64_t elem_n
 template<>
 void CopyElemOnGpu<float, float16>(DeviceCtx* ctx, const float* in_dptr, float16* out_dptr,
                                    int64_t elem_num) {
+  if (RoundUp(elem_num, 2) == 0) { return; }
   CastOnGpu<float, half>
       <<<BlocksNum4ThreadsNum(RoundUp(elem_num, 2) / 2), kCudaThreadsNumPerBlock, 0,
          ctx->cuda_stream()>>>(in_dptr, reinterpret_cast<half*>(out_dptr), elem_num);
@@ -675,6 +677,7 @@ void CopyElemOnGpu<float, float16>(DeviceCtx* ctx, const float* in_dptr, float16
 template<>
 void CopyElemOnGpu<float16, float>(DeviceCtx* ctx, const float16* in_dptr, float* out_dptr,
                                    int64_t elem_num) {
+  if (RoundUp(elem_num, 2) == 0) { return; }
   CastOnGpu<half, float>
       <<<BlocksNum4ThreadsNum(RoundUp(elem_num, 2) / 2), kCudaThreadsNumPerBlock, 0,
          ctx->cuda_stream()>>>(reinterpret_cast<const half*>(in_dptr), out_dptr, elem_num);
diff --git a/oneflow/core/kernel/util/cuda_arithemetic_interface.cu b/oneflow/core/kernel/util/cuda_arithemetic_interface.cu
index 286bf6c30..2707d624a 100644
--- a/oneflow/core/kernel/util/cuda_arithemetic_interface.cu
+++ b/oneflow/core/kernel/util/cuda_arithemetic_interface.cu
@@ -69,6 +69,7 @@ void LaunchTransposeGpu(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeVie
     cur_stride *= x_shape.At(i);
   }
   for (int32_t i = 0; i < NDIMS; ++i) { x_strides.val[i] = buff[permutation[i]]; }
+  if (elem_cnt == 0) { return; }
   TransposeGpu<NDIMS, T>
       <<<SMBlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
           y_shape_struct, x_strides, elem_cnt, x, y);
diff --git a/oneflow/core/kernel/util/cuda_dnn_interface.cu b/oneflow/core/kernel/util/cuda_dnn_interface.cu
index 97e7b9da7..85c755a61 100644
--- a/oneflow/core/kernel/util/cuda_dnn_interface.cu
+++ b/oneflow/core/kernel/util/cuda_dnn_interface.cu
@@ -132,6 +132,7 @@ template<typename T>
 struct ReluHelper final {
   static void ReluForward(DeviceCtx* ctx, const int64_t n, const T* x, T* y) {
     CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+    if (n == 0) { return; }
     if (x == y) {
       InplaceReluForwardGpu<T>
           <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, y);
diff --git a/oneflow/core/ndarray/ndarray_apply_binary_core.cu b/oneflow/core/ndarray/ndarray_apply_binary_core.cu
index c7dea0721..5ea4c5779 100644
--- a/oneflow/core/ndarray/ndarray_apply_binary_core.cu
+++ b/oneflow/core/ndarray/ndarray_apply_binary_core.cu
@@ -40,12 +40,14 @@ struct NdarrayApplyBinaryCoreWrapper<DeviceType::kGPU, T, binary_func> final {
                     const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,
                     const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {
     size_t n = y.host_shape().HostElemNum();
+    if (n == 0) { return; }
     RUN_CUDA_KERNEL((NdarrayApplyBinaryApplyGpu<T, binary_func>), ctx, n, n, y.host_ptr(),
                     a.host_ptr(), b.host_ptr());
   }
   static void InplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
                            const XpuVarNdarray<const T>& x) {
     size_t n = y.host_shape().HostElemNum();
+    if (n == 0) { return; }
     RUN_CUDA_KERNEL((NdarrayApplyBinaryInplaceApplyGpu<T, binary_func>), ctx, n, n, y.host_ptr(),
                     x.host_ptr());
   }
diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h b/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h
index eaf5a76a0..2250e059b 100644
--- a/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h
+++ b/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h
@@ -99,7 +99,9 @@ struct NdarrayApplyBroadcastBinary<
     CHECK_EQ(y.shape().NumAxes(), a.shape().NumAxes());
     CHECK_EQ(y.shape().NumAxes(), b.shape().NumAxes());
     for (int i = 0; i < y.shape().NumAxes(); ++i) {
-      CHECK_EQ(y.shape().At(i), std::max(a.shape().At(i), b.shape().At(i)));
+      CHECK_EQ(y.shape().At(i), (a.shape().At(i) == 0 || b.shape().At(i) == 0)
+                                    ? 0
+                                    : std::max(a.shape().At(i), b.shape().At(i)));
       if (a.shape().At(i) != b.shape().At(i)) {
         CHECK(a.shape().At(i) == 1 || b.shape().At(i) == 1);
       }
diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu
index 9335521a5..c1ed93595 100644
--- a/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu
+++ b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu
@@ -89,6 +89,7 @@ struct NdarrayApplyBroadcastBinaryCoreWrapper<DeviceType::kGPU, T, NDIMS, binary
                     const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,
                     const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {
     size_t n = y.host_shape().HostElemNum();
+    if (n == 0) { return; }
     if (IsKernelSafeInt32(n) && PartialBroadcast<int32_t>(ctx, y, a, b)) { return; }
     if (!IsKernelSafeInt32(n) && PartialBroadcast<int64_t>(ctx, y, a, b)) { return; }
     RUN_CUDA_KERNEL((GpuBroadcastBinaryFunc<T, NDIMS, binary_func>), ctx, n, y, a, b);
@@ -151,6 +152,7 @@ struct NdarrayApplyBroadcastInplaceBinaryCoreWrapper<DeviceType::kGPU, T, NDIMS,
     size_t n = y.host_shape().HostElemNum();
     XpuVarNdarray<const T> a(y.host_shape(), y.host_ptr());
     using NBB = NdarrayApplyBroadcastBinaryCoreWrapper<DeviceType::kGPU, T, NDIMS, binary_func>;
+    if (n == 0) { return; }
     if (IsKernelSafeInt32(n) && NBB::template PartialBroadcast<int32_t>(ctx, y, a, x)) { return; }
     if (!IsKernelSafeInt32(n) && NBB::template PartialBroadcast<int64_t>(ctx, y, a, x)) { return; }
     RUN_CUDA_KERNEL((GpuInplaceBroadcastBinaryFunc<T, NDIMS, binary_func>), ctx, n, y, x);
diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu
index 31e9df6c5..be6c625ad 100644
--- a/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu
+++ b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu
@@ -30,6 +30,7 @@ template<typename T, int NDIMS, template<typename> class unary_func>
 struct NdarrayApplyBroadcastUnaryCoreWrapper<DeviceType::kGPU, T, NDIMS, unary_func> final {
   static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {
     size_t n = y.host_shape().HostElemNum();
+    if (n == 0) { return; }
     RUN_CUDA_KERNEL((GpuBroadcastUnaryFunc<T, NDIMS, unary_func>), ctx, n, y, x);
   }
 };
diff --git a/oneflow/core/ndarray/ndarray_apply_unary_core.cu b/oneflow/core/ndarray/ndarray_apply_unary_core.cu
index 2b6963b66..1b9192ba9 100644
--- a/oneflow/core/ndarray/ndarray_apply_unary_core.cu
+++ b/oneflow/core/ndarray/ndarray_apply_unary_core.cu
@@ -31,6 +31,7 @@ template<typename T, template<typename> class unary_func>
 struct NdarrayApplyUnaryCoreWrapper<DeviceType::kGPU, T, unary_func> final {
   static void InplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y) {
     size_t n = y.host_shape().HostElemNum();
+    if (n == 0) { return; }
     RUN_CUDA_KERNEL((NdarrayApplyUnaryInplaceApplyGpu<T, unary_func>), ctx, n, y.host_ptr(), n);
   }
 };
diff --git a/oneflow/core/ndarray/ndarray_assign_core.cu b/oneflow/core/ndarray/ndarray_assign_core.cu
index 3ee26a91a..17d79f64c 100644
--- a/oneflow/core/ndarray/ndarray_assign_core.cu
+++ b/oneflow/core/ndarray/ndarray_assign_core.cu
@@ -33,6 +33,7 @@ struct NdarrayAssignCoreWrapper<DeviceType::kGPU, T, NDIMS> final {
   static void Assign(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
                      const XpuReducedNdarray<T, NDIMS>& reduced) {
     size_t n = y.host_shape().HostElemNum();
+    if (n == 0) { return; }
     RUN_CUDA_KERNEL((NdarrayAssignGpu<T, NDIMS>), ctx, n, y, reduced);
   }
 };
diff --git a/oneflow/user/kernels/add_n_kernel.cu b/oneflow/user/kernels/add_n_kernel.cu
index 0ae893773..cb9d15611 100644
--- a/oneflow/user/kernels/add_n_kernel.cu
+++ b/oneflow/user/kernels/add_n_kernel.cu
@@ -56,7 +56,7 @@ struct GpuAddCaller {
     for (int32_t i = 0; i < N; ++i) {
       para.in[i] = ctx->Tensor4ArgNameAndIndex("in", i)->dptr<T>();
     }
-
+    if (n == 0) { return; }
     gpu_add<T, N>
         <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
             n, para);
diff --git a/oneflow/user/kernels/clip_by_value_kernel.cu b/oneflow/user/kernels/clip_by_value_kernel.cu
index df9f95289..5a5acffd8 100644
--- a/oneflow/user/kernels/clip_by_value_kernel.cu
+++ b/oneflow/user/kernels/clip_by_value_kernel.cu
@@ -36,12 +36,14 @@ template<typename T>
 struct ClipKernelUtil<DeviceType::kGPU, T> {
   template<typename F>
   static void Forward(DeviceCtx* ctx, F clip_func, const int64_t n, const T* x, T* y) {
+    if (n == 0) { return; }
     RUN_CUDA_KERNEL((CudaClipForward<T, F>), ctx, n, clip_func, n, x, y);
   }
 
   template<typename F>
   static void Backward(DeviceCtx* ctx, F clip_func, const int64_t n, const T* x, const T* dy,
                        T* dx) {
+    if (n == 0) { return; }
     RUN_CUDA_KERNEL((CudaClipBackward<T, F>), ctx, n, clip_func, n, x, dy, dx);
   }
 };
diff --git a/oneflow/user/kernels/concat_kernel.cpp b/oneflow/user/kernels/concat_kernel.cpp
index 3619dba8c..2fe564317 100644
--- a/oneflow/user/kernels/concat_kernel.cpp
+++ b/oneflow/user/kernels/concat_kernel.cpp
@@ -59,6 +59,7 @@ class ConcatKernel final : public user_op::OpKernel {
     for (const auto& in_arg_pair : ctx->inputs()) {
       const user_op::Tensor* in_tensor =
           ctx->Tensor4ArgNameAndIndex(in_arg_pair.first, in_arg_pair.second);
+      if (in_tensor->shape().elem_cnt() == 0) { continue; }
       const int64_t in_cols = in_tensor->shape().Count(axis);
       CHECK_EQ(in_tensor->shape().elem_cnt(), rows * in_cols);
       if (in_cols > 0) {
diff --git a/oneflow/user/kernels/constant_kernel.cpp b/oneflow/user/kernels/constant_kernel.cpp
index ff6f18a3c..969f07ce9 100644
--- a/oneflow/user/kernels/constant_kernel.cpp
+++ b/oneflow/user/kernels/constant_kernel.cpp
@@ -30,7 +30,8 @@ class ConstantKernel final : public OpKernel {
     Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0);
     bool is_floating_value = ctx->Attr<bool>("is_floating_value");
     const int64_t elem_cnt = out_tensor->shape().elem_cnt();
-    CHECK_GT(elem_cnt, 0);
+    CHECK_GE(elem_cnt, 0);
+    if (elem_cnt == 0) { return; }
     NewKernelUtil<device_type>::Fill(ctx->device_ctx(), elem_cnt,
                                      is_floating_value
                                          ? static_cast<T>(ctx->Attr<double>("floating_value"))
diff --git a/oneflow/user/kernels/empty_kernel.cpp b/oneflow/user/kernels/empty_kernel.cpp
index 732d6d2d4..e056c83a7 100644
--- a/oneflow/user/kernels/empty_kernel.cpp
+++ b/oneflow/user/kernels/empty_kernel.cpp
@@ -27,10 +27,6 @@ class EmptyKernel final : public OpKernel {
 
  private:
   void Compute(user_op::KernelComputeContext* ctx) const override {
-    Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0);
-    const int64_t elem_cnt = out_tensor->shape().elem_cnt();
-    CHECK_GT(elem_cnt, 0);
-
     // Do nothing
   }
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
diff --git a/oneflow/user/kernels/math_binary_elementwise_kernel.cu b/oneflow/user/kernels/math_binary_elementwise_kernel.cu
index 6dc1e8cea..29c62fdba 100644
--- a/oneflow/user/kernels/math_binary_elementwise_kernel.cu
+++ b/oneflow/user/kernels/math_binary_elementwise_kernel.cu
@@ -52,6 +52,7 @@ class MathBinaryElementwiseGpuKernel final : public user_op::OpKernel {
     user_op::Tensor* tensor_z = ctx->Tensor4ArgNameAndIndex("z", 0);
     int64_t n = tensor_x->shape().elem_cnt();
     CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+    if (n == 0) { return; }
     MathBinaryElementwiseForwardGpu<BinaryFunctor, T>
         <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
             n, tensor_x->dptr<T>(), tensor_y->dptr<T>(), tensor_z->mut_dptr<T>());
@@ -73,6 +74,7 @@ class MathBinaryElementwiseXGradGpuKernel final : public user_op::OpKernel {
     user_op::Tensor* tensor_dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
     int64_t n = tensor_x->shape().elem_cnt();
     CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+    if (n == 0) { return; }
     MathBinaryElementwiseBackwardXGradGpu<BinaryFunctor, T>
         <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
             n, tensor_x->dptr<T>(), tensor_y->dptr<T>(), tensor_dz->dptr<T>(),
@@ -95,6 +97,7 @@ class MathBinaryElementwiseYGradGpuKernel final : public user_op::OpKernel {
     user_op::Tensor* tensor_dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
     int64_t n = tensor_x->shape().elem_cnt();
     CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+    if (n == 0) { return; }
     MathBinaryElementwiseBackwardYGradGpu<BinaryFunctor, T>
         <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
             n, tensor_x->dptr<T>(), tensor_y->dptr<T>(), tensor_dz->dptr<T>(),
@@ -143,6 +146,7 @@ class MathBinaryElementwiseGpuHalfKernel final : public user_op::OpKernel {
     half* z = reinterpret_cast<half*>(tensor_z->mut_dptr<float16>());
     int64_t n = tensor_x->shape().elem_cnt();
     CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+    if (n == 0) { return; }
     MathBinaryElementwiseForwardGpu<BinaryFunctor, half>
         <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
             n, x, y, z);
@@ -169,6 +173,7 @@ class MathBinaryElementwiseXGradGpuHalfKernel final : public user_op::OpKernel {
     half* dx = reinterpret_cast<half*>(tensor_dx->mut_dptr<float16>());
     int64_t n = tensor_x->shape().elem_cnt();
     CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+    if (n == 0) { return; }
     MathBinaryElementwiseBackwardXGradGpu<BinaryFunctor, half>
         <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
             n, x, y, dz, dx);
@@ -195,6 +200,7 @@ class MathBinaryElementwiseYGradGpuHalfKernel final : public user_op::OpKernel {
     half* dy = reinterpret_cast<half*>(tensor_dy->mut_dptr<float16>());
     int64_t n = tensor_x->shape().elem_cnt();
     CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+    if (n == 0) { return; }
     MathBinaryElementwiseBackwardYGradGpu<BinaryFunctor, half>
         <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
             n, x, y, dz, dy);
diff --git a/oneflow/user/kernels/math_unary_elementwise_kernel.cu b/oneflow/user/kernels/math_unary_elementwise_kernel.cu
index 32144bffe..7daeb0bb4 100644
--- a/oneflow/user/kernels/math_unary_elementwise_kernel.cu
+++ b/oneflow/user/kernels/math_unary_elementwise_kernel.cu
@@ -46,6 +46,7 @@ class MathUnaryElementwiseGpuKernel final : public user_op::OpKernel {
     T* y = tensor_y->mut_dptr<T>();
     int64_t n = tensor_x->shape().elem_cnt();
     CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+    if (n == 0) { return; }
     MathUnaryElementwiseForwardGpu<UnaryFunctor, T>
         <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
             n, x, y);
@@ -70,6 +71,7 @@ class MathUnaryElementwiseGradGpuKernel final : public user_op::OpKernel {
     T* dx = tensor_dx->mut_dptr<T>();
     int64_t n = tensor_x->shape().elem_cnt();
     CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+    if (n == 0) { return; }
     MathUnaryElementwiseBackwardGpu<UnaryFunctor, T>
         <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
             n, x, dy, dx);
@@ -110,6 +112,7 @@ class MathUnaryElementwiseGpuHalfKernel final : public user_op::OpKernel {
     half* y = reinterpret_cast<half*>(tensor_y->mut_dptr<float16>());
     int64_t n = tensor_x->shape().elem_cnt();
     CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+    if (n == 0) { return; }
     MathUnaryElementwiseForwardGpu<UnaryFunctor, half>
         <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
             n, x, y);
@@ -134,6 +137,7 @@ class MathUnaryElementwiseGradGpuHalfKernel final : public user_op::OpKernel {
     half* dx = reinterpret_cast<half*>(tensor_dx->mut_dptr<float16>());
     int64_t n = tensor_x->shape().elem_cnt();
     CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+    if (n == 0) { return; }
     MathUnaryElementwiseBackwardGpu<UnaryFunctor, half>
         <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
             n, x, dy, dx);
diff --git a/oneflow/user/kernels/reduce_kernel.cpp b/oneflow/user/kernels/reduce_kernel.cpp
index 1d0a6083d..b6c235bb5 100644
--- a/oneflow/user/kernels/reduce_kernel.cpp
+++ b/oneflow/user/kernels/reduce_kernel.cpp
@@ -16,6 +16,7 @@ limitations under the License.
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/core/ndarray/ndarray_util.h"
 #include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/kernel/kernel_util.h"
 
 namespace oneflow {
 
@@ -33,6 +34,16 @@ class ReduceKernel final : public user_op::OpKernel {
     user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex("output_tensor", 0);
     user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
     const auto& axis = ctx->Attr<std::vector<int32_t>>("axis");
+
+    if (input_tensor->shape().elem_cnt() == 0) {
+      if (output_tensor->shape().elem_cnt() != 0) {
+        AutoMemset(
+            ctx->device_ctx(), output_tensor->mut_dptr<T>(), 0,
+            output_tensor->shape().elem_cnt() * GetSizeOfDataType(output_tensor->data_type()),
+            output_tensor->mem_case());
+      }
+      return;
+    }
     const Shape& reduced_shape =
         CreateReducedShape(input_tensor->shape(), {axis.begin(), axis.end()});
     NdarrayReduce<device_type, T, BinaryFunc>::Reduce(
diff --git a/oneflow/user/kernels/reduce_like_kernels.cpp b/oneflow/user/kernels/reduce_like_kernels.cpp
index d456857fc..1cab5da46 100644
--- a/oneflow/user/kernels/reduce_like_kernels.cpp
+++ b/oneflow/user/kernels/reduce_like_kernels.cpp
@@ -14,7 +14,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/ndarray/binary_func.h"
 #include "oneflow/core/ndarray/ndarray_util.h"
+#include "oneflow/core/kernel/kernel_util.h"
 
 namespace oneflow {
 
@@ -39,6 +41,14 @@ class ReduceSumLikeOpKernel final : public user_op::OpKernel {
     user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0);
     user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0);
     const auto& axis = ctx->Attr<std::vector<int32_t>>("axis");
+    if (tensor_x->shape().elem_cnt() == 0) {
+      if (tensor_y->shape().elem_cnt() != 0) {
+        AutoMemset(ctx->device_ctx(), tensor_y->mut_dptr<T>(), 0,
+                   tensor_y->shape().elem_cnt() * GetSizeOfDataType(tensor_y->data_type()),
+                   tensor_y->mem_case());
+      }
+      return;
+    }
     if (axis.empty()) {
       CHECK_EQ(tensor_x->shape(), tensor_y->shape());
       Memcpy<device_type>(ctx->device_ctx(), tensor_y->mut_dptr(), tensor_x->dptr(),
diff --git a/oneflow/user/kernels/slice_util.cu b/oneflow/user/kernels/slice_util.cu
index 4ed08a84e..7edb16af4 100644
--- a/oneflow/user/kernels/slice_util.cu
+++ b/oneflow/user/kernels/slice_util.cu
@@ -48,6 +48,7 @@ void LaunchSliceForward(DeviceCtx* ctx, const SliceParams& params, const T* enti
   int64_t elem_cnt = params.elem_cnt();
   SliceIndexHelper<NDIM> entire_idx_cvtr(params.dims);
   SliceIndexHelper<NDIM> sliced_idx_cvtr(params.size);
+  if (elem_cnt == 0) { return; }
   SliceForwardGpu<T, NDIM>
       <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
           elem_cnt, params, entire_idx_cvtr, sliced_idx_cvtr, entire, sliced);
diff --git a/oneflow/user/kernels/triu_kernel.cu b/oneflow/user/kernels/triu_kernel.cu
index bf2173d4d..2dd732982 100644
--- a/oneflow/user/kernels/triu_kernel.cu
+++ b/oneflow/user/kernels/triu_kernel.cu
@@ -90,6 +90,7 @@ class GpuTriuKernel final : public user_op::OpKernel {
     const int64_t num_cols = shape.At(shape.NumAxes() - 1);
     user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("out", 0);
     const int32_t elem_cnt = shape.elem_cnt();
+    if (elem_cnt == 0) { return; }
     if (num_cols % (kCudaWarpSize * 2) == 0) {
       const int64_t total_rows = elem_cnt / num_cols;
       TriuWarpProcessRowGpu<<<BlocksNum4ThreadsNum(total_rows * kCudaWarpSize),
diff --git a/oneflow/user/ops/math_binary_broadcast_ops.cpp b/oneflow/user/ops/math_binary_broadcast_ops.cpp
index 98008f9c8..641781191 100644
--- a/oneflow/user/ops/math_binary_broadcast_ops.cpp
+++ b/oneflow/user/ops/math_binary_broadcast_ops.cpp
@@ -47,7 +47,9 @@ Maybe<void> InferTensorDescBinaryBroadcastNormal(user_op::InferContext* ctx) {
       CHECK_OR_RETURN(x_shape.At(i) == 1 || y_shape.At(i) == 1 || x_shape.At(i) == y_shape.At(i))
           << "op: " << ctx->op_name() << ", type: " << ctx->op_type_name() << ", i: " << i
           << ", x_shape: " << x_shape << ", y_shape: " << y_shape;
-      out_shape.Set(i, std::max(x_shape.At(i), y_shape.At(i)));
+      out_shape.Set(i, (x_shape.At(i) == 0 || y_shape.At(i) == 0)
+                           ? 0
+                           : std::max(x_shape.At(i), y_shape.At(i)));
     }
     *tensor_z->mut_shape() = out_shape;
   }
diff --git a/oneflow/user/ops/reshape_op.cpp b/oneflow/user/ops/reshape_op.cpp
index 573028fb2..de4331180 100644
--- a/oneflow/user/ops/reshape_op.cpp
+++ b/oneflow/user/ops/reshape_op.cpp
@@ -48,7 +48,7 @@ Maybe<void> LogicalTensorDescInferFn(user_op::InferContext* ctx) {
   *out_tensor_desc = in_tensor_desc;
   CHECK_GE_OR_RETURN(shape.NumAxes(), 1);
   DimVector dim_vec = {shape.dim_vec().begin(), shape.dim_vec().end()};
-  FOR_RANGE(int32_t, i, 0, dim_vec.size()) { CHECK_GT_OR_RETURN(dim_vec.at(i), 0); }
+  FOR_RANGE(int32_t, i, 0, dim_vec.size()) { CHECK_GE_OR_RETURN(dim_vec.at(i), 0); }
   *out_shape = Shape(dim_vec);
   CHECK_EQ_OR_RETURN(out_shape->elem_cnt(), in_shape.elem_cnt());
   return Maybe<void>::Ok();
diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp
index e409f7058..a39d973ca 100644
--- a/oneflow/user/ops/slice_op.cpp
+++ b/oneflow/user/ops/slice_op.cpp
@@ -40,14 +40,16 @@ Maybe<void> InferSliceOpTensorDesc(user_op::InferContext* ctx) {
   DimVector dim_vec(ndim);
   FOR_RANGE(size_t, i, 0, dim_vec.size()) {
     const int64_t dim_size = x_shape.At(i);
-    if (dim_size == 0) {
+    const int64_t step = step_vec.at(i);
+    int64_t start = start_vec.at(i);
+    int64_t stop = stop_vec.at(i);
+    if (dim_size == 0 || start == stop) {
       dim_vec[i] = 0;
       continue;
     }
-    const int64_t step = step_vec.at(i);
     CHECK_NE_OR_RETURN(step, 0) << "slice step cannot be 0";
-    int64_t start = RegulateSliceStart(start_vec.at(i), dim_size);
-    int64_t stop = RegulateSliceStop(stop_vec.at(i), dim_size);
+    start = RegulateSliceStart(start, dim_size);
+    stop = RegulateSliceStop(stop, dim_size);
     if (step > 0) {
       CHECK_LT_OR_RETURN(start, stop) << "slice start must be less than stop when step > 0"
                                          ", otherwise empty result will be outputted.";
diff --git a/python/oneflow/framework/tensor.py b/python/oneflow/framework/tensor.py
index e599fc328..b395273a7 100644
--- a/python/oneflow/framework/tensor.py
+++ b/python/oneflow/framework/tensor.py
@@ -40,7 +40,8 @@ def _tensor_numpy(eager_local_tensor):
         tuple(eager_local_tensor.shape),
         dtype=flow.convert_oneflow_dtype_to_numpy_dtype(eager_local_tensor.dtype),
     )
-    copy_to_numpy(ndarray)
+    if ndarray.size != 0:
+        copy_to_numpy(ndarray)
     return ndarray
 
 
diff --git a/python/oneflow/nn/modules/reduce_ops.py b/python/oneflow/nn/modules/reduce_ops.py
index 867ef0323..f43fbb02e 100644
--- a/python/oneflow/nn/modules/reduce_ops.py
+++ b/python/oneflow/nn/modules/reduce_ops.py
@@ -114,6 +114,9 @@ class Min(Module):
         self._op = _build_reduce_op("reduce_min", keepdims)
 
     def forward(self, input):
+        # TODO: move this check into the functor
+        if input.shape.numel() == 0:
+            raise RuntimeError("operation does not have an identity.")
         axis_checked = _check_axis(self.axis, input.shape)
         if len(axis_checked) == 0:
             return input
@@ -151,6 +154,9 @@ class Max(Module):
         self._op = _build_reduce_op("reduce_max", keepdims)
 
     def forward(self, input):
+        # TODO: move this check into the functor
+        if input.shape.numel() == 0:
+            raise RuntimeError("operation does not have an identity.")
         axis_checked = _check_axis(self.axis, input.shape)
         if len(axis_checked) == 0:
             return input
diff --git a/python/oneflow/ops/array_ops.py b/python/oneflow/ops/array_ops.py
index f32c64bf5..1c7eb5789 100644
--- a/python/oneflow/ops/array_ops.py
+++ b/python/oneflow/ops/array_ops.py
@@ -44,7 +44,7 @@ def check_slice_tup_list(slice_tup_list, shape):
         if start is None:
             start = 0 if step > 0 else np.iinfo(np.int64).max
         elif start < -dim_size or start >= dim_size:
-            raise ValueError("slice start must be in range [-size, size)")
+            start, stop, step = 0, 0, 1
         if stop is None:
             stop = np.iinfo(np.int64).max if step > 0 else np.iinfo(np.int64).min
         elif stop < -dim_size - 1 or stop > dim_size:
diff --git a/python/oneflow/test/modules/test_abs.py b/python/oneflow/test/modules/test_abs.py
index 0b8e8fd2d..f0d1404e0 100644
--- a/python/oneflow/test/modules/test_abs.py
+++ b/python/oneflow/test/modules/test_abs.py
@@ -23,6 +23,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_abs_forward(test_case, device):
@@ -81,6 +82,13 @@ class TestAbs(flow.unittest.TestCase):
         for device in ["cpu", "cuda"]:
             test_tensor_against_pytorch(test_case, "abs", device=device)
 
+    @autotest()
+    def test_abs_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device)
+        y = torch.abs(x)
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_activation.py b/python/oneflow/test/modules/test_activation.py
index 40996c845..ed103468e 100644
--- a/python/oneflow/test/modules/test_activation.py
+++ b/python/oneflow/test/modules/test_activation.py
@@ -38,6 +38,16 @@ class TestReLUModule(flow.unittest.TestCase):
         y = m(x)
         return y
 
+    @autotest(auto_backward=False)
+    def test_relu_module_with_0shape_data(test_case):
+        m = torch.nn.ReLU()
+        m.train(random())
+        device = random_device()
+        m.to(device)
+        x = random_pytorch_tensor(4, 2, 3, 0, 3).to(device)
+        y = m(x)
+        return y
+
 
 @flow.unittest.skip_unless_1n1d()
 class TestReLU6Module(flow.unittest.TestCase):
@@ -51,6 +61,16 @@ class TestReLU6Module(flow.unittest.TestCase):
         y = m(x)
         return y
 
+    @autotest(auto_backward=False)
+    def test_relu6_module_with_0shape_data(test_case):
+        m = torch.nn.ReLU6()
+        m.train(random())
+        device = random_device()
+        m.to(device)
+        x = random_pytorch_tensor(4, 2, 3, 0, 3).to(device)
+        y = m(x)
+        return y
+
 
 @flow.unittest.skip_unless_1n1d()
 class TestTanh(flow.unittest.TestCase):
@@ -64,6 +84,16 @@ class TestTanh(flow.unittest.TestCase):
         y = m(x)
         return y
 
+    @autotest(auto_backward=False)
+    def test_tanh_module_with_0shape_data(test_case):
+        m = torch.nn.Tanh()
+        m.train(random())
+        device = random_device()
+        m.to(device)
+        x = random_pytorch_tensor(4, 2, 3, 0, 3).to(device)
+        y = m(x)
+        return y
+
     @autotest()
     def test_flow_tanh_with_random_data(test_case):
         device = random_device()
@@ -71,6 +101,13 @@ class TestTanh(flow.unittest.TestCase):
         y = torch.tanh(x)
         return y
 
+    @autotest(auto_backward=False)
+    def test_flow_tanh_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(4, 2, 3, 0, 3).to(device)
+        y = torch.tanh(x)
+        return y
+
 
 @flow.unittest.skip_unless_1n1d()
 class TestELUModule(flow.unittest.TestCase):
@@ -84,6 +121,16 @@ class TestELUModule(flow.unittest.TestCase):
         y = m(x)
         return y
 
+    @autotest(auto_backward=False)
+    def test_elu_module_with_0shape_data(test_case):
+        m = torch.nn.ELU(alpha=random() | nothing())
+        m.train(random())
+        device = random_device()
+        m.to(device)
+        x = random_pytorch_tensor(4, 2, 3, 0, 3).to(device)
+        y = m(x)
+        return y
+
 
 @flow.unittest.skip_unless_1n1d()
 class TestGelu(flow.unittest.TestCase):
diff --git a/python/oneflow/test/modules/test_add.py b/python/oneflow/test/modules/test_add.py
index cbc7a2a51..1a9c2c95f 100644
--- a/python/oneflow/test/modules/test_add.py
+++ b/python/oneflow/test/modules/test_add.py
@@ -22,6 +22,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_add_forward(test_case, shape, device):
@@ -151,6 +152,14 @@ class TestAddModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    @autotest()
+    def test_0shape_add(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(2, 0, 3).to(device)
+        y = random_pytorch_tensor(2, 1, 3).to(device)
+        out = x + y
+        return out
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_argwhere.py b/python/oneflow/test/modules/test_argwhere.py
index ac7cd37b5..62edfc9c8 100644
--- a/python/oneflow/test/modules/test_argwhere.py
+++ b/python/oneflow/test/modules/test_argwhere.py
@@ -39,7 +39,7 @@ class TestArgwhere(flow.unittest.TestCase):
     def test_argwhere(test_case):
         arg_dict = OrderedDict()
         arg_dict["test_fun"] = [_test_argwhere]
-        arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]
+        arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6), (2, 3, 0, 4)]
         arg_dict["device"] = ["cpu", "cuda"]
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
diff --git a/python/oneflow/test/modules/test_cast.py b/python/oneflow/test/modules/test_cast.py
index 2d21a2142..10b00e819 100644
--- a/python/oneflow/test/modules/test_cast.py
+++ b/python/oneflow/test/modules/test_cast.py
@@ -66,6 +66,17 @@ class TestCast(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    def test_cast_with_0shape_data(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [
+            _test_cast_float2int,
+            _test_cast_int2float,
+        ]
+        arg_dict["device"] = ["cpu", "cuda"]
+        arg_dict["shape"] = [(2, 3, 0, 5)]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_ceil.py b/python/oneflow/test/modules/test_ceil.py
index 002251d66..b721ae25b 100644
--- a/python/oneflow/test/modules/test_ceil.py
+++ b/python/oneflow/test/modules/test_ceil.py
@@ -54,6 +54,13 @@ class TestCeilModule(flow.unittest.TestCase):
         y = torch.ceil(input)
         return y
 
+    @autotest(auto_backward=False)
+    def test_ceil_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device)
+        y = torch.ceil(x)
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_clamp.py b/python/oneflow/test/modules/test_clamp.py
index 640408dd1..afcc68cd4 100644
--- a/python/oneflow/test/modules/test_clamp.py
+++ b/python/oneflow/test/modules/test_clamp.py
@@ -153,6 +153,13 @@ class TestClampModule(flow.unittest.TestCase):
         )
         return y
 
+    @autotest(auto_backward=False)
+    def test_clamp_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device)
+        y = torch.clamp(x, min=random().to(float), max=random().to(float))
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_concat.py b/python/oneflow/test/modules/test_concat.py
index e5dea24b2..2097f04e0 100644
--- a/python/oneflow/test/modules/test_concat.py
+++ b/python/oneflow/test/modules/test_concat.py
@@ -22,6 +22,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_concat_origin(test_case, device):
@@ -132,6 +133,14 @@ class TestModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    @autotest(n=10, auto_backward=False)
+    def test_concat_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(4, 2, 3, 2, 4).to(device)
+        y = random_pytorch_tensor(4, 2, 3, random(0, 3), 4).to(device)
+        z = torch.cat((x, y), dim=2)
+        return z
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_constant.py b/python/oneflow/test/modules/test_constant.py
index 84d33da60..db6700ee7 100644
--- a/python/oneflow/test/modules/test_constant.py
+++ b/python/oneflow/test/modules/test_constant.py
@@ -18,11 +18,11 @@ import unittest
 from collections import OrderedDict
 
 import numpy as np
-from test_util import GenArgList
-
 import oneflow as flow
+
 import oneflow.unittest
-from oneflow.framework.tensor import register_tensor_op
+from test_util import GenArgList
+from automated_test_util import *
 
 
 def _test_ones(test_case, device, shape):
@@ -119,7 +119,7 @@ class TestConstantModule(flow.unittest.TestCase):
             _test_new_ones,
         ]
         arg_dict["device"] = ["cpu", "cuda"]
-        arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]
+        arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5), (2, 0, 4)]
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
diff --git a/python/oneflow/test/modules/test_div.py b/python/oneflow/test/modules/test_div.py
index daba1e6f0..40fddb3e7 100644
--- a/python/oneflow/test/modules/test_div.py
+++ b/python/oneflow/test/modules/test_div.py
@@ -23,6 +23,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_div_impl(test_case, shape, device):
@@ -100,6 +101,14 @@ class TestDiv(flow.unittest.TestCase):
                 device=arg[1],
             )
 
+    @autotest(auto_backward=False)
+    def test_0shape_div(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device)
+        y = random_pytorch_tensor(4, 2, 1, 0, 3).to(device)
+        z = x / y
+        return z
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_eq.py b/python/oneflow/test/modules/test_eq.py
index c1feca9bb..b838397c2 100644
--- a/python/oneflow/test/modules/test_eq.py
+++ b/python/oneflow/test/modules/test_eq.py
@@ -22,6 +22,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_eq(test_case, shape, device):
@@ -99,6 +100,14 @@ class TestEq(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    @autotest(auto_backward=False)
+    def test_eq_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(3, 2, 0, 3).to(device)
+        y = random_pytorch_tensor(3, 2, 0, 3).to(device)
+        z = torch.eq(x, y)
+        return z
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_expm1.py b/python/oneflow/test/modules/test_expm1.py
index 454b5daf6..04821a2d7 100644
--- a/python/oneflow/test/modules/test_expm1.py
+++ b/python/oneflow/test/modules/test_expm1.py
@@ -54,6 +54,13 @@ class TestExpm1Module(flow.unittest.TestCase):
         y = torch.expm1(input)
         return y
 
+    @autotest(auto_backward=False)
+    def test_expm1_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device)
+        y = torch.expm1(x)
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_fmod.py b/python/oneflow/test/modules/test_fmod.py
index e0ab4da4d..d20381fd3 100644
--- a/python/oneflow/test/modules/test_fmod.py
+++ b/python/oneflow/test/modules/test_fmod.py
@@ -91,6 +91,13 @@ class TestFmodModule(flow.unittest.TestCase):
         other = random_pytorch_tensor().to(device)
         return torch.fmod(input, other)
 
+    @autotest(auto_backward=False)
+    def test_fmod_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device)
+        y = torch.fmod(x, 2)
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_greater.py b/python/oneflow/test/modules/test_greater.py
index a7990b637..e977621dc 100644
--- a/python/oneflow/test/modules/test_greater.py
+++ b/python/oneflow/test/modules/test_greater.py
@@ -118,6 +118,15 @@ class TestGreater(flow.unittest.TestCase):
         y2 = x1 > x2
         return (y1, y2)
 
+    @autotest(auto_backward=False)
+    def test_greater_with_0shape_data(test_case):
+        device = random_device()
+        x1 = random_pytorch_tensor(4, 2, 3, 0, 5).to(device)
+        x2 = random_pytorch_tensor(4, 2, 3, 0, 5).to(device)
+        y1 = torch.gt(x1, x2)
+        y2 = x1 > x2
+        return (y1, y2)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_ne.py b/python/oneflow/test/modules/test_ne.py
index 229ec5a8f..48f80d873 100644
--- a/python/oneflow/test/modules/test_ne.py
+++ b/python/oneflow/test/modules/test_ne.py
@@ -22,6 +22,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_ne(test_case, shape, device):
@@ -99,6 +100,16 @@ class TestNe(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    @autotest(auto_backward=False)
+    def test_ne_with_0shape_data(test_case):
+        device = random_device()
+        x1 = random_pytorch_tensor(4, 2, 3, 0, 5).to(device)
+        x2 = random_pytorch_tensor(4, 2, 3, 0, 5).to(device)
+        y1 = torch.ne(x1, x2)
+        y2 = torch.ne(x1, 2)
+        y3 = torch.ne(x1, 2.0)
+        return (y1, y2, y3)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_negative.py b/python/oneflow/test/modules/test_negative.py
index 7f8e7e45a..4534545a2 100644
--- a/python/oneflow/test/modules/test_negative.py
+++ b/python/oneflow/test/modules/test_negative.py
@@ -22,6 +22,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_negtive(test_case, shape, device):
@@ -77,6 +78,15 @@ class TestNegativeModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    @autotest(auto_backward=False)
+    def test_negative_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(4, 2, 3, 0, 5).to(device)
+        y1 = torch.negative(x)
+        y2 = torch.neg(x)
+        y3 = -x
+        return (y1, y2, y3)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_pow.py b/python/oneflow/test/modules/test_pow.py
index 1a87b26ff..2a177e4a1 100644
--- a/python/oneflow/test/modules/test_pow.py
+++ b/python/oneflow/test/modules/test_pow.py
@@ -96,7 +96,7 @@ def _test_pow_backward_impl(test_case, device):
 class TestPow(flow.unittest.TestCase):
     def test_pow_forward(test_case):
         arg_dict = OrderedDict()
-        arg_dict["shape"] = [(2, 3), (2, 3, 4, 5)]
+        arg_dict["shape"] = [(2, 3), (2, 3, 4, 5), (2, 3, 0, 5)]
         arg_dict["scalar"] = [2.1, 0.8]
         arg_dict["device"] = ["cpu", "cuda"]
         for arg in GenArgList(arg_dict):
diff --git a/python/oneflow/test/modules/test_reshape.py b/python/oneflow/test/modules/test_reshape.py
index bf5cf34da..f8e844c7f 100644
--- a/python/oneflow/test/modules/test_reshape.py
+++ b/python/oneflow/test/modules/test_reshape.py
@@ -23,6 +23,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_reshape(test_case, device):
@@ -83,6 +84,15 @@ class TestModule(flow.unittest.TestCase):
         y = torch.reshape(x, shape=(-1,))
         return y
 
+    @autotest(auto_backward=False)
+    def test_reshape_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(4, 2, 0, 3).to(device)
+        y = torch.reshape(
+            x, shape=(random(0, 5).to(int).value(), 0, random(0, 5).to(int).value())
+        )
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_sign.py b/python/oneflow/test/modules/test_sign.py
index f46a6165b..50bf57501 100644
--- a/python/oneflow/test/modules/test_sign.py
+++ b/python/oneflow/test/modules/test_sign.py
@@ -22,6 +22,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_sign_impl(test_case, shape, device):
@@ -47,6 +48,13 @@ class TestSign(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             _test_sign_impl(test_case, *arg)
 
+    @autotest(auto_backward=False)
+    def test_sign_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(4, 2, 3, 0, 4).to(device)
+        y = torch.sign(x)
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_squeeze.py b/python/oneflow/test/modules/test_squeeze.py
index a69f92cc3..7b158ee29 100644
--- a/python/oneflow/test/modules/test_squeeze.py
+++ b/python/oneflow/test/modules/test_squeeze.py
@@ -108,6 +108,13 @@ class TestSqueeze(flow.unittest.TestCase):
         y = torch.squeeze(x, random(1, 3).to(int))
         return y
 
+    @autotest(auto_backward=False)
+    def test_squeeze_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(3, 2, 1, 0).to(device)
+        y = torch.squeeze(x)
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_sub.py b/python/oneflow/test/modules/test_sub.py
index 623d33d91..db9f9a9e0 100644
--- a/python/oneflow/test/modules/test_sub.py
+++ b/python/oneflow/test/modules/test_sub.py
@@ -23,6 +23,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_sub_impl(test_case, shape, device):
@@ -110,6 +111,17 @@ class TestSubModule(flow.unittest.TestCase):
                 device=arg[1],
             )
 
+    @autotest(auto_backward=False)
+    def test_sub_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(2, 0, 3).to(device)
+        y = random_pytorch_tensor(2, 1, 3).to(device)
+        out1 = x - y
+        out2 = x - 2
+        out3 = 2 - x
+        out4 = torch.sub(x, y)
+        return out1, out2, out3, out4
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_sum.py b/python/oneflow/test/modules/test_sum.py
index e9927eeda..87462f56d 100644
--- a/python/oneflow/test/modules/test_sum.py
+++ b/python/oneflow/test/modules/test_sum.py
@@ -69,12 +69,19 @@ class TestSumModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             _test_sum_impl(test_case, *arg)
 
+    @autotest()
     def test_sum_against_pytorch(test_case):
-        arg_dict = OrderedDict()
-        arg_dict["test_type"] = [test_flow_against_pytorch, test_tensor_against_pytorch]
-        arg_dict["device"] = ["cpu", "cuda"]
-        for arg in GenArgList(arg_dict):
-            arg[0](test_case, "sum", device=arg[1])
+        device = random_device()
+        x = random_pytorch_tensor(4, random(0, 5), 2).to(device)
+        y = torch.sum(x)
+        return y
+
+    @autotest(auto_backward=False)
+    def test_sum_with_0shape_tensor(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(4, 4, 3, 0, 2).to(device)
+        y = torch.sum(x, dim=np.random.randint(0, 3))
+        return y
 
 
 if __name__ == "__main__":
diff --git a/python/oneflow/test/modules/test_transpose.py b/python/oneflow/test/modules/test_transpose.py
index 8ba4b1330..675253df1 100644
--- a/python/oneflow/test/modules/test_transpose.py
+++ b/python/oneflow/test/modules/test_transpose.py
@@ -102,6 +102,13 @@ class TestTranspose(flow.unittest.TestCase):
         y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int))
         return y
 
+    @autotest(auto_backward=False)
+    def test_transpose_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(4, 2, 3, 0, 4).to(device)
+        y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int))
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_triu.py b/python/oneflow/test/modules/test_triu.py
index 7c85f6796..34fbe40f7 100644
--- a/python/oneflow/test/modules/test_triu.py
+++ b/python/oneflow/test/modules/test_triu.py
@@ -23,6 +23,7 @@ from test_util import GenArgList
 import oneflow as flow
 import oneflow.nn as nn
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_triu(test_case, diagonal, device):
@@ -50,6 +51,13 @@ class TestTriu(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    @autotest(auto_backward=False)
+    def test_triu_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device)
+        y = torch.triu(x)
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_unsqueeze.py b/python/oneflow/test/modules/test_unsqueeze.py
index 1baf9cbae..953cbf217 100644
--- a/python/oneflow/test/modules/test_unsqueeze.py
+++ b/python/oneflow/test/modules/test_unsqueeze.py
@@ -81,6 +81,13 @@ class TestUnsqueeze(flow.unittest.TestCase):
         y = torch.unsqueeze(x, random(1, 3).to(int))
         return y
 
+    @autotest(auto_backward=False)
+    def test_unsqueeze_with_0shape_data(test_case):
+        device = random_device()
+        x = random_pytorch_tensor(3, 2, 1, 0).to(device)
+        y = torch.unsqueeze(x, random(0, 2).to(int))
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py
index 23a57f906..9c55aaba7 100644
--- a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py
+++ b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py
@@ -78,9 +78,20 @@ def get_args(callable, *args, **kwargs):
         return x
 
     for arg in args:
-        arg = get_generator_value(arg)
-        pytorch_args.append(get_pytorch_value(arg))
-        oneflow_args.append(get_oneflow_value(arg))
+        # TODO: refine this code
+        if isinstance(arg, tuple):
+            pytorch_tuple_args = []
+            oneflow_tuple_args = []
+            for t in arg:
+                t = get_generator_value(t)
+                pytorch_tuple_args.append(get_pytorch_value(t))
+                oneflow_tuple_args.append(get_oneflow_value(t))
+            pytorch_args.append(tuple(pytorch_tuple_args))
+            oneflow_args.append(tuple(oneflow_tuple_args))
+        else:
+            arg = get_generator_value(arg)
+            pytorch_args.append(get_pytorch_value(arg))
+            oneflow_args.append(get_oneflow_value(arg))
     for (key, value) in kwargs.items():
         value = get_generator_value(value)
         if isinstance(value, Nothing):
-- 
GitLab