Skip to content
Snippets Groups Projects
Unverified Commit 373cefce authored by Lyon's avatar Lyon Committed by GitHub
Browse files

Align module params with torch (#4865)


* align mean module

* allow negative dim param

* support tuple of negative dim param

* refine

* format

Co-authored-by: default avatarXiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 1ce9f262
No related branches found
No related tags found
No related merge requests found
......@@ -24,20 +24,23 @@ from oneflow.python.framework.tensor import register_tensor_op
def _check_axis(axis, shape):
ndim = len(shape)
# TODO(yaochi): refine this function when all related ops in `python/ops/math_ops.py` migrated
if axis is None:
axis = list(range(len(shape)))
if isinstance(axis, int):
axis = [axis]
assert isinstance(axis, (list, tuple)), "Invalid axis {}".format(axis)
for x in axis:
if x < 0:
x += len(shape)
assert x >= 0 and x < len(shape), "Invalid axis {}, len(shape): {}".format(
axis, len(shape)
axis = list(axis)
for i in range(len(axis)):
assert (
-ndim <= axis[i] <= ndim - 1
), "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
-ndim, ndim - 1, axis[i]
)
if axis[i] < 0:
axis[i] = axis[i] + ndim
return axis
......@@ -211,6 +214,22 @@ class Mean(Module):
self.axes = list(axis) if isinstance(axis, collections.Sized) else [axis]
def forward(self, input_tensor):
ndim = input_tensor.ndimension()
if isinstance(self.axis, int) and self.axis < 0:
assert -ndim <= self.axis <= -1, "axis should be in range:[-ndims,-1]"
self.axis = ndim + self.axis
self.axes = [self.axis]
if isinstance(self.axis, collections.Sized):
for i in range(len(self.axes)):
assert (
-ndim <= self.axes[i] <= ndim - 1
), "Dimension out of range (expected to be in range of [-{}, {}], but got {})".format(
ndim, ndim - 1, self.axes[i]
)
if self.axes[i] < 0:
self.axes[i] = self.axes[i] + ndim
reduce_sum = flow.experimental.sum(
input_tensor, dim=self.axis, keepdims=self.keepdims
)
......
......@@ -99,6 +99,13 @@ class TestStd(flow.unittest.TestCase):
np_out = np.std(np_arr, axis=1)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
def test_std_negative_dim(test_case):
np_arr = np.random.randn(4, 2, 3, 5)
input = flow.Tensor(np_arr)
of_out = input.std(dim=(-2, -1, -3), keepdim=False)
np_out = np.std(np_arr, axis=(-2, -1, -3))
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
......
......@@ -34,6 +34,12 @@ class TestMeanModule(flow.unittest.TestCase):
np_out = np.mean(input.numpy(), axis=0)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
def test_mean_negative_dim(test_case):
input = flow.Tensor(np.random.randn(2, 3, 4, 5), dtype=flow.float32)
of_out = flow.mean(input, dim=(-2, -1, -3))
np_out = np.mean(input.numpy(), axis=(-2, -1, -3))
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
if __name__ == "__main__":
unittest.main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment