diff --git a/oneflow/python/test/modules/test_upsample2d.py b/oneflow/python/test/modules/test_upsample2d.py index cb3f434394929573af1cf29d77df6637c9958bb6..7f21f7c87c3be4cacc854d2c19ea129fad229a8b 100644 --- a/oneflow/python/test/modules/test_upsample2d.py +++ b/oneflow/python/test/modules/test_upsample2d.py @@ -269,6 +269,34 @@ def _test_upsample2d_bilinear_aligncorner_backward(test_case, device): test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5)) +def _test_interpolate_float_scale(test_case, device): + input = flow.Tensor( + np.arange(1, 10).reshape((1, 1, 3, 3)), + device=flow.device(device), + dtype=flow.float32, + requires_grad=True, + ) + m = flow.nn.Upsample(scale_factor=1.5) + of_out = m(input) + np_out = np.array( + [ + [ + [ + [1.0, 1.0, 2.0, 3.0], + [1.0, 1.0, 2.0, 3.0], + [4.0, 4.0, 5.0, 6.0], + [7.0, 7.0, 8.0, 9.0], + ] + ] + ] + ) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + of_out = of_out.sum() + of_out.backward() + np_grad = np.array([[[[4.0, 2.0, 2.0], [2.0, 1.0, 1.0], [2.0, 1.0, 1.0]]]]) + test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5)) + + @unittest.skipIf( not flow.unittest.env.eager_execution_enabled(), ".numpy() doesn't work in lazy mode", @@ -286,6 +314,7 @@ class TestUpsample2d(flow.unittest.TestCase): _test_upsample2d_bilinear_4dim, _test_upsample2d_backward, _test_upsample2d_bilinear_aligncorner_backward, + _test_interpolate_float_scale, ] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): diff --git a/oneflow/user/kernels/upsample_kernel.cpp b/oneflow/user/kernels/upsample_kernel.cpp index bcdf3f79cfb6fff29ab103420ad409c38c889622..55381a89eb39c056e6d73b3c1247ba99b55cbff5 100644 --- a/oneflow/user/kernels/upsample_kernel.cpp +++ b/oneflow/user/kernels/upsample_kernel.cpp @@ -23,7 +23,7 @@ namespace { static int64_t GetNearestInputIndex(const int64_t out_dim_idx, const float scale, const int64_t in_dim_size) { - int64_t index = static_cast<int64_t>(floorf((static_cast<float>(out_dim_idx) + 0.5f) * scale)); + int64_t index = static_cast<int64_t>(std::floor((static_cast<float>(out_dim_idx) * scale))); index = index > in_dim_size - 1 ? in_dim_size - 1 : index; index = index < static_cast<int64_t>(0) ? static_cast<int64_t>(0) : index; return index; diff --git a/oneflow/user/kernels/upsample_kernel.cu b/oneflow/user/kernels/upsample_kernel.cu index 80aa5b8786930916ffcfdf354c873159f12a17dc..b8d56be47d3be17f858eaeb90cd3546bcd601166 100644 --- a/oneflow/user/kernels/upsample_kernel.cu +++ b/oneflow/user/kernels/upsample_kernel.cu @@ -24,7 +24,7 @@ namespace { __device__ int64_t GetNearestInputIndex(const int64_t out_dim_idx, const float scale, const int64_t in_dim_size) { - return max(min(static_cast<int64_t>(floorf((static_cast<float>(out_dim_idx) + 0.5f) * scale)), + return max(min(static_cast<int64_t>(std::floor((static_cast<float>(out_dim_idx) * scale))), in_dim_size - 1), static_cast<int64_t>(0)); } diff --git a/oneflow/user/ops/upsample_op.cpp b/oneflow/user/ops/upsample_op.cpp index 6241753d05a53839f9eb6ceed1f97a04bf7c2ce0..8ee59286e182ff36a9a8a1681f2fffcc53a00555 100644 --- a/oneflow/user/ops/upsample_op.cpp +++ b/oneflow/user/ops/upsample_op.cpp @@ -65,8 +65,8 @@ REGISTER_USER_OP("upsample_grad") LOG(FATAL) << "upsample_nearest only supports NCHW"; } *dx_shape = Shape({dy_shape.At(0), dy_shape.At(1), - static_cast<int32_t>(dy_shape.At(2) / height_scale), - static_cast<int32_t>(dy_shape.At(3) / width_scale)}); + static_cast<int32_t>(std::round(dy_shape.At(2) / height_scale)), + static_cast<int32_t>(std::round(dy_shape.At(3) / width_scale))}); return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {