diff --git a/oneflow/python/nn/modules/batchnorm.py b/oneflow/python/nn/modules/batchnorm.py
index 31b3c45ae5fbb4b94672165316b7a2e22d026278..6b70e9976dd66bb26615120fc58fa0b5e8edcb0c 100644
--- a/oneflow/python/nn/modules/batchnorm.py
+++ b/oneflow/python/nn/modules/batchnorm.py
@@ -30,8 +30,6 @@ class _NormBase(Module):
         momentum: float = 0.1,
         affine: bool = True,
         track_running_stats: bool = True,
-        device: Union[str, flow.device] = None,
-        dtype: flow.dtype = None,
     ) -> None:
         super().__init__()
         self.num_features = num_features
@@ -39,23 +37,19 @@ class _NormBase(Module):
         self.momentum = momentum
         self.affine = affine
         self.track_running_stats = track_running_stats
-        self.device = device
-        self.dtype = dtype
 
         if self.affine:
-            self.weight = flow.nn.Parameter(
-                flow.Tensor(num_features, device=self.device)
-            )
-            self.bias = flow.nn.Parameter(flow.Tensor(num_features, device=self.device))
+            self.weight = flow.nn.Parameter(flow.Tensor(num_features))
+            self.bias = flow.nn.Parameter(flow.Tensor(num_features))
         else:
             self.register_parameter("weight", None)
             self.register_parameter("bias", None)
         if self.track_running_stats:
             self.register_buffer(
-                "running_mean", flow.Tensor(num_features, device=self.device),
+                "running_mean", flow.Tensor(num_features),
             )
             self.register_buffer(
-                "running_var", flow.Tensor(num_features, device=self.device),
+                "running_var", flow.Tensor(num_features),
             )
         else:
             self.register_parameter("running_mean", None)
@@ -106,28 +100,19 @@ class _BatchNorm(_NormBase):
         momentum=0.1,
         affine=True,
         track_running_stats=True,
-        device=None,
-        dtype=None,
     ):
-        super().__init__(
-            num_features, eps, momentum, affine, track_running_stats, device, dtype
-        )
+        super().__init__(num_features, eps, momentum, affine, track_running_stats)
 
     def forward(self, x):
-        if self.dtype is None:
-            self.dtype = x.dtype
-        if self.device is None:
-            self.device = x.device
-
         self._check_input_dim(x)
-        reduce_axis = []
-        for dim in range(len(x.shape)):
-            if dim != 1:
-                reduce_axis.append(dim)
-        mean = x.mean(dim=reduce_axis, keepdim=False)
-        variance = x.var(dim=reduce_axis, keepdim=False)
-
         if x.device == flow.device("cpu"):
+            reduce_axis = []
+            for dim in range(len(x.shape)):
+                if dim != 1:
+                    reduce_axis.append(dim)
+            mean = x.mean(dim=reduce_axis, keepdim=False)
+            variance = x.var(dim=reduce_axis, keepdim=False)
+
             if self.training and self.track_running_stats:
                 running_mean = (
                     self.momentum * self.running_mean + (1 - self.momentum) * mean
@@ -173,21 +158,38 @@ class _BatchNorm(_NormBase):
                 affined = affined * weight
             if self.bias:
                 affined = affined + bias
-            return affined.to(dtype=self.dtype)
+            return affined
 
         else:
-            res = flow.F.normalization(
-                x,
-                self.running_mean if self.track_running_stats else mean,
-                self.running_var if self.track_running_stats else variance,
-                self.weight,
-                self.bias,
-                axis=1,
-                epsilon=self.eps,
-                momentum=self.momentum,
-                is_training=self.training,
-            )
-            return res.to(dtype=self.dtype, device=self.device)
+            if self.track_running_stats:
+                return flow.F.normalization(
+                    x,
+                    self.running_mean,
+                    self.running_var,
+                    self.weight,
+                    self.bias,
+                    axis=1,
+                    epsilon=self.eps,
+                    momentum=self.momentum,
+                    is_training=self.training,
+                )
+            else:
+                reduce_axis = []
+                for dim in range(len(x.shape)):
+                    if dim != 1:
+                        reduce_axis.append(dim)
+
+                return flow.F.normalization(
+                    x,
+                    x.mean(dim=reduce_axis, keepdim=False),
+                    x.var(dim=reduce_axis, keepdim=False),
+                    self.weight,
+                    self.bias,
+                    axis=1,
+                    epsilon=self.eps,
+                    momentum=self.momentum,
+                    is_training=self.training,
+                )
 
 
 @oneflow_export("nn.BatchNorm1d")
diff --git a/oneflow/python/test/modules/test_batchnorm.py b/oneflow/python/test/modules/test_batchnorm.py
index b50991888d75d0c175337f72721e85b7674b2809..4ca86422579ff8b429d04475aca786645e292b6a 100644
--- a/oneflow/python/test/modules/test_batchnorm.py
+++ b/oneflow/python/test/modules/test_batchnorm.py
@@ -44,8 +44,8 @@ def _test_batchnorm1d_2d_input(test_case, device):
         dtype=np.float32,
     )
 
-    m = flow.nn.BatchNorm1d(
-        num_features=5, eps=1e-5, momentum=0.1, device=flow.device(device)
+    m = flow.nn.BatchNorm1d(num_features=5, eps=1e-5, momentum=0.1).to(
+        device=flow.device(device)
     )
     x = flow.Tensor(input_arr, device=flow.device(device))
     y = m(x)
@@ -85,8 +85,8 @@ def _test_batchnorm1d_3d_input(test_case, device):
         dtype=np.float32,
     )
 
-    m = flow.nn.BatchNorm1d(
-        num_features=3, eps=1e-5, momentum=0.1, device=flow.device(device)
+    m = flow.nn.BatchNorm1d(num_features=3, eps=1e-5, momentum=0.1).to(
+        device=flow.device(device)
     )
     x = flow.Tensor(input_arr, device=flow.device(device))
     y = m(x)
@@ -154,12 +154,8 @@ def _test_batchnorm2d(test_case, device):
         dtype=np.float32,
     )
 
-    m = flow.nn.BatchNorm2d(
-        num_features=2,
-        eps=1e-5,
-        momentum=0.1,
-        device=flow.device(device),
-        dtype=flow.float64,
+    m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(
+        device=flow.device(device)
     )
     x = flow.Tensor(input_arr, device=flow.device(device), dtype=flow.float32)
     y = m(x)
@@ -228,12 +224,8 @@ def _test_batchnorm2d_track_running_stats(test_case, device):
     )
 
     m = flow.nn.BatchNorm2d(
-        num_features=2,
-        eps=1e-5,
-        momentum=0.1,
-        track_running_stats=False,
-        device=flow.device(device),
-    )
+        num_features=2, eps=1e-5, momentum=0.1, track_running_stats=False,
+    ).to(device=flow.device(device))
     x = flow.Tensor(input_arr, device=flow.device(device))
     y = m(x)
     test_case.assertTrue(np.allclose(y.numpy(), output, 1e-04, 1e-04))
@@ -300,8 +292,8 @@ def _test_batchnorm2d_4d_input(test_case, device):
         dtype=np.float32,
     )
 
-    m = flow.nn.BatchNorm2d(
-        num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device)
+    m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(
+        device=flow.device(device)
     )
     x = flow.Tensor(input_arr, device=flow.device(device))
     y = m(x)
@@ -369,8 +361,8 @@ def test_batchnorm2d_infer(test_case, device):
         dtype=np.float32,
     )
 
-    m = flow.nn.BatchNorm2d(
-        num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device)
+    m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(
+        device=flow.device(device)
     )
     m.eval()
     x = flow.Tensor(input_arr, device=flow.device(device))
@@ -439,8 +431,8 @@ def test_batchnorm2d_infer_4d_input(test_case, device):
         dtype=np.float32,
     )
 
-    m = flow.nn.BatchNorm2d(
-        num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device)
+    m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(
+        device=flow.device(device)
     )
     m.eval()
     x = flow.Tensor(input_arr, device=flow.device(device))
@@ -479,8 +471,8 @@ def _test_batchnorm2d_backward(test_case, device):
         dtype=np.float32,
     )
 
-    m = flow.nn.BatchNorm2d(
-        num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device)
+    m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(
+        device=flow.device(device)
     )
     x = flow.Tensor(input_arr, device=flow.device(device), requires_grad=True)
     y = m(x)