diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst
index b580a83efd6077c2fe55e1cff642819849309733..85b28991783bf2533a3e7ed87f30d02658943576 100644
--- a/docs/source/experimental.rst
+++ b/docs/source/experimental.rst
@@ -215,3 +215,5 @@ Experimental features
 .. autofunction:: oneflow.experimental.Tensor.topk
 .. autofunction:: oneflow.experimental.nn.GroupNorm
 .. autofunction:: oneflow.experimental.nn.ZeroPad2d
+.. autofunction:: oneflow.experimental.tensor_buffer_to_tensor
+.. autofunction:: oneflow.experimental.tensor_to_tensor_buffer
diff --git a/oneflow/python/nn/modules/tensor_buffer.py b/oneflow/python/nn/modules/tensor_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..334dbbe2ffae48a380986357885066e9701fb8dd
--- /dev/null
+++ b/oneflow/python/nn/modules/tensor_buffer.py
@@ -0,0 +1,128 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from typing import Sequence
+
+import oneflow as flow
+from oneflow.python.nn.module import Module
+from oneflow.python.oneflow_export import oneflow_export, experimental_api
+
+
+class TensorBufferToTensor(Module):
+    def __init__(self, dtype, instance_shape):
+        super().__init__()
+        self._op = (
+            flow.builtin_op("tensor_buffer_to_tensor")
+            .Input("in")
+            .Output("out")
+            .Attr("dtype", dtype)
+            .Attr("instance_shape", instance_shape)
+            .Build()
+        )
+
+    def forward(self, input):
+        return self._op(input)[0]
+
+
+@oneflow_export("tensor_buffer_to_tensor")
+@experimental_api
+def tensor_buffer_to_tensor_op(x, dtype: flow.dtype, instance_shape: Sequence[int]):
+    """This operator converts the Tensor's type from TensorBuffer to original type.
+    Some operator's output data type is `TensorBuffer`, you can use this operator to convert back
+    to `Tensor`.
+
+    Refer to `Concept Explanation <https://docs.oneflow.org/basics_topics/concept_explanation.html#3tensorbuffer-tensorlist>`_
+    for more about TensorBuffer.
+
+    Args:
+        x (oneflow.Tensor): The input Tensor.
+        dtype (flow.dtype): The data dtype.
+        instance_shape (Sequence[int]): The shape of each TensorBuffer instance.
+
+    Returns:
+        oneflow.Tensor: The result Tensor.
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import numpy as np
+        >>> import oneflow.experimental as flow
+        >>> flow.enable_eager_execution()
+
+        >>> x = np.random.randn(4, 16, 64, 64).astype(np.float32)
+        >>> x = flow.Tensor(x)
+        >>> x = flow.tensor_to_tensor_buffer(x, instance_dims=2)
+        >>> output = flow.tensor_buffer_to_tensor(x, instance_shape=(64, 64), dtype=flow.float)
+        >>> output.shape
+        flow.Size([4, 16, 64, 64])
+
+    """
+    return TensorBufferToTensor(dtype=dtype, instance_shape=instance_shape)(x)
+
+
+class TensorToTensorBuffer(Module):
+    def __init__(self, instance_dims):
+        super().__init__()
+        self._op = (
+            flow.builtin_op("tensor_to_tensor_buffer")
+            .Input("in")
+            .Output("out")
+            .Attr("instance_dims", instance_dims)
+            .Build()
+        )
+
+    def forward(self, input):
+        return self._op(input)[0]
+
+
+@oneflow_export("tensor_to_tensor_buffer")
+@experimental_api
+def tensor_to_tensor_buffer(x, instance_dims: int):
+    """This operator converts the Tensor's type to TensorBuffer.
+
+    Refer to `Concept Explanation <https://docs.oneflow.org/basics_topics/concept_explanation.html#3tensorbuffer-tensorlist>`_
+    for more about TensorBuffer.
+
+    Args:
+        x (oneflow.Tensor): The input Tensor.
+        instance_dims (int): The dimensions of dynamic tensor instance.
+
+    Returns:
+        oneflow.Tensor: The result Tensor.
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import numpy as np
+        >>> import oneflow.experimental as flow
+        >>> flow.enable_eager_execution()
+
+        >>> x = np.random.randn(4, 16, 64, 64).astype(np.float32)
+        >>> x = flow.Tensor(x)
+        >>> x = flow.tensor_to_tensor_buffer(x, instance_dims=2)
+        >>> output = flow.tensor_buffer_to_tensor(x, instance_shape=(64, 64), dtype=flow.float)
+        >>> output.shape
+        flow.Size([4, 16, 64, 64])
+    
+    """
+    return TensorToTensorBuffer(instance_dims=instance_dims)(x)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(raise_on_error=True)
diff --git a/oneflow/python/ops/tensor_buffer_ops.py b/oneflow/python/ops/tensor_buffer_ops.py
index b3534bf7f214a6a4e6bfc524aa8bd575a29195e0..4f15585dbbc7e28e6a6dfcff06960215cc507e39 100644
--- a/oneflow/python/ops/tensor_buffer_ops.py
+++ b/oneflow/python/ops/tensor_buffer_ops.py
@@ -26,6 +26,7 @@ from typing import Optional, Sequence, List
 
 
 @oneflow_export("tensor_buffer_to_tensor")
+@stable_api
 def tensor_buffer_to_tensor(
     x: oneflow._oneflow_internal.BlobDesc,
     dtype: flow.dtype,
@@ -89,6 +90,7 @@ def tensor_buffer_to_tensor(
 
 
 @oneflow_export("tensor_to_tensor_buffer")
+@stable_api
 def tensor_to_tensor_buffer(
     x: oneflow._oneflow_internal.BlobDesc,
     instance_dims: int,
diff --git a/oneflow/python/test/modules/test_tensor_buffer.py b/oneflow/python/test/modules/test_tensor_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbe746512b3bb9e93fcd20949f3eb1da46c2c21c
--- /dev/null
+++ b/oneflow/python/test/modules/test_tensor_buffer.py
@@ -0,0 +1,51 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+
+import oneflow.experimental as flow
+from test_util import GenArgList, type_name_to_flow_type
+
+
+def _test_tensor_buffer_convert(test_case, device):
+    input = flow.Tensor(
+        np.random.rand(16, 24, 32, 36), dtype=flow.float32, device=flow.device(device)
+    )
+    tensor_buffer = flow.tensor_to_tensor_buffer(input, instance_dims=2)
+    orig_tensor = flow.tensor_buffer_to_tensor(
+        tensor_buffer, dtype=flow.float32, instance_shape=[32, 36]
+    )
+
+    test_case.assertTrue(np.array_equal(input.numpy(), orig_tensor.numpy()))
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestTensorBufferOps(flow.unittest.TestCase):
+    def test_tensor_buffer_convert(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [_test_tensor_buffer_convert]
+        arg_dict["device"] = ["cpu"]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
+
+if __name__ == "__main__":
+    unittest.main()