Skip to content
Snippets Groups Projects
Unverified Commit 5eb22e5c authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

fix upsample nearest bug (#5347)


Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent a6430075
No related branches found
No related tags found
No related merge requests found
......@@ -269,6 +269,34 @@ def _test_upsample2d_bilinear_aligncorner_backward(test_case, device):
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
def _test_interpolate_float_scale(test_case, device):
input = flow.Tensor(
np.arange(1, 10).reshape((1, 1, 3, 3)),
device=flow.device(device),
dtype=flow.float32,
requires_grad=True,
)
m = flow.nn.Upsample(scale_factor=1.5)
of_out = m(input)
np_out = np.array(
[
[
[
[1.0, 1.0, 2.0, 3.0],
[1.0, 1.0, 2.0, 3.0],
[4.0, 4.0, 5.0, 6.0],
[7.0, 7.0, 8.0, 9.0],
]
]
]
)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
of_out = of_out.sum()
of_out.backward()
np_grad = np.array([[[[4.0, 2.0, 2.0], [2.0, 1.0, 1.0], [2.0, 1.0, 1.0]]]])
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
......@@ -286,6 +314,7 @@ class TestUpsample2d(flow.unittest.TestCase):
_test_upsample2d_bilinear_4dim,
_test_upsample2d_backward,
_test_upsample2d_bilinear_aligncorner_backward,
_test_interpolate_float_scale,
]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
......
......@@ -23,7 +23,7 @@ namespace {
static int64_t GetNearestInputIndex(const int64_t out_dim_idx, const float scale,
const int64_t in_dim_size) {
int64_t index = static_cast<int64_t>(floorf((static_cast<float>(out_dim_idx) + 0.5f) * scale));
int64_t index = static_cast<int64_t>(std::floor((static_cast<float>(out_dim_idx) * scale)));
index = index > in_dim_size - 1 ? in_dim_size - 1 : index;
index = index < static_cast<int64_t>(0) ? static_cast<int64_t>(0) : index;
return index;
......
......@@ -24,7 +24,7 @@ namespace {
__device__ int64_t GetNearestInputIndex(const int64_t out_dim_idx, const float scale,
const int64_t in_dim_size) {
return max(min(static_cast<int64_t>(floorf((static_cast<float>(out_dim_idx) + 0.5f) * scale)),
return max(min(static_cast<int64_t>(std::floor((static_cast<float>(out_dim_idx) * scale))),
in_dim_size - 1),
static_cast<int64_t>(0));
}
......
......@@ -65,8 +65,8 @@ REGISTER_USER_OP("upsample_grad")
LOG(FATAL) << "upsample_nearest only supports NCHW";
}
*dx_shape = Shape({dy_shape.At(0), dy_shape.At(1),
static_cast<int32_t>(dy_shape.At(2) / height_scale),
static_cast<int32_t>(dy_shape.At(3) / width_scale)});
static_cast<int32_t>(std::round(dy_shape.At(2) / height_scale)),
static_cast<int32_t>(std::round(dy_shape.At(3) / width_scale))});
return Maybe<void>::Ok();
})
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment