diff --git a/oneflow/python/nn/modules/batchnorm.py b/oneflow/python/nn/modules/batchnorm.py index 31b3c45ae5fbb4b94672165316b7a2e22d026278..6b70e9976dd66bb26615120fc58fa0b5e8edcb0c 100644 --- a/oneflow/python/nn/modules/batchnorm.py +++ b/oneflow/python/nn/modules/batchnorm.py @@ -30,8 +30,6 @@ class _NormBase(Module): momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, - device: Union[str, flow.device] = None, - dtype: flow.dtype = None, ) -> None: super().__init__() self.num_features = num_features @@ -39,23 +37,19 @@ class _NormBase(Module): self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats - self.device = device - self.dtype = dtype if self.affine: - self.weight = flow.nn.Parameter( - flow.Tensor(num_features, device=self.device) - ) - self.bias = flow.nn.Parameter(flow.Tensor(num_features, device=self.device)) + self.weight = flow.nn.Parameter(flow.Tensor(num_features)) + self.bias = flow.nn.Parameter(flow.Tensor(num_features)) else: self.register_parameter("weight", None) self.register_parameter("bias", None) if self.track_running_stats: self.register_buffer( - "running_mean", flow.Tensor(num_features, device=self.device), + "running_mean", flow.Tensor(num_features), ) self.register_buffer( - "running_var", flow.Tensor(num_features, device=self.device), + "running_var", flow.Tensor(num_features), ) else: self.register_parameter("running_mean", None) @@ -106,28 +100,19 @@ class _BatchNorm(_NormBase): momentum=0.1, affine=True, track_running_stats=True, - device=None, - dtype=None, ): - super().__init__( - num_features, eps, momentum, affine, track_running_stats, device, dtype - ) + super().__init__(num_features, eps, momentum, affine, track_running_stats) def forward(self, x): - if self.dtype is None: - self.dtype = x.dtype - if self.device is None: - self.device = x.device - self._check_input_dim(x) - reduce_axis = [] - for dim in range(len(x.shape)): - if dim != 1: - reduce_axis.append(dim) - mean = x.mean(dim=reduce_axis, keepdim=False) - variance = x.var(dim=reduce_axis, keepdim=False) - if x.device == flow.device("cpu"): + reduce_axis = [] + for dim in range(len(x.shape)): + if dim != 1: + reduce_axis.append(dim) + mean = x.mean(dim=reduce_axis, keepdim=False) + variance = x.var(dim=reduce_axis, keepdim=False) + if self.training and self.track_running_stats: running_mean = ( self.momentum * self.running_mean + (1 - self.momentum) * mean @@ -173,21 +158,38 @@ class _BatchNorm(_NormBase): affined = affined * weight if self.bias: affined = affined + bias - return affined.to(dtype=self.dtype) + return affined else: - res = flow.F.normalization( - x, - self.running_mean if self.track_running_stats else mean, - self.running_var if self.track_running_stats else variance, - self.weight, - self.bias, - axis=1, - epsilon=self.eps, - momentum=self.momentum, - is_training=self.training, - ) - return res.to(dtype=self.dtype, device=self.device) + if self.track_running_stats: + return flow.F.normalization( + x, + self.running_mean, + self.running_var, + self.weight, + self.bias, + axis=1, + epsilon=self.eps, + momentum=self.momentum, + is_training=self.training, + ) + else: + reduce_axis = [] + for dim in range(len(x.shape)): + if dim != 1: + reduce_axis.append(dim) + + return flow.F.normalization( + x, + x.mean(dim=reduce_axis, keepdim=False), + x.var(dim=reduce_axis, keepdim=False), + self.weight, + self.bias, + axis=1, + epsilon=self.eps, + momentum=self.momentum, + is_training=self.training, + ) @oneflow_export("nn.BatchNorm1d") diff --git a/oneflow/python/test/modules/test_batchnorm.py b/oneflow/python/test/modules/test_batchnorm.py index b50991888d75d0c175337f72721e85b7674b2809..4ca86422579ff8b429d04475aca786645e292b6a 100644 --- a/oneflow/python/test/modules/test_batchnorm.py +++ b/oneflow/python/test/modules/test_batchnorm.py @@ -44,8 +44,8 @@ def _test_batchnorm1d_2d_input(test_case, device): dtype=np.float32, ) - m = flow.nn.BatchNorm1d( - num_features=5, eps=1e-5, momentum=0.1, device=flow.device(device) + m = flow.nn.BatchNorm1d(num_features=5, eps=1e-5, momentum=0.1).to( + device=flow.device(device) ) x = flow.Tensor(input_arr, device=flow.device(device)) y = m(x) @@ -85,8 +85,8 @@ def _test_batchnorm1d_3d_input(test_case, device): dtype=np.float32, ) - m = flow.nn.BatchNorm1d( - num_features=3, eps=1e-5, momentum=0.1, device=flow.device(device) + m = flow.nn.BatchNorm1d(num_features=3, eps=1e-5, momentum=0.1).to( + device=flow.device(device) ) x = flow.Tensor(input_arr, device=flow.device(device)) y = m(x) @@ -154,12 +154,8 @@ def _test_batchnorm2d(test_case, device): dtype=np.float32, ) - m = flow.nn.BatchNorm2d( - num_features=2, - eps=1e-5, - momentum=0.1, - device=flow.device(device), - dtype=flow.float64, + m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to( + device=flow.device(device) ) x = flow.Tensor(input_arr, device=flow.device(device), dtype=flow.float32) y = m(x) @@ -228,12 +224,8 @@ def _test_batchnorm2d_track_running_stats(test_case, device): ) m = flow.nn.BatchNorm2d( - num_features=2, - eps=1e-5, - momentum=0.1, - track_running_stats=False, - device=flow.device(device), - ) + num_features=2, eps=1e-5, momentum=0.1, track_running_stats=False, + ).to(device=flow.device(device)) x = flow.Tensor(input_arr, device=flow.device(device)) y = m(x) test_case.assertTrue(np.allclose(y.numpy(), output, 1e-04, 1e-04)) @@ -300,8 +292,8 @@ def _test_batchnorm2d_4d_input(test_case, device): dtype=np.float32, ) - m = flow.nn.BatchNorm2d( - num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device) + m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to( + device=flow.device(device) ) x = flow.Tensor(input_arr, device=flow.device(device)) y = m(x) @@ -369,8 +361,8 @@ def test_batchnorm2d_infer(test_case, device): dtype=np.float32, ) - m = flow.nn.BatchNorm2d( - num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device) + m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to( + device=flow.device(device) ) m.eval() x = flow.Tensor(input_arr, device=flow.device(device)) @@ -439,8 +431,8 @@ def test_batchnorm2d_infer_4d_input(test_case, device): dtype=np.float32, ) - m = flow.nn.BatchNorm2d( - num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device) + m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to( + device=flow.device(device) ) m.eval() x = flow.Tensor(input_arr, device=flow.device(device)) @@ -479,8 +471,8 @@ def _test_batchnorm2d_backward(test_case, device): dtype=np.float32, ) - m = flow.nn.BatchNorm2d( - num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device) + m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to( + device=flow.device(device) ) x = flow.Tensor(input_arr, device=flow.device(device), requires_grad=True) y = m(x)