Skip to content
Snippets Groups Projects
Unverified Commit 250823aa authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

Fix upsample bilinear bug (#5363)


* fix upsample nearest bug

* fix upsample nearest bug (#5347)

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* fix upsample bilinear bug

* fix upsample bilinear bug

* recover code

* align with pytorch

* redesign upsample bilinear

* fix align corner bug

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent bcef01f4
No related branches found
No related tags found
No related merge requests found
......@@ -70,6 +70,7 @@ Maybe<void> Upsample::Capture(UpsampleInterpState* ctx, const TensorTuple& input
ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
ctx->interpolation = JUST(composed_attrs.GetAttr<std::string>("interpolation"));
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
......@@ -84,8 +85,9 @@ Maybe<void> Upsample::Apply(const UpsampleInterpState* ctx, const TensorTuple& o
JUST(attrs.SetAttr<bool>("align_corners", ctx->align_corners));
JUST(attrs.SetAttr<std::string>("data_format", ctx->data_format));
JUST(attrs.SetAttr<std::string>("interpolation", ctx->interpolation));
const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
in_grads->resize(1);
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {out_grads.at(0)}, attrs));
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {out_grads.at(0), x}, attrs));
return Maybe<void>::Ok();
}
......
......@@ -631,6 +631,7 @@ Maybe<one::UserOpExpr> UpsampleGradOp(const float& height_scale, const float& wi
const std::string& interpolation, const std::string& name) {
return one::OpBuilder("upsample_grad", name)
.Input("dy")
.Input("x")
.Output("dx")
.Attr<float>("height_scale", height_scale)
.Attr<float>("width_scale", width_scale)
......
......@@ -269,7 +269,7 @@ def _test_upsample2d_bilinear_aligncorner_backward(test_case, device):
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
def _test_interpolate_float_scale(test_case, device):
def _test_interpolate_nearest_float_scale(test_case, device):
input = flow.Tensor(
np.arange(1, 10).reshape((1, 1, 3, 3)),
device=flow.device(device),
......@@ -297,6 +297,55 @@ def _test_interpolate_float_scale(test_case, device):
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
def _test_interpolate_bilinear_float_scale(test_case, device):
input = flow.Tensor(
np.arange(1, 5, dtype=np.int32).reshape((1, 1, 2, 2)),
device=flow.device(device),
dtype=flow.float32,
requires_grad=True,
)
m = flow.nn.Upsample(scale_factor=0.5, mode="bilinear")
of_out = m(input)
np_out = np.array([[[[1.0]]]])
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
of_out = of_out.sum()
of_out.backward()
np_grad = np.array([[[[1.0, 0.0], [0.0, 0.0]]]])
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
input = flow.Tensor(
np.arange(1, 10, dtype=np.int32).reshape((1, 1, 3, 3)),
device=flow.device(device),
dtype=flow.float32,
requires_grad=True,
)
m = flow.nn.Upsample(scale_factor=0.5, mode="bilinear")
of_out = m(input)
np_out = np.array([[[[1.0]]]])
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
of_out = of_out.sum()
of_out.backward()
np_grad = np.array([[[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]])
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
def _test_upsample_bilinear_align_corners(test_case, device):
input = flow.Tensor(
np.arange(1, 5, dtype=np.int32).reshape((1, 1, 2, 2)),
device=flow.device(device),
dtype=flow.float32,
requires_grad=True,
)
m = flow.nn.Upsample(scale_factor=0.5, mode="bilinear", align_corners=True)
of_out = m(input)
np_out = np.array([[[[1.0]]]])
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
of_out = of_out.sum()
of_out.backward()
np_grad = np.array([[[[1.0, 0.0], [0.0, 0.0]]]])
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
......@@ -314,7 +363,9 @@ class TestUpsample2d(flow.unittest.TestCase):
_test_upsample2d_bilinear_4dim,
_test_upsample2d_backward,
_test_upsample2d_bilinear_aligncorner_backward,
_test_interpolate_float_scale,
_test_interpolate_nearest_float_scale,
_test_interpolate_bilinear_float_scale,
_test_upsample_bilinear_align_corners,
]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
......
......@@ -67,13 +67,21 @@ T GetAreaPixelScale(const int64_t input_size, const int64_t output_size, bool al
}
template<typename T>
T GetAreaPixelSourceIndex(const T scale, const int64_t dst_index, bool align_corners) {
T GetAreaPixelSourceIndex(const T scale, const int64_t dst_index, const int64_t in_len,
bool align_corners) {
T src_index;
if (align_corners) {
return scale * static_cast<T>(dst_index);
src_index = scale * static_cast<T>(dst_index);
} else {
T src_index = (static_cast<T>(dst_index) + 0.5f) * scale - 0.5f;
return (src_index < 0) ? 0 : src_index;
src_index = (static_cast<T>(dst_index) + 0.5f) * scale - 0.5f;
}
int64_t sx = static_cast<int64_t>(floorf(src_index));
src_index = (sx < 0) ? 0 : src_index;
if (scale > static_cast<T>(1.0)) {
src_index = sx >= in_len - 1 ? in_len - 2 : static_cast<T>(sx);
}
return src_index;
}
template<typename T>
......@@ -90,8 +98,8 @@ template<typename T>
void GetBilinearParam(const bool align_corners, const int64_t h, const int64_t w,
const int64_t in_height, const int64_t in_width, const T scale_h,
const T scale_w, BilinearParam<T>* params) {
const T in_h = GetAreaPixelSourceIndex(scale_h, h, align_corners);
const T in_w = GetAreaPixelSourceIndex(scale_w, w, align_corners);
const T in_h = GetAreaPixelSourceIndex(scale_h, h, in_height, align_corners);
const T in_w = GetAreaPixelSourceIndex(scale_w, w, in_width, align_corners);
params->top_h_index = in_h > 0.0 ? floorf(in_h) : 0;
params->bottom_h_index = (in_h < in_height - 1) ? ceilf(in_h) : in_height - 1;
params->h_lerp = in_h - floorf(in_h);
......
......@@ -67,13 +67,21 @@ __host__ T GetAreaPixelScale(const int64_t input_size, const int64_t output_size
}
template<typename T>
__device__ T GetAreaPixelSourceIndex(const T scale, const int64_t dst_index, bool align_corners) {
__device__ T GetAreaPixelSourceIndex(const T scale, const int64_t dst_index, const int64_t in_len,
bool align_corners) {
T src_index;
if (align_corners) {
return scale * static_cast<T>(dst_index);
src_index = scale * static_cast<T>(dst_index);
} else {
T src_index = (static_cast<T>(dst_index) + 0.5f) * scale - 0.5f;
return (src_index < 0) ? 0 : src_index;
src_index = (static_cast<T>(dst_index) + 0.5f) * scale - 0.5f;
}
int64_t sx = static_cast<int64_t>(floorf(src_index));
src_index = (sx < 0) ? 0 : src_index;
if (scale > static_cast<T>(1.0)) {
src_index = sx >= in_len - 1 ? in_len - 2 : static_cast<T>(sx);
}
return src_index;
}
template<typename T>
......@@ -90,8 +98,8 @@ template<typename T>
__device__ void GetBilinearParam(const bool align_corners, const int64_t h, const int64_t w,
const int64_t in_height, const int64_t in_width, const T scale_h,
const T scale_w, BilinearParam<T>* params) {
const T in_h = GetAreaPixelSourceIndex(scale_h, h, align_corners);
const T in_w = GetAreaPixelSourceIndex(scale_w, w, align_corners);
const T in_h = GetAreaPixelSourceIndex(scale_h, h, in_height, align_corners);
const T in_w = GetAreaPixelSourceIndex(scale_w, w, in_width, align_corners);
params->top_h_index = in_h > 0.0 ? floorf(in_h) : 0;
params->bottom_h_index = (in_h < in_height - 1) ? ceilf(in_h) : in_height - 1;
params->h_lerp = in_h - floorf(in_h);
......
......@@ -50,6 +50,7 @@ REGISTER_USER_OP("upsample")
REGISTER_USER_OP("upsample_grad")
.Input("dy")
.Input("x")
.Output("dx")
.Attr<float>("height_scale")
.Attr<float>("width_scale")
......@@ -64,9 +65,7 @@ REGISTER_USER_OP("upsample_grad")
if (ctx->Attr<std::string>("data_format") != "channels_first" || dy_shape.NumAxes() != 4) {
LOG(FATAL) << "upsample_nearest only supports NCHW";
}
*dx_shape = Shape({dy_shape.At(0), dy_shape.At(1),
static_cast<int32_t>(std::round(dy_shape.At(2) / height_scale)),
static_cast<int32_t>(std::round(dy_shape.At(3) / width_scale))});
*dx_shape = ctx->InputShape("x", 0);
return Maybe<void>::Ok();
})
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
......@@ -85,6 +84,7 @@ REGISTER_USER_OP_GRAD("upsample")
user_op::UserOpConfWrapper grad_op =
builder.Op("upsample_grad")
.Input("dy", op.GetGradTensorWithOpOutput("y", 0))
.Input("x", op.input("x", 0))
.Output("dx")
.Attr("height_scale", op.attr<float>("height_scale"))
.Attr("width_scale", op.attr<float>("width_scale"))
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment