Unverified Commit e8808b18 authored by YongtaoShi, committed by GitHub

add batchnorm3d module (#5631)


* add batchnorm3d module

* add testing flag, add more log

Signed-off-by: daquexian <daquexian566@gmail.com>

* only test cpu device

* auto format by CI

Co-authored-by: daquexian <daquexian566@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
parent d57de627
......@@ -14,6 +14,7 @@ Operators for neural networks
BCEWithLogitsLoss,
BatchNorm1d,
BatchNorm2d,
BatchNorm3d,
COCOReader,
CTCLoss,
CoinFlip,
......
......@@ -41,7 +41,7 @@ from oneflow.nn.modules.adaptive_pool import (
AdaptiveAvgPool2d,
AdaptiveAvgPool3d,
)
from oneflow.nn.modules.batchnorm import BatchNorm1d, BatchNorm2d
from oneflow.nn.modules.batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d
from oneflow.nn.modules.container import (
ModuleDict,
ModuleList,
......
......@@ -328,7 +328,84 @@ class BatchNorm2d(_BatchNorm):
def _check_input_dim(self, input):
if input.ndim != 4:
raise ValueError("expected 4D input (got {}D input)".format(input.ndim()))
raise ValueError("expected 4D input (got {}D input)".format(input.ndim))
class BatchNorm3d(_BatchNorm):
r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
with additional channel dimension) as described in the paper
`Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated per-dimension over
the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
to 1 and the elements of :math:`\beta` are set to 0. The standard-deviation is calculated
via the biased estimator, equivalent to `torch.var(input, unbiased=False)`.
Also by default, during training this layer keeps running estimates of its
computed mean and variance, which are then used for normalization during
evaluation. The running estimates are kept with a default :attr:`momentum`
of 0.1.
If :attr:`track_running_stats` is set to ``False``, this layer then does not
keep running estimates, and batch statistics are instead used during
evaluation time as well.
.. note::
This :attr:`momentum` argument is different from one used in optimizer
classes and the conventional notion of momentum. Mathematically, the
update rule for running statistics here is
:math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
new observed value.
Because the Batch Normalization is done over the `C` dimension, computing statistics
on `(N, D, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.
Args:
num_features: :math:`C` from an expected input of size
:math:`(N, C, D, H, W)`
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Can be set to ``None`` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics, and initializes statistics
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
When these buffers are ``None``, this module always uses batch statistics.
in both training and eval modes. Default: ``True``
Shape:
- Input: :math:`(N, C, D, H, W)`
- Output: :math:`(N, C, D, H, W)` (same shape as input)
For example:
.. code-block:: python
>>> import oneflow as flow
>>> import numpy as np
>>> x = flow.Tensor(np.random.randn(3, 2, 5, 8, 4))
>>> m = flow.nn.BatchNorm3d(num_features=2, eps=1e-5, momentum=0.1)
>>> y = m(x)
>>> y.size()
flow.Size([3, 2, 5, 8, 4])
"""
def _check_input_dim(self, input):
if input.ndim != 5:
raise ValueError("expected 5D input (got {}D input)".format(input.ndim))
if __name__ == "__main__":
......
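The normalization and running-statistics rules spelled out in the docstring above can be checked numerically against plain NumPy. A minimal sketch (illustrative only, not part of this commit; the array and variable names are made up):

    import numpy as np

    # Toy 5D input (N, C, D, H, W), matching the docstring's shape convention.
    x = np.random.randn(3, 2, 5, 8, 4).astype(np.float32)
    eps, momentum = 1e-5, 0.1

    # Per-channel statistics over the (N, D, H, W) slices, biased variance
    # (equivalent to unbiased=False).
    axes = (0, 2, 3, 4)
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)

    # y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta, with gamma=1, beta=0.
    y = (x - mean) / np.sqrt(var + eps)
    print(y.shape)  # (3, 2, 5, 8, 4), same shape as the input

    # Running-statistics update from the note in the docstring:
    # x_hat_new = (1 - momentum) * x_hat + momentum * x_t
    running_mean = np.zeros_like(mean)
    running_mean = (1 - momentum) * running_mean + momentum * mean

This is the same computation the docstring example performs through flow.nn.BatchNorm3d with default gamma and beta.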
......@@ -488,6 +488,17 @@ class TestBatchNorm(flow.unittest.TestCase):
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
@autotest(n=20, auto_backward=True)
def test_batchnorm3d_module_with_random_data(test_case):
channel = random().to(int)
m = torch.nn.BatchNorm3d(num_features=channel, track_running_stats=False)
m.train(random())
device = "cpu"
m.to(device)
x = random_pytorch_tensor(ndim=5, dim1=channel, requires_grad=True).to(device)
y = m(x)
return y
@unittest.skip("batchnorm module has a bug")
def test_with_random_data(test_case):
for device in ["cpu", "cuda"]:
......
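The @autotest decorator used above constructs paired PyTorch/OneFlow modules, runs both on the same random input, and compares the results. For reference, the same cross-check can be written out by hand; the sketch below mirrors the state-dict copy that the test utility performs, but it is an illustration, not the utility itself (assumes both frameworks are installed):

    import numpy as np
    import torch
    import oneflow as flow

    np_x = np.random.randn(3, 2, 5, 8, 4).astype(np.float32)

    torch_m = torch.nn.BatchNorm3d(num_features=2, track_running_stats=False)
    flow_m = flow.nn.BatchNorm3d(num_features=2, track_running_stats=False)

    # Copy the PyTorch parameters into the OneFlow module as numpy arrays,
    # the same way DualObject loads the state dict in automated_test_util.
    state = {k: v.detach().cpu().numpy() for k, v in torch_m.state_dict().items()}
    flow_m.load_state_dict(state)

    torch_y = torch_m(torch.tensor(np_x)).detach().cpu().numpy()
    flow_y = flow_m(flow.Tensor(np_x)).numpy()
    print(np.allclose(torch_y, flow_y, rtol=1e-4, atol=1e-5))  # expected: True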
......@@ -27,6 +27,8 @@ from .generators import Nothing, generator, random_tensor
postulate = [".rand", ".Tensor"]
testing = False
def torch_tensor_to_flow(x):
return flow.tensor(x.cpu().numpy())
......@@ -186,9 +188,11 @@ class DualObject:
state_dict = pytorch.state_dict()
state_dict = {k: v.detach().cpu().numpy() for (k, v) in state_dict.items()}
oneflow.load_state_dict(state_dict)
dual_modules_to_test.append(self)
if testing:
dual_modules_to_test.append(self)
if isinstance(pytorch, torch_original.Tensor):
dual_objects_to_test.append(self)
if testing:
dual_objects_to_test.append(self)
def __repr__(self):
return f"PyTorch object:\n{self.pytorch}\n\nOneFlow object:\n{self.oneflow}"
......@@ -239,10 +243,13 @@ def check_tensor_equality(torch_tensor, flow_tensor, rtol=0.0001, atol=1e-05):
if torch_tensor.grad is not None:
assert (
flow_tensor.grad is not None
), "OneFlow tensor doesn't have grad while PyTorch tensor has one"
if not np.allclose(
torch_tensor.grad.detach().cpu().numpy(), flow_tensor.grad.numpy()
):
), f"OneFlow tensor doesn't have grad while PyTorch tensor has one, PyTorch tensor is\n {torch_tensor}\n, OneFlow tensor is\n{flow_tensor} "
torch_grad = torch_tensor.grad.detach().cpu().numpy()
flow_grad = flow_tensor.grad.numpy()
if not np.allclose(torch_grad, flow_grad, rtol=rtol, atol=atol):
print(
"Grads are not equal. PyTorch grad: \n{torch_grad}\n, OneFlow grad: \n{flow_grad}"
)
return False
return np.allclose(
torch_tensor.detach().cpu().numpy(),
......@@ -273,7 +280,10 @@ def autotest(n=20, auto_backward=True, rtol=0.0001, atol=1e-05):
dual_modules_to_test.clear()
dual_objects_to_test.clear()
try:
global testing
testing = True
res = f(test_case)
testing = False
except PyTorchDoesNotSupportError as e:
if verbose:
print(e)
......@@ -297,7 +307,7 @@ def autotest(n=20, auto_backward=True, rtol=0.0001, atol=1e-05):
)
)
for x in dual_objects_to_test:
test_case.assertTrue(check_equality(x, rtol=rtol, atol=atol))
test_case.assertTrue(check_equality(x, rtol=rtol, atol=atol), x)
if verbose:
print("test passed")
n -= 1
......
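The module-level testing flag introduced above makes sure dual modules and tensors are only collected while an @autotest-decorated body is actually running. A generic sketch of that guard pattern (plain Python, not the library code; shown with try/finally so the flag is reset even if the wrapped body raises):

    # Minimal sketch of a module-level collection guard (illustrative).
    testing = False
    collected = []

    def register(obj):
        # Only collect objects while a test body is running.
        if testing:
            collected.append(obj)

    def run_autotest_body(body):
        global testing
        testing = True
        try:
            body()  # objects created here get registered
        finally:
            testing = False  # reset even if the body raises

    run_autotest_body(lambda: register("dual object"))
    print(collected)  # ['dual object']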