diff --git a/oneflow/python/nn/modules/dataset.py b/oneflow/python/nn/modules/dataset.py
index 9b3fa9a4b53428523a42d9cb6c55aa162b802530..8b35ed603598ec983c9c1eecc90ece63277187ad 100644
--- a/oneflow/python/nn/modules/dataset.py
+++ b/oneflow/python/nn/modules/dataset.py
@@ -382,13 +382,6 @@ class ImageResize(Module):
                 .Attr("interpolation_type", interpolation_type)
                 .Build()
             )
-            # TODO(Liang Depeng)
-            # scale = flow.tensor_buffer_to_tensor(
-            #     scale, dtype=flow.float32, instance_shape=(2,)
-            # )
-            # new_size = flow.tensor_buffer_to_tensor(
-            #     new_size, dtype=flow.int32, instance_shape=(2,)
-            # )
         else:
             if (
                 not isinstance(target_size, (list, tuple))
@@ -417,8 +410,19 @@ class ImageResize(Module):
             )
 
     def forward(self, input):
-        res = self._op(input)[0]
-        return res
+        res = self._op(input)
+        res_image = res[0]
+        if len(res) == 3:
+            new_size = flow.experimental.tensor_buffer_to_tensor(
+                res[1], dtype=flow.int32, instance_shape=(2,)
+            )
+            scale = flow.experimental.tensor_buffer_to_tensor(
+                res[2], dtype=flow.float32, instance_shape=(2,)
+            )
+        else:
+            new_size = None
+            scale = res[1]
+        return res_image, scale, new_size
 
 
 @oneflow_export("tmp.RawDecoder")
diff --git a/oneflow/python/test/modules/image_test_util.py b/oneflow/python/test/modules/image_test_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..897f219b30957a6b2b2d606e52b1fbd450b0fb6e
--- /dev/null
+++ b/oneflow/python/test/modules/image_test_util.py
@@ -0,0 +1,174 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import numpy as np
+import cv2
+import oneflow as flow
+import PIL
+import random
+import os
+
+global_coco_dict = dict()
+default_coco_anno_file = "/dataset/mscoco_2017/annotations/instances_val2017.json"
+default_coco_image_dir = "/dataset/mscoco_2017/val2017"
+
+
+def get_coco(anno_file):
+    global global_coco_dict
+
+    if anno_file not in global_coco_dict:
+        from pycocotools.coco import COCO
+
+        global_coco_dict[anno_file] = COCO(anno_file)
+
+    return global_coco_dict[anno_file]
+
+
+def random_sample_images_from_coco(
+    anno_file=default_coco_anno_file, image_dir=default_coco_image_dir, batch_size=2
+):
+    image_files = []
+    image_ids = []
+    batch_group_id = -1
+
+    coco = get_coco(anno_file)
+    img_ids = coco.getImgIds()
+
+    while len(image_files) < batch_size:
+        rand_img_id = random.choice(img_ids)
+        img_h = coco.imgs[rand_img_id]["height"]
+        img_w = coco.imgs[rand_img_id]["width"]
+        group_id = int(img_h / img_w)
+
+        if batch_group_id == -1:
+            batch_group_id = group_id
+
+        if group_id != batch_group_id:
+            continue
+
+        image_files.append(os.path.join(image_dir, coco.imgs[rand_img_id]["file_name"]))
+        image_ids.append(rand_img_id)
+
+    assert len(image_files) == len(image_ids)
+    return image_files, image_ids
+
+
+def read_images_by_cv(image_files, dtype, channels=3):
+    np_dtype = flow.convert_oneflow_dtype_to_numpy_dtype(dtype)
+    images = [cv2.imread(image_file).astype(np_dtype) for image_file in image_files]
+    assert all(isinstance(image, np.ndarray) for image in images)
+    assert all(image.ndim == 3 for image in images)
+    assert all(image.shape[2] == channels for image in images)
+    return images
+
+
+def read_images_by_pil(image_files, dtype, channels=3):
+    image_objs = [PIL.Image.open(image_file) for image_file in image_files]
+    images = []
+    np_dtype = flow.convert_oneflow_dtype_to_numpy_dtype(dtype)
+
+    for im in image_objs:
+        bands = im.getbands()
+        band = "".join(bands)
+        if band == "RGB":
+            # convert to BGR
+            images.append(np.asarray(im).astype(np_dtype)[:, :, ::-1])
+        elif band == "L":
+            gs_image = np.asarray(im).astype(np_dtype)
+            gs_image_shape = gs_image.shape
+            assert len(gs_image_shape) == 2
+            gs_image = gs_image.reshape(gs_image_shape + (1,))
+            gs_image = np.broadcast_to(gs_image, shape=gs_image_shape + (3,))
+            images.append(gs_image)
+        elif band == "BGR":
+            images.append(np.asarray(im).astype(np_dtype))
+        else:
+            raise NotImplementedError
+
+    assert all(isinstance(image, np.ndarray) for image in images)
+    assert all(image.ndim == 3 for image in images)
+    assert all(image.shape[2] == channels for image in images)
+
+    return images
+
+
+def infer_images_static_shape(images, channels=3):
+    image_shapes = [image.shape for image in images]
+    assert all(image.ndim == 3 for image in images)
+    assert all(image.shape[2] == channels for image in images)
+    image_shapes = np.asarray(image_shapes)
+
+    max_h = np.max(image_shapes[:, 0]).item()
+    max_w = np.max(image_shapes[:, 1]).item()
+    image_static_shape = (len(images), max_h, max_w, channels)
+
+    group_ids = []  # 0: h < w, 1: h >= w
+    aspect_ratio_list = []  # shorter / longer
+    for image_shape in image_shapes:
+        h, w = image_shape[0:2]
+        if h < w:
+            group_id = 0
+            aspect_ratio = h / w
+        else:
+            group_id = 1
+            aspect_ratio = w / h
+        group_ids.append(group_id)
+        aspect_ratio_list.append(aspect_ratio)
+    assert all(group_id == group_ids[0] for group_id in group_ids)
+
+    return image_static_shape, aspect_ratio_list
+
+
+def compute_keep_aspect_ratio_resized_size(
+    target_size, min_size, max_size, aspect_ratio, resize_side
+):
+    if resize_side == "shorter":
+        min_res_size = target_size
+        max_res_size = int(round(min_res_size / aspect_ratio))
+        if max_size is not None and max_res_size > max_size:
+            max_res_size = max_size
+            min_res_size = int(round(max_res_size * aspect_ratio))
+    elif resize_side == "longer":
+        max_res_size = target_size
+        min_res_size = int(round(max_res_size * aspect_ratio))
+        if min_size is not None and min_res_size < min_size:
+            min_res_size = min_size
+            max_res_size = int(round(min_res_size / aspect_ratio))
+    else:
+        raise NotImplementedError
+
+    return (min_res_size, max_res_size)
+
+
+def infer_keep_aspect_ratio_resized_images_static_shape(
+    target_size,
+    min_size,
+    max_size,
+    aspect_ratio_list,
+    resize_side="shorter",
+    channels=3,
+):
+    resized_size_list = []
+    for aspect_ratio in aspect_ratio_list:
+        resized_size_list.append(
+            compute_keep_aspect_ratio_resized_size(
+                target_size, min_size, max_size, aspect_ratio, resize_side
+            )
+        )
+
+    res_min_size, res_max_size = max(
+        resized_size_list, key=lambda size: size[0] * size[1]
+    )
+    return (res_min_size, res_max_size, channels)
diff --git a/oneflow/python/test/modules/test_dataset.py b/oneflow/python/test/modules/test_dataset.py
index f35e706dd323b9f9f1275c01ab43b8e14d810e46..541cc1c65185eb4ca1075545b9c3f50413afeef6 100644
--- a/oneflow/python/test/modules/test_dataset.py
+++ b/oneflow/python/test/modules/test_dataset.py
@@ -73,7 +73,7 @@ class TestOFRecordModule(flow.unittest.TestCase):
         gt_np = cv2.imread("/dataset/imagenette/ofrecord/gt_tensor_buffer_image.png")
         test_case.assertTrue(np.array_equal(image_raw_buffer_nd, gt_np))
 
-        image = resize(image_raw_buffer)
+        image = resize(image_raw_buffer)[0]
 
         resized_image_raw_buffer_nd = image.numpy()[0]
         gt_np = cv2.imread(
diff --git a/oneflow/python/test/modules/test_image_resize.py b/oneflow/python/test/modules/test_image_resize.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb9d66e4c0452328959537c4bd36c165a0521586
--- /dev/null
+++ b/oneflow/python/test/modules/test_image_resize.py
@@ -0,0 +1,282 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+import cv2
+import numpy as np
+import oneflow.experimental as flow
+import oneflow.experimental.nn as nn
+import image_test_util
+
+
+def _of_image_resize(
+    image_list,
+    dtype=flow.float32,
+    origin_dtype=flow.float32,
+    channels=3,
+    keep_aspect_ratio=False,
+    target_size=None,
+    min_size=None,
+    max_size=None,
+    resize_side="shorter",
+    interpolation_type="bilinear",
+):
+    assert isinstance(image_list, (list, tuple))
+    assert all(isinstance(image, np.ndarray) for image in image_list)
+    assert all(image.ndim == 3 for image in image_list)
+    assert all(image.shape[2] == channels for image in image_list)
+
+    res_image_list = []
+    res_size_list = []
+    res_scale_list = []
+    image_resize_module = nn.image.Resize(
+        target_size=target_size,
+        min_size=min_size,
+        max_size=max_size,
+        keep_aspect_ratio=keep_aspect_ratio,
+        resize_side=resize_side,
+        dtype=dtype,
+        interpolation_type=interpolation_type,
+        channels=channels,
+    )
+    for image in image_list:
+        tensor_dtype = dtype if keep_aspect_ratio else origin_dtype
+        input = flow.Tensor(
+            np.expand_dims(image, axis=0), dtype=tensor_dtype, device=flow.device("cpu")
+        )
+        image_buffer = flow.tensor_to_tensor_buffer(input, instance_dims=3)
+        res_image, scale, new_size = image_resize_module(image_buffer)
+        res_image = res_image.numpy()
+        scale = scale.numpy()
+        if not keep_aspect_ratio:
+            new_size = np.asarray([(target_size, target_size)])
+        else:
+            new_size = new_size.numpy()
+        res_image_list.append(res_image[0])
+        res_size_list.append(new_size[0])
+        res_scale_list.append(scale[0])
+    return (res_image_list, res_scale_list, res_size_list)
+
+
+def _get_resize_size_and_scale(
+    w,
+    h,
+    target_size,
+    min_size=None,
+    max_size=None,
+    keep_aspect_ratio=True,
+    resize_side="shorter",
+):
+    if keep_aspect_ratio:
+        assert isinstance(target_size, int)
+        aspect_ratio = float(min((w, h))) / float(max((w, h)))
+        (
+            min_res_size,
+            max_res_size,
+        ) = image_test_util.compute_keep_aspect_ratio_resized_size(
+            target_size, min_size, max_size, aspect_ratio, resize_side
+        )
+        if w < h:
+            res_w = min_res_size
+            res_h = max_res_size
+        else:
+            res_w = max_res_size
+            res_h = min_res_size
+
+    else:
+        assert isinstance(target_size, (list, tuple))
+        assert len(target_size) == 2
+        assert all(isinstance(size, int) for size in target_size)
+        res_w, res_h = target_size
+
+    scale_w = res_w / w
+    scale_h = res_h / h
+    return (res_w, res_h), (scale_w, scale_h)
+
+
+def _cv_image_resize(
+    image_list,
+    target_size,
+    keep_aspect_ratio=True,
+    min_size=None,
+    max_size=None,
+    resize_side="shorter",
+    interpolation=cv2.INTER_LINEAR,
+    dtype=np.float32,
+):
+    res_image_list = []
+    res_size_list = []
+    res_scale_list = []
+
+    for image in image_list:
+        h, w = image.shape[:2]
+        new_size, scale = _get_resize_size_and_scale(
+            w, h, target_size, min_size, max_size, keep_aspect_ratio, resize_side
+        )
+        res_image_list.append(
+            cv2.resize(image.squeeze(), new_size, interpolation=interpolation).astype(
+                dtype
+            )
+        )
+        res_size_list.append(new_size)
+        res_scale_list.append(scale)
+
+    return res_image_list, res_scale_list, res_size_list
+
+
+def _test_image_resize_with_cv(
+    test_case,
+    image_files,
+    target_size,
+    min_size=None,
+    max_size=None,
+    keep_aspect_ratio=True,
+    resize_side="shorter",
+    dtype=flow.float32,
+    origin_dtype=None,
+):
+    if origin_dtype is None:
+        origin_dtype = dtype
+
+    image_list = image_test_util.read_images_by_cv(image_files, origin_dtype)
+
+    of_res_images, of_scales, of_new_sizes = _of_image_resize(
+        image_list=image_list,
+        dtype=dtype,
+        origin_dtype=origin_dtype,
+        keep_aspect_ratio=keep_aspect_ratio,
+        target_size=target_size,
+        min_size=min_size,
+        max_size=max_size,
+        resize_side=resize_side,
+    )
+
+    cv_res_images, cv_scales, cv_new_sizes = _cv_image_resize(
+        image_list=image_list,
+        target_size=target_size,
+        keep_aspect_ratio=keep_aspect_ratio,
+        min_size=min_size,
+        max_size=max_size,
+        resize_side=resize_side,
+        dtype=flow.convert_oneflow_dtype_to_numpy_dtype(dtype),
+    )
+
+    for (
+        of_res_image,
+        cv_res_image,
+        of_scale,
+        cv_scale,
+        of_new_size,
+        cv_new_size,
+    ) in zip(
+        of_res_images, cv_res_images, of_scales, cv_scales, of_new_sizes, cv_new_sizes,
+    ):
+        test_case.assertTrue(np.allclose(of_res_image, cv_res_image))
+        test_case.assertTrue(np.allclose(of_scale, cv_scale))
+        test_case.assertTrue(np.allclose(of_new_size, cv_new_size))
+
+
+@flow.unittest.skip_unless_1n1d()
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestImageResize(flow.unittest.TestCase):
+    def test_image_resize_to_fixed_size(test_case):
+        image_files, _ = image_test_util.random_sample_images_from_coco()
+        _test_image_resize_with_cv(
+            test_case, image_files, target_size=(224, 224), keep_aspect_ratio=False,
+        )
+
+    def test_image_resize_shorter_to_target_size(test_case):
+        image_files, _ = image_test_util.random_sample_images_from_coco()
+        _test_image_resize_with_cv(
+            test_case,
+            image_files,
+            target_size=800,
+            keep_aspect_ratio=True,
+            resize_side="shorter",
+        )
+
+    def test_image_resize_longer_to_target_size(test_case):
+        image_files, _ = image_test_util.random_sample_images_from_coco()
+        _test_image_resize_with_cv(
+            test_case,
+            image_files,
+            target_size=1000,
+            keep_aspect_ratio=True,
+            resize_side="longer",
+        )
+
+    def test_image_resize_shorter_to_target_size_with_max_size(test_case):
+        image_files, _ = image_test_util.random_sample_images_from_coco()
+        _test_image_resize_with_cv(
+            test_case,
+            image_files,
+            target_size=800,
+            max_size=1333,
+            keep_aspect_ratio=True,
+            resize_side="shorter",
+        )
+
+    def test_image_resize_longer_to_target_size_with_min_size(test_case):
+        image_files, _ = image_test_util.random_sample_images_from_coco()
+        _test_image_resize_with_cv(
+            test_case,
+            image_files,
+            target_size=1000,
+            min_size=600,
+            keep_aspect_ratio=True,
+            resize_side="longer",
+        )
+
+    def test_image_resize_to_fixed_size_with_dtype_uint8(test_case):
+        image_files, _ = image_test_util.random_sample_images_from_coco()
+        _test_image_resize_with_cv(
+            test_case,
+            image_files,
+            target_size=(1000, 1000),
+            keep_aspect_ratio=False,
+            dtype=flow.uint8,
+        )
+
+    def test_image_reisze_shorter_to_target_size_with_max_size_with_dtype_uint8(
+        test_case,
+    ):
+        image_files, _ = image_test_util.random_sample_images_from_coco()
+        _test_image_resize_with_cv(
+            test_case,
+            image_files,
+            target_size=1000,
+            max_size=1600,
+            keep_aspect_ratio=True,
+            resize_side="shorter",
+            dtype=flow.uint8,
+        )
+
+    def test_image_resize_uint8_to_float(test_case):
+        image_files, _ = image_test_util.random_sample_images_from_coco()
+        _test_image_resize_with_cv(
+            test_case,
+            image_files,
+            target_size=(1000, 1000),
+            keep_aspect_ratio=False,
+            dtype=flow.float32,
+            origin_dtype=flow.uint8,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/python/test/modules/test_resnet50_with_bn.py b/oneflow/python/test/modules/test_resnet50_with_bn.py
index 50f6ee3b2aec8c79d8ae0cc3d76f9e70b72d598e..41daebabd43532d2ab45a504939872718a3fcce8 100644
--- a/oneflow/python/test/modules/test_resnet50_with_bn.py
+++ b/oneflow/python/test/modules/test_resnet50_with_bn.py
@@ -82,7 +82,7 @@ class TestResNet50(flow.unittest.TestCase):
             val_record = record_reader()
             label = record_label_decoder(val_record)
             image_raw_buffer = record_image_decoder(val_record)
-            image = resize(image_raw_buffer)
+            image = resize(image_raw_buffer)[0]
             image = crop_mirror_normal(image)
             image = image.to("cuda")
             label = label.to("cuda")
diff --git a/oneflow/python/test/modules/test_resnet50_without_bn.py b/oneflow/python/test/modules/test_resnet50_without_bn.py
index e2cecc59cfb3221493cea743b5ba15257938f91a..0e6c3a595a55cb59c19dcdbd80efc778a3cb2ed0 100644
--- a/oneflow/python/test/modules/test_resnet50_without_bn.py
+++ b/oneflow/python/test/modules/test_resnet50_without_bn.py
@@ -188,7 +188,7 @@ class TestResNet50(flow.unittest.TestCase):
             val_record = record_reader()
             label = record_label_decoder(val_record)
             image_raw_buffer = record_image_decoder(val_record)
-            image = resize(image_raw_buffer)
+            image = resize(image_raw_buffer)[0]
             image = crop_mirror_normal(image)
             image = image.to("cuda")
             label = label.to("cuda")