diff --git a/oneflow/python/nn/modules/loss.py b/oneflow/python/nn/modules/loss.py
index 398dda795948e0be81dad4bf609a52bfcf35dd0d..46c79f6372c80b5369f7b9a2bcaf487147f8b990 100644
--- a/oneflow/python/nn/modules/loss.py
+++ b/oneflow/python/nn/modules/loss.py
@@ -22,7 +22,30 @@ from oneflow.python.nn.module import Module
 
 @oneflow_export("nn.CrossEntropyLoss")
 class CrossEntropyLoss(Module):
-    r"""
+    r"""This criterion combines :class:`~flow.nn.LogSoftmax` and :class:`~flow.nn.NLLLoss` in one single class.
+
+    It is useful when training a classification problem with `C` classes.
+
+    The `input` is expected to contain raw, unnormalized scores for each class.
+
+    `input` has to be a Tensor of size either :math:`(minibatch, C)` or
+    :math:`(minibatch, C, d_1, d_2, ..., d_K)`
+    with :math:`K \geq 1` for the `K`-dimensional case (described later).
+
+    This criterion expects a class index in the range :math:`[0, C-1]` as the
+    `target` for each element of a 1D tensor of size `minibatch`.
+
+    The loss can be described as:
+
+    .. math::
+        \text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)
+                       = -x[class] + \log\left(\sum_j \exp(x[j])\right)
+
+    Can also be used for higher dimension inputs, such as 2D images, by providing
+    an input of size :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`,
+    where :math:`K` is the number of dimensions, and a target of shape
+    :math:`(minibatch, d_1, d_2, ..., d_K)`.
+
     Args:
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
@@ -33,6 +56,7 @@ class CrossEntropyLoss(Module):
             :attr:`reduction`. Default: ``'mean'``
 
     For example:
+
     .. code-block:: python
 
         import oneflow as flow
@@ -47,6 +71,7 @@ class CrossEntropyLoss(Module):
         # out_sum: [2.2769074]
         out_mean = flow.nn.CrossEntropyLoss(reduction="mean")(input, target)
         # out_mean: [0.7589692]
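+
+        # The 4-D (b, c, h, w) case, e.g. per-pixel losses over an image;
+        # a sketch assuming numpy is imported as np and that Tensor
+        # construction mirrors the example above:
+        input = flow.Tensor(
+            np.array([[[[0.12, 0.36], [0.22, 0.66]], [[0.13, 0.34], [0.52, -0.96]]]]),
+            dtype=flow.float32,
+        )
+        target = flow.Tensor(np.array([[[1, 0], [0, 1]]]), dtype=flow.int32)
+        out = flow.nn.CrossEntropyLoss(reduction="none")(input, target)
+        # out: [[[0.6882, 0.6832], [0.8544, 1.8006]]]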
+
 
     """
 
@@ -79,12 +104,20 @@ class CrossEntropyLoss(Module):
         )
 
     def forward(self, input, target):
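+        # For 4-D (b, c, h, w) input, move the class axis last and collapse
+        # the batch/spatial axes to (b*h*w, c), since self._op reduces over
+        # the final axis; with reduction="none" the per-element losses are
+        # reshaped back to (b, h, w) before being returned.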
+        input_shape_len = len(input.shape)
+        if input_shape_len == 4:
+            b, c, h, w = input.shape[0], input.shape[1], input.shape[2], input.shape[3]
+            input = flow.tmp.transpose(input, (0, 2, 3, 1))  # (b, h, w, c)
+            input = flow.tmp.reshape(input, shape=[-1, c])
+            target = flow.tmp.flatten(target)
         prob, out = self._op(input, target, depth=input.shape[len(input.shape) - 1])
         if self.reduction == "mean":
             return flow.mean(out)
         elif self.reduction == "sum":
             return flow.sum(out)
         else:
+            if input_shape_len == 4:
+                out = flow.tmp.reshape(out, shape=[b, h, w])
             return out
 
 
diff --git a/oneflow/python/test/modules/test_crossentropyloss.py b/oneflow/python/test/modules/test_crossentropyloss.py
index 75f23358be6d63b5830da926383407fb1615a1a7..c8597f1d7fc80f9c54dec95970afc2730127f8f3 100644
--- a/oneflow/python/test/modules/test_crossentropyloss.py
+++ b/oneflow/python/test/modules/test_crossentropyloss.py
@@ -30,7 +30,34 @@ g_test_samples = [
         "out": np.array([1.1380, 1.7332, 1.4287], dtype=np.float32),
         "out_sum": np.array([4.2999], dtype=np.float32),
         "out_mean": np.array([1.4333], dtype=np.float32),
-    }
+    },
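+    # 4-D (NCHW) input with a 2x2 spatial map of per-pixel targets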
+    {
+        "input": np.array(
+            [[[[0.12, 0.36], [0.22, 0.66]], [[0.13, 0.34], [0.52, -0.96]]]]
+        ),
+        "target": np.array([[[1, 0], [0, 1]]], dtype=np.int32),
+        "out": np.array([[[0.6882, 0.6832], [0.8544, 1.8006]]], dtype=np.float32),
+        "out_sum": np.array([4.0263], dtype=np.float32),
+        "out_mean": np.array([1.0066], dtype=np.float32),
+    },
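+    # batch of two identical 4-D samples: out_sum doubles, out_mean is unchanged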
+    {
+        "input": np.array(
+            [
+                [[[0.12, 0.36], [0.22, 0.66]], [[0.13, 0.34], [0.52, -0.96]]],
+                [[[0.12, 0.36], [0.22, 0.66]], [[0.13, 0.34], [0.52, -0.96]]],
+            ]
+        ),
+        "target": np.array([[[1, 0], [0, 1]], [[1, 0], [0, 1]]], dtype=np.int32),
+        "out": np.array(
+            [
+                [[0.6882, 0.6832], [0.8544, 1.8006]],
+                [[0.6882, 0.6832], [0.8544, 1.8006]],
+            ],
+            dtype=np.float32,
+        ),
+        "out_sum": np.array([8.0526], dtype=np.float32),
+        "out_mean": np.array([1.0066], dtype=np.float32),
+    },
 ]