diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst
index 7bdf703fb857991e2bf231188f604acef0f0446e..d6ab0db3961ee9907ddea355aed5594a49d08408 100644
--- a/docs/source/experimental.rst
+++ b/docs/source/experimental.rst
@@ -61,6 +61,7 @@ Experimental features
 .. autofunction:: oneflow.experimental.zeros
 .. autofunction:: oneflow.experimental.zeros_like
 .. autofunction:: oneflow.experimental.ones_like
+.. autofunction:: oneflow.experimental.Tensor.new_ones
 .. autofunction:: oneflow.experimental.nn.Module
 .. autofunction:: oneflow.experimental.nn.Parameter
 .. autofunction:: oneflow.experimental.nn.Sequential
diff --git a/oneflow/python/nn/modules/constant.py b/oneflow/python/nn/modules/constant.py
index 4d04f421f79018eadd78157a3ecaaa16cb2f0ac8..4b8137217f02ea85885c1b9fe9f7b6747b4990cd 100644
--- a/oneflow/python/nn/modules/constant.py
+++ b/oneflow/python/nn/modules/constant.py
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 import oneflow as flow
+from oneflow.python.framework.tensor import register_tensor_op
 from oneflow.python.nn.module import Module
 from oneflow.python.oneflow_export import oneflow_export, experimental_api
 from oneflow.python.nn.common_types import _size_any_t
@@ -201,6 +202,96 @@ def ones_like_op(other):
     return OnesLike()(other)
 
 
+class NewOnes(Module):
+    def __init__(
+        self,
+        size: Optional[Union[_size_any_t, flow.Size]] = None,
+        dtype: Optional[flow.dtype] = None,
+        device: Optional[Union[flow.device, str]] = None,
+        requires_grad: bool = False,
+    ):
+        super().__init__()
+
+        self.device = device
+        self.requires_grad = requires_grad
+        if size is not None:
+            size = _single(size)
+        self.size = size
+        self.dtype = dtype
+
+    def forward(self, x):
+        new_size = self.size
+        new_dtype = self.dtype
+        new_device = self.device
+        new_requires_grad = self.requires_grad
+
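+        # Any argument left as None falls back to the corresponding attribute of the input tensor.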
+        if self.size is None:
+            new_size = x.shape
+
+        if self.dtype is None:
+            new_dtype = x.dtype
+
+        if self.device is None:
+            new_device = x.device
+
+        assert isinstance(
+            new_size, (int, tuple, flow.Size)
+        ), f"invalid size: {new_size} (expected int, tuple, or flow.Size)"
+        assert isinstance(
+            new_dtype, flow.dtype
+        ), f"invalid dtype: {new_dtype} (expected flow.dtype)"
+        assert isinstance(
+            new_device, (str, flow.device)
+        ), f"invalid device: {new_device} (expected str or flow.device)"
+        assert isinstance(
+            new_requires_grad, bool
+        ), f"invalid requires_grad: {new_requires_grad} (expected bool)"
+
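+        # Build a tensor of ones with the resolved shape and dtype, then move it to the target device.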
+        res = flow.F.constant(new_size, 1.0, new_dtype)
+        res = res.to(new_device)
+        res.requires_grad = new_requires_grad
+        return res
+
+
+@register_tensor_op("new_ones")
+@experimental_api
+def new_ones_op(x, size=None, dtype=None, device=None, requires_grad=False):
+    r"""
+    
+    Returns a Tensor of size size filled with 1. By default, the returned Tensor has the same torch.dtype and torch.device as this tensor.
+
+    Args:
+        size (int...): a list, tuple, or flow.Size of integers defining the shape of the output tensor.
+        dtype (flow.dtype, optional):  the desired type of returned tensor. Default: if None, same flow.dtype as this tensor.
+        device (flow.device, optional): the desired device of returned tensor. Default: if None, same flow.device as this tensor.
+        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.
+    
+    For example:
+
+    .. code-block:: python
+
+        >>> import numpy as np
+        >>> import oneflow.experimental as flow
+        >>> flow.enable_eager_execution()
+
+        >>> x = flow.Tensor(np.ones((1, 2, 3)))
+        >>> y = x.new_ones((2, 2))
+        >>> y
+        tensor([[1., 1.],
+                [1., 1.]], dtype=oneflow.float32)
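+
+        >>> # Sketch: overriding the defaults (dtype/device values here are illustrative)
+        >>> z = x.new_ones((2, 2), dtype=flow.float32, device="cpu", requires_grad=True)
+        >>> z.requires_grad
+        True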
+    """
+    return NewOnes(size=size, dtype=dtype, device=device, requires_grad=requires_grad)(
+        x
+    )
+
+
 if __name__ == "__main__":
     import doctest
 
diff --git a/oneflow/python/test/modules/test_constant.py b/oneflow/python/test/modules/test_constant.py
index d3f2b7bc9161c3af555c835a9e814031abd6911d..a377519dec0b8b9a41d31aa014bfc8676f7c55cb 100644
--- a/oneflow/python/test/modules/test_constant.py
+++ b/oneflow/python/test/modules/test_constant.py
@@ -82,6 +82,29 @@ def _test_zeros_like(test_case, device, shape):
     )
 
 
+def _test_new_ones(test_case, device, shape):
+    x = flow.Tensor(np.ones(shape), device=flow.device(device))
+    y = x.new_ones(shape, device=device)
+    test_case.assertTrue(x.dtype == y.dtype)
+    test_case.assertTrue(x.device == y.device)
+    test_case.assertTrue(x.requires_grad == y.requires_grad)
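+    # Also check the values: every element of the new tensor should be exactly one.
+    test_case.assertTrue(np.array_equal(np.ones(shape), y.numpy()))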
+
+    x = flow.Tensor(np.ones(shape), device=flow.device(device))
+    y = x.new_ones(x.shape, device=device)
+    test_case.assertTrue(x.dtype == y.dtype)
+    test_case.assertTrue(x.device == y.device)
+    test_case.assertTrue(x.requires_grad == y.requires_grad)
+
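+    # With requires_grad=True, the gradient of sum() w.r.t. the new tensor is all ones.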
+    x = flow.Tensor(np.ones(shape), device=flow.device(device))
+    x = x.new_ones(shape, device=device, requires_grad=True)
+    y = x.sum()
+    y.backward()
+    test_case.assertTrue(np.array_equal(np.ones_like(x.numpy()), x.grad.numpy()))
+
+
 @unittest.skipIf(
     not flow.unittest.env.eager_execution_enabled(),
     ".numpy() doesn't work in lazy mode",
@@ -96,6 +119,7 @@ class TestConstantModule(flow.unittest.TestCase):
             _test_zeros_backward,
             _test_ones_like,
             _test_zeros_like,
+            _test_new_ones,
         ]
         arg_dict["device"] = ["cpu", "cuda"]
         arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]