From a9b35aed10f0e244c6834d65403344cfdfc84d3f Mon Sep 17 00:00:00 2001
From: Luyang <flowingsun007@163.com>
Date: Thu, 10 Jun 2021 08:30:31 +0800
Subject: [PATCH] Fix slice bug (#5117)

* align squeeze module to torch

* add squeeze for slice

* refine as comments

* refine as comments

* refine

* refine

* refine

* refine test case

* refine

* refine conv

* fix conv slice

* refine

Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 oneflow/python/framework/tensor.py          | 11 +++++-
 oneflow/python/nn/modules/conv.py           |  9 ++---
 oneflow/python/nn/modules/slice.py          |  2 +-
 oneflow/python/nn/modules/squeeze.py        |  9 ++++-
 oneflow/python/test/modules/test_slice.py   | 42 ++++++++++++++++++++-
 oneflow/python/test/modules/test_squeeze.py |  8 ++++
 6 files changed, 72 insertions(+), 9 deletions(-)

diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py
index df724d41f..3a17babf8 100644
--- a/oneflow/python/framework/tensor.py
+++ b/oneflow/python/framework/tensor.py
@@ -402,9 +402,18 @@ class Tensor:
     @register_local_tensor_method()
     def __getitem__(self, key):
         # TODO: support inplace __getitem__
+        assert isinstance(key, (int, tuple, slice)), "Unsupported key type!"
+        # Axes indexed by an int are dropped from the result; sliced axes are kept.
+        if isinstance(key, tuple):
+            squeeze_dims = [i for i, k in enumerate(key) if isinstance(k, int)]
+        elif isinstance(key, int):
+            squeeze_dims = [0]
+        else:
+            squeeze_dims = []
         start, stop, step, _ = self._get_slice_obj(key)
         res = flow.experimental.slice(self, list(zip(start, stop, step)))
-        return res
+        # No int index: return as-is (squeeze with dim=None drops all size-1 axes).
+        return res.squeeze(dim=squeeze_dims) if squeeze_dims else res
 
     @_auto_determine
     @register_local_tensor_method()
diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py
index 443c1a8fd..35b14a8e1 100644
--- a/oneflow/python/nn/modules/conv.py
+++ b/oneflow/python/nn/modules/conv.py
@@ -210,6 +210,7 @@ class Conv2d(Module):
 
         assert padding_mode == "zeros"
         kernel_size = _pair(kernel_size)
+        self.kernel_size = kernel_size
         stride = _pair(stride)
         padding = _pair(padding)
         dilation = _pair(dilation)
@@ -271,16 +272,14 @@ class Conv2d(Module):
     def forward(self, x):
         if x.device.type == "cpu" and self.groups > 1:
             in_channel_axis = 1
-            filter_out_axis = 0
             in_split_list = ConvUtil.split(
                 x, axis=in_channel_axis, split_num=self.groups
             )
-            filter_split_list = ConvUtil.split(
-                self.weight, axis=filter_out_axis, split_num=self.groups
-            )
             out_list = []
             for i in range(len(in_split_list)):
-                out_list.append(self._cpu_op(in_split_list[i], self.weight[i])[0])
+                out_list.append(
+                    self._cpu_op(in_split_list[i], self.weight[i : i + 1, :, :, :])[0]
+                )
             res = flow.experimental.cat(out_list, dim=in_channel_axis)
         else:
             res = self._op(x, self.weight)[0]
diff --git a/oneflow/python/nn/modules/slice.py b/oneflow/python/nn/modules/slice.py
index dd8bce2c1..9f3b858b2 100644
--- a/oneflow/python/nn/modules/slice.py
+++ b/oneflow/python/nn/modules/slice.py
@@ -113,7 +113,7 @@ def slice_update_op(x, update, slice_tup_list: Sequence[Tuple[int, int, int]]):
         >>> y.numpy()
         array([1., 2., 3., 4., 1.], dtype=float32)
     """
-    start, stop, step = check_slice_tup_list(slice_tup_list, x.shape)
+    start, stop, step = GetSliceAttrs(slice_tup_list, x.shape)
     return SliceUpdate(start, stop, step)(x, update)
 
 
diff --git a/oneflow/python/nn/modules/squeeze.py b/oneflow/python/nn/modules/squeeze.py
index ed11eb2a1..7563a3be4 100644
--- a/oneflow/python/nn/modules/squeeze.py
+++ b/oneflow/python/nn/modules/squeeze.py
@@ -24,6 +24,7 @@ from typing import Optional, Sequence
 class Squeeze(Module):
     def __init__(self, dim: Optional[Sequence[int]] = None) -> None:
         super().__init__()
+        self.dim = dim
 
         self._op = (
             flow.builtin_op("squeeze")
@@ -34,6 +35,8 @@ class Squeeze(Module):
         )
 
     def forward(self, x):
+        if self.dim is None:  # no axes were given: hand the input back untouched
+            return x
         return self._op(x)[0]
 
 
@@ -67,8 +70,12 @@ def squeeze_op(input, dim: Optional[Sequence[int]] = None):
         (1, 3)
 
     """
-    if type(dim) == int:
+    if isinstance(dim, int):
         dim = [dim]
+    elif dim is None:
+        dim = range(input.ndim)
+
+    dim = list(filter(lambda i: input.size(i) == 1, dim))
     return Squeeze(dim=dim)(input)
 
 
diff --git a/oneflow/python/test/modules/test_slice.py b/oneflow/python/test/modules/test_slice.py
index 93aad7328..b9469fe1b 100644
--- a/oneflow/python/test/modules/test_slice.py
+++ b/oneflow/python/test/modules/test_slice.py
@@ -32,6 +32,14 @@ def _test_slice(test_case, device):
     test_case.assertTrue(np.array_equal(y.numpy(), np_out))
 
 
+def _test_slice_1_dim(test_case, device):  # int and slice keys on a 1-D tensor
+    np_arr = np.random.randn(100).astype(np.float32)
+    x = flow.Tensor(np_arr, device=flow.device(device))
+    test_case.assertTrue(np.allclose(x[1].numpy(), np_arr[1], 1e-5, 1e-5))
+    test_case.assertTrue(np.allclose(x[99].numpy(), np_arr[99], 1e-5, 1e-5))
+    test_case.assertTrue(np.allclose(x[0:2].numpy(), np_arr[0:2], 1e-5, 1e-5))
+
+
 def _test_slice_4_dim(test_case, device):
     np_arr = np.random.randn(5, 3, 6, 9).astype(np.float32)
     x = flow.Tensor(np_arr, device=flow.device(device))
@@ -42,6 +50,32 @@ def _test_slice_4_dim(test_case, device):
     test_case.assertTrue(np.array_equal(y.numpy(), np_out))
 
 
+def _test_slice_with_int_index(test_case, device):
+    np_arr = np.random.randn(2, 3, 4).astype(np.float32)
+    x = flow.Tensor(np_arr, device=flow.device(device))
+    of_out = x[0, 1:2]
+    np_out = np_arr[0, 1:2]
+    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
+    # NOTE(review): only index 0 on axis 0 is covered (no x[1, :] or x[:, 0]).
+    np_arr = np.random.randn(2, 3, 4).astype(np.float32)
+    x = flow.Tensor(np_arr, device=flow.device(device))
+    of_out = x[0, :]
+    np_out = np_arr[0, :]
+    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
+    # Hand-written 2x2x2 values with a leading int index.
+    np_arr = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]).astype(np.float32)
+    x = flow.Tensor(np_arr, device=flow.device(device))
+    of_out = x[0, :, :]
+    np_out = np_arr[0, :, :]
+    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
+    # Leading int index on a 4-D input.
+    np_arr = np.random.randn(2, 3, 4, 5).astype(np.float32)
+    x = flow.Tensor(np_arr, device=flow.device(device))
+    of_out = x[0, :, :, :]
+    np_out = np_arr[0, :, :, :]
+    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
+
+
 def _test_slice_backward(test_case, device):
     np_arr = np.random.randn(3, 6, 9).astype(np.float32)
     x = flow.Tensor(np_arr, device=flow.device(device), requires_grad=True)
@@ -62,7 +96,13 @@ def _test_slice_backward(test_case, device):
 class TestSlice(flow.unittest.TestCase):
     def test_slice(test_case):
         arg_dict = OrderedDict()
-        arg_dict["test_fun"] = [_test_slice, _test_slice_4_dim, _test_slice_backward]
+        arg_dict["test_fun"] = [
+            _test_slice,
+            _test_slice_1_dim,
+            _test_slice_4_dim,
+            _test_slice_with_int_index,
+            _test_slice_backward,
+        ]
         arg_dict["device"] = ["cpu", "cuda"]
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
diff --git a/oneflow/python/test/modules/test_squeeze.py b/oneflow/python/test/modules/test_squeeze.py
index 58644a73b..e98497053 100644
--- a/oneflow/python/test/modules/test_squeeze.py
+++ b/oneflow/python/test/modules/test_squeeze.py
@@ -38,6 +38,13 @@ def _test_squeeze(test_case, device):
     )
 
 
+def _test_squeeze_1d_input(test_case, device):  # shape (10,): values must survive
+    np_arr = np.random.rand(10)
+    input = flow.Tensor(np_arr, device=flow.device(device))
+    output = flow.squeeze(input)
+    test_case.assertTrue(np.allclose(output.numpy(), np_arr, 1e-5, 1e-5))
+
+
 def _test_tensor_squeeze(test_case, device):
     np_arr = np.random.rand(1, 1, 1, 3)
     input = flow.Tensor(np_arr, device=flow.device(device))
@@ -85,6 +92,7 @@ class TestSqueeze(flow.unittest.TestCase):
         arg_dict = OrderedDict()
         arg_dict["test_fun"] = [
             _test_squeeze,
+            _test_squeeze_1d_input,
             _test_squeeze_int,
             _test_tensor_squeeze,
             _test_squeeze_backward,
-- 
GitLab