Unverified commit feac8104, authored by Xiaoyu Zhang, committed by GitHub

Fix concat backward bug (#5443)


* add argmax test

* fix ci error

* fix docstring warning

* fix tensor greater and less bug

* fix conflict

* add test_flow_xxx_against_pytorch func

* fix concat backward bug

* auto format by CI

* format

* Add autograd engine warning (#5444)

* add autograd engine warning

* fix bug

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 00892632
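
For context, here is a minimal eager-mode repro of the bug this commit addresses, distilled from the `_test_concat_grad_and_no_grad` test added below. It is a sketch rather than part of the commit, and the expectation that `input2.grad` stays `None` is my reading of the intended behaviour.

```python
import numpy as np
import oneflow as flow

# Concatenate a requires_grad=True tensor with a requires_grad=False tensor,
# then backprop through the sum. Before this fix, Concat::Apply handed a grad
# back for every input, including the one that does not require grad.
input1 = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32, requires_grad=True)
input2 = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32, requires_grad=False)

out = flow.cat([input1, input2], dim=1).sum()
out.backward()

# d(sum)/d(input1) is all ones; input2 should receive no grad at all.
assert np.allclose(input1.grad.numpy(), np.ones((2, 6, 5, 3)))
assert input2.grad is None  # assumption: grad stays unset for a no-grad leaf
```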
@@ -129,7 +129,11 @@ Maybe<bool> FunctionNode::Apply(bool create_graph) {
   JUST((*backward_fn_)(output_grads, &input_grads, create_graph));
   for (int i = 0; i < input_meta_datas_.size(); ++i) {
     if (input_grads.at(i)) {
-      CHECK_NOTNULL_OR_RETURN(input_meta_datas_.at(i));
+      CHECK_NOTNULL_OR_RETURN(input_meta_datas_.at(i))
+          << op_name_
+          << " calculated a grad for a tensor whose requires_grad is False. Please submit an "
+             "issue at `https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as "
+             "soon as possible";
       JUST(input_meta_datas_.at(i)->now_grad_arg()->PushPartialTensor(input_grads.at(i)));
     }
   }
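To make the intent of the new check concrete, here is a small Python sketch of the engine loop above. It is simplified pseudocode, not OneFlow's actual API; `input_metas`, `op_name`, and `accumulate` are hypothetical names.

```python
def apply_node(node, output_grads):
    input_grads = node.backward_fn(output_grads)
    for i, grad in enumerate(input_grads):
        if grad is None:
            continue  # the backward fn produced no grad for this input
        meta = node.input_metas[i]  # None when the input has requires_grad=False
        if meta is None:
            # Corresponds to the CHECK_NOTNULL_OR_RETURN message above.
            raise RuntimeError(
                f"{node.op_name} calculated a grad for a tensor whose "
                "requires_grad is False"
            )
        meta.accumulate(grad)  # analogous to PushPartialTensor above
```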
@@ -24,7 +24,7 @@ namespace oneflow {
 namespace one {
 struct ConcatInterpState : public OpExprInterpState {
-  bool requires_grad;
+  std::vector<bool> requires_grad;
   int64_t axis;
   int64_t input_num;
 };
@@ -57,14 +57,8 @@ Maybe<void> Concat::Init(const OpExpr& op) {
 Maybe<void> Concat::Capture(ConcatInterpState* ctx, const TensorTuple& inputs,
                             const TensorTuple& outputs, const AttrMap& attrs) const {
-  ctx->requires_grad = false;
-  for (const auto& input : inputs) {
-    if (input->requires_grad()) {
-      ctx->requires_grad = true;
-      break;
-    }
-  }
-  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+  ctx->requires_grad.resize(inputs.size());
+  for (int i = 0; i < inputs.size(); ++i) { ctx->requires_grad[i] = inputs.at(i)->requires_grad(); }
   ComposedAttrMap composed_attrs(attrs, base_attrs_);
   ctx->axis = JUST(composed_attrs.GetAttr<int64_t>("axis"));
@@ -75,7 +69,6 @@ Maybe<void> Concat::Capture(ConcatInterpState* ctx, const TensorTuple& inputs,
 Maybe<void> Concat::Apply(const ConcatInterpState* ctx, const TensorTuple& out_grads,
                           TensorTuple* in_grads) const {
-  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
   in_grads->resize(ctx->input_num);
   TensorTuple inputs(ctx->input_num + 1);
@@ -86,7 +79,8 @@ Maybe<void> Concat::Apply(const ConcatInterpState* ctx, const TensorTuple& out_g
   const auto& results = JUST(OpInterpUtil::Dispatch<TensorTuple>(*grad_op_, inputs, concat_attrs));
   CHECK_EQ_OR_RETURN(results->size(), ctx->input_num);
-  for (int i = 0; i < ctx->input_num; ++i) { in_grads->at(i) = results->at(i); }
+  for (int i = 0; i < ctx->input_num; ++i)
+    if (ctx->requires_grad.at(i)) { in_grads->at(i) = results->at(i); }
   return Maybe<void>::Ok();
 }
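Functionally, the fixed Capture/Apply pair amounts to the rule sketched below: record one requires_grad flag per input, split the output grad along the concat axis, and hand a slice back only to inputs that asked for a grad. This is a NumPy sketch under my own naming, not the grad op the code actually dispatches to.

```python
import numpy as np

def concat_backward(out_grad, input_sizes, requires_grad, axis):
    """input_sizes: per-input extent along `axis`;
    requires_grad: one flag per input (the ctx->requires_grad vector above)."""
    splits = np.split(out_grad, np.cumsum(input_sizes)[:-1], axis=axis)
    return [g if rg else None for g, rg in zip(splits, requires_grad)]

# Two (2, 3) inputs concatenated along axis 1; only the first needs a grad.
grads = concat_backward(np.ones((2, 6)), [3, 3], [True, False], axis=1)
assert grads[0].shape == (2, 3) and grads[1] is None
```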
@@ -98,6 +98,28 @@ def _test_concat_with_three_tensor_backward(test_case, device):
     )
+
+
+def _test_concat_grad_and_no_grad(test_case, device):
+    input1 = flow.Tensor(
+        np.random.randn(2, 6, 5, 3),
+        dtype=flow.float32,
+        device=flow.device(device),
+        requires_grad=True,
+    )
+    input2 = flow.Tensor(
+        np.random.randn(2, 6, 5, 3),
+        dtype=flow.float32,
+        device=flow.device(device),
+        requires_grad=False,
+    )
+    of_out = flow.cat([input1, input2], dim=1)
+    of_out = of_out.sum()
+    of_out.backward()
+    test_case.assertTrue(
+        np.allclose(input1.grad.numpy(), np.ones((2, 6, 5, 3)), 1e-4, 1e-4)
+    )
+
 @unittest.skipIf(
     not flow.unittest.env.eager_execution_enabled(),
     ".numpy() doesn't work in lazy mode",
@@ -110,6 +132,7 @@ class TestModule(flow.unittest.TestCase):
             _test_concat_with_axis_one,
             _test_concat_with_three_tensor,
             _test_concat_with_three_tensor_backward,
+            _test_concat_grad_and_no_grad,
         ]
         arg_dict["device"] = ["cpu", "cuda"]
         for arg in GenArgList(arg_dict):