Skip to content
Snippets Groups Projects
Unverified Commit de05e7a5 authored by Xiaoyu Xu's avatar Xiaoyu Xu Committed by GitHub
Browse files

fix interpreter determin output leaf and grad (#4872)


* fix interpreter determin output leaf and grad

* fix GradMode get

* simplify

* add test for no_grad

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 72204213
No related branches found
No related tags found
No related merge requests found
......@@ -146,18 +146,18 @@ Maybe<void> DetermineRequiresGrad(TensorTuple* outputs, const bool& requires_gra
Maybe<void> AutogradInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs,
TensorTuple* outputs, const AttrMap& attrs) const {
bool requires_grad = false;
if (autograd::GradMode::is_enabled() && !JUST(op_expr.IsGradDisabled())) {
requires_grad =
std::any_of(inputs.begin(), inputs.end(),
[](const std::shared_ptr<Tensor>& tensor) { return tensor->requires_grad(); });
}
{
autograd::AutoGradMode mode(false);
JUST(internal_->Apply(op_expr, inputs, outputs, attrs));
if (!JUST(op_expr.IsGradDisabled())) {
requires_grad = std::any_of(
inputs.begin(), inputs.end(),
[](const std::shared_ptr<Tensor>& tensor) { return tensor->requires_grad(); });
}
JUST(DetermineIsLeaf(outputs, inputs.size() == 0, requires_grad));
JUST(DetermineRequiresGrad(outputs, requires_grad));
}
if (autograd::GradMode::is_enabled() && requires_grad) {
if (requires_grad) {
const auto& grad_closure = JUST(op_expr.GetOrCreateOpGradClosure());
grad_closure->Capture(inputs, *outputs, attrs);
......
......@@ -160,6 +160,11 @@ class TestTensor(flow.unittest.TestCase):
test_case.assertTrue(z.requires_grad)
test_case.assertFalse(z.is_leaf)
with flow.no_grad():
m = x + y
test_case.assertTrue(m.is_leaf)
test_case.assertFalse(m.requires_grad)
v = flow.Tensor(*shape, requires_grad=True)
z.retain_grad()
w = v + z
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment