From bd2d3dc2467e1d222f295593e41c5d4f8d8f2fc1 Mon Sep 17 00:00:00 2001 From: liufengwei0103 <2472937968@qq.com> Date: Sun, 18 Jul 2021 16:29:09 +0800 Subject: [PATCH] bw_gen_fn return maybe (#5455) * modified SetInputArgModifyFn * Delete the CHECK changes in the assign_op.cpp file * Format * Modified the OutputArgModifyFn interface * add return * maybe error stack from CheckAndConstructOp to OutputArgModifier callback function * maybe error stack from CheckAndConstructOp to OutputArgModifier callback function * OutputArgModifier return maybe part_1 * maybe error stack from CheckAndConstructOp to OutputArgModifier callback function * input_arg_modifier return maybe * gen_bw_fn return maybe * bw_gen_fn return maybe * fix bug: return Maybe without JUST Co-authored-by: aishangjj <702572275@qq.com> --- .../user_sigmoid/user_sigmoid_cpp_def.cpp | 3 +- .../core/framework/user_op_grad_registry.h | 2 +- oneflow/core/job_rewriter/user_grad.cpp | 2 +- oneflow/user/ops/acc_op.cpp | 4 +- oneflow/user/ops/adaptive_pool_op.cpp | 3 +- oneflow/user/ops/broadcast_like_op.cpp | 3 +- oneflow/user/ops/ctc_loss_op.cpp | 44 ++++++------ oneflow/user/ops/diag_op.cpp | 4 +- oneflow/user/ops/dim_gather_op.cpp | 44 ++++++------ .../ops/elementwise_maximum_minimum_ops.cpp | 3 +- oneflow/user/ops/elu_op.cpp | 32 +++++---- ..._attention_query_mul_key_and_value_ops.cpp | 3 +- oneflow/user/ops/hardsigmoid_op.cpp | 3 +- oneflow/user/ops/hardswish_op.cpp | 30 ++++---- oneflow/user/ops/hardtanh_op.cpp | 34 ++++----- .../ops/hierarchical_parallel_cast_op.cpp | 3 +- oneflow/user/ops/matmul_op.cpp | 5 +- oneflow/user/ops/normalization_op.cpp | 6 +- oneflow/user/ops/pack_op.cpp | 4 +- oneflow/user/ops/parallel_cast_op.cpp | 3 +- oneflow/user/ops/partial_fc_sample_op.cpp | 3 +- oneflow/user/ops/relu_op.cpp | 30 ++++---- oneflow/user/ops/repeat_op.cpp | 4 +- oneflow/user/ops/scalar_pow_op.cpp | 32 +++++---- oneflow/user/ops/slice_op.cpp | 3 +- oneflow/user/ops/unpack_op.cpp | 4 +- oneflow/user/ops/where_op.cpp | 72 ++++++++++--------- 27 files changed, 211 insertions(+), 172 deletions(-) diff --git a/oneflow/compatible_single_client_python/test/custom_ops/user_sigmoid/user_sigmoid_cpp_def.cpp b/oneflow/compatible_single_client_python/test/custom_ops/user_sigmoid/user_sigmoid_cpp_def.cpp index e9086020f..5d786666c 100644 --- a/oneflow/compatible_single_client_python/test/custom_ops/user_sigmoid/user_sigmoid_cpp_def.cpp +++ b/oneflow/compatible_single_client_python/test/custom_ops/user_sigmoid/user_sigmoid_cpp_def.cpp @@ -71,7 +71,7 @@ REGISTER_USER_OP("user_sigmoid_backward") }); REGISTER_USER_OP_GRAD("user_sigmoid_forward") - .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { const auto grad_op_name = ctx->FwOp().op_name() + "_grad"; const auto& grad_op_func = [&ctx](user_op::BackwardOpBuilder& builder) { return builder.OpTypeName("user_sigmoid_backward") @@ -86,6 +86,7 @@ REGISTER_USER_OP_GRAD("user_sigmoid_forward") return ctx->GetOp(grad_op_name).output("dx", 0); }; ctx->FwOp().InputGradBind(user_op::OpArg("x", 0), dx_get_func); + return Maybe<void>::Ok(); }); } // namespace diff --git a/oneflow/core/framework/user_op_grad_registry.h b/oneflow/core/framework/user_op_grad_registry.h index cd0134d72..6a3027aba 100644 --- a/oneflow/core/framework/user_op_grad_registry.h +++ b/oneflow/core/framework/user_op_grad_registry.h @@ -25,7 +25,7 @@ namespace user_op { using AddOpFn = std::function<void(const UserOpConfWrapper&)>; using GenBackwardOpConfFn = std::function<Maybe<void>(const UserOpWrapper&, AddOpFn)>; -using BackwardOpConfGenFn = std::function<void(BackwardOpConfContext*)>; +using BackwardOpConfGenFn = std::function<Maybe<void>(BackwardOpConfContext*)>; struct OpGradRegistryResult { std::string op_type_name; diff --git a/oneflow/core/job_rewriter/user_grad.cpp b/oneflow/core/job_rewriter/user_grad.cpp index 9f9ca3668..fb3142116 100644 --- a/oneflow/core/job_rewriter/user_grad.cpp +++ b/oneflow/core/job_rewriter/user_grad.cpp @@ -37,7 +37,7 @@ Maybe<void> GenerateBackwardOpConf( if (nullptr != val->bw_gen_fn) { // new refined interface user_op::BackwardOpConfContext ctx(fw_user_op, bw_op_confs); - val->bw_gen_fn(&ctx); + JUST(val->bw_gen_fn(&ctx)); } else if (nullptr != val->gen_bw_fn) { // old interface, will be removed when all backward gradient configs are using new interface auto AddOp = [&](const user_op::UserOpConfWrapper& wrapper) { diff --git a/oneflow/user/ops/acc_op.cpp b/oneflow/user/ops/acc_op.cpp index 8ba2ab1c9..4c3188bb5 100644 --- a/oneflow/user/ops/acc_op.cpp +++ b/oneflow/user/ops/acc_op.cpp @@ -65,7 +65,8 @@ REGISTER_USER_OP("acc") return Maybe<void>::Ok(); }); -REGISTER_USER_OP_GRAD("acc").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { +REGISTER_USER_OP_GRAD("acc").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) + -> Maybe<void> { const auto grad_op_name = ctx->FwOp().op_name() + "_grad"; ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { return builder.OpTypeName("repeat") @@ -77,6 +78,7 @@ REGISTER_USER_OP_GRAD("acc").SetBackwardOpConfGenFn([](user_op::BackwardOpConfCo ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), [&ctx, &grad_op_name]() -> const std::string& { return ctx->GetOp(grad_op_name).output("out", 0); }); + return Maybe<void>::Ok(); }); } // namespace diff --git a/oneflow/user/ops/adaptive_pool_op.cpp b/oneflow/user/ops/adaptive_pool_op.cpp index edb365f48..edf04f1cc 100644 --- a/oneflow/user/ops/adaptive_pool_op.cpp +++ b/oneflow/user/ops/adaptive_pool_op.cpp @@ -98,7 +98,7 @@ REGISTER_USER_OP("adaptive_avg_pool2d_grad") .SetDataTypeInferFn(InferBWDataType); REGISTER_USER_OP_GRAD("adaptive_avg_pool2d") - .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { const auto adaptive_avg_pool2d_grad_op_name = ctx->FwOp().op_name() + "_grad"; ctx->DefineOp(adaptive_avg_pool2d_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { return builder.OpTypeName("adaptive_avg_pool2d_grad") @@ -112,6 +112,7 @@ REGISTER_USER_OP_GRAD("adaptive_avg_pool2d") [&ctx, &adaptive_avg_pool2d_grad_op_name]() -> const std::string& { return ctx->GetOp(adaptive_avg_pool2d_grad_op_name).output("dx", 0); }); + return Maybe<void>::Ok(); }); } // namespace diff --git a/oneflow/user/ops/broadcast_like_op.cpp b/oneflow/user/ops/broadcast_like_op.cpp index 801dc85d7..a1a54b6a4 100644 --- a/oneflow/user/ops/broadcast_like_op.cpp +++ b/oneflow/user/ops/broadcast_like_op.cpp @@ -101,7 +101,7 @@ REGISTER_USER_OP("broadcast_like") .SetDataTypeInferFn(InferDataType); REGISTER_USER_OP_GRAD("broadcast_like") - .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { const auto x_grad_op_name = ctx->FwOp().op_name() + "_x_grad"; ctx->DefineOp(x_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { return builder.OpTypeName("reduce_sum_like") @@ -116,6 +116,7 @@ REGISTER_USER_OP_GRAD("broadcast_like") [&ctx, &x_grad_op_name]() -> const std::string& { return ctx->GetOp(x_grad_op_name).output("y", 0); }); + return Maybe<void>::Ok(); }); } // namespace oneflow diff --git a/oneflow/user/ops/ctc_loss_op.cpp b/oneflow/user/ops/ctc_loss_op.cpp index 9b002d84b..7df7451e4 100644 --- a/oneflow/user/ops/ctc_loss_op.cpp +++ b/oneflow/user/ops/ctc_loss_op.cpp @@ -100,27 +100,29 @@ REGISTER_USER_OP("ctc_loss_grad") return Maybe<void>::Ok(); }); -REGISTER_USER_OP_GRAD("ctc_loss").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { - const auto ctc_loss_grad_op_name = ctx->FwOp().op_name() + "_grad"; - ctx->DefineOp(ctc_loss_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { - return builder.OpTypeName("ctc_loss_grad") - .InputBind("grad_out", ctx->FwOp().output_grad("loss", 0)) - .InputBind("log_probs", ctx->FwOp().input("log_probs", 0)) - .InputBind("targets", ctx->FwOp().input("targets", 0)) - .InputBind("input_lengths", ctx->FwOp().input("input_lengths", 0)) - .InputBind("target_lengths", ctx->FwOp().input("target_lengths", 0)) - .InputBind("loss", ctx->FwOp().output("loss", 0)) - .InputBind("alpha", ctx->FwOp().output("alpha", 0)) - .Attr("blank", ctx->FwOp().attr<int32_t>("blank")) - .Attr("zero_infinity", ctx->FwOp().attr<bool>("zero_infinity")) - .Output("grad") - .Build(); - }); - ctx->FwOp().InputGradBind(user_op::OpArg("log_probs", 0), - [&ctx, &ctc_loss_grad_op_name]() -> const std::string& { - return ctx->GetOp(ctc_loss_grad_op_name).output("grad", 0); - }); -}); +REGISTER_USER_OP_GRAD("ctc_loss") + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { + const auto ctc_loss_grad_op_name = ctx->FwOp().op_name() + "_grad"; + ctx->DefineOp(ctc_loss_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("ctc_loss_grad") + .InputBind("grad_out", ctx->FwOp().output_grad("loss", 0)) + .InputBind("log_probs", ctx->FwOp().input("log_probs", 0)) + .InputBind("targets", ctx->FwOp().input("targets", 0)) + .InputBind("input_lengths", ctx->FwOp().input("input_lengths", 0)) + .InputBind("target_lengths", ctx->FwOp().input("target_lengths", 0)) + .InputBind("loss", ctx->FwOp().output("loss", 0)) + .InputBind("alpha", ctx->FwOp().output("alpha", 0)) + .Attr("blank", ctx->FwOp().attr<int32_t>("blank")) + .Attr("zero_infinity", ctx->FwOp().attr<bool>("zero_infinity")) + .Output("grad") + .Build(); + }); + ctx->FwOp().InputGradBind(user_op::OpArg("log_probs", 0), + [&ctx, &ctc_loss_grad_op_name]() -> const std::string& { + return ctx->GetOp(ctc_loss_grad_op_name).output("grad", 0); + }); + return Maybe<void>::Ok(); + }); REGISTER_USER_OP("ctc_greedy_decoder") .Input("log_probs") diff --git a/oneflow/user/ops/diag_op.cpp b/oneflow/user/ops/diag_op.cpp index 9a4d28de1..21ee3e897 100644 --- a/oneflow/user/ops/diag_op.cpp +++ b/oneflow/user/ops/diag_op.cpp @@ -78,7 +78,8 @@ REGISTER_USER_OP("diag_grad") return Maybe<void>::Ok(); }); -REGISTER_USER_OP_GRAD("diag").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { +REGISTER_USER_OP_GRAD("diag").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) + -> Maybe<void> { const auto grad_op_name = ctx->FwOp().op_name() + "_grad"; ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { return builder.OpTypeName("diag_grad") @@ -92,5 +93,6 @@ REGISTER_USER_OP_GRAD("diag").SetBackwardOpConfGenFn([](user_op::BackwardOpConfC ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), [&ctx, &grad_op_name]() -> const std::string& { return ctx->GetOp(grad_op_name).output("dx", 0); }); + return Maybe<void>::Ok(); }); } // namespace oneflow diff --git a/oneflow/user/ops/dim_gather_op.cpp b/oneflow/user/ops/dim_gather_op.cpp index 189dc7fa0..17f670f57 100644 --- a/oneflow/user/ops/dim_gather_op.cpp +++ b/oneflow/user/ops/dim_gather_op.cpp @@ -178,27 +178,29 @@ REGISTER_USER_OP("dim_scatter_add_like") return Maybe<void>::Ok(); }); -REGISTER_USER_OP_GRAD("dim_gather").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { - const auto op_grad_name = ctx->FwOp().op_name() + "_grad"; - - ctx->DefineOp(op_grad_name, [&ctx](user_op::BackwardOpBuilder& builder) { - return builder - .OpTypeName( - "dim_scatter_add_like") // dim_scatter_add_like(like, dim, index, input) -> output - .InputBind("index", ctx->FwOp().input("index", 0)) // scatter.index <- gather.index - .InputBind("input", - ctx->FwOp().output_grad("output", 0)) // scatter.input <- grad of gather.out - .InputBind("like", ctx->FwOp().input("input", 0)) - .Output("output") - .Attr("dim", ctx->FwOp().attr<int32_t>("dim")) - .Build(); - }); - - ctx->FwOp().InputGradBind(user_op::OpArg("input", 0), - [&ctx, &op_grad_name]() -> const std::string& { - return ctx->GetOp(op_grad_name).output("output", 0); - }); -}); +REGISTER_USER_OP_GRAD("dim_gather") + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { + const auto op_grad_name = ctx->FwOp().op_name() + "_grad"; + + ctx->DefineOp(op_grad_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder + .OpTypeName( + "dim_scatter_add_like") // dim_scatter_add_like(like, dim, index, input) -> output + .InputBind("index", ctx->FwOp().input("index", 0)) // scatter.index <- gather.index + .InputBind("input", + ctx->FwOp().output_grad("output", 0)) // scatter.input <- grad of gather.out + .InputBind("like", ctx->FwOp().input("input", 0)) + .Output("output") + .Attr("dim", ctx->FwOp().attr<int32_t>("dim")) + .Build(); + }); + + ctx->FwOp().InputGradBind(user_op::OpArg("input", 0), + [&ctx, &op_grad_name]() -> const std::string& { + return ctx->GetOp(op_grad_name).output("output", 0); + }); + return Maybe<void>::Ok(); + }); } // namespace user_op diff --git a/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp b/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp index 12eaeb69e..4bce761d2 100644 --- a/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp +++ b/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp @@ -69,7 +69,7 @@ Maybe<void> InferDataType(InferContext* ctx) { } user_op::BackwardOpConfGenFn MakeGenBackwardOpFn(const std::string& op_type_name) { - return [=](user_op::BackwardOpConfContext* ctx) -> void { + return [=](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { const bool x_need_grad = ctx->FwOp().NeedGenGradTensor4OpInput("x", 0); const bool y_need_grad = ctx->FwOp().NeedGenGradTensor4OpInput("y", 0); const auto grad_op_name = ctx->FwOp().op_name() + "_grad"; @@ -95,6 +95,7 @@ user_op::BackwardOpConfGenFn MakeGenBackwardOpFn(const std::string& op_type_name return ctx->GetOp(grad_op_name).output("dy", 0); }); } + return Maybe<void>::Ok(); }; } diff --git a/oneflow/user/ops/elu_op.cpp b/oneflow/user/ops/elu_op.cpp index be2b0efe6..fd3ab133d 100644 --- a/oneflow/user/ops/elu_op.cpp +++ b/oneflow/user/ops/elu_op.cpp @@ -72,21 +72,23 @@ REGISTER_USER_OP("elu_grad") return Maybe<void>::Ok(); }); -REGISTER_USER_OP_GRAD("elu").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { - const auto elu_grad_op_name = ctx->FwOp().op_name() + "_grad"; - ctx->DefineOp(elu_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { - return builder.OpTypeName("elu_grad") - .InputBind("x", ctx->FwOp().input("in", 0)) - .InputBind("dy", ctx->FwOp().output_grad("out", 0)) - .Attr("alpha", ctx->FwOp().attr<double>("alpha")) - .Output("dx") - .Build(); - }); - ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), - [&ctx, &elu_grad_op_name]() -> const std::string& { - return ctx->GetOp(elu_grad_op_name).output("dx", 0); - }); -}); +REGISTER_USER_OP_GRAD("elu").SetBackwardOpConfGenFn( + [](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { + const auto elu_grad_op_name = ctx->FwOp().op_name() + "_grad"; + ctx->DefineOp(elu_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("elu_grad") + .InputBind("x", ctx->FwOp().input("in", 0)) + .InputBind("dy", ctx->FwOp().output_grad("out", 0)) + .Attr("alpha", ctx->FwOp().attr<double>("alpha")) + .Output("dx") + .Build(); + }); + ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), + [&ctx, &elu_grad_op_name]() -> const std::string& { + return ctx->GetOp(elu_grad_op_name).output("dx", 0); + }); + return Maybe<void>::Ok(); + }); } // namespace diff --git a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp index 498c426a0..0748daa1b 100644 --- a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp +++ b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp @@ -120,7 +120,7 @@ REGISTER_USER_OP("fused_self_attention_query_mul_key_and_value_grad") }); REGISTER_USER_OP_GRAD("fused_self_attention_query_mul_key_and_value") - .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { std::string grad_op_name = ctx->FwOp().op_name() + "_grad"; ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { @@ -136,6 +136,7 @@ REGISTER_USER_OP_GRAD("fused_self_attention_query_mul_key_and_value") ctx->FwOp().InputGradBind(user_op::OpArg("hidden_states", 0), [=]() -> const std::string& { return ctx->GetOp(grad_op_name).output("hidden_states_grad", 0); }); + return Maybe<void>::Ok(); }); } // namespace oneflow diff --git a/oneflow/user/ops/hardsigmoid_op.cpp b/oneflow/user/ops/hardsigmoid_op.cpp index a22942e41..1ec34740e 100644 --- a/oneflow/user/ops/hardsigmoid_op.cpp +++ b/oneflow/user/ops/hardsigmoid_op.cpp @@ -73,7 +73,7 @@ REGISTER_USER_OP("hardsigmoid_grad") }); REGISTER_USER_OP_GRAD("hardsigmoid") - .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { const auto hardsigmoid_grad_op_name = ctx->FwOp().op_name() + "_grad"; ctx->DefineOp(hardsigmoid_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { return builder.OpTypeName("hardsigmoid_grad") @@ -86,6 +86,7 @@ REGISTER_USER_OP_GRAD("hardsigmoid") [&ctx, &hardsigmoid_grad_op_name]() -> const std::string& { return ctx->GetOp(hardsigmoid_grad_op_name).output("dx", 0); }); + return Maybe<void>::Ok(); }); } // namespace diff --git a/oneflow/user/ops/hardswish_op.cpp b/oneflow/user/ops/hardswish_op.cpp index ccd602fc4..ffef66123 100644 --- a/oneflow/user/ops/hardswish_op.cpp +++ b/oneflow/user/ops/hardswish_op.cpp @@ -70,20 +70,22 @@ REGISTER_USER_OP("hardswish_grad") return Maybe<void>::Ok(); }); -REGISTER_USER_OP_GRAD("hardswish").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { - const auto hardswish_grad_op_name = ctx->FwOp().op_name() + "_grad"; - ctx->DefineOp(hardswish_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { - return builder.OpTypeName("hardswish_grad") - .InputBind("x", ctx->FwOp().input("in", 0)) - .InputBind("dy", ctx->FwOp().output_grad("out", 0)) - .Output("dx") - .Build(); - }); - ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), - [&ctx, &hardswish_grad_op_name]() -> const std::string& { - return ctx->GetOp(hardswish_grad_op_name).output("dx", 0); - }); -}); +REGISTER_USER_OP_GRAD("hardswish") + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { + const auto hardswish_grad_op_name = ctx->FwOp().op_name() + "_grad"; + ctx->DefineOp(hardswish_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("hardswish_grad") + .InputBind("x", ctx->FwOp().input("in", 0)) + .InputBind("dy", ctx->FwOp().output_grad("out", 0)) + .Output("dx") + .Build(); + }); + ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), + [&ctx, &hardswish_grad_op_name]() -> const std::string& { + return ctx->GetOp(hardswish_grad_op_name).output("dx", 0); + }); + return Maybe<void>::Ok(); + }); } // namespace diff --git a/oneflow/user/ops/hardtanh_op.cpp b/oneflow/user/ops/hardtanh_op.cpp index a8ce26247..32e0a69fb 100644 --- a/oneflow/user/ops/hardtanh_op.cpp +++ b/oneflow/user/ops/hardtanh_op.cpp @@ -82,22 +82,24 @@ REGISTER_USER_OP("hardtanh_grad") return Maybe<void>::Ok(); }); -REGISTER_USER_OP_GRAD("hardtanh").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { - const auto hardtanh_grad_op_name = ctx->FwOp().op_name() + "_grad"; - ctx->DefineOp(hardtanh_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { - return builder.OpTypeName("hardtanh_grad") - .InputBind("y", ctx->FwOp().output("out", 0)) - .InputBind("dy", ctx->FwOp().output_grad("out", 0)) - .Attr("min_val", ctx->FwOp().attr<double>("min_val")) - .Attr("max_val", ctx->FwOp().attr<double>("max_val")) - .Output("dx") - .Build(); - }); - ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), - [&ctx, &hardtanh_grad_op_name]() -> const std::string& { - return ctx->GetOp(hardtanh_grad_op_name).output("dx", 0); - }); -}); +REGISTER_USER_OP_GRAD("hardtanh") + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { + const auto hardtanh_grad_op_name = ctx->FwOp().op_name() + "_grad"; + ctx->DefineOp(hardtanh_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("hardtanh_grad") + .InputBind("y", ctx->FwOp().output("out", 0)) + .InputBind("dy", ctx->FwOp().output_grad("out", 0)) + .Attr("min_val", ctx->FwOp().attr<double>("min_val")) + .Attr("max_val", ctx->FwOp().attr<double>("max_val")) + .Output("dx") + .Build(); + }); + ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), + [&ctx, &hardtanh_grad_op_name]() -> const std::string& { + return ctx->GetOp(hardtanh_grad_op_name).output("dx", 0); + }); + return Maybe<void>::Ok(); + }); } // namespace diff --git a/oneflow/user/ops/hierarchical_parallel_cast_op.cpp b/oneflow/user/ops/hierarchical_parallel_cast_op.cpp index a43442e81..205751aa0 100644 --- a/oneflow/user/ops/hierarchical_parallel_cast_op.cpp +++ b/oneflow/user/ops/hierarchical_parallel_cast_op.cpp @@ -84,7 +84,7 @@ REGISTER_USER_OP("hierarchical_parallel_cast_like") .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); REGISTER_USER_OP_GRAD("hierarchical_parallel_cast") - .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { if (ctx->FwOp().NeedGenGradTensor4OpInput("in", 0)) { const auto& grad_mode = ctx->FwOp().attr<std::string>("grad_mode"); if (grad_mode == "identity") { @@ -122,6 +122,7 @@ REGISTER_USER_OP_GRAD("hierarchical_parallel_cast") UNIMPLEMENTED(); } } + return Maybe<void>::Ok(); }); } // namespace oneflow diff --git a/oneflow/user/ops/matmul_op.cpp b/oneflow/user/ops/matmul_op.cpp index 6c7cdccbd..dd2086665 100644 --- a/oneflow/user/ops/matmul_op.cpp +++ b/oneflow/user/ops/matmul_op.cpp @@ -410,11 +410,11 @@ REGISTER_USER_OP("broadcast_matmul_grad_b") }); REGISTER_USER_OP_GRAD("broadcast_matmul") - .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> void { + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { bool transpose_a = ctx->FwOp().attr<bool>("transpose_a"); bool transpose_b = ctx->FwOp().attr<bool>("transpose_b"); double alpha = ctx->FwOp().attr<double>("alpha"); - CHECK(!transpose_a); + CHECK_OR_RETURN(!transpose_a); std::string a_grad_op_name = ctx->FwOp().op_name() + "_a_grad"; ctx->DefineOp(a_grad_op_name, @@ -456,6 +456,7 @@ REGISTER_USER_OP_GRAD("broadcast_matmul") ctx->FwOp().InputGradBind(user_op::OpArg("b", 0), [&]() -> const std::string& { return ctx->GetOp(b_grad_op_name).output("out", 0); }); + return Maybe<void>::Ok(); }); } // namespace oneflow diff --git a/oneflow/user/ops/normalization_op.cpp b/oneflow/user/ops/normalization_op.cpp index f81d0e6e8..e31d8fce0 100644 --- a/oneflow/user/ops/normalization_op.cpp +++ b/oneflow/user/ops/normalization_op.cpp @@ -485,7 +485,7 @@ REGISTER_USER_OP("cudnn_fused_normalization_add_relu_grad") #endif REGISTER_USER_OP_GRAD("normalization") - .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { const bool is_training = ctx->FwOp().attr<bool>("training"); const bool is_fp16 = ctx->FwOp().arg_tensor_desc("y", 0).data_type() == DataType::kFloat16; @@ -656,10 +656,11 @@ REGISTER_USER_OP_GRAD("normalization") [&ctx, &beta_identity_op_name]() -> const std::string& { return ctx->GetOp(beta_identity_op_name).output("out", 0); }); + return Maybe<void>::Ok(); }); REGISTER_USER_OP_GRAD("normalization_add_relu") - .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { const auto grad_op_name = ctx->FwOp().op_name() + "_grad"; ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { builder.OpTypeName("normalization_add_relu_grad") @@ -718,6 +719,7 @@ REGISTER_USER_OP_GRAD("normalization_add_relu") [&ctx, &beta_identity_op_name]() -> const std::string& { return ctx->GetOp(beta_identity_op_name).output("out", 0); }); + return Maybe<void>::Ok(); }); } // namespace diff --git a/oneflow/user/ops/pack_op.cpp b/oneflow/user/ops/pack_op.cpp index f3f2dcb39..62f8d04b2 100644 --- a/oneflow/user/ops/pack_op.cpp +++ b/oneflow/user/ops/pack_op.cpp @@ -63,7 +63,8 @@ REGISTER_USER_OP("pack") return Maybe<void>::Ok(); }); -REGISTER_USER_OP_GRAD("pack").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { +REGISTER_USER_OP_GRAD("pack").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) + -> Maybe<void> { const auto grad_op_name = ctx->FwOp().op_name() + "_grad"; ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { return builder.OpTypeName("unpack") @@ -75,6 +76,7 @@ REGISTER_USER_OP_GRAD("pack").SetBackwardOpConfGenFn([](user_op::BackwardOpConfC ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), [&ctx, &grad_op_name]() -> const std::string& { return ctx->GetOp(grad_op_name).output("out", 0); }); + return Maybe<void>::Ok(); }); } // namespace diff --git a/oneflow/user/ops/parallel_cast_op.cpp b/oneflow/user/ops/parallel_cast_op.cpp index 1f2acb0d6..70fcc6a11 100644 --- a/oneflow/user/ops/parallel_cast_op.cpp +++ b/oneflow/user/ops/parallel_cast_op.cpp @@ -60,7 +60,7 @@ REGISTER_USER_OP("parallel_cast") .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); REGISTER_USER_OP_GRAD("parallel_cast") - .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { if (ctx->FwOp().NeedGenGradTensor4OpInput("in", 0)) { const auto& grad_sbp_parallel_str = ctx->FwOp().attr<std::string>("grad_sbp_parallel"); if (grad_sbp_parallel_str.empty()) { @@ -81,6 +81,7 @@ REGISTER_USER_OP_GRAD("parallel_cast") }); } } + return Maybe<void>::Ok(); }); } // namespace oneflow diff --git a/oneflow/user/ops/partial_fc_sample_op.cpp b/oneflow/user/ops/partial_fc_sample_op.cpp index 8586c94d3..b40d6f94d 100644 --- a/oneflow/user/ops/partial_fc_sample_op.cpp +++ b/oneflow/user/ops/partial_fc_sample_op.cpp @@ -137,7 +137,7 @@ REGISTER_USER_OP("distributed_partial_fc_sample_disable_boxing") }); REGISTER_USER_OP_GRAD("distributed_partial_fc_sample") - .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { const auto disable_boxing_op_name = ctx->FwOp().op_name() + "_disable_boxing"; ctx->DefineOp(disable_boxing_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { return builder.OpTypeName("distributed_partial_fc_sample_disable_boxing") @@ -168,6 +168,7 @@ REGISTER_USER_OP_GRAD("distributed_partial_fc_sample") [&ctx, &unsorted_segment_sum_like_op_name]() -> const std::string& { return ctx->GetOp(unsorted_segment_sum_like_op_name).output("out", 0); }); + return Maybe<void>::Ok(); }); } // namespace oneflow diff --git a/oneflow/user/ops/relu_op.cpp b/oneflow/user/ops/relu_op.cpp index 08c89c062..f9c4d7582 100644 --- a/oneflow/user/ops/relu_op.cpp +++ b/oneflow/user/ops/relu_op.cpp @@ -73,20 +73,22 @@ REGISTER_USER_OP("relu_grad") return Maybe<void>::Ok(); }); -REGISTER_USER_OP_GRAD("relu").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { - const auto relu_grad_op_name = ctx->FwOp().op_name() + "_grad"; - ctx->DefineOp(relu_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { - return builder.OpTypeName("relu_grad") - .InputBind("y", ctx->FwOp().output("out", 0)) - .InputBind("dy", ctx->FwOp().output_grad("out", 0)) - .Output("dx") - .Build(); - }); - ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), - [&ctx, &relu_grad_op_name]() -> const std::string& { - return ctx->GetOp(relu_grad_op_name).output("dx", 0); - }); -}); +REGISTER_USER_OP_GRAD("relu").SetBackwardOpConfGenFn( + [](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { + const auto relu_grad_op_name = ctx->FwOp().op_name() + "_grad"; + ctx->DefineOp(relu_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("relu_grad") + .InputBind("y", ctx->FwOp().output("out", 0)) + .InputBind("dy", ctx->FwOp().output_grad("out", 0)) + .Output("dx") + .Build(); + }); + ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), + [&ctx, &relu_grad_op_name]() -> const std::string& { + return ctx->GetOp(relu_grad_op_name).output("dx", 0); + }); + return Maybe<void>::Ok(); + }); } // namespace diff --git a/oneflow/user/ops/repeat_op.cpp b/oneflow/user/ops/repeat_op.cpp index dd102adf3..f3d280901 100644 --- a/oneflow/user/ops/repeat_op.cpp +++ b/oneflow/user/ops/repeat_op.cpp @@ -54,7 +54,8 @@ REGISTER_USER_OP("repeat") return Maybe<void>::Ok(); }); -REGISTER_USER_OP_GRAD("repeat").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { +REGISTER_USER_OP_GRAD("repeat").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) + -> Maybe<void> { const auto grad_op_name = ctx->FwOp().op_name() + "_grad"; ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { return builder.OpTypeName("acc") @@ -66,6 +67,7 @@ REGISTER_USER_OP_GRAD("repeat").SetBackwardOpConfGenFn([](user_op::BackwardOpCon ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), [&ctx, &grad_op_name]() -> const std::string& { return ctx->GetOp(grad_op_name).output("out", 0); }); + return Maybe<void>::Ok(); }); } // namespace diff --git a/oneflow/user/ops/scalar_pow_op.cpp b/oneflow/user/ops/scalar_pow_op.cpp index 5d2bf81bb..4cb386395 100644 --- a/oneflow/user/ops/scalar_pow_op.cpp +++ b/oneflow/user/ops/scalar_pow_op.cpp @@ -68,21 +68,23 @@ REGISTER_USER_OP("scalar_pow_grad") return Maybe<void>::Ok(); }); -REGISTER_USER_OP_GRAD("scalar_pow").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { - const auto scalar_pow_grad_op_name = ctx->FwOp().op_name() + "_grad"; - ctx->DefineOp(scalar_pow_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { - return builder.OpTypeName("scalar_pow_grad") - .InputBind("x", ctx->FwOp().input("in", 0)) - .InputBind("dy", ctx->FwOp().output_grad("out", 0)) - .Attr<double>("exponent", ctx->FwOp().attr<double>("exponent")) - .Output("dx") - .Build(); - }); - ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), - [&ctx, &scalar_pow_grad_op_name]() -> const std::string& { - return ctx->GetOp(scalar_pow_grad_op_name).output("dx", 0); - }); -}); +REGISTER_USER_OP_GRAD("scalar_pow") + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { + const auto scalar_pow_grad_op_name = ctx->FwOp().op_name() + "_grad"; + ctx->DefineOp(scalar_pow_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("scalar_pow_grad") + .InputBind("x", ctx->FwOp().input("in", 0)) + .InputBind("dy", ctx->FwOp().output_grad("out", 0)) + .Attr<double>("exponent", ctx->FwOp().attr<double>("exponent")) + .Output("dx") + .Build(); + }); + ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), + [&ctx, &scalar_pow_grad_op_name]() -> const std::string& { + return ctx->GetOp(scalar_pow_grad_op_name).output("dx", 0); + }); + return Maybe<void>::Ok(); + }); } // namespace diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp index a41ba41f0..e409f7058 100644 --- a/oneflow/user/ops/slice_op.cpp +++ b/oneflow/user/ops/slice_op.cpp @@ -322,7 +322,7 @@ Maybe<void> GetLogicalSliceSbpSignatures(user_op::SbpContext* ctx) { return Maybe<void>::Ok(); } -void GenSliceUpdateGradOp(user_op::BackwardOpConfContext* ctx) { +Maybe<void> GenSliceUpdateGradOp(user_op::BackwardOpConfContext* ctx) { const std::string update_grad_op_name = ctx->FwOp().op_name() + "_update_grad"; ctx->DefineOp(update_grad_op_name, [&](user_op::BackwardOpBuilder& builder) { return builder.OpTypeName("slice") @@ -358,6 +358,7 @@ void GenSliceUpdateGradOp(user_op::BackwardOpConfContext* ctx) { ctx->FwOp().InputGradBind(user_op::OpArg("x", 0), [&]() -> const std::string& { return ctx->GetOp(x_grad_op_name).output("y", 0); }); + return Maybe<void>::Ok(); } } // namespace diff --git a/oneflow/user/ops/unpack_op.cpp b/oneflow/user/ops/unpack_op.cpp index a1156feab..e8a34bf6f 100644 --- a/oneflow/user/ops/unpack_op.cpp +++ b/oneflow/user/ops/unpack_op.cpp @@ -64,7 +64,8 @@ REGISTER_USER_OP("unpack") return Maybe<void>::Ok(); }); -REGISTER_USER_OP_GRAD("unpack").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { +REGISTER_USER_OP_GRAD("unpack").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) + -> Maybe<void> { const auto grad_op_name = ctx->FwOp().op_name() + "_grad"; ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { return builder.OpTypeName("pack") @@ -76,6 +77,7 @@ REGISTER_USER_OP_GRAD("unpack").SetBackwardOpConfGenFn([](user_op::BackwardOpCon ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), [&ctx, &grad_op_name]() -> const std::string& { return ctx->GetOp(grad_op_name).output("out", 0); }); + return Maybe<void>::Ok(); }); } // namespace diff --git a/oneflow/user/ops/where_op.cpp b/oneflow/user/ops/where_op.cpp index 5b9350657..13e4ab3db 100644 --- a/oneflow/user/ops/where_op.cpp +++ b/oneflow/user/ops/where_op.cpp @@ -71,43 +71,45 @@ REGISTER_USER_OP("where") }) .SetGetSbpFn(GetWhereSbpSignatures); -REGISTER_USER_OP_GRAD("where").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { - const auto zero_op_name = ctx->FwOp().op_name() + "_zero_grad"; - ctx->DefineOp(zero_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { - return builder.OpTypeName("zero_like") - .InputBind("like", ctx->FwOp().input("x", 0)) - .Output("out") - .Build(); - }); +REGISTER_USER_OP_GRAD("where").SetBackwardOpConfGenFn( + [](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { + const auto zero_op_name = ctx->FwOp().op_name() + "_zero_grad"; + ctx->DefineOp(zero_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("zero_like") + .InputBind("like", ctx->FwOp().input("x", 0)) + .Output("out") + .Build(); + }); - const auto x_grad_op_name = ctx->FwOp().op_name() + "_x_grad"; - ctx->DefineOp(x_grad_op_name, [&ctx, &zero_op_name](user_op::BackwardOpBuilder& builder) { - return builder.OpTypeName("where") - .InputBind("condition", ctx->FwOp().input("condition", 0)) - .InputBind("x", ctx->FwOp().output_grad("out", 0)) - .InputBind("y", ctx->GetOp(zero_op_name).output("out", 0)) - .Output("out") - .Build(); - }); + const auto x_grad_op_name = ctx->FwOp().op_name() + "_x_grad"; + ctx->DefineOp(x_grad_op_name, [&ctx, &zero_op_name](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("where") + .InputBind("condition", ctx->FwOp().input("condition", 0)) + .InputBind("x", ctx->FwOp().output_grad("out", 0)) + .InputBind("y", ctx->GetOp(zero_op_name).output("out", 0)) + .Output("out") + .Build(); + }); - const auto y_grad_op_name = ctx->FwOp().op_name() + "_y_grad"; - ctx->DefineOp(y_grad_op_name, [&ctx, &zero_op_name](user_op::BackwardOpBuilder& builder) { - return builder.OpTypeName("where") - .InputBind("condition", ctx->FwOp().input("condition", 0)) - .InputBind("x", ctx->GetOp(zero_op_name).output("out", 0)) - .InputBind("y", ctx->FwOp().output_grad("out", 0)) - .Output("out") - .Build(); - }); + const auto y_grad_op_name = ctx->FwOp().op_name() + "_y_grad"; + ctx->DefineOp(y_grad_op_name, [&ctx, &zero_op_name](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("where") + .InputBind("condition", ctx->FwOp().input("condition", 0)) + .InputBind("x", ctx->GetOp(zero_op_name).output("out", 0)) + .InputBind("y", ctx->FwOp().output_grad("out", 0)) + .Output("out") + .Build(); + }); - ctx->FwOp().InputGradBind(user_op::OpArg("x", 0), - [&ctx, &x_grad_op_name]() -> const std::string& { - return ctx->GetOp(x_grad_op_name).output("out", 0); - }); - ctx->FwOp().InputGradBind(user_op::OpArg("y", 0), - [&ctx, &y_grad_op_name]() -> const std::string& { - return ctx->GetOp(y_grad_op_name).output("out", 0); - }); -}); + ctx->FwOp().InputGradBind(user_op::OpArg("x", 0), + [&ctx, &x_grad_op_name]() -> const std::string& { + return ctx->GetOp(x_grad_op_name).output("out", 0); + }); + ctx->FwOp().InputGradBind(user_op::OpArg("y", 0), + [&ctx, &y_grad_op_name]() -> const std::string& { + return ctx->GetOp(y_grad_op_name).output("out", 0); + }); + return Maybe<void>::Ok(); + }); } // namespace oneflow -- GitLab