diff --git a/oneflow/python/nn/module.py b/oneflow/python/nn/module.py
index d6d5985a1d63a57527ba9accf53905fc74661249..85f37ff5eceb43f2bae4e6e61bc581bde5264d7f 100644
--- a/oneflow/python/nn/module.py
+++ b/oneflow/python/nn/module.py
@@ -40,6 +40,18 @@ class _IncompatibleKeys(
     __str__ = __repr__
 
 
+def _addindent(s_, numSpaces):
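+    # Indents every line after the first by `numSpaces` spaces, e.g.
+    # _addindent("a\nb", 2) -> "a\n  b"; used by __repr__ to nest child reprs.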
+    s = s_.split("\n")
+    # don't do anything for single-line stuff
+    if len(s) == 1:
+        return s_
+    first = s.pop(0)
+    s = [(numSpaces * " ") + line for line in s]
+    s = "\n".join(s)
+    s = first + "\n" + s
+    return s
+
+
 # See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use
 # of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be
 # the type of the subclass, not the looser type of `Module`.
@@ -529,3 +541,40 @@ class Module(object):
             return t.to(device)
 
         return self._apply(convert)
+
+    def _get_name(self):
+        return self.__class__.__name__
+
+    def extra_repr(self) -> str:
+        r"""Set the extra representation of the module
+
+        To print customized extra information, you should re-implement
+        this method in your own modules. Both single-line and multi-line
+        strings are acceptable.
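+
+        A minimal sketch of an override (``in_features`` here is a
+        hypothetical attribute of the subclass)::
+
+            def extra_repr(self) -> str:
+                return f"in_features={self.in_features}"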
+        """
+        return ""
+
+    def __repr__(self):
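+        # Renders a leaf module as e.g. "ReLU(inplace=True)" and a module
+        # with children as an indented block, for instance:
+        #   Sequential(
+        #     (0): ReLU()
+        #   )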
+        # We treat the extra repr like the sub-module, one item per line
+        extra_lines = []
+        extra_repr = self.extra_repr()
+        # an empty string would be split into [''], hence the guard
+        if extra_repr:
+            extra_lines = extra_repr.split("\n")
+        child_lines = []
+        for key, module in self._modules.items():
+            mod_str = repr(module)
+            mod_str = _addindent(mod_str, 2)
+            child_lines.append("(" + key + "): " + mod_str)
+        lines = extra_lines + child_lines
+
+        main_str = self._get_name() + "("
+        if lines:
+            # simple one-liner info, which most builtin Modules will use
+            if len(extra_lines) == 1 and not child_lines:
+                main_str += extra_lines[0]
+            else:
+                main_str += "\n  " + "\n  ".join(lines) + "\n"
+
+        main_str += ")"
+        return main_str
diff --git a/oneflow/python/nn/modules/activation.py b/oneflow/python/nn/modules/activation.py
index ffca0b143a57a381aa7b2867f26a22709e95969c..c855a03345ec3290b4f72cdccdec165ece53bf32 100644
--- a/oneflow/python/nn/modules/activation.py
+++ b/oneflow/python/nn/modules/activation.py
@@ -13,13 +13,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
+import warnings
+from typing import Optional
+
 import oneflow as flow
 import oneflow._oneflow_internal
 from oneflow.python.nn.module import Module
 from oneflow.python.nn.modules.utils import _check_inplace_valid
 from oneflow.python.oneflow_export import oneflow_export, experimental_api
 from oneflow.python.framework.tensor import register_tensor_op
-from typing import Optional
 
 
 def _softmax_need_transpose(x, axis):
@@ -140,6 +143,10 @@ class ReLU(Module):
             _check_inplace_valid(x)
         return flow.F.relu(x, self.inplace)
 
+    def extra_repr(self):
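+        # e.g. repr(nn.ReLU(inplace=True)) renders as "ReLU(inplace=True)",
+        # while repr(nn.ReLU()) renders as plain "ReLU()".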
+        inplace_str = "inplace=True" if self.inplace else ""
+        return inplace_str
+
 
 @oneflow_export("nn.ReLU6")
 @experimental_api
@@ -182,10 +189,17 @@ class ReLU6(Module):
 
     def __init__(self, inplace: bool = False):
         super().__init__()
+        self.inplace = inplace
 
     def forward(self, x):
+        if self.inplace:
+            warnings.warn("ReLU6 module do not support inplace now")
         return flow.F.hardtanh(x, min_val=0.0, max_val=6.0)
 
+    def extra_repr(self):
+        inplace_str = "inplace=True" if self.inplace else ""
+        return inplace_str
+
 
 @oneflow_export("nn.Tanh")
 @experimental_api
@@ -308,10 +322,18 @@ class ELU(Module):
     def __init__(self, alpha: float = 1.0, inplace: bool = False):
         super().__init__()
         self.alpha = alpha
+        self.inplace = inplace
 
     def forward(self, x):
+        if self.inplace:
+            warnings.warn("ELU module do not support inplace now")
         return flow.F.elu(x, alpha=self.alpha)
 
+    def extra_repr(self):
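+        # e.g. "alpha=1.0" with the defaults, or "alpha=0.5, inplace=True".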
+        param_str = f"alpha={self.alpha}"
+        param_str += ", inplace=True" if self.inplace else ""
+        return param_str
+
 
 @oneflow_export("nn.GELU")
 @experimental_api
@@ -498,10 +520,17 @@ class Hardsigmoid(Module):
 
     def __init__(self, inplace: bool = False):
         super().__init__()
+        self.inplace = inplace
 
     def forward(self, x):
+        if self.inplace:
+            warnings.warn("Hardsigmoid module do not support inplace now")
         return flow.F.hardsigmoid(x)
 
+    def extra_repr(self):
+        inplace_str = "inplace=True" if self.inplace else ""
+        return inplace_str
+
 
 @oneflow_export("nn.Softmax")
 @experimental_api
@@ -520,6 +549,9 @@ class Softmax(Module):
             res = flow.F.transpose(res, perm=permute)
         return res
 
+    def extra_repr(self):
+        return f"axis={self.axis}"
+
 
 @oneflow_export("softmax")
 @register_tensor_op("softmax")
@@ -637,7 +669,7 @@ class LogSoftmax(Module):
         return res
 
     def extra_repr(self):
-        return "dim={dim}".format(dim=self.dim)
+        return f"dim={self.dim}"
 
 
 @oneflow_export("nn.LogSigmoid")
@@ -735,6 +767,9 @@ class Softplus(Module):
             * flow.experimental.log(1.0 + flow.experimental.exp(self.beta * x)),
         )
 
+    def extra_repr(self):
+        return f"beta={self.beta}, threshold={self.threshold}"
+
 
 @oneflow_export("nn.Hardswish")
 @experimental_api
@@ -777,10 +812,17 @@ class Hardswish(Module):
 
     def __init__(self, inplace: bool = False):
         super().__init__()
+        self.inplace = inplace
 
     def forward(self, x):
+        if self.inplace:
+            warnings.warn("Hardswish module do not support inplace now")
         return flow.F.hardswish(x)
 
+    def extra_repr(self):
+        inplace_str = "inplace=True" if self.inplace else ""
+        return inplace_str
+
 
 @oneflow_export("nn.Hardtanh")
 @experimental_api
@@ -853,10 +895,18 @@ class Hardtanh(Module):
 
         self.min_val = min_val
         self.max_val = max_val
+        self.inplace = inplace
 
     def forward(self, x):
+        if self.inplace:
+            warnings.warn("Hardtanh module do not support inplace now")
         return flow.F.hardtanh(x, min_val=self.min_val, max_val=self.max_val)
 
+    def extra_repr(self):
+        param_str = f"min_val={self.min_val}, max_val={self.max_val}"
+        param_str += ", inplace=True" if self.inplace else ""
+        return param_str
+
 
 @oneflow_export("nn.LeakyReLU")
 @experimental_api
@@ -897,10 +947,18 @@ class LeakyReLU(Module):
     def __init__(self, negative_slope: float = 1e-2, inplace: bool = False):
         super().__init__()
         self.negative_slope = negative_slope
+        self.inplace = inplace
 
     def forward(self, x):
+        if self.inplace:
+            warnings.warn("LeakyReLU module do not support inplace now")
         return flow.F.leaky_relu(x, alpha=self.negative_slope)
 
+    def extra_repr(self):
+        param_str = f"negative_slope={self.negative_slope}"
+        param_str += ", inplace=True" if self.inplace else ""
+        return param_str
+
 
 @oneflow_export("nn.Mish")
 @experimental_api
diff --git a/oneflow/python/nn/modules/batchnorm.py b/oneflow/python/nn/modules/batchnorm.py
index 6b70e9976dd66bb26615120fc58fa0b5e8edcb0c..67dbecbc0eb000e2cc2adf26d50e2c0f262b2bc2 100644
--- a/oneflow/python/nn/modules/batchnorm.py
+++ b/oneflow/python/nn/modules/batchnorm.py
@@ -91,6 +91,12 @@ class _NormBase(Module):
             error_msgs,
         )
 
+    def extra_repr(self):
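+        # Sketch of the output, assuming PyTorch-style defaults
+        # (eps=1e-05, momentum=0.1):
+        # "100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True"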
+        return (
+            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
+            "track_running_stats={track_running_stats}".format(**self.__dict__)
+        )
+
 
 class _BatchNorm(_NormBase):
     def __init__(
diff --git a/oneflow/python/nn/modules/container.py b/oneflow/python/nn/modules/container.py
index b2f4f0c0a7d8778f367eff037adca914ef17585e..f36517b320a1823faa44ad03d70442c00819ef80 100644
--- a/oneflow/python/nn/modules/container.py
+++ b/oneflow/python/nn/modules/container.py
@@ -52,14 +52,24 @@ class Sequential(Module):
 
         >>> import oneflow.experimental.nn as nn
         >>> nn.Sequential(nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU()) #doctest: +ELLIPSIS
-        <oneflow.python.nn.modules.container.Sequential object at 0x...>
+        Sequential(
+          (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
+          (1): ReLU()
+          (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
+          (3): ReLU()
+        )
         >>> nn.Sequential(OrderedDict([
         ...    ('conv1', nn.Conv2d(1,20,5)),
         ...    ('relu1', nn.ReLU()),
         ...    ('conv2', nn.Conv2d(20,64,5)),
         ...    ('relu2', nn.ReLU())
         ... ])) #doctest: +ELLIPSIS
-        <oneflow.python.nn.modules.container.Sequential object at 0x...>
+        Sequential(
+          (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
+          (relu1): ReLU()
+          (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
+          (relu2): ReLU()
+        )
 
     """
 
diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py
index 4bf2f17c2ca8cc189e5ae4301432837eda8cc89f..58432a5784c82057cbc55bb57d8395632638686a 100644
--- a/oneflow/python/nn/modules/conv.py
+++ b/oneflow/python/nn/modules/conv.py
@@ -186,6 +186,7 @@ class Conv1d(Module):
         super().__init__()
 
         assert padding_mode == "zeros"
+        self.padding_mode = padding_mode
         self.kernel_size = _single(kernel_size)
         self.stride = _single(stride)
         self.padding = _single(padding)
@@ -257,6 +258,23 @@ class Conv1d(Module):
             )
         return res
 
+    def extra_repr(self):
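+        # kernel_size and stride are always shown; padding, dilation, groups,
+        # bias and padding_mode appear only when they differ from their
+        # defaults, e.g. "16, 33, kernel_size=(3,), stride=(1,)".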
+        s = (
+            "{in_channels}, {out_channels}, kernel_size={kernel_size}"
+            ", stride={stride}"
+        )
+        if self.padding != (0,) * len(self.padding):
+            s += ", padding={padding}"
+        if self.dilation != (1,) * len(self.dilation):
+            s += ", dilation={dilation}"
+        if self.groups != 1:
+            s += ", groups={groups}"
+        if self.bias is None:
+            s += ", bias=False"
+        if self.padding_mode != "zeros":
+            s += ", padding_mode={padding_mode}"
+        return s.format(**self.__dict__)
+
 
 @oneflow_export("nn.Conv2d")
 @experimental_api
@@ -394,6 +412,7 @@ class Conv2d(Module):
         super().__init__()
 
         assert padding_mode == "zeros"
+        self.padding_mode = padding_mode
         self.kernel_size = _pair(kernel_size)
         self.stride = _pair(stride)
         self.padding = _pair(padding)
@@ -466,6 +485,23 @@ class Conv2d(Module):
             )
         return res
 
+    def extra_repr(self):
+        s = (
+            "{in_channels}, {out_channels}, kernel_size={kernel_size}"
+            ", stride={stride}"
+        )
+        if self.padding != (0,) * len(self.padding):
+            s += ", padding={padding}"
+        if self.dilation != (1,) * len(self.dilation):
+            s += ", dilation={dilation}"
+        if self.groups != 1:
+            s += ", groups={groups}"
+        if self.bias is None:
+            s += ", bias=False"
+        if self.padding_mode != "zeros":
+            s += ", padding_mode={padding_mode}"
+        return s.format(**self.__dict__)
+
 
 if __name__ == "__main__":
     import doctest
diff --git a/oneflow/python/nn/modules/flatten.py b/oneflow/python/nn/modules/flatten.py
index 0583356914376966d47835c28a0aee7f6a1bae68..a32949997b4c14aee506da62c653159389558f82 100644
--- a/oneflow/python/nn/modules/flatten.py
+++ b/oneflow/python/nn/modules/flatten.py
@@ -51,6 +51,9 @@ class Flatten(Module):
     def forward(self, input):
         return flow.F.flatten(input, start_dim=self.start_dim, end_dim=self.end_dim)
 
+    def extra_repr(self) -> str:
+        return "start_dim={}, end_dim={}".format(self.start_dim, self.end_dim)
+
 
 @oneflow_export("flatten")
 @register_tensor_op("flatten")
diff --git a/oneflow/python/nn/modules/linear.py b/oneflow/python/nn/modules/linear.py
index 77047fe6a1c5cf721fe26e7643aeeba50de28c65..3f9aee4cd52416c035aca8dae680ae3722de316a 100644
--- a/oneflow/python/nn/modules/linear.py
+++ b/oneflow/python/nn/modules/linear.py
@@ -18,7 +18,6 @@ from oneflow.python.oneflow_export import oneflow_export, experimental_api
 from oneflow.python.framework.tensor import Tensor
 from oneflow.python.nn.module import Module
 from oneflow.python.nn.init import _calculate_fan_in_and_fan_out
-from typing import Optional, List, Tuple
 import math
 
 
@@ -100,6 +99,8 @@ class Linear(Module):
     def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
         super().__init__()
 
+        self.in_features = in_features
+        self.out_features = out_features
         self.use_bias = bias
         self.weight = flow.nn.Parameter(flow.Tensor(out_features, in_features))
         self.bias = None
@@ -131,6 +132,11 @@ class Linear(Module):
 
         return res
 
+    def extra_repr(self) -> str:
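+        # e.g. repr(nn.Linear(20, 30)) renders
+        # "Linear(in_features=20, out_features=30, bias=True)".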
+        return "in_features={}, out_features={}, bias={}".format(
+            self.in_features, self.out_features, self.bias is not None
+        )
+
 
 if __name__ == "__main__":
     import doctest
diff --git a/oneflow/python/nn/modules/normalization.py b/oneflow/python/nn/modules/normalization.py
index 6df5cfe1b5b2b1315664cdfd00a32d83c6eb0878..0b53993465892f14641d611b2b0f69064a843245 100644
--- a/oneflow/python/nn/modules/normalization.py
+++ b/oneflow/python/nn/modules/normalization.py
@@ -131,6 +131,11 @@ class GroupNorm(Module):
 
         return res
 
+    def extra_repr(self) -> str:
+        return "{num_groups}, {num_channels}, eps={eps}, " "affine={affine}".format(
+            **self.__dict__
+        )
+
 
 @oneflow_export("nn.LayerNorm")
 @experimental_api
diff --git a/oneflow/python/nn/modules/padding.py b/oneflow/python/nn/modules/padding.py
index b7911db8b130e9c1dc260dfe2b1514d54c35b4f1..1d24752e86098db6f278dafeff51addbfbdf8f8c 100644
--- a/oneflow/python/nn/modules/padding.py
+++ b/oneflow/python/nn/modules/padding.py
@@ -108,6 +108,9 @@ class ReplicationPad2d(Module):
                 "Padding size should be less than the corresponding input dimension. Please check."
             )
 
+    def extra_repr(self) -> str:
+        return "{}".format(self.padding)
+
 
 @oneflow_export("nn.ReflectionPad2d")
 @experimental_api
@@ -183,6 +186,9 @@ class ReflectionPad2d(Module):
                 "padding size should be less than the corresponding input dimension!"
             )
 
+    def extra_repr(self) -> str:
+        return "{}".format(self.padding)
+
 
 if __name__ == "__main__":
     import doctest
diff --git a/oneflow/python/nn/modules/pixelshuffle.py b/oneflow/python/nn/modules/pixelshuffle.py
index 9bc0108b106691a84668c4424f3ccd2bd2070e7b..00fd7910defc43a45227c9530291088318ba9298 100644
--- a/oneflow/python/nn/modules/pixelshuffle.py
+++ b/oneflow/python/nn/modules/pixelshuffle.py
@@ -155,6 +155,9 @@ class PixelShufflev2(Module):
 
         return out
 
+    def extra_repr(self) -> str:
+        return f"w_upscale_factor={self.w_upscale_factor}, h_upscale_factor={self.h_upscale_factor}"
+
 
 if __name__ == "__main__":
     import doctest
diff --git a/oneflow/python/nn/modules/pooling.py b/oneflow/python/nn/modules/pooling.py
index d34a7105305beb428c352c834827e4e453e44e24..57de40a2cbb4fc10f614eb51e3b8e512a5801695 100644
--- a/oneflow/python/nn/modules/pooling.py
+++ b/oneflow/python/nn/modules/pooling.py
@@ -123,6 +123,7 @@ class AvgPool2d(Module):
             padding, tuple
         ), "padding can only int int or tuple of 2 ints."
         padding = _pair(padding)
+        self.padding = padding
         padding = [0, 0, *padding]
 
         assert count_include_pad is None, "count_include_pad not supported yet"
@@ -149,6 +150,12 @@ class AvgPool2d(Module):
             data_format=self._channel_pos,
         )
 
+    def extra_repr(self) -> str:
+        return (
+            "kernel_size={kernel_size}, stride={stride}, padding={padding}"
+            ", ceil_mode={ceil_mode}".format(**self.__dict__)
+        )
+
 
 @oneflow_export("nn.AvgPool3d")
 @experimental_api
@@ -352,6 +359,11 @@ class MaxPool1d(Module):
         else:
             return y
 
+    def extra_repr(self) -> str:
+        return "kernel_size={}, stride={}, padding={}".format(
+            self.kernel_size, self.stride, self.padding
+        )
+
 
 @oneflow_export("nn.MaxPool2d")
 @experimental_api
@@ -452,6 +464,7 @@ class MaxPool2d(Module):
         self.ceil_mode = ceil_mode
 
         padding = _pair(padding)
+        self.padding = padding
         if len(padding) == 2:
             if data_format == "NCHW":
                 padding = (0, 0, padding[0], padding[1])
@@ -485,6 +498,11 @@ class MaxPool2d(Module):
         else:
             return y
 
+    def extra_repr(self) -> str:
+        return "kernel_size={}, stride={}, padding={}, dilation={}".format(
+            self.kernel_size, self.stride, self.padding, self.dilation
+        )
+
 
 @oneflow_export("nn.MaxPool3d")
 @experimental_api
@@ -588,6 +606,7 @@ class MaxPool3d(Module):
         )
         self.dilation = _GetSequence(dilation, 3, "dilation")
         padding = _triple(padding)
+        self.padding = padding
         self.return_indices = return_indices
         self.ceil_mode = ceil_mode
 
@@ -624,6 +643,11 @@ class MaxPool3d(Module):
         else:
             return y
 
+    def extra_repr(self) -> str:
+        return "kernel_size={}, stride={}, padding={}, dilation={}".format(
+            self.kernel_size, self.stride, self.padding, self.dilation
+        )
+
 
 if __name__ == "__main__":
     import doctest
diff --git a/oneflow/python/nn/modules/upsampling.py b/oneflow/python/nn/modules/upsampling.py
index b747f10efacfc60decc5814ee057e7e9d3adf1f8..9316525cdc884db77cffc5548b48feea56c29195 100644
--- a/oneflow/python/nn/modules/upsampling.py
+++ b/oneflow/python/nn/modules/upsampling.py
@@ -146,6 +146,14 @@ class Upsample(Module):
         )
         return res
 
+    def extra_repr(self) -> str:
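+        # Shows whichever of scale_factor / size was passed at construction,
+        # plus the mode, e.g. "scale_factor=2.0, mode=nearest".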
+        if self.scale_factor is not None:
+            info = "scale_factor=" + str(self.scale_factor)
+        else:
+            info = "size=" + str(self.size)
+        info += ", mode=" + self.mode
+        return info
+
 
 @oneflow_export("nn.UpsamplingNearest2d")
 @experimental_api