Skip to content
Snippets Groups Projects
Unverified Commit 2a382985, authored by Luyang, committed by GitHub
Browse files

remove device dtype params (#5434)

parent 96fe7bef
No related branches found
No related tags found
No related merge requests found
......@@ -30,8 +30,6 @@ class _NormBase(Module):
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True,
device: Union[str, flow.device] = None,
dtype: flow.dtype = None,
) -> None:
super().__init__()
self.num_features = num_features
......@@ -39,23 +37,19 @@ class _NormBase(Module):
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
self.device = device
self.dtype = dtype
if self.affine:
self.weight = flow.nn.Parameter(
flow.Tensor(num_features, device=self.device)
)
self.bias = flow.nn.Parameter(flow.Tensor(num_features, device=self.device))
self.weight = flow.nn.Parameter(flow.Tensor(num_features))
self.bias = flow.nn.Parameter(flow.Tensor(num_features))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
if self.track_running_stats:
self.register_buffer(
"running_mean", flow.Tensor(num_features, device=self.device),
"running_mean", flow.Tensor(num_features),
)
self.register_buffer(
"running_var", flow.Tensor(num_features, device=self.device),
"running_var", flow.Tensor(num_features),
)
else:
self.register_parameter("running_mean", None)
......@@ -106,28 +100,19 @@ class _BatchNorm(_NormBase):
momentum=0.1,
affine=True,
track_running_stats=True,
device=None,
dtype=None,
):
super().__init__(
num_features, eps, momentum, affine, track_running_stats, device, dtype
)
super().__init__(num_features, eps, momentum, affine, track_running_stats)
def forward(self, x):
if self.dtype is None:
self.dtype = x.dtype
if self.device is None:
self.device = x.device
self._check_input_dim(x)
reduce_axis = []
for dim in range(len(x.shape)):
if dim != 1:
reduce_axis.append(dim)
mean = x.mean(dim=reduce_axis, keepdim=False)
variance = x.var(dim=reduce_axis, keepdim=False)
if x.device == flow.device("cpu"):
reduce_axis = []
for dim in range(len(x.shape)):
if dim != 1:
reduce_axis.append(dim)
mean = x.mean(dim=reduce_axis, keepdim=False)
variance = x.var(dim=reduce_axis, keepdim=False)
if self.training and self.track_running_stats:
running_mean = (
self.momentum * self.running_mean + (1 - self.momentum) * mean
......@@ -173,21 +158,38 @@ class _BatchNorm(_NormBase):
affined = affined * weight
if self.bias:
affined = affined + bias
return affined.to(dtype=self.dtype)
return affined
else:
res = flow.F.normalization(
x,
self.running_mean if self.track_running_stats else mean,
self.running_var if self.track_running_stats else variance,
self.weight,
self.bias,
axis=1,
epsilon=self.eps,
momentum=self.momentum,
is_training=self.training,
)
return res.to(dtype=self.dtype, device=self.device)
if self.track_running_stats:
return flow.F.normalization(
x,
self.running_mean,
self.running_var,
self.weight,
self.bias,
axis=1,
epsilon=self.eps,
momentum=self.momentum,
is_training=self.training,
)
else:
reduce_axis = []
for dim in range(len(x.shape)):
if dim != 1:
reduce_axis.append(dim)
return flow.F.normalization(
x,
x.mean(dim=reduce_axis, keepdim=False),
x.var(dim=reduce_axis, keepdim=False),
self.weight,
self.bias,
axis=1,
epsilon=self.eps,
momentum=self.momentum,
is_training=self.training,
)
@oneflow_export("nn.BatchNorm1d")
......
......@@ -44,8 +44,8 @@ def _test_batchnorm1d_2d_input(test_case, device):
dtype=np.float32,
)
m = flow.nn.BatchNorm1d(
num_features=5, eps=1e-5, momentum=0.1, device=flow.device(device)
m = flow.nn.BatchNorm1d(num_features=5, eps=1e-5, momentum=0.1).to(
device=flow.device(device)
)
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
......@@ -85,8 +85,8 @@ def _test_batchnorm1d_3d_input(test_case, device):
dtype=np.float32,
)
m = flow.nn.BatchNorm1d(
num_features=3, eps=1e-5, momentum=0.1, device=flow.device(device)
m = flow.nn.BatchNorm1d(num_features=3, eps=1e-5, momentum=0.1).to(
device=flow.device(device)
)
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
......@@ -154,12 +154,8 @@ def _test_batchnorm2d(test_case, device):
dtype=np.float32,
)
m = flow.nn.BatchNorm2d(
num_features=2,
eps=1e-5,
momentum=0.1,
device=flow.device(device),
dtype=flow.float64,
m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(
device=flow.device(device)
)
x = flow.Tensor(input_arr, device=flow.device(device), dtype=flow.float32)
y = m(x)
......@@ -228,12 +224,8 @@ def _test_batchnorm2d_track_running_stats(test_case, device):
)
m = flow.nn.BatchNorm2d(
num_features=2,
eps=1e-5,
momentum=0.1,
track_running_stats=False,
device=flow.device(device),
)
num_features=2, eps=1e-5, momentum=0.1, track_running_stats=False,
).to(device=flow.device(device))
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
test_case.assertTrue(np.allclose(y.numpy(), output, 1e-04, 1e-04))
......@@ -300,8 +292,8 @@ def _test_batchnorm2d_4d_input(test_case, device):
dtype=np.float32,
)
m = flow.nn.BatchNorm2d(
num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device)
m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(
device=flow.device(device)
)
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
......@@ -369,8 +361,8 @@ def test_batchnorm2d_infer(test_case, device):
dtype=np.float32,
)
m = flow.nn.BatchNorm2d(
num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device)
m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(
device=flow.device(device)
)
m.eval()
x = flow.Tensor(input_arr, device=flow.device(device))
......@@ -439,8 +431,8 @@ def test_batchnorm2d_infer_4d_input(test_case, device):
dtype=np.float32,
)
m = flow.nn.BatchNorm2d(
num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device)
m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(
device=flow.device(device)
)
m.eval()
x = flow.Tensor(input_arr, device=flow.device(device))
......@@ -479,8 +471,8 @@ def _test_batchnorm2d_backward(test_case, device):
dtype=np.float32,
)
m = flow.nn.BatchNorm2d(
num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device)
m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(
device=flow.device(device)
)
x = flow.Tensor(input_arr, device=flow.device(device), requires_grad=True)
y = m(x)
......
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment