diff --git a/oneflow/compatible_single_client_python/test/custom_ops/user_sigmoid/user_sigmoid_cpp_def.cpp b/oneflow/compatible_single_client_python/test/custom_ops/user_sigmoid/user_sigmoid_cpp_def.cpp
index e9086020fd1177a4e80ee26d115d371ded42ca9d..5d786666c92fd9f94d46f406f413ec53664facea 100644
--- a/oneflow/compatible_single_client_python/test/custom_ops/user_sigmoid/user_sigmoid_cpp_def.cpp
+++ b/oneflow/compatible_single_client_python/test/custom_ops/user_sigmoid/user_sigmoid_cpp_def.cpp
@@ -71,7 +71,7 @@ REGISTER_USER_OP("user_sigmoid_backward")
     });
 
 REGISTER_USER_OP_GRAD("user_sigmoid_forward")
-    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
       const auto grad_op_name = ctx->FwOp().op_name() + "_grad";
       const auto& grad_op_func = [&ctx](user_op::BackwardOpBuilder& builder) {
         return builder.OpTypeName("user_sigmoid_backward")
@@ -86,6 +86,7 @@ REGISTER_USER_OP_GRAD("user_sigmoid_forward")
         return ctx->GetOp(grad_op_name).output("dx", 0);
       };
       ctx->FwOp().InputGradBind(user_op::OpArg("x", 0), dx_get_func);
+      return Maybe<void>::Ok();
     });
 
 }  // namespace
diff --git a/oneflow/core/framework/user_op_grad_registry.h b/oneflow/core/framework/user_op_grad_registry.h
index cd0134d72d626fe0f3cdef23344b20c978ee9773..6a3027aba340c6c318e1784590743346f7c9d053 100644
--- a/oneflow/core/framework/user_op_grad_registry.h
+++ b/oneflow/core/framework/user_op_grad_registry.h
@@ -25,7 +25,7 @@ namespace user_op {
 
 using AddOpFn = std::function<void(const UserOpConfWrapper&)>;
 using GenBackwardOpConfFn = std::function<Maybe<void>(const UserOpWrapper&, AddOpFn)>;
-using BackwardOpConfGenFn = std::function<void(BackwardOpConfContext*)>;
+using BackwardOpConfGenFn = std::function<Maybe<void>(BackwardOpConfContext*)>;
 
 struct OpGradRegistryResult {
   std::string op_type_name;
diff --git a/oneflow/core/job_rewriter/user_grad.cpp b/oneflow/core/job_rewriter/user_grad.cpp
index 9f9ca3668842b43e788031ed1550fe1d66bde097..fb314211679f4f35dacf3896faa8233fa8345683 100644
--- a/oneflow/core/job_rewriter/user_grad.cpp
+++ b/oneflow/core/job_rewriter/user_grad.cpp
@@ -37,7 +37,7 @@ Maybe<void> GenerateBackwardOpConf(
   if (nullptr != val->bw_gen_fn) {
     // new refined interface
     user_op::BackwardOpConfContext ctx(fw_user_op, bw_op_confs);
-    val->bw_gen_fn(&ctx);
+    JUST(val->bw_gen_fn(&ctx));
   } else if (nullptr != val->gen_bw_fn) {
     // old interface, will be removed when all backward gradient configs are using new interface
     auto AddOp = [&](const user_op::UserOpConfWrapper& wrapper) {
diff --git a/oneflow/user/ops/acc_op.cpp b/oneflow/user/ops/acc_op.cpp
index 8ba2ab1c900155e90771e21f2b582dce4e874750..4c3188bb5f21fdfd02c1bf531872646e0c6eff06 100644
--- a/oneflow/user/ops/acc_op.cpp
+++ b/oneflow/user/ops/acc_op.cpp
@@ -65,7 +65,8 @@ REGISTER_USER_OP("acc")
       return Maybe<void>::Ok();
     });
 
-REGISTER_USER_OP_GRAD("acc").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+REGISTER_USER_OP_GRAD("acc").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx)
+                                                        -> Maybe<void> {
   const auto grad_op_name = ctx->FwOp().op_name() + "_grad";
   ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
     return builder.OpTypeName("repeat")
@@ -77,6 +78,7 @@ REGISTER_USER_OP_GRAD("acc").SetBackwardOpConfGenFn([](user_op::BackwardOpConfCo
   ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
                             [&ctx, &grad_op_name]() -> const std::string& {
                               return ctx->GetOp(grad_op_name).output("out", 0);
                             });
+  return Maybe<void>::Ok();
 });
 }  // namespace
diff --git a/oneflow/user/ops/adaptive_pool_op.cpp b/oneflow/user/ops/adaptive_pool_op.cpp
index edb365f4819142014cc0bac0a5a4ce3bd0702209..edf04f1ccfaf37eb57a845171c78686552b4dc4f 100644
--- a/oneflow/user/ops/adaptive_pool_op.cpp
+++ b/oneflow/user/ops/adaptive_pool_op.cpp
@@ -98,7 +98,7 @@ REGISTER_USER_OP("adaptive_avg_pool2d_grad")
     .SetDataTypeInferFn(InferBWDataType);
 
 REGISTER_USER_OP_GRAD("adaptive_avg_pool2d")
-    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
       const auto adaptive_avg_pool2d_grad_op_name = ctx->FwOp().op_name() + "_grad";
       ctx->DefineOp(adaptive_avg_pool2d_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
         return builder.OpTypeName("adaptive_avg_pool2d_grad")
@@ -112,6 +112,7 @@ REGISTER_USER_OP_GRAD("adaptive_avg_pool2d")
                                 [&ctx, &adaptive_avg_pool2d_grad_op_name]() -> const std::string& {
                                   return ctx->GetOp(adaptive_avg_pool2d_grad_op_name).output("dx", 0);
                                 });
+      return Maybe<void>::Ok();
     });
 
 }  // namespace
diff --git a/oneflow/user/ops/broadcast_like_op.cpp b/oneflow/user/ops/broadcast_like_op.cpp
index 801dc85d7070a3a99feeffc2bdeafb27e70a16df..a1a54b6a407e10708822613774f54719ce65d5e7 100644
--- a/oneflow/user/ops/broadcast_like_op.cpp
+++ b/oneflow/user/ops/broadcast_like_op.cpp
@@ -101,7 +101,7 @@ REGISTER_USER_OP("broadcast_like")
     .SetDataTypeInferFn(InferDataType);
 
 REGISTER_USER_OP_GRAD("broadcast_like")
-    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
       const auto x_grad_op_name = ctx->FwOp().op_name() + "_x_grad";
       ctx->DefineOp(x_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
         return builder.OpTypeName("reduce_sum_like")
@@ -116,6 +116,7 @@ REGISTER_USER_OP_GRAD("broadcast_like")
                                 [&ctx, &x_grad_op_name]() -> const std::string& {
                                   return ctx->GetOp(x_grad_op_name).output("y", 0);
                                 });
+      return Maybe<void>::Ok();
     });
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/ctc_loss_op.cpp b/oneflow/user/ops/ctc_loss_op.cpp
index 9b002d84b40d27c5d80fdd4f65d16bad9be60eb8..7df7451e4753d92542608b4c92ce44bd51e44742 100644
--- a/oneflow/user/ops/ctc_loss_op.cpp
+++ b/oneflow/user/ops/ctc_loss_op.cpp
@@ -100,27 +100,29 @@ REGISTER_USER_OP("ctc_loss_grad")
       return Maybe<void>::Ok();
     });
 
-REGISTER_USER_OP_GRAD("ctc_loss").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
-  const auto ctc_loss_grad_op_name = ctx->FwOp().op_name() + "_grad";
-  ctx->DefineOp(ctc_loss_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
-    return builder.OpTypeName("ctc_loss_grad")
-        .InputBind("grad_out", ctx->FwOp().output_grad("loss", 0))
-        .InputBind("log_probs", ctx->FwOp().input("log_probs", 0))
-        .InputBind("targets", ctx->FwOp().input("targets", 0))
-        .InputBind("input_lengths", ctx->FwOp().input("input_lengths", 0))
-        .InputBind("target_lengths", ctx->FwOp().input("target_lengths", 0))
-        .InputBind("loss", ctx->FwOp().output("loss", 0))
-        .InputBind("alpha", ctx->FwOp().output("alpha", 0))
-        .Attr("blank", ctx->FwOp().attr<int32_t>("blank"))
-        .Attr("zero_infinity", ctx->FwOp().attr<bool>("zero_infinity"))
-        .Output("grad")
-        .Build();
-  });
-  ctx->FwOp().InputGradBind(user_op::OpArg("log_probs", 0),
-                            [&ctx, &ctc_loss_grad_op_name]() -> const std::string& {
-                              return ctx->GetOp(ctc_loss_grad_op_name).output("grad", 0);
-                            });
-});
+REGISTER_USER_OP_GRAD("ctc_loss")
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
+      const auto ctc_loss_grad_op_name = ctx->FwOp().op_name() + "_grad";
+      ctx->DefineOp(ctc_loss_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
+        return builder.OpTypeName("ctc_loss_grad")
+            .InputBind("grad_out", ctx->FwOp().output_grad("loss", 0))
+            .InputBind("log_probs", ctx->FwOp().input("log_probs", 0))
+            .InputBind("targets", ctx->FwOp().input("targets", 0))
+            .InputBind("input_lengths", ctx->FwOp().input("input_lengths", 0))
+            .InputBind("target_lengths", ctx->FwOp().input("target_lengths", 0))
+            .InputBind("loss", ctx->FwOp().output("loss", 0))
+            .InputBind("alpha", ctx->FwOp().output("alpha", 0))
+            .Attr("blank", ctx->FwOp().attr<int32_t>("blank"))
+            .Attr("zero_infinity", ctx->FwOp().attr<bool>("zero_infinity"))
+            .Output("grad")
+            .Build();
+      });
+      ctx->FwOp().InputGradBind(user_op::OpArg("log_probs", 0),
+                                [&ctx, &ctc_loss_grad_op_name]() -> const std::string& {
+                                  return ctx->GetOp(ctc_loss_grad_op_name).output("grad", 0);
+                                });
+      return Maybe<void>::Ok();
+    });
 
 REGISTER_USER_OP("ctc_greedy_decoder")
     .Input("log_probs")
diff --git a/oneflow/user/ops/diag_op.cpp b/oneflow/user/ops/diag_op.cpp
index 9a4d28de19e46affefab9d84128cbada16bd90bc..21ee3e89740cfa658df16d7c78542ab632bc9c43 100644
--- a/oneflow/user/ops/diag_op.cpp
+++ b/oneflow/user/ops/diag_op.cpp
@@ -78,7 +78,8 @@ REGISTER_USER_OP("diag_grad")
       return Maybe<void>::Ok();
     });
 
-REGISTER_USER_OP_GRAD("diag").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+REGISTER_USER_OP_GRAD("diag").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx)
+                                                         -> Maybe<void> {
   const auto grad_op_name = ctx->FwOp().op_name() + "_grad";
   ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
     return builder.OpTypeName("diag_grad")
@@ -92,5 +93,6 @@ REGISTER_USER_OP_GRAD("diag").SetBackwardOpConfGenFn([](user_op::BackwardOpConfC
   ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
                             [&ctx, &grad_op_name]() -> const std::string& {
                               return ctx->GetOp(grad_op_name).output("dx", 0);
                             });
+  return Maybe<void>::Ok();
 });
 }  // namespace oneflow
diff --git a/oneflow/user/ops/dim_gather_op.cpp b/oneflow/user/ops/dim_gather_op.cpp
index 189dc7fa0c8c591fe02e88eb67380e0c01c4c091..17f670f5764f4f06e9fd01ee94bbd4b1f3f389cc 100644
--- a/oneflow/user/ops/dim_gather_op.cpp
+++ b/oneflow/user/ops/dim_gather_op.cpp
@@ -178,27 +178,29 @@ REGISTER_USER_OP("dim_scatter_add_like")
       return Maybe<void>::Ok();
     });
 
-REGISTER_USER_OP_GRAD("dim_gather").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
-  const auto op_grad_name = ctx->FwOp().op_name() + "_grad";
-
-  ctx->DefineOp(op_grad_name, [&ctx](user_op::BackwardOpBuilder& builder) {
-    return builder
-        .OpTypeName(
-            "dim_scatter_add_like")  // dim_scatter_add_like(like, dim, index, input) -> output
-        .InputBind("index", ctx->FwOp().input("index", 0))  // scatter.index <- gather.index
-        .InputBind("input",
-                   ctx->FwOp().output_grad("output", 0))  // scatter.input <- grad of gather.out
-        .InputBind("like", ctx->FwOp().input("input", 0))
-        .Output("output")
-        .Attr("dim", ctx->FwOp().attr<int32_t>("dim"))
-        .Build();
-  });
-
-  ctx->FwOp().InputGradBind(user_op::OpArg("input", 0),
-                            [&ctx, &op_grad_name]() -> const std::string& {
-                              return ctx->GetOp(op_grad_name).output("output", 0);
-                            });
-});
+REGISTER_USER_OP_GRAD("dim_gather")
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
+      const auto op_grad_name = ctx->FwOp().op_name() + "_grad";
+
+      ctx->DefineOp(op_grad_name, [&ctx](user_op::BackwardOpBuilder& builder) {
+        return builder
+            .OpTypeName(
+                "dim_scatter_add_like")  // dim_scatter_add_like(like, dim, index, input) -> output
+            .InputBind("index", ctx->FwOp().input("index", 0))  // scatter.index <- gather.index
+            .InputBind("input",
+                       ctx->FwOp().output_grad("output", 0))  // scatter.input <- grad of gather.out
+            .InputBind("like", ctx->FwOp().input("input", 0))
+            .Output("output")
+            .Attr("dim", ctx->FwOp().attr<int32_t>("dim"))
+            .Build();
+      });
+
+      ctx->FwOp().InputGradBind(user_op::OpArg("input", 0),
+                                [&ctx, &op_grad_name]() -> const std::string& {
+                                  return ctx->GetOp(op_grad_name).output("output", 0);
+                                });
+      return Maybe<void>::Ok();
+    });
 
 }  // namespace user_op
diff --git a/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp b/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp
index 12eaeb69e4c2bd2d40d722570dddd38df299f2ad..4bce761d226e1213ea813997c7f06e2483bfa452 100644
--- a/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp
+++ b/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp
@@ -69,7 +69,7 @@ Maybe<void> InferDataType(InferContext* ctx) {
 }
 
 user_op::BackwardOpConfGenFn MakeGenBackwardOpFn(const std::string& op_type_name) {
-  return [=](user_op::BackwardOpConfContext* ctx) -> void {
+  return [=](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
     const bool x_need_grad = ctx->FwOp().NeedGenGradTensor4OpInput("x", 0);
     const bool y_need_grad = ctx->FwOp().NeedGenGradTensor4OpInput("y", 0);
     const auto grad_op_name = ctx->FwOp().op_name() + "_grad";
@@ -95,6 +95,7 @@ user_op::BackwardOpConfGenFn MakeGenBackwardOpFn(const std::string& op_type_name
         return ctx->GetOp(grad_op_name).output("dy", 0);
       });
     }
+    return Maybe<void>::Ok();
   };
 }
diff --git a/oneflow/user/ops/elu_op.cpp b/oneflow/user/ops/elu_op.cpp
index be2b0efe6a7bfea3e0f4f4115dee3c2a274b86a6..fd3ab133d114ddbd7e80c04802bf2174831322b9 100644
--- a/oneflow/user/ops/elu_op.cpp
+++ b/oneflow/user/ops/elu_op.cpp
@@ -72,21 +72,23 @@ REGISTER_USER_OP("elu_grad")
       return Maybe<void>::Ok();
     });
 
-REGISTER_USER_OP_GRAD("elu").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
-  const auto elu_grad_op_name = ctx->FwOp().op_name() + "_grad";
-  ctx->DefineOp(elu_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
-    return builder.OpTypeName("elu_grad")
-        .InputBind("x", ctx->FwOp().input("in", 0))
-        .InputBind("dy", ctx->FwOp().output_grad("out", 0))
-        .Attr("alpha", ctx->FwOp().attr<double>("alpha"))
-        .Output("dx")
-        .Build();
-  });
-  ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
-                            [&ctx, &elu_grad_op_name]() -> const std::string& {
-                              return ctx->GetOp(elu_grad_op_name).output("dx", 0);
-                            });
-});
+REGISTER_USER_OP_GRAD("elu").SetBackwardOpConfGenFn(
+    [](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
+      const auto elu_grad_op_name = ctx->FwOp().op_name() + "_grad";
+      ctx->DefineOp(elu_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
+        return builder.OpTypeName("elu_grad")
+            .InputBind("x", ctx->FwOp().input("in", 0))
+            .InputBind("dy", ctx->FwOp().output_grad("out", 0))
+            .Attr("alpha", ctx->FwOp().attr<double>("alpha"))
+            .Output("dx")
+            .Build();
+      });
+      ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
+                                [&ctx, &elu_grad_op_name]() -> const std::string& {
+                                  return ctx->GetOp(elu_grad_op_name).output("dx", 0);
+                                });
+      return Maybe<void>::Ok();
+    });
 
 }  // namespace
diff --git a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp
index 498c426a0bdfbd11688750751420d740c3843efe..0748daa1b22bc54fd342c575186894395aa6bf3c 100644
--- a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp
+++ b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp
@@ -120,7 +120,7 @@ REGISTER_USER_OP("fused_self_attention_query_mul_key_and_value_grad")
     });
 
 REGISTER_USER_OP_GRAD("fused_self_attention_query_mul_key_and_value")
-    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
       std::string grad_op_name = ctx->FwOp().op_name() + "_grad";
 
       ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
@@ -136,6 +136,7 @@ REGISTER_USER_OP_GRAD("fused_self_attention_query_mul_key_and_value")
       ctx->FwOp().InputGradBind(user_op::OpArg("hidden_states", 0), [=]() -> const std::string& {
         return ctx->GetOp(grad_op_name).output("hidden_states_grad", 0);
       });
+      return Maybe<void>::Ok();
     });
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/hardsigmoid_op.cpp b/oneflow/user/ops/hardsigmoid_op.cpp
index a22942e41091c2e95300887b78e3082c41dc5373..1ec34740e6841e5d3f63ade9f882f6f54f4a36bb 100644
--- a/oneflow/user/ops/hardsigmoid_op.cpp
+++ b/oneflow/user/ops/hardsigmoid_op.cpp
@@ -73,7 +73,7 @@ REGISTER_USER_OP("hardsigmoid_grad")
     });
 
 REGISTER_USER_OP_GRAD("hardsigmoid")
-    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
      const auto hardsigmoid_grad_op_name = ctx->FwOp().op_name() + "_grad";
      ctx->DefineOp(hardsigmoid_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
        return builder.OpTypeName("hardsigmoid_grad")
@@ -86,6 +86,7 @@ REGISTER_USER_OP_GRAD("hardsigmoid")
                                 [&ctx, &hardsigmoid_grad_op_name]() -> const std::string& {
                                   return ctx->GetOp(hardsigmoid_grad_op_name).output("dx", 0);
                                 });
+      return Maybe<void>::Ok();
     });
 
 }  // namespace
diff --git a/oneflow/user/ops/hardswish_op.cpp b/oneflow/user/ops/hardswish_op.cpp
index ccd602fc477f8542aa414e66bb07457a78a89b13..ffef66123b65066e042ab4e5f821b3febe6e92ee 100644
--- a/oneflow/user/ops/hardswish_op.cpp
+++ b/oneflow/user/ops/hardswish_op.cpp
@@ -70,20 +70,22 @@ REGISTER_USER_OP("hardswish_grad")
       return Maybe<void>::Ok();
     });
 
-REGISTER_USER_OP_GRAD("hardswish").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
-  const auto hardswish_grad_op_name = ctx->FwOp().op_name() + "_grad";
-  ctx->DefineOp(hardswish_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
-    return builder.OpTypeName("hardswish_grad")
-        .InputBind("x", ctx->FwOp().input("in", 0))
-        .InputBind("dy", ctx->FwOp().output_grad("out", 0))
-        .Output("dx")
-        .Build();
-  });
-  ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
-                            [&ctx, &hardswish_grad_op_name]() -> const std::string& {
-                              return ctx->GetOp(hardswish_grad_op_name).output("dx", 0);
-                            });
-});
+REGISTER_USER_OP_GRAD("hardswish")
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
+      const auto hardswish_grad_op_name = ctx->FwOp().op_name() + "_grad";
+      ctx->DefineOp(hardswish_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
+        return builder.OpTypeName("hardswish_grad")
+            .InputBind("x", ctx->FwOp().input("in", 0))
+            .InputBind("dy", ctx->FwOp().output_grad("out", 0))
+            .Output("dx")
+            .Build();
+      });
+      ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
+                                [&ctx, &hardswish_grad_op_name]() -> const std::string& {
+                                  return ctx->GetOp(hardswish_grad_op_name).output("dx", 0);
+                                });
+      return Maybe<void>::Ok();
+    });
 
 }  // namespace
diff --git a/oneflow/user/ops/hardtanh_op.cpp b/oneflow/user/ops/hardtanh_op.cpp
index a8ce26247e1b817b69a06bb03f97148deb2a191a..32e0a69fb9abb4a28f2bf2357cc1a400ae4993c7 100644
--- a/oneflow/user/ops/hardtanh_op.cpp
+++ b/oneflow/user/ops/hardtanh_op.cpp
@@ -82,22 +82,24 @@ REGISTER_USER_OP("hardtanh_grad")
       return Maybe<void>::Ok();
     });
 
-REGISTER_USER_OP_GRAD("hardtanh").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
-  const auto hardtanh_grad_op_name = ctx->FwOp().op_name() + "_grad";
-  ctx->DefineOp(hardtanh_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
-    return builder.OpTypeName("hardtanh_grad")
-        .InputBind("y", ctx->FwOp().output("out", 0))
-        .InputBind("dy", ctx->FwOp().output_grad("out", 0))
-        .Attr("min_val", ctx->FwOp().attr<double>("min_val"))
-        .Attr("max_val", ctx->FwOp().attr<double>("max_val"))
-        .Output("dx")
-        .Build();
-  });
-  ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
-                            [&ctx, &hardtanh_grad_op_name]() -> const std::string& {
-                              return ctx->GetOp(hardtanh_grad_op_name).output("dx", 0);
-                            });
-});
+REGISTER_USER_OP_GRAD("hardtanh")
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
+      const auto hardtanh_grad_op_name = ctx->FwOp().op_name() + "_grad";
+      ctx->DefineOp(hardtanh_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
+        return builder.OpTypeName("hardtanh_grad")
+            .InputBind("y", ctx->FwOp().output("out", 0))
+            .InputBind("dy", ctx->FwOp().output_grad("out", 0))
+            .Attr("min_val", ctx->FwOp().attr<double>("min_val"))
+            .Attr("max_val", ctx->FwOp().attr<double>("max_val"))
+            .Output("dx")
+            .Build();
+      });
+      ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
+                                [&ctx, &hardtanh_grad_op_name]() -> const std::string& {
+                                  return ctx->GetOp(hardtanh_grad_op_name).output("dx", 0);
+                                });
+      return Maybe<void>::Ok();
+    });
 
 }  // namespace
diff --git a/oneflow/user/ops/hierarchical_parallel_cast_op.cpp b/oneflow/user/ops/hierarchical_parallel_cast_op.cpp
index a43442e81bc2f7200f79591e2a019383e42570a4..205751aa084f60f79a34dce2399cb5433c77dca3 100644
--- a/oneflow/user/ops/hierarchical_parallel_cast_op.cpp
+++ b/oneflow/user/ops/hierarchical_parallel_cast_op.cpp
@@ -84,7 +84,7 @@ REGISTER_USER_OP("hierarchical_parallel_cast_like")
     .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast);
 
 REGISTER_USER_OP_GRAD("hierarchical_parallel_cast")
-    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
       if (ctx->FwOp().NeedGenGradTensor4OpInput("in", 0)) {
         const auto& grad_mode = ctx->FwOp().attr<std::string>("grad_mode");
         if (grad_mode == "identity") {
@@ -122,6 +122,7 @@ REGISTER_USER_OP_GRAD("hierarchical_parallel_cast")
           UNIMPLEMENTED();
         }
       }
+      return Maybe<void>::Ok();
     });
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/matmul_op.cpp b/oneflow/user/ops/matmul_op.cpp
index 6c7cdccbd59249b0939f0218845f23577e9c4fd0..dd2086665333a90513a1600339e509828c05efc6 100644
--- a/oneflow/user/ops/matmul_op.cpp
+++ b/oneflow/user/ops/matmul_op.cpp
@@ -410,11 +410,11 @@ REGISTER_USER_OP("broadcast_matmul_grad_b")
     });
 
 REGISTER_USER_OP_GRAD("broadcast_matmul")
-    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> void {
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
       bool transpose_a = ctx->FwOp().attr<bool>("transpose_a");
       bool transpose_b = ctx->FwOp().attr<bool>("transpose_b");
       double alpha = ctx->FwOp().attr<double>("alpha");
-      CHECK(!transpose_a);
+      CHECK_OR_RETURN(!transpose_a);
 
       std::string a_grad_op_name = ctx->FwOp().op_name() + "_a_grad";
       ctx->DefineOp(a_grad_op_name,
@@ -456,6 +456,7 @@ REGISTER_USER_OP_GRAD("broadcast_matmul")
       ctx->FwOp().InputGradBind(user_op::OpArg("b", 0), [&]() -> const std::string& {
        return ctx->GetOp(b_grad_op_name).output("out", 0);
       });
+      return Maybe<void>::Ok();
     });
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/normalization_op.cpp b/oneflow/user/ops/normalization_op.cpp
index f81d0e6e8cd9623c431869f09c4d4e03418cf193..e31d8fce0ed06bb8d03ec656e743af6bbdfcf760 100644
--- a/oneflow/user/ops/normalization_op.cpp
+++ b/oneflow/user/ops/normalization_op.cpp
@@ -485,7 +485,7 @@ REGISTER_USER_OP("cudnn_fused_normalization_add_relu_grad")
 #endif
 
 REGISTER_USER_OP_GRAD("normalization")
-    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
       const bool is_training = ctx->FwOp().attr<bool>("training");
       const bool is_fp16 = ctx->FwOp().arg_tensor_desc("y", 0).data_type() == DataType::kFloat16;
 
@@ -656,10 +656,11 @@ REGISTER_USER_OP_GRAD("normalization")
                                 [&ctx, &beta_identity_op_name]() -> const std::string& {
                                   return ctx->GetOp(beta_identity_op_name).output("out", 0);
                                 });
+      return Maybe<void>::Ok();
     });
 
 REGISTER_USER_OP_GRAD("normalization_add_relu")
-    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
       const auto grad_op_name = ctx->FwOp().op_name() + "_grad";
       ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
         builder.OpTypeName("normalization_add_relu_grad")
@@ -718,6 +719,7 @@ REGISTER_USER_OP_GRAD("normalization_add_relu")
                                 [&ctx, &beta_identity_op_name]() -> const std::string& {
                                   return ctx->GetOp(beta_identity_op_name).output("out", 0);
                                 });
+      return Maybe<void>::Ok();
     });
 
 }  // namespace
diff --git a/oneflow/user/ops/pack_op.cpp b/oneflow/user/ops/pack_op.cpp
index f3f2dcb39660e1ec8bbc48d7c92d26d4a5ebe689..62f8d04b2c93337c2a910f9bef3210f87ef4a7c8 100644
--- a/oneflow/user/ops/pack_op.cpp
+++ b/oneflow/user/ops/pack_op.cpp
@@ -63,7 +63,8 @@ REGISTER_USER_OP("pack")
       return Maybe<void>::Ok();
     });
 
-REGISTER_USER_OP_GRAD("pack").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+REGISTER_USER_OP_GRAD("pack").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx)
+                                                         -> Maybe<void> {
   const auto grad_op_name = ctx->FwOp().op_name() + "_grad";
   ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
     return builder.OpTypeName("unpack")
@@ -75,6 +76,7 @@ REGISTER_USER_OP_GRAD("pack").SetBackwardOpConfGenFn([](user_op::BackwardOpConfC
   ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
                             [&ctx, &grad_op_name]() -> const std::string& {
                               return ctx->GetOp(grad_op_name).output("out", 0);
                             });
+  return Maybe<void>::Ok();
 });
 
 }  // namespace
diff --git a/oneflow/user/ops/parallel_cast_op.cpp b/oneflow/user/ops/parallel_cast_op.cpp
index 1f2acb0d65c1b02b161b67205eeb7e773d59a32d..70fcc6a11d88e9df97d3ffe5d2e80477731b1013 100644
--- a/oneflow/user/ops/parallel_cast_op.cpp
+++ b/oneflow/user/ops/parallel_cast_op.cpp
@@ -60,7 +60,7 @@ REGISTER_USER_OP("parallel_cast")
     .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast);
 
 REGISTER_USER_OP_GRAD("parallel_cast")
-    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
       if (ctx->FwOp().NeedGenGradTensor4OpInput("in", 0)) {
         const auto& grad_sbp_parallel_str = ctx->FwOp().attr<std::string>("grad_sbp_parallel");
         if (grad_sbp_parallel_str.empty()) {
@@ -81,6 +81,7 @@ REGISTER_USER_OP_GRAD("parallel_cast")
           });
         }
       }
+      return Maybe<void>::Ok();
     });
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/partial_fc_sample_op.cpp b/oneflow/user/ops/partial_fc_sample_op.cpp
index 8586c94d3aa3c157ea71414228ae5d8b237b0e35..b40d6f94d8d5cfe975707f48b95887af323ac879 100644
--- a/oneflow/user/ops/partial_fc_sample_op.cpp
+++ b/oneflow/user/ops/partial_fc_sample_op.cpp
@@ -137,7 +137,7 @@ REGISTER_USER_OP("distributed_partial_fc_sample_disable_boxing")
     });
 
 REGISTER_USER_OP_GRAD("distributed_partial_fc_sample")
-    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
       const auto disable_boxing_op_name = ctx->FwOp().op_name() + "_disable_boxing";
       ctx->DefineOp(disable_boxing_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
         return builder.OpTypeName("distributed_partial_fc_sample_disable_boxing")
@@ -168,6 +168,7 @@ REGISTER_USER_OP_GRAD("distributed_partial_fc_sample")
                                 [&ctx, &unsorted_segment_sum_like_op_name]() -> const std::string& {
                                   return ctx->GetOp(unsorted_segment_sum_like_op_name).output("out", 0);
                                 });
+      return Maybe<void>::Ok();
     });
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/relu_op.cpp b/oneflow/user/ops/relu_op.cpp
index 08c89c0627ccd6b0e539350f6111acf33f225dc7..f9c4d75824b50e7f335f0cad45570658f8e22b16 100644
--- a/oneflow/user/ops/relu_op.cpp
+++ b/oneflow/user/ops/relu_op.cpp
@@ -73,20 +73,22 @@ REGISTER_USER_OP("relu_grad")
       return Maybe<void>::Ok();
     });
 
-REGISTER_USER_OP_GRAD("relu").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
-  const auto relu_grad_op_name = ctx->FwOp().op_name() + "_grad";
-  ctx->DefineOp(relu_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
-    return builder.OpTypeName("relu_grad")
-        .InputBind("y", ctx->FwOp().output("out", 0))
-        .InputBind("dy", ctx->FwOp().output_grad("out", 0))
-        .Output("dx")
-        .Build();
-  });
-  ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
-                            [&ctx, &relu_grad_op_name]() -> const std::string& {
-                              return ctx->GetOp(relu_grad_op_name).output("dx", 0);
-                            });
-});
+REGISTER_USER_OP_GRAD("relu").SetBackwardOpConfGenFn(
+    [](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
+      const auto relu_grad_op_name = ctx->FwOp().op_name() + "_grad";
+      ctx->DefineOp(relu_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
+        return builder.OpTypeName("relu_grad")
+            .InputBind("y", ctx->FwOp().output("out", 0))
+            .InputBind("dy", ctx->FwOp().output_grad("out", 0))
+            .Output("dx")
+            .Build();
+      });
+      ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
+                                [&ctx, &relu_grad_op_name]() -> const std::string& {
+                                  return ctx->GetOp(relu_grad_op_name).output("dx", 0);
+                                });
+      return Maybe<void>::Ok();
+    });
 
 }  // namespace
diff --git a/oneflow/user/ops/repeat_op.cpp b/oneflow/user/ops/repeat_op.cpp
index dd102adf3a481466b5ca188efd8c7101fbe80710..f3d28090110437550f7653c821c40f03756084b1 100644
--- a/oneflow/user/ops/repeat_op.cpp
+++ b/oneflow/user/ops/repeat_op.cpp
@@ -54,7 +54,8 @@ REGISTER_USER_OP("repeat")
       return Maybe<void>::Ok();
     });
 
-REGISTER_USER_OP_GRAD("repeat").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+REGISTER_USER_OP_GRAD("repeat").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx)
+                                                           -> Maybe<void> {
   const auto grad_op_name = ctx->FwOp().op_name() + "_grad";
   ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
     return builder.OpTypeName("acc")
@@ -66,6 +67,7 @@ REGISTER_USER_OP_GRAD("repeat").SetBackwardOpConfGenFn([](user_op::BackwardOpCon
   ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
                             [&ctx, &grad_op_name]() -> const std::string& {
                               return ctx->GetOp(grad_op_name).output("out", 0);
                             });
+  return Maybe<void>::Ok();
 });
 
 }  // namespace
diff --git a/oneflow/user/ops/scalar_pow_op.cpp b/oneflow/user/ops/scalar_pow_op.cpp
index 5d2bf81bb091c6c3e80c0ae366641feb6ccf8fb5..4cb386395654a4a98b05e2b5071697744f87e92c 100644
--- a/oneflow/user/ops/scalar_pow_op.cpp
+++ b/oneflow/user/ops/scalar_pow_op.cpp
@@ -68,21 +68,23 @@ REGISTER_USER_OP("scalar_pow_grad")
       return Maybe<void>::Ok();
     });
 
-REGISTER_USER_OP_GRAD("scalar_pow").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
-  const auto scalar_pow_grad_op_name = ctx->FwOp().op_name() + "_grad";
-  ctx->DefineOp(scalar_pow_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
-    return builder.OpTypeName("scalar_pow_grad")
-        .InputBind("x", ctx->FwOp().input("in", 0))
-        .InputBind("dy", ctx->FwOp().output_grad("out", 0))
-        .Attr<double>("exponent", ctx->FwOp().attr<double>("exponent"))
-        .Output("dx")
-        .Build();
-  });
-  ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
-                            [&ctx, &scalar_pow_grad_op_name]() -> const std::string& {
-                              return ctx->GetOp(scalar_pow_grad_op_name).output("dx", 0);
-                            });
-});
+REGISTER_USER_OP_GRAD("scalar_pow")
+    .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
+      const auto scalar_pow_grad_op_name = ctx->FwOp().op_name() + "_grad";
+      ctx->DefineOp(scalar_pow_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
+        return builder.OpTypeName("scalar_pow_grad")
+            .InputBind("x", ctx->FwOp().input("in", 0))
+            .InputBind("dy", ctx->FwOp().output_grad("out", 0))
+            .Attr<double>("exponent", ctx->FwOp().attr<double>("exponent"))
+            .Output("dx")
+            .Build();
+      });
+      ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
+                                [&ctx, &scalar_pow_grad_op_name]() -> const std::string& {
+                                  return ctx->GetOp(scalar_pow_grad_op_name).output("dx", 0);
+                                });
+      return Maybe<void>::Ok();
+    });
 
 }  // namespace
diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp
index a41ba41f04060e93d9ef6cbbd158ad52dd3aafe4..e409f7058b345eaf023963ddc9b2a66de657dfa8 100644
--- a/oneflow/user/ops/slice_op.cpp
+++ b/oneflow/user/ops/slice_op.cpp
@@ -322,7 +322,7 @@ Maybe<void> GetLogicalSliceSbpSignatures(user_op::SbpContext* ctx) {
   return Maybe<void>::Ok();
 }
 
-void GenSliceUpdateGradOp(user_op::BackwardOpConfContext* ctx) {
+Maybe<void> GenSliceUpdateGradOp(user_op::BackwardOpConfContext* ctx) {
   const std::string update_grad_op_name = ctx->FwOp().op_name() + "_update_grad";
   ctx->DefineOp(update_grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
     return builder.OpTypeName("slice")
@@ -358,6 +358,7 @@ void GenSliceUpdateGradOp(user_op::BackwardOpConfContext* ctx) {
   ctx->FwOp().InputGradBind(user_op::OpArg("x", 0), [&]() -> const std::string& {
     return ctx->GetOp(x_grad_op_name).output("y", 0);
   });
+  return Maybe<void>::Ok();
 }
 
 }  // namespace
diff --git a/oneflow/user/ops/unpack_op.cpp b/oneflow/user/ops/unpack_op.cpp
index a1156feab85cfd11e77452ebe466c47bcfeb4987..e8a34bf6fa5502c6077391cf7b050a2358a83254 100644
--- a/oneflow/user/ops/unpack_op.cpp
+++ b/oneflow/user/ops/unpack_op.cpp
@@ -64,7 +64,8 @@ REGISTER_USER_OP("unpack")
       return Maybe<void>::Ok();
     });
 
-REGISTER_USER_OP_GRAD("unpack").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
+REGISTER_USER_OP_GRAD("unpack").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx)
+                                                           -> Maybe<void> {
   const auto grad_op_name = ctx->FwOp().op_name() + "_grad";
   ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
     return builder.OpTypeName("pack")
@@ -76,6 +77,7 @@ REGISTER_USER_OP_GRAD("unpack").SetBackwardOpConfGenFn([](user_op::BackwardOpCon
   ctx->FwOp().InputGradBind(user_op::OpArg("in", 0),
                             [&ctx, &grad_op_name]() -> const std::string& {
                               return ctx->GetOp(grad_op_name).output("out", 0);
                             });
+  return Maybe<void>::Ok();
 });
 
 }  // namespace
diff --git a/oneflow/user/ops/where_op.cpp b/oneflow/user/ops/where_op.cpp
index 5b9350657a2cd3fc19256b1010dcd8ede9170456..13e4ab3dbb03f650312932608fba8692b8e9c275 100644
--- a/oneflow/user/ops/where_op.cpp
+++ b/oneflow/user/ops/where_op.cpp
@@ -71,43 +71,45 @@ REGISTER_USER_OP("where")
     })
     .SetGetSbpFn(GetWhereSbpSignatures);
 
-REGISTER_USER_OP_GRAD("where").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
-  const auto zero_op_name = ctx->FwOp().op_name() + "_zero_grad";
-  ctx->DefineOp(zero_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
-    return builder.OpTypeName("zero_like")
-        .InputBind("like", ctx->FwOp().input("x", 0))
-        .Output("out")
-        .Build();
-  });
+REGISTER_USER_OP_GRAD("where").SetBackwardOpConfGenFn(
+    [](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
+      const auto zero_op_name = ctx->FwOp().op_name() + "_zero_grad";
+      ctx->DefineOp(zero_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
+        return builder.OpTypeName("zero_like")
+            .InputBind("like", ctx->FwOp().input("x", 0))
+            .Output("out")
+            .Build();
+      });
 
-  const auto x_grad_op_name = ctx->FwOp().op_name() + "_x_grad";
-  ctx->DefineOp(x_grad_op_name, [&ctx, &zero_op_name](user_op::BackwardOpBuilder& builder) {
-    return builder.OpTypeName("where")
-        .InputBind("condition", ctx->FwOp().input("condition", 0))
-        .InputBind("x", ctx->FwOp().output_grad("out", 0))
-        .InputBind("y", ctx->GetOp(zero_op_name).output("out", 0))
-        .Output("out")
-        .Build();
-  });
+      const auto x_grad_op_name = ctx->FwOp().op_name() + "_x_grad";
+      ctx->DefineOp(x_grad_op_name, [&ctx, &zero_op_name](user_op::BackwardOpBuilder& builder) {
+        return builder.OpTypeName("where")
+            .InputBind("condition", ctx->FwOp().input("condition", 0))
+            .InputBind("x", ctx->FwOp().output_grad("out", 0))
+            .InputBind("y", ctx->GetOp(zero_op_name).output("out", 0))
+            .Output("out")
+            .Build();
+      });
 
-  const auto y_grad_op_name = ctx->FwOp().op_name() + "_y_grad";
-  ctx->DefineOp(y_grad_op_name, [&ctx, &zero_op_name](user_op::BackwardOpBuilder& builder) {
-    return builder.OpTypeName("where")
-        .InputBind("condition", ctx->FwOp().input("condition", 0))
-        .InputBind("x", ctx->GetOp(zero_op_name).output("out", 0))
-        .InputBind("y", ctx->FwOp().output_grad("out", 0))
-        .Output("out")
-        .Build();
-  });
+      const auto y_grad_op_name = ctx->FwOp().op_name() + "_y_grad";
+      ctx->DefineOp(y_grad_op_name, [&ctx, &zero_op_name](user_op::BackwardOpBuilder& builder) {
+        return builder.OpTypeName("where")
+            .InputBind("condition", ctx->FwOp().input("condition", 0))
+            .InputBind("x", ctx->GetOp(zero_op_name).output("out", 0))
+            .InputBind("y", ctx->FwOp().output_grad("out", 0))
+            .Output("out")
+            .Build();
+      });
 
-  ctx->FwOp().InputGradBind(user_op::OpArg("x", 0),
-                            [&ctx, &x_grad_op_name]() -> const std::string& {
-                              return ctx->GetOp(x_grad_op_name).output("out", 0);
-                            });
-  ctx->FwOp().InputGradBind(user_op::OpArg("y", 0),
-                            [&ctx, &y_grad_op_name]() -> const std::string& {
-                              return ctx->GetOp(y_grad_op_name).output("out", 0);
-                            });
-});
+      ctx->FwOp().InputGradBind(user_op::OpArg("x", 0),
+                                [&ctx, &x_grad_op_name]() -> const std::string& {
+                                  return ctx->GetOp(x_grad_op_name).output("out", 0);
+                                });
+      ctx->FwOp().InputGradBind(user_op::OpArg("y", 0),
+                                [&ctx, &y_grad_op_name]() -> const std::string& {
+                                  return ctx->GetOp(y_grad_op_name).output("out", 0);
+                                });
+      return Maybe<void>::Ok();
+    });
 
 }  // namespace oneflow