From e39e3c62957c4c4d7c447fe72eab7ea01cd9de01 Mon Sep 17 00:00:00 2001 From: liufengwei0103 <2472937968@qq.com> Date: Fri, 16 Jul 2021 01:20:24 +0800 Subject: [PATCH] Input arg modifier return maybe (#5453) * modified SetInputArgModifyFn * Delete the CHECK changes in the assign_op.cpp file * Format * Modified the OutputArgModifyFn interface * add return * maybe error stack from CheckAndConstructOp to OutputArgModifier callback function * maybe error stack from CheckAndConstructOp to OutputArgModifier callback function * OutputArgModifier return maybe part_1 * maybe error stack from CheckAndConstructOp to OutputArgModifier callback function * input_arg_modifier return maybe * change lambda for JUST macro * fix conflicts Co-authored-by: aishangjj <702572275@qq.com> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- oneflow/core/framework/user_op_registry.h | 2 +- oneflow/core/operator/user_op.cpp | 4 +- .../user/kernels/stateful_local_opkernel.cpp | 2 +- oneflow/user/ops/assign_op.cpp | 11 +- oneflow/user/ops/batch_gather_op.cpp | 5 +- oneflow/user/ops/broadcast_like_op.cpp | 5 +- oneflow/user/ops/cast_like_op.cpp | 5 +- .../ops/categorical_ordinal_encode_op.cpp | 3 +- oneflow/user/ops/combined_margin_loss_op.cpp | 3 +- oneflow/user/ops/dim_gather_op.cpp | 10 +- oneflow/user/ops/dropout_op.cpp | 3 +- .../ops/dynamic_loss_scale_schedule_op.cpp | 9 +- oneflow/user/ops/fake_quantization_op.cpp | 7 +- oneflow/user/ops/fused_bias_add_op.cpp | 5 +- ...fused_scale_tril_softmax_mask_scale_op.cpp | 5 +- oneflow/user/ops/gather_op.cpp | 5 +- oneflow/user/ops/gpt_data_loader_op.cpp | 7 +- oneflow/user/ops/image_preprocess_ops.cpp | 5 +- oneflow/user/ops/min_max_observer_op.cpp | 5 +- oneflow/user/ops/model_update_ops.cpp | 115 +++++++++++------- .../moving_average_min_max_observer_op.cpp | 11 +- oneflow/user/ops/nd_index_slice_ops.cpp | 20 +-- oneflow/user/ops/normalization_op.cpp | 9 +- oneflow/user/ops/ofrecord_decoder_ops.cpp | 20 +-- oneflow/user/ops/one_hot_op.cpp | 5 +- oneflow/user/ops/onerec_decoder_op.cpp | 5 +- oneflow/user/ops/pad2d_ops.cpp | 15 ++- oneflow/user/ops/partial_fc_sample_op.cpp | 5 +- oneflow/user/ops/reduce_like_ops.cpp | 5 +- oneflow/user/ops/reshape_like_op.cpp | 5 +- oneflow/user/ops/sigmoid_cross_entropy_op.cpp | 6 +- oneflow/user/ops/slice_op.cpp | 18 +-- oneflow/user/ops/smooth_l1_loss_op.cpp | 5 +- oneflow/user/ops/softmax_cross_entropy_op.cpp | 3 +- oneflow/user/ops/sparse_cross_entropy_op.cpp | 5 +- .../ops/sparse_softmax_cross_entropy_op.cpp | 5 +- oneflow/user/ops/split_like_op.cpp | 7 +- oneflow/user/ops/two_stage_reduce_ops.cpp | 3 +- .../ops/unsorted_batch_segment_sum_op.cpp | 5 +- oneflow/user/ops/unsorted_segment_sum_op.cpp | 12 +- oneflow/user/ops/where_op.cpp | 3 +- 41 files changed, 230 insertions(+), 158 deletions(-) diff --git a/oneflow/core/framework/user_op_registry.h b/oneflow/core/framework/user_op_registry.h index fa91c6537..077789c1b 100644 --- a/oneflow/core/framework/user_op_registry.h +++ b/oneflow/core/framework/user_op_registry.h @@ -48,7 +48,7 @@ using SbpSignatureInferFn = std::function<Maybe<void>(InferSbpSignatureFnContext using InputArgModifier = InputBlobModifier; using GetInputArgModifier = std::function<InputArgModifier*(const std::string& in_arg_name, int32_t in_arg_index)>; -using InputArgModifyFn = std::function<void(GetInputArgModifier, const UserOpConfWrapper&)>; +using InputArgModifyFn = std::function<Maybe<void>(GetInputArgModifier, const UserOpConfWrapper&)>; using OutputArgModifier = OutputBlobModifier; using GetOutputArgModifier = std::function<OutputArgModifier*(const std::string& out_arg_name, int32_t out_arg_index)>; diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index 91a631b42..59ab04844 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -489,7 +489,7 @@ class UserOpInferParallelDistributionFnContext }; Maybe<void> UserOp::InitFromOpConf() { - CHECK(op_conf().has_user_conf()); + CHECK_OR_RETURN(op_conf().has_user_conf()); for (const auto& pair : op_conf().user_conf().input()) { EnrollRepeatedInputBn(pair.first, pair.second.s_size()); for (int32_t i = 0; i < pair.second.s_size(); ++i) { @@ -516,7 +516,7 @@ Maybe<void> UserOp::InitFromOpConf() { } return nullptr; }; - val_->input_arg_modify_fn(GetInputArgModifierFn, *user_op_conf_); + JUST(val_->input_arg_modify_fn(GetInputArgModifierFn, *user_op_conf_)); } if (val_->output_arg_modify_fn) { user_op::GetOutputArgModifier GetOutputArgModifierFn = diff --git a/oneflow/user/kernels/stateful_local_opkernel.cpp b/oneflow/user/kernels/stateful_local_opkernel.cpp index 9a2390c00..b2cb5310e 100644 --- a/oneflow/user/kernels/stateful_local_opkernel.cpp +++ b/oneflow/user/kernels/stateful_local_opkernel.cpp @@ -303,7 +303,7 @@ Maybe<void> InitTensorTupleIndexes4Bns(const std::shared_ptr<const OperatorConf> auto* map = arg_modifier_signature.mutable_ibn2input_blob_modifier(); return &map->at(ibn); }; - op_reg_val->input_arg_modify_fn(GetInputArgModifierFn, op_conf_wrapper); + JUST(op_reg_val->input_arg_modify_fn(GetInputArgModifierFn, op_conf_wrapper)); } if (op_reg_val->output_arg_modify_fn) { user_op::GetOutputArgModifier GetOutputArgModifierFn = diff --git a/oneflow/user/ops/assign_op.cpp b/oneflow/user/ops/assign_op.cpp index ed4d8ddbf..f54342c2c 100644 --- a/oneflow/user/ops/assign_op.cpp +++ b/oneflow/user/ops/assign_op.cpp @@ -48,19 +48,20 @@ Maybe<void> GetSbpSignatures(user_op::SbpContext* ctx) { return Maybe<void>::Ok(); } -void InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { +Maybe<void> InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* ref_modifier = GetInputArgModifierFn("ref", 0); - CHECK(ref_modifier != nullptr); + CHECK_OR_RETURN(ref_modifier != nullptr); ref_modifier->set_is_mutable(true); user_op::InputArgModifier* value_modifier = GetInputArgModifierFn("value", 0); - CHECK(value_modifier != nullptr); + CHECK_OR_RETURN(value_modifier != nullptr); value_modifier->set_requires_grad(false); if (conf.has_input("condition", 0)) { user_op::InputArgModifier* condition_modifier = GetInputArgModifierFn("condition", 0); - CHECK(condition_modifier != nullptr); + CHECK_OR_RETURN(condition_modifier != nullptr); condition_modifier->set_requires_grad(false); } + return Maybe<void>::Ok(); } Maybe<void> InferDataType(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/batch_gather_op.cpp b/oneflow/user/ops/batch_gather_op.cpp index 9fbc8fa80..0c9ae8bec 100644 --- a/oneflow/user/ops/batch_gather_op.cpp +++ b/oneflow/user/ops/batch_gather_op.cpp @@ -44,10 +44,11 @@ REGISTER_USER_OP("batch_gather") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const int64_t indices_num_axes = diff --git a/oneflow/user/ops/broadcast_like_op.cpp b/oneflow/user/ops/broadcast_like_op.cpp index b5489b4a8..801dc85d7 100644 --- a/oneflow/user/ops/broadcast_like_op.cpp +++ b/oneflow/user/ops/broadcast_like_op.cpp @@ -91,10 +91,11 @@ REGISTER_USER_OP("broadcast_like") .Output("y") .SetTensorDescInferFn(InferTensorDesc) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); - CHECK(like_modifier != nullptr); + CHECK_OR_RETURN(like_modifier != nullptr); like_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn(GetSbpSignatures) .SetDataTypeInferFn(InferDataType); diff --git a/oneflow/user/ops/cast_like_op.cpp b/oneflow/user/ops/cast_like_op.cpp index 3cedd56f9..a2a7face1 100644 --- a/oneflow/user/ops/cast_like_op.cpp +++ b/oneflow/user/ops/cast_like_op.cpp @@ -27,10 +27,11 @@ REGISTER_NO_GRAD_USER_OP("cast_like") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* dtype_like_modifier = GetInputArgModifierFn("dtype_like", 0); - CHECK_NOTNULL(dtype_like_modifier); + CHECK_NOTNULL_OR_RETURN(dtype_like_modifier); dtype_like_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); diff --git a/oneflow/user/ops/categorical_ordinal_encode_op.cpp b/oneflow/user/ops/categorical_ordinal_encode_op.cpp index a17cee442..cb2f0d435 100644 --- a/oneflow/user/ops/categorical_ordinal_encode_op.cpp +++ b/oneflow/user/ops/categorical_ordinal_encode_op.cpp @@ -45,7 +45,7 @@ REGISTER_NO_GRAD_USER_OP("CategoricalOrdinalEncode") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* table = GetInputArgModifierFn("table", 0); table->set_is_mutable(true); table->set_requires_grad(false); @@ -54,6 +54,7 @@ REGISTER_NO_GRAD_USER_OP("CategoricalOrdinalEncode") size->set_requires_grad(false); user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); in->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { CHECK_EQ_OR_RETURN(ctx->parallel_num(), 1); diff --git a/oneflow/user/ops/combined_margin_loss_op.cpp b/oneflow/user/ops/combined_margin_loss_op.cpp index 7742a6eab..825b59901 100644 --- a/oneflow/user/ops/combined_margin_loss_op.cpp +++ b/oneflow/user/ops/combined_margin_loss_op.cpp @@ -39,9 +39,10 @@ REGISTER_USER_OP("combined_margin_loss") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* label_arg_modifier = GetInputArgModifierFn("label", 0); label_arg_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { ctx->NewBuilder() diff --git a/oneflow/user/ops/dim_gather_op.cpp b/oneflow/user/ops/dim_gather_op.cpp index 11e741b52..189dc7fa0 100644 --- a/oneflow/user/ops/dim_gather_op.cpp +++ b/oneflow/user/ops/dim_gather_op.cpp @@ -59,10 +59,11 @@ REGISTER_USER_OP("dim_gather") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& index_tensor = @@ -131,10 +132,11 @@ REGISTER_USER_OP("dim_scatter_add_like") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* like_arg_modifier = GetInputArgModifierFn("like", 0); - CHECK(like_arg_modifier != nullptr); + CHECK_OR_RETURN(like_arg_modifier != nullptr); like_arg_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& index_tensor = diff --git a/oneflow/user/ops/dropout_op.cpp b/oneflow/user/ops/dropout_op.cpp index de97b6b61..eb865cd9f 100644 --- a/oneflow/user/ops/dropout_op.cpp +++ b/oneflow/user/ops/dropout_op.cpp @@ -33,9 +33,10 @@ REGISTER_USER_OP("dropout") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* mask = GetInputArgModifierFn("mask", 0); mask->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); diff --git a/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp b/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp index f2399f7bf..e24277c42 100644 --- a/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp +++ b/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp @@ -43,14 +43,15 @@ Maybe<void> InferDataType(user_op::InferContext* ctx) { return Maybe<void>::Ok(); } -void InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { +Maybe<void> InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* loss_scale = GetInputArgModifierFn("loss_scale", 0); - CHECK(loss_scale != nullptr); + CHECK_OR_RETURN(loss_scale != nullptr); loss_scale->set_is_mutable(true); user_op::InputArgModifier* good_step_counter = GetInputArgModifierFn("good_step_counter", 0); - CHECK(good_step_counter != nullptr); + CHECK_OR_RETURN(good_step_counter != nullptr); good_step_counter->set_is_mutable(true); + return Maybe<void>::Ok(); } } // namespace diff --git a/oneflow/user/ops/fake_quantization_op.cpp b/oneflow/user/ops/fake_quantization_op.cpp index 63d0a576a..512823f8b 100644 --- a/oneflow/user/ops/fake_quantization_op.cpp +++ b/oneflow/user/ops/fake_quantization_op.cpp @@ -47,14 +47,15 @@ REGISTER_USER_OP("fake_quantization") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* scale = GetInputArgModifierFn("scale", 0); - CHECK(scale != nullptr); + CHECK_OR_RETURN(scale != nullptr); scale->set_requires_grad(false); user_op::InputArgModifier* zero_point = GetInputArgModifierFn("zero_point", 0); - CHECK(zero_point != nullptr); + CHECK_OR_RETURN(zero_point != nullptr); zero_point->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); diff --git a/oneflow/user/ops/fused_bias_add_op.cpp b/oneflow/user/ops/fused_bias_add_op.cpp index 31cdf422b..b61d473e4 100644 --- a/oneflow/user/ops/fused_bias_add_op.cpp +++ b/oneflow/user/ops/fused_bias_add_op.cpp @@ -167,10 +167,11 @@ REGISTER_USER_OP("fused_bias_add_mask_scale") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); - CHECK(mask_modifier != nullptr); + CHECK_OR_RETURN(mask_modifier != nullptr); mask_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const auto axis = ctx->Attr<int32_t>("axis"); diff --git a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp index 008a295d2..12fbe80dc 100644 --- a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp +++ b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp @@ -43,10 +43,11 @@ REGISTER_USER_OP("fused_tril_scale_softmax_mask_scale") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); - CHECK(mask_modifier != nullptr); + CHECK_OR_RETURN(mask_modifier != nullptr); mask_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); diff --git a/oneflow/user/ops/gather_op.cpp b/oneflow/user/ops/gather_op.cpp index 9e951996f..5f85af9ee 100644 --- a/oneflow/user/ops/gather_op.cpp +++ b/oneflow/user/ops/gather_op.cpp @@ -42,10 +42,11 @@ REGISTER_USER_OP("gather") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const int64_t in_num_axes = diff --git a/oneflow/user/ops/gpt_data_loader_op.cpp b/oneflow/user/ops/gpt_data_loader_op.cpp index 24e87fd02..1a36a0c83 100644 --- a/oneflow/user/ops/gpt_data_loader_op.cpp +++ b/oneflow/user/ops/gpt_data_loader_op.cpp @@ -82,12 +82,13 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("megatron_gpt_mmap_data_loader") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> void { - if (!conf.has_input("iteration", 0)) { return; } + const user_op::UserOpConfWrapper& conf) -> Maybe<void> { + if (!conf.has_input("iteration", 0)) { return Maybe<void>::Ok(); } user_op::InputArgModifier* input_modifier = GetInputArgModifierFn("iteration", 0); - CHECK(input_modifier != nullptr); + CHECK_OR_RETURN(input_modifier != nullptr); input_modifier->set_is_mutable(true); input_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); diff --git a/oneflow/user/ops/image_preprocess_ops.cpp b/oneflow/user/ops/image_preprocess_ops.cpp index cdb528e07..d4810479f 100644 --- a/oneflow/user/ops/image_preprocess_ops.cpp +++ b/oneflow/user/ops/image_preprocess_ops.cpp @@ -202,10 +202,11 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_random_crop") }) .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL(in_modifier); + CHECK_NOTNULL_OR_RETURN(in_modifier); in_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); diff --git a/oneflow/user/ops/min_max_observer_op.cpp b/oneflow/user/ops/min_max_observer_op.cpp index 565251aec..d1003ba28 100644 --- a/oneflow/user/ops/min_max_observer_op.cpp +++ b/oneflow/user/ops/min_max_observer_op.cpp @@ -51,10 +51,11 @@ REGISTER_NO_GRAD_USER_OP("min_max_observer") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); - CHECK(in != nullptr); + CHECK_OR_RETURN(in != nullptr); in->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { // NOTE(Liang Depeng): input needs to be broadcast in order to accurately calculate the diff --git a/oneflow/user/ops/model_update_ops.cpp b/oneflow/user/ops/model_update_ops.cpp index db9f4fac8..deab6b592 100644 --- a/oneflow/user/ops/model_update_ops.cpp +++ b/oneflow/user/ops/model_update_ops.cpp @@ -257,27 +257,75 @@ Maybe<void> InferLambUpdateDataType(user_op::InferContext* ctx) { } return Maybe<void>::Ok(); } -void SetInputArgModifierMutable(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const std::string& arg_name, int32_t arg_index) { +Maybe<void> SetInputArgModifierMutable(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const std::string& arg_name, int32_t arg_index) { user_op::InputArgModifier* arg_modifier = GetInputArgModifierFn(arg_name, arg_index); - CHECK_NOTNULL(arg_modifier); + CHECK_NOTNULL_OR_RETURN(arg_modifier); arg_modifier->set_is_mutable(true); + return Maybe<void>::Ok(); +} + +Maybe<void> AdamInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "m", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "v", 0)); + return Maybe<void>::Ok(); } -void AdamInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "m", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "v", 0); +Maybe<void> LambInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "m", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "v", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "beta1_t", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "beta2_t", 0)); + return Maybe<void>::Ok(); +} + +Maybe<void> SgdInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + return Maybe<void>::Ok(); } -void LambInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "m", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "v", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "beta1_t", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "beta2_t", 0); +Maybe<void> IndexedSlicesSgdInputArgModifyFn( + const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + return Maybe<void>::Ok(); +} + +Maybe<void> MomentumInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "momentum", 0)); + return Maybe<void>::Ok(); +} + +Maybe<void> IndexedSlicesMomentumInputArgModifyFn( + const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "momentum", 0)); + return Maybe<void>::Ok(); +} + +Maybe<void> RmsPropUpdateInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "mean_square", 0)); + if (conf.attr<bool>("centered")) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "mean_gradient", 0)); + } + return Maybe<void>::Ok(); +} + +Maybe<void> LarsUpdateInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "momentum", 0)); + return Maybe<void>::Ok(); } Maybe<void> InferRmsPropUpdateTensorDesc(user_op::InferContext* ctx) { @@ -370,10 +418,7 @@ REGISTER_NO_GRAD_USER_OP("sgd_update") } return Maybe<void>::Ok(); }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> void { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - }) + .SetInputArgModifyFn(SgdInputArgModifyFn) .SetDataTypeInferFn(InferSGDUpdateDataType); REGISTER_NO_GRAD_USER_OP("indexed_slices_sgd_update") @@ -404,10 +449,7 @@ REGISTER_NO_GRAD_USER_OP("indexed_slices_sgd_update") } return Maybe<void>::Ok(); }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> void { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - }) + .SetInputArgModifyFn(IndexedSlicesSgdInputArgModifyFn) .SetDataTypeInferFn(InferIndexedSlicesSGDUpdateDataType); REGISTER_NO_GRAD_USER_OP("momentum_update") @@ -436,11 +478,7 @@ REGISTER_NO_GRAD_USER_OP("momentum_update") } return Maybe<void>::Ok(); }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> void { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "momentum", 0); - }) + .SetInputArgModifyFn(MomentumInputArgModifyFn) .SetDataTypeInferFn(InferMomentumUpdateDataType); REGISTER_NO_GRAD_USER_OP("indexed_slices_momentum_update") @@ -475,11 +513,7 @@ REGISTER_NO_GRAD_USER_OP("indexed_slices_momentum_update") } return Maybe<void>::Ok(); }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> void { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "momentum", 0); - }) + .SetInputArgModifyFn(IndexedSlicesMomentumInputArgModifyFn) .SetDataTypeInferFn(InferIndexedSlicesMomentumUpdateDataType); REGISTER_NO_GRAD_USER_OP("adam_update") @@ -638,14 +672,7 @@ REGISTER_NO_GRAD_USER_OP("rmsprop_update") } return Maybe<void>::Ok(); }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> void { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "mean_square", 0); - if (conf.attr<bool>("centered")) { - SetInputArgModifierMutable(GetInputArgModifierFn, "mean_gradient", 0); - } - }) + .SetInputArgModifyFn(RmsPropUpdateInputArgModifyFn) .SetDataTypeInferFn(InferRmsPropUpdateDataType); REGISTER_NO_GRAD_USER_OP("lars_update") @@ -675,11 +702,7 @@ REGISTER_NO_GRAD_USER_OP("lars_update") } return Maybe<void>::Ok(); }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> void { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "momentum", 0); - }) + .SetInputArgModifyFn(LarsUpdateInputArgModifyFn) .SetDataTypeInferFn(InferLarsUpdateDataType); } // namespace diff --git a/oneflow/user/ops/moving_average_min_max_observer_op.cpp b/oneflow/user/ops/moving_average_min_max_observer_op.cpp index 02f23e508..8c4c59dc8 100644 --- a/oneflow/user/ops/moving_average_min_max_observer_op.cpp +++ b/oneflow/user/ops/moving_average_min_max_observer_op.cpp @@ -59,25 +59,26 @@ REGISTER_NO_GRAD_USER_OP("moving_average_min_max_observer") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); - CHECK(in != nullptr); + CHECK_OR_RETURN(in != nullptr); in->set_requires_grad(false); user_op::InputArgModifier* current_train_step = GetInputArgModifierFn("current_train_step", 0); - CHECK(current_train_step != nullptr); + CHECK_OR_RETURN(current_train_step != nullptr); current_train_step->set_requires_grad(false); user_op::InputArgModifier* moving_max = GetInputArgModifierFn("moving_max", 0); - CHECK(moving_max != nullptr); + CHECK_OR_RETURN(moving_max != nullptr); moving_max->set_requires_grad(false); moving_max->set_is_mutable(true); user_op::InputArgModifier* moving_min = GetInputArgModifierFn("moving_min", 0); - CHECK(moving_min != nullptr); + CHECK_OR_RETURN(moving_min != nullptr); moving_min->set_requires_grad(false); moving_min->set_is_mutable(true); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { // NOTE(Liang Depeng): all inputs need to be broadcast in order to accuratly calculate the diff --git a/oneflow/user/ops/nd_index_slice_ops.cpp b/oneflow/user/ops/nd_index_slice_ops.cpp index f2a8da1f9..761969fe0 100644 --- a/oneflow/user/ops/nd_index_slice_ops.cpp +++ b/oneflow/user/ops/nd_index_slice_ops.cpp @@ -129,10 +129,11 @@ REGISTER_USER_OP("gather_nd") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& params_tensor = @@ -175,10 +176,11 @@ REGISTER_USER_OP("scatter_nd") .SetTensorDescInferFn(InferScatterNdTensorDesc) .SetDataTypeInferFn(InferScatterNdDataType) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& indices_desc = @@ -257,10 +259,11 @@ REGISTER_USER_OP("tensor_scatter_nd_update") .SetDataTypeInferFn(InferTensorScatterNdOptDataType) .SetGetSbpFn(GetTensorScatterNdOptSbpSignatures) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }); REGISTER_USER_OP("tensor_scatter_nd_add") @@ -272,10 +275,11 @@ REGISTER_USER_OP("tensor_scatter_nd_add") .SetDataTypeInferFn(InferTensorScatterNdOptDataType) .SetGetSbpFn(GetTensorScatterNdOptSbpSignatures) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }); REGISTER_USER_OP_GRAD("gather_nd") diff --git a/oneflow/user/ops/normalization_op.cpp b/oneflow/user/ops/normalization_op.cpp index 4b418d19c..f81d0e6e8 100644 --- a/oneflow/user/ops/normalization_op.cpp +++ b/oneflow/user/ops/normalization_op.cpp @@ -68,8 +68,8 @@ std::function<Maybe<void>(const std::string&)> MakeSetParamDataTypeFn(user_op::I }; } -void FwInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { +Maybe<void> FwInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { bool training; if (conf.op_type_name() == "normalization") { training = conf.attr<bool>("training"); @@ -77,13 +77,14 @@ void FwInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierF training = true; } user_op::InputArgModifier* moving_mean_modifier = GetInputArgModifierFn("moving_mean", 0); - CHECK(moving_mean_modifier != nullptr); + CHECK_OR_RETURN(moving_mean_modifier != nullptr); moving_mean_modifier->set_is_mutable(training); moving_mean_modifier->set_requires_grad(false); user_op::InputArgModifier* moving_variance_modifier = GetInputArgModifierFn("moving_variance", 0); - CHECK(moving_variance_modifier != nullptr); + CHECK_OR_RETURN(moving_variance_modifier != nullptr); moving_variance_modifier->set_is_mutable(training); moving_variance_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); } Maybe<void> FwGetSbpFn(user_op::SbpContext* ctx) { diff --git a/oneflow/user/ops/ofrecord_decoder_ops.cpp b/oneflow/user/ops/ofrecord_decoder_ops.cpp index 2fa02969b..ff19fe789 100644 --- a/oneflow/user/ops/ofrecord_decoder_ops.cpp +++ b/oneflow/user/ops/ofrecord_decoder_ops.cpp @@ -38,10 +38,11 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_raw_decoder") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL(in_modifier); + CHECK_NOTNULL_OR_RETURN(in_modifier); in_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { ctx->NewBuilder() @@ -70,10 +71,11 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_bytes_decoder") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL(in_modifier); + CHECK_NOTNULL_OR_RETURN(in_modifier); in_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { @@ -97,10 +99,11 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_image_decoder") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL(in_modifier); + CHECK_NOTNULL_OR_RETURN(in_modifier); in_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { ctx->NewBuilder() @@ -142,10 +145,11 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_image_decoder_random_crop") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL(in_modifier); + CHECK_NOTNULL_OR_RETURN(in_modifier); in_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); diff --git a/oneflow/user/ops/one_hot_op.cpp b/oneflow/user/ops/one_hot_op.cpp index 58b965839..ac50d4aef 100644 --- a/oneflow/user/ops/one_hot_op.cpp +++ b/oneflow/user/ops/one_hot_op.cpp @@ -40,10 +40,11 @@ REGISTER_NO_GRAD_USER_OP("one_hot") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& indices_tensor = diff --git a/oneflow/user/ops/onerec_decoder_op.cpp b/oneflow/user/ops/onerec_decoder_op.cpp index 4c59f59c3..9506e5b36 100644 --- a/oneflow/user/ops/onerec_decoder_op.cpp +++ b/oneflow/user/ops/onerec_decoder_op.cpp @@ -41,10 +41,11 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("onerec_decoder") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL(in_modifier); + CHECK_NOTNULL_OR_RETURN(in_modifier); in_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { ctx->NewBuilder() diff --git a/oneflow/user/ops/pad2d_ops.cpp b/oneflow/user/ops/pad2d_ops.cpp index fa5a51d96..c13009985 100644 --- a/oneflow/user/ops/pad2d_ops.cpp +++ b/oneflow/user/ops/pad2d_ops.cpp @@ -83,10 +83,11 @@ REGISTER_USER_OP("reflection_pad2d") }) .SetGetSbpFn(GetOpSbpSignature) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); - CHECK_NOTNULL(x_modifier); + CHECK_NOTNULL_OR_RETURN(x_modifier); x_modifier->set_requires_grad(true); + return Maybe<void>::Ok(); }) .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); @@ -175,10 +176,11 @@ REGISTER_USER_OP("replication_pad2d") }) .SetGetSbpFn(GetOpSbpSignature) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); - CHECK_NOTNULL(x_modifier); + CHECK_NOTNULL_OR_RETURN(x_modifier); x_modifier->set_requires_grad(true); + return Maybe<void>::Ok(); }) .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); @@ -269,10 +271,11 @@ REGISTER_USER_OP("constant_pad2d") }) .SetGetSbpFn(GetOpSbpSignature) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); - CHECK_NOTNULL(x_modifier); + CHECK_NOTNULL_OR_RETURN(x_modifier); x_modifier->set_requires_grad(true); + return Maybe<void>::Ok(); }) .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); diff --git a/oneflow/user/ops/partial_fc_sample_op.cpp b/oneflow/user/ops/partial_fc_sample_op.cpp index 603e6ae0c..8586c94d3 100644 --- a/oneflow/user/ops/partial_fc_sample_op.cpp +++ b/oneflow/user/ops/partial_fc_sample_op.cpp @@ -63,10 +63,11 @@ REGISTER_USER_OP("distributed_partial_fc_sample") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); - CHECK_NOTNULL(label_modifier); + CHECK_NOTNULL_OR_RETURN(label_modifier); label_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { ctx->NewBuilder() diff --git a/oneflow/user/ops/reduce_like_ops.cpp b/oneflow/user/ops/reduce_like_ops.cpp index a8e291234..c8fc889de 100644 --- a/oneflow/user/ops/reduce_like_ops.cpp +++ b/oneflow/user/ops/reduce_like_ops.cpp @@ -85,10 +85,11 @@ REGISTER_NO_GRAD_USER_OP("reduce_sum_like") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* like_arg_modifier = GetInputArgModifierFn("like", 0); - CHECK(like_arg_modifier != nullptr); + CHECK_OR_RETURN(like_arg_modifier != nullptr); like_arg_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }); } // namespace oneflow diff --git a/oneflow/user/ops/reshape_like_op.cpp b/oneflow/user/ops/reshape_like_op.cpp index f7e9b9c4d..8a3f360f9 100644 --- a/oneflow/user/ops/reshape_like_op.cpp +++ b/oneflow/user/ops/reshape_like_op.cpp @@ -40,10 +40,11 @@ REGISTER_USER_OP("reshape_like") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); - CHECK_NOTNULL(like_modifier); + CHECK_NOTNULL_OR_RETURN(like_modifier); like_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); diff --git a/oneflow/user/ops/sigmoid_cross_entropy_op.cpp b/oneflow/user/ops/sigmoid_cross_entropy_op.cpp index 086c4d2ef..225d833ac 100644 --- a/oneflow/user/ops/sigmoid_cross_entropy_op.cpp +++ b/oneflow/user/ops/sigmoid_cross_entropy_op.cpp @@ -22,9 +22,10 @@ REGISTER_USER_OP("sigmoid_cross_entropy") .Input("label") .Output("loss") .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); cond_arg_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); @@ -58,9 +59,10 @@ REGISTER_USER_OP("sigmoid_cross_entropy_grad") .Input("label") .Output("prediction_diff") .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); cond_arg_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp index 2ae354515..f8c2cf004 100644 --- a/oneflow/user/ops/slice_op.cpp +++ b/oneflow/user/ops/slice_op.cpp @@ -137,14 +137,15 @@ Maybe<void> GetSliceGradOpSbpSignature(user_op::SbpContext* ctx) { return Maybe<void>::Ok(); } -void InferSliceGradInputArgModifier(user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { +Maybe<void> InferSliceGradInputArgModifier(user_op::GetInputArgModifier GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* dy_modifier = GetInputArgModifierFn("dy", 0); - CHECK_NOTNULL(dy_modifier); + CHECK_NOTNULL_OR_RETURN(dy_modifier); dy_modifier->set_requires_grad(false); user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); - CHECK_NOTNULL(like_modifier); + CHECK_NOTNULL_OR_RETURN(like_modifier); like_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); } Maybe<void> InferSliceUpdateOpTensorDesc(user_op::InferContext* ctx) { @@ -269,14 +270,15 @@ Maybe<void> GetLogicalSliceAssignSbpSignatures(user_op::SbpContext* ctx) { return Maybe<void>::Ok(); } -void InferLogicalSliceAssignInputArgModifier(user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { +Maybe<void> InferLogicalSliceAssignInputArgModifier( + user_op::GetInputArgModifier GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* ref_modifier = GetInputArgModifierFn("ref", 0); - CHECK(ref_modifier != nullptr); + CHECK_OR_RETURN(ref_modifier != nullptr); ref_modifier->set_is_mutable(true); user_op::InputArgModifier* value_modifier = GetInputArgModifierFn("value", 0); - CHECK(value_modifier != nullptr); + CHECK_OR_RETURN(value_modifier != nullptr); value_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); } Maybe<void> InferLogicalSliceTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/smooth_l1_loss_op.cpp b/oneflow/user/ops/smooth_l1_loss_op.cpp index 3d55fd24e..fd753fbdc 100644 --- a/oneflow/user/ops/smooth_l1_loss_op.cpp +++ b/oneflow/user/ops/smooth_l1_loss_op.cpp @@ -31,10 +31,11 @@ REGISTER_USER_OP("smooth_l1_loss") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); - CHECK(label_modifier != nullptr); + CHECK_OR_RETURN(label_modifier != nullptr); label_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& prediction_tensor = diff --git a/oneflow/user/ops/softmax_cross_entropy_op.cpp b/oneflow/user/ops/softmax_cross_entropy_op.cpp index cfcb1da07..da8189d61 100644 --- a/oneflow/user/ops/softmax_cross_entropy_op.cpp +++ b/oneflow/user/ops/softmax_cross_entropy_op.cpp @@ -23,9 +23,10 @@ REGISTER_USER_OP("softmax_cross_entropy") .Output("prob") //'prob' is just for compute prediction's grad, prob's grad will be ignored .Output("out") .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); cond_arg_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); diff --git a/oneflow/user/ops/sparse_cross_entropy_op.cpp b/oneflow/user/ops/sparse_cross_entropy_op.cpp index 2d74bc04e..8972b7d14 100644 --- a/oneflow/user/ops/sparse_cross_entropy_op.cpp +++ b/oneflow/user/ops/sparse_cross_entropy_op.cpp @@ -157,10 +157,11 @@ void GenBackwardOpConf4SparseCrossEntropy(const std::string& op_type_name, .Attr<int64_t>("depth") \ .SetTensorDescInferFn(InferTensorDescFn) \ .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, \ - const user_op::UserOpConfWrapper&) { \ + const user_op::UserOpConfWrapper&) -> Maybe<void> { \ user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); \ - CHECK(label_modifier != nullptr); \ + CHECK_OR_RETURN(label_modifier != nullptr); \ label_modifier->set_requires_grad(false); \ + return Maybe<void>::Ok(); \ }) \ .SetGetSbpFn(GetSbpFn<sbp_sig>) \ .SetDataTypeInferFn(InferDataType); diff --git a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp index ee778b140..e8b8e2df8 100644 --- a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp +++ b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp @@ -162,10 +162,11 @@ void GenBackwardOpConf4SparseSoftmaxCrossEntropy(const std::string& op_type_name .Attr<int64_t>("depth") \ .SetTensorDescInferFn(InferTensorDescFn) \ .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, \ - const user_op::UserOpConfWrapper&) { \ + const user_op::UserOpConfWrapper&) -> Maybe<void> { \ user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); \ - CHECK(label_modifier != nullptr); \ + CHECK_OR_RETURN(label_modifier != nullptr); \ label_modifier->set_requires_grad(false); \ + return Maybe<void>::Ok(); \ }) \ .SetGetSbpFn(GetSbpFn<sbp_sig>) \ .SetDataTypeInferFn(InferDataType); diff --git a/oneflow/user/ops/split_like_op.cpp b/oneflow/user/ops/split_like_op.cpp index f5da325d7..06249e0e3 100644 --- a/oneflow/user/ops/split_like_op.cpp +++ b/oneflow/user/ops/split_like_op.cpp @@ -67,13 +67,14 @@ Maybe<void> InferDataType(user_op::InferContext* ctx) { return Maybe<void>::Ok(); } -void SetLikeArgModifier(user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper& user_op_conf) { +Maybe<void> SetLikeArgModifier(user_op::GetInputArgModifier GetInputArgModifierFn, + const user_op::UserOpConfWrapper& user_op_conf) { FOR_RANGE(int32_t, i, 0, user_op_conf.input_size("like")) { user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", i); - CHECK_NOTNULL(like_modifier); + CHECK_NOTNULL_OR_RETURN(like_modifier); like_modifier->set_requires_grad(false); } + return Maybe<void>::Ok(); } Maybe<void> GetSbpSignature(user_op::SbpContext* ctx) { diff --git a/oneflow/user/ops/two_stage_reduce_ops.cpp b/oneflow/user/ops/two_stage_reduce_ops.cpp index 47cdf0516..722af2d5b 100644 --- a/oneflow/user/ops/two_stage_reduce_ops.cpp +++ b/oneflow/user/ops/two_stage_reduce_ops.cpp @@ -264,10 +264,11 @@ REGISTER_REDUCE_DEVICE_STAGE_USER_OP_GRAD("reduce_max_device_stage", "reduce_max .SetTensorDescInferFn(InferReduceGlobalStageTensorDescFn) \ .SetDataTypeInferFn(InferReduceGlobalStageDtypeFn) \ .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, \ - const user_op::UserOpConfWrapper&) { \ + const user_op::UserOpConfWrapper&) -> Maybe<void> { \ user_op::InputArgModifier* device_count_modifier = \ GetInputArgModifierFn("device_count", 0); \ device_count_modifier->set_requires_grad(false); \ + return Maybe<void>::Ok(); \ }) \ .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { \ ctx->NewBuilder() \ diff --git a/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp b/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp index 7200a83ef..675bda4fa 100644 --- a/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp +++ b/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp @@ -50,10 +50,11 @@ REGISTER_USER_OP("unsorted_batch_segment_sum") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); - CHECK_NOTNULL(segment_ids_modifier); + CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); segment_ids_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const int64_t segment_ids_num_axes = diff --git a/oneflow/user/ops/unsorted_segment_sum_op.cpp b/oneflow/user/ops/unsorted_segment_sum_op.cpp index 8048e1506..6a072a911 100644 --- a/oneflow/user/ops/unsorted_segment_sum_op.cpp +++ b/oneflow/user/ops/unsorted_segment_sum_op.cpp @@ -46,10 +46,11 @@ REGISTER_USER_OP("unsorted_segment_sum") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); - CHECK_NOTNULL(segment_ids_modifier); + CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); segment_ids_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const int64_t data_num_axes = @@ -131,13 +132,14 @@ REGISTER_USER_OP("unsorted_segment_sum_like") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); - CHECK_NOTNULL(segment_ids_modifier); + CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); segment_ids_modifier->set_requires_grad(false); user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); - CHECK_NOTNULL(like_modifier); + CHECK_NOTNULL_OR_RETURN(like_modifier); like_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const int64_t data_num_axes = diff --git a/oneflow/user/ops/where_op.cpp b/oneflow/user/ops/where_op.cpp index 78795a92b..5b9350657 100644 --- a/oneflow/user/ops/where_op.cpp +++ b/oneflow/user/ops/where_op.cpp @@ -56,9 +56,10 @@ REGISTER_USER_OP("where") .Output("out") .SetTensorDescInferFn(InferWhereTensorDesc) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("condition", 0); cond_arg_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { const DataType& cond_dtype = ctx->InputDType("condition", 0); -- GitLab