diff --git a/oneflow/core/autograd/autograd_engine.cpp b/oneflow/core/autograd/autograd_engine.cpp
index a7b38fcff6570121d4e1c2bccd85684b3fb42226..ab8c14f06041d2e88f317e2dada77b1b87310e8c 100644
--- a/oneflow/core/autograd/autograd_engine.cpp
+++ b/oneflow/core/autograd/autograd_engine.cpp
@@ -129,7 +129,11 @@ Maybe<bool> FunctionNode::Apply(bool create_graph) {
   JUST((*backward_fn_)(output_grads, &input_grads, create_graph));
   for (int i = 0; i < input_meta_datas_.size(); ++i) {
     if (input_grads.at(i)) {
-      CHECK_NOTNULL_OR_RETURN(input_meta_datas_.at(i));
+      CHECK_NOTNULL_OR_RETURN(input_meta_datas_.at(i))
+          << op_name_
+          << " calculates grad for a tensor whose requires_grad is False. Please submit an issue "
+             "at `https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as "
+             "possible";
       JUST(input_meta_datas_.at(i)->now_grad_arg()->PushPartialTensor(input_grads.at(i)));
     }
   }
diff --git a/oneflow/core/autograd/gradient_funcs/concat.cpp b/oneflow/core/autograd/gradient_funcs/concat.cpp
index 95fd8370af7f5037340cf74c309b1b13d12e5da9..9bbb13aa50d49566879476e56bd9664447b492d4 100644
--- a/oneflow/core/autograd/gradient_funcs/concat.cpp
+++ b/oneflow/core/autograd/gradient_funcs/concat.cpp
@@ -24,7 +24,7 @@ namespace oneflow {
 namespace one {
 
 struct ConcatInterpState : public OpExprInterpState {
-  bool requires_grad;
+  std::vector<bool> requires_grad;
   int64_t axis;
   int64_t input_num;
 };
@@ -57,14 +57,8 @@ Maybe<void> Concat::Init(const OpExpr& op) {
 
 Maybe<void> Concat::Capture(ConcatInterpState* ctx, const TensorTuple& inputs,
                             const TensorTuple& outputs, const AttrMap& attrs) const {
-  ctx->requires_grad = false;
-  for (const auto& input : inputs) {
-    if (input->requires_grad()) {
-      ctx->requires_grad = true;
-      break;
-    }
-  }
-  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+  ctx->requires_grad.resize(inputs.size());
+  for (int i = 0; i < inputs.size(); ++i) { ctx->requires_grad[i] = inputs.at(i)->requires_grad(); }
 
   ComposedAttrMap composed_attrs(attrs, base_attrs_);
   ctx->axis = JUST(composed_attrs.GetAttr<int64_t>("axis"));
@@ -75,7 +69,6 @@
 
 Maybe<void> Concat::Apply(const ConcatInterpState* ctx, const TensorTuple& out_grads,
                           TensorTuple* in_grads) const {
-  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
   in_grads->resize(ctx->input_num);
   TensorTuple inputs(ctx->input_num + 1);
@@ -86,7 +79,8 @@ Maybe<void> Concat::Apply(const ConcatInterpState* ctx, const TensorTuple& out_g
 
   const auto& results = JUST(OpInterpUtil::Dispatch<TensorTuple>(*grad_op_, inputs, concat_attrs));
   CHECK_EQ_OR_RETURN(results->size(), ctx->input_num);
-  for (int i = 0; i < ctx->input_num; ++i) { in_grads->at(i) = results->at(i); }
+  for (int i = 0; i < ctx->input_num; ++i)
+    if (ctx->requires_grad.at(i)) { in_grads->at(i) = results->at(i); }
 
   return Maybe<void>::Ok();
 }
diff --git a/oneflow/python/test/modules/test_concat.py b/oneflow/python/test/modules/test_concat.py
index f1996e24ce63b6818965c23c68e25c097b3bcf15..65e6f9c7a06f57c0ed511e4c9b4d78edec0484e1 100644
--- a/oneflow/python/test/modules/test_concat.py
+++ b/oneflow/python/test/modules/test_concat.py
@@ -98,6 +98,28 @@ def _test_concat_with_three_tensor_backward(test_case, device):
     )
 
 
+def _test_concat_grad_and_no_grad(test_case, device):
+    input1 = flow.Tensor(
+        np.random.randn(2, 6, 5, 3),
+        dtype=flow.float32,
+        device=flow.device(device),
+        requires_grad=True,
+    )
+    input2 = flow.Tensor(
+        np.random.randn(2, 6, 5, 3),
+        dtype=flow.float32,
+        device=flow.device(device),
+        requires_grad=False,
+    )
+
+    of_out = flow.cat([input1, input2], dim=1)
+    of_out = of_out.sum()
+    of_out.backward()
+    test_case.assertTrue(
+        np.allclose(input1.grad.numpy(), np.ones((2, 6, 5, 3)), 1e-4, 1e-4)
+    )
+
+
 @unittest.skipIf(
     not flow.unittest.env.eager_execution_enabled(),
     ".numpy() doesn't work in lazy mode",
@@ -110,6 +132,7 @@ class TestModule(flow.unittest.TestCase):
             _test_concat_with_axis_one,
             _test_concat_with_three_tensor,
             _test_concat_with_three_tensor_backward,
+            _test_concat_grad_and_no_grad,
         ]
         arg_dict["device"] = ["cpu", "cuda"]
         for arg in GenArgList(arg_dict):