diff --git a/oneflow/core/framework/user_op_registry.h b/oneflow/core/framework/user_op_registry.h index fa91c65370987549988d2a2564a32c153b63e58a..077789c1b92359945247a66a3f0724304931d0f4 100644 --- a/oneflow/core/framework/user_op_registry.h +++ b/oneflow/core/framework/user_op_registry.h @@ -48,7 +48,7 @@ using SbpSignatureInferFn = std::function<Maybe<void>(InferSbpSignatureFnContext using InputArgModifier = InputBlobModifier; using GetInputArgModifier = std::function<InputArgModifier*(const std::string& in_arg_name, int32_t in_arg_index)>; -using InputArgModifyFn = std::function<void(GetInputArgModifier, const UserOpConfWrapper&)>; +using InputArgModifyFn = std::function<Maybe<void>(GetInputArgModifier, const UserOpConfWrapper&)>; using OutputArgModifier = OutputBlobModifier; using GetOutputArgModifier = std::function<OutputArgModifier*(const std::string& out_arg_name, int32_t out_arg_index)>; diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index 91a631b420a4efb2418079ca4d7edd281fdd75dd..59ab048448974c3c79bf036cea9c6868498f54d3 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -489,7 +489,7 @@ class UserOpInferParallelDistributionFnContext }; Maybe<void> UserOp::InitFromOpConf() { - CHECK(op_conf().has_user_conf()); + CHECK_OR_RETURN(op_conf().has_user_conf()); for (const auto& pair : op_conf().user_conf().input()) { EnrollRepeatedInputBn(pair.first, pair.second.s_size()); for (int32_t i = 0; i < pair.second.s_size(); ++i) { @@ -516,7 +516,7 @@ Maybe<void> UserOp::InitFromOpConf() { } return nullptr; }; - val_->input_arg_modify_fn(GetInputArgModifierFn, *user_op_conf_); + JUST(val_->input_arg_modify_fn(GetInputArgModifierFn, *user_op_conf_)); } if (val_->output_arg_modify_fn) { user_op::GetOutputArgModifier GetOutputArgModifierFn = diff --git a/oneflow/user/kernels/stateful_local_opkernel.cpp b/oneflow/user/kernels/stateful_local_opkernel.cpp index 9a2390c00558f7f430c340ce8948eef18cca9c29..b2cb5310e782b3339fb8efc98ac83048d6c33107 100644 --- a/oneflow/user/kernels/stateful_local_opkernel.cpp +++ b/oneflow/user/kernels/stateful_local_opkernel.cpp @@ -303,7 +303,7 @@ Maybe<void> InitTensorTupleIndexes4Bns(const std::shared_ptr<const OperatorConf> auto* map = arg_modifier_signature.mutable_ibn2input_blob_modifier(); return &map->at(ibn); }; - op_reg_val->input_arg_modify_fn(GetInputArgModifierFn, op_conf_wrapper); + JUST(op_reg_val->input_arg_modify_fn(GetInputArgModifierFn, op_conf_wrapper)); } if (op_reg_val->output_arg_modify_fn) { user_op::GetOutputArgModifier GetOutputArgModifierFn = diff --git a/oneflow/user/ops/assign_op.cpp b/oneflow/user/ops/assign_op.cpp index ed4d8ddbf0ffe64b7107f633833057f03a5b9d76..f54342c2cce6b45f9f863c47f0529e843c18bec4 100644 --- a/oneflow/user/ops/assign_op.cpp +++ b/oneflow/user/ops/assign_op.cpp @@ -48,19 +48,20 @@ Maybe<void> GetSbpSignatures(user_op::SbpContext* ctx) { return Maybe<void>::Ok(); } -void InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { +Maybe<void> InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* ref_modifier = GetInputArgModifierFn("ref", 0); - CHECK(ref_modifier != nullptr); + CHECK_OR_RETURN(ref_modifier != nullptr); ref_modifier->set_is_mutable(true); user_op::InputArgModifier* value_modifier = GetInputArgModifierFn("value", 0); - CHECK(value_modifier != nullptr); + CHECK_OR_RETURN(value_modifier != nullptr); value_modifier->set_requires_grad(false); if (conf.has_input("condition", 0)) { user_op::InputArgModifier* condition_modifier = GetInputArgModifierFn("condition", 0); - CHECK(condition_modifier != nullptr); + CHECK_OR_RETURN(condition_modifier != nullptr); condition_modifier->set_requires_grad(false); } + return Maybe<void>::Ok(); } Maybe<void> InferDataType(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/batch_gather_op.cpp b/oneflow/user/ops/batch_gather_op.cpp index 9fbc8fa80c004ae77936a52c75187f192e35c096..0c9ae8bec3db8260bbd838a2f1bc9a880b62a73f 100644 --- a/oneflow/user/ops/batch_gather_op.cpp +++ b/oneflow/user/ops/batch_gather_op.cpp @@ -44,10 +44,11 @@ REGISTER_USER_OP("batch_gather") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const int64_t indices_num_axes = diff --git a/oneflow/user/ops/broadcast_like_op.cpp b/oneflow/user/ops/broadcast_like_op.cpp index b5489b4a8abf033cd78115f2021227ae61f7295c..801dc85d7070a3a99feeffc2bdeafb27e70a16df 100644 --- a/oneflow/user/ops/broadcast_like_op.cpp +++ b/oneflow/user/ops/broadcast_like_op.cpp @@ -91,10 +91,11 @@ REGISTER_USER_OP("broadcast_like") .Output("y") .SetTensorDescInferFn(InferTensorDesc) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); - CHECK(like_modifier != nullptr); + CHECK_OR_RETURN(like_modifier != nullptr); like_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn(GetSbpSignatures) .SetDataTypeInferFn(InferDataType); diff --git a/oneflow/user/ops/cast_like_op.cpp b/oneflow/user/ops/cast_like_op.cpp index 3cedd56f982981a29fc7f14843ab3899269d27e6..a2a7face17c582aa83887fce77db49d7c87b0f15 100644 --- a/oneflow/user/ops/cast_like_op.cpp +++ b/oneflow/user/ops/cast_like_op.cpp @@ -27,10 +27,11 @@ REGISTER_NO_GRAD_USER_OP("cast_like") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* dtype_like_modifier = GetInputArgModifierFn("dtype_like", 0); - CHECK_NOTNULL(dtype_like_modifier); + CHECK_NOTNULL_OR_RETURN(dtype_like_modifier); dtype_like_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); diff --git a/oneflow/user/ops/categorical_ordinal_encode_op.cpp b/oneflow/user/ops/categorical_ordinal_encode_op.cpp index a17cee4429b586b30f5d7f9617e0f1adb5df14e0..cb2f0d4351f0547ab82df5f0a254789f3239b3f5 100644 --- a/oneflow/user/ops/categorical_ordinal_encode_op.cpp +++ b/oneflow/user/ops/categorical_ordinal_encode_op.cpp @@ -45,7 +45,7 @@ REGISTER_NO_GRAD_USER_OP("CategoricalOrdinalEncode") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* table = GetInputArgModifierFn("table", 0); table->set_is_mutable(true); table->set_requires_grad(false); @@ -54,6 +54,7 @@ REGISTER_NO_GRAD_USER_OP("CategoricalOrdinalEncode") size->set_requires_grad(false); user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); in->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { CHECK_EQ_OR_RETURN(ctx->parallel_num(), 1); diff --git a/oneflow/user/ops/combined_margin_loss_op.cpp b/oneflow/user/ops/combined_margin_loss_op.cpp index 7742a6eab9e12c07abca44e058657e007a1304f1..825b59901000edca88938215e62a0e36d15d4ca4 100644 --- a/oneflow/user/ops/combined_margin_loss_op.cpp +++ b/oneflow/user/ops/combined_margin_loss_op.cpp @@ -39,9 +39,10 @@ REGISTER_USER_OP("combined_margin_loss") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* label_arg_modifier = GetInputArgModifierFn("label", 0); label_arg_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { ctx->NewBuilder() diff --git a/oneflow/user/ops/dim_gather_op.cpp b/oneflow/user/ops/dim_gather_op.cpp index 11e741b52b8918bae86a62c8c6d8b11b8dd3e1ad..189dc7fa0c8c591fe02e88eb67380e0c01c4c091 100644 --- a/oneflow/user/ops/dim_gather_op.cpp +++ b/oneflow/user/ops/dim_gather_op.cpp @@ -59,10 +59,11 @@ REGISTER_USER_OP("dim_gather") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& index_tensor = @@ -131,10 +132,11 @@ REGISTER_USER_OP("dim_scatter_add_like") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* like_arg_modifier = GetInputArgModifierFn("like", 0); - CHECK(like_arg_modifier != nullptr); + CHECK_OR_RETURN(like_arg_modifier != nullptr); like_arg_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& index_tensor = diff --git a/oneflow/user/ops/dropout_op.cpp b/oneflow/user/ops/dropout_op.cpp index de97b6b610c175351f2edd712d466c2282fc5847..eb865cd9f8c04bd4be522bb0b18451e48fb7b3f6 100644 --- a/oneflow/user/ops/dropout_op.cpp +++ b/oneflow/user/ops/dropout_op.cpp @@ -33,9 +33,10 @@ REGISTER_USER_OP("dropout") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* mask = GetInputArgModifierFn("mask", 0); mask->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); diff --git a/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp b/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp index f2399f7bf041de1201b46fb69811a2eafb1d242b..e24277c4235aefcb8c2212c43d195747538c232a 100644 --- a/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp +++ b/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp @@ -43,14 +43,15 @@ Maybe<void> InferDataType(user_op::InferContext* ctx) { return Maybe<void>::Ok(); } -void InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { +Maybe<void> InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* loss_scale = GetInputArgModifierFn("loss_scale", 0); - CHECK(loss_scale != nullptr); + CHECK_OR_RETURN(loss_scale != nullptr); loss_scale->set_is_mutable(true); user_op::InputArgModifier* good_step_counter = GetInputArgModifierFn("good_step_counter", 0); - CHECK(good_step_counter != nullptr); + CHECK_OR_RETURN(good_step_counter != nullptr); good_step_counter->set_is_mutable(true); + return Maybe<void>::Ok(); } } // namespace diff --git a/oneflow/user/ops/fake_quantization_op.cpp b/oneflow/user/ops/fake_quantization_op.cpp index 63d0a576aa6ba0b7df0b46989107d6aff1c8f1f9..512823f8b958eb558d7ed3b5da7133c9cd0fc601 100644 --- a/oneflow/user/ops/fake_quantization_op.cpp +++ b/oneflow/user/ops/fake_quantization_op.cpp @@ -47,14 +47,15 @@ REGISTER_USER_OP("fake_quantization") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* scale = GetInputArgModifierFn("scale", 0); - CHECK(scale != nullptr); + CHECK_OR_RETURN(scale != nullptr); scale->set_requires_grad(false); user_op::InputArgModifier* zero_point = GetInputArgModifierFn("zero_point", 0); - CHECK(zero_point != nullptr); + CHECK_OR_RETURN(zero_point != nullptr); zero_point->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); diff --git a/oneflow/user/ops/fused_bias_add_op.cpp b/oneflow/user/ops/fused_bias_add_op.cpp index 31cdf422ba65bafc4f20d52e4ba159895565d1a3..b61d473e440a6c18d3aab8a824a87f61f9f8b859 100644 --- a/oneflow/user/ops/fused_bias_add_op.cpp +++ b/oneflow/user/ops/fused_bias_add_op.cpp @@ -167,10 +167,11 @@ REGISTER_USER_OP("fused_bias_add_mask_scale") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); - CHECK(mask_modifier != nullptr); + CHECK_OR_RETURN(mask_modifier != nullptr); mask_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const auto axis = ctx->Attr<int32_t>("axis"); diff --git a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp index 008a295d2bab2af52ca131e8ed6cc4d82a65b08d..12fbe80dcd7fda23bd1866073b54dfa7fd0db0a9 100644 --- a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp +++ b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp @@ -43,10 +43,11 @@ REGISTER_USER_OP("fused_tril_scale_softmax_mask_scale") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); - CHECK(mask_modifier != nullptr); + CHECK_OR_RETURN(mask_modifier != nullptr); mask_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); diff --git a/oneflow/user/ops/gather_op.cpp b/oneflow/user/ops/gather_op.cpp index 9e951996fdc30624434d066a11a3a40026c36abd..5f85af9ee749d4b5fc417fc8e137b4e96d695e9c 100644 --- a/oneflow/user/ops/gather_op.cpp +++ b/oneflow/user/ops/gather_op.cpp @@ -42,10 +42,11 @@ REGISTER_USER_OP("gather") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const int64_t in_num_axes = diff --git a/oneflow/user/ops/gpt_data_loader_op.cpp b/oneflow/user/ops/gpt_data_loader_op.cpp index 24e87fd021be7aeb5178818166069f08bb206bfa..1a36a0c83a5bfd61707e5b8910f55f0dc8954b27 100644 --- a/oneflow/user/ops/gpt_data_loader_op.cpp +++ b/oneflow/user/ops/gpt_data_loader_op.cpp @@ -82,12 +82,13 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("megatron_gpt_mmap_data_loader") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> void { - if (!conf.has_input("iteration", 0)) { return; } + const user_op::UserOpConfWrapper& conf) -> Maybe<void> { + if (!conf.has_input("iteration", 0)) { return Maybe<void>::Ok(); } user_op::InputArgModifier* input_modifier = GetInputArgModifierFn("iteration", 0); - CHECK(input_modifier != nullptr); + CHECK_OR_RETURN(input_modifier != nullptr); input_modifier->set_is_mutable(true); input_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); diff --git a/oneflow/user/ops/image_preprocess_ops.cpp b/oneflow/user/ops/image_preprocess_ops.cpp index cdb528e0702a39f279bf81eaa5df4c665b8d1e44..d4810479f762171a73799d563953d480301fd6d4 100644 --- a/oneflow/user/ops/image_preprocess_ops.cpp +++ b/oneflow/user/ops/image_preprocess_ops.cpp @@ -202,10 +202,11 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_random_crop") }) .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL(in_modifier); + CHECK_NOTNULL_OR_RETURN(in_modifier); in_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); diff --git a/oneflow/user/ops/min_max_observer_op.cpp b/oneflow/user/ops/min_max_observer_op.cpp index 565251aec4ca0d8f626d5347342afcfc130d3d7f..d1003ba287f06ed68ae134483266b628aa17a765 100644 --- a/oneflow/user/ops/min_max_observer_op.cpp +++ b/oneflow/user/ops/min_max_observer_op.cpp @@ -51,10 +51,11 @@ REGISTER_NO_GRAD_USER_OP("min_max_observer") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); - CHECK(in != nullptr); + CHECK_OR_RETURN(in != nullptr); in->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { // NOTE(Liang Depeng): input needs to be broadcast in order to accurately calculate the diff --git a/oneflow/user/ops/model_update_ops.cpp b/oneflow/user/ops/model_update_ops.cpp index db9f4fac83a62e5522277d361ebbd68ebcb93ef2..deab6b5926b21ae816f2c3011516ae782cc229a8 100644 --- a/oneflow/user/ops/model_update_ops.cpp +++ b/oneflow/user/ops/model_update_ops.cpp @@ -257,27 +257,75 @@ Maybe<void> InferLambUpdateDataType(user_op::InferContext* ctx) { } return Maybe<void>::Ok(); } -void SetInputArgModifierMutable(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const std::string& arg_name, int32_t arg_index) { +Maybe<void> SetInputArgModifierMutable(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const std::string& arg_name, int32_t arg_index) { user_op::InputArgModifier* arg_modifier = GetInputArgModifierFn(arg_name, arg_index); - CHECK_NOTNULL(arg_modifier); + CHECK_NOTNULL_OR_RETURN(arg_modifier); arg_modifier->set_is_mutable(true); + return Maybe<void>::Ok(); +} + +Maybe<void> AdamInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "m", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "v", 0)); + return Maybe<void>::Ok(); } -void AdamInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "m", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "v", 0); +Maybe<void> LambInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "m", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "v", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "beta1_t", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "beta2_t", 0)); + return Maybe<void>::Ok(); +} + +Maybe<void> SgdInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + return Maybe<void>::Ok(); } -void LambInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "m", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "v", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "beta1_t", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "beta2_t", 0); +Maybe<void> IndexedSlicesSgdInputArgModifyFn( + const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + return Maybe<void>::Ok(); +} + +Maybe<void> MomentumInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "momentum", 0)); + return Maybe<void>::Ok(); +} + +Maybe<void> IndexedSlicesMomentumInputArgModifyFn( + const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "momentum", 0)); + return Maybe<void>::Ok(); +} + +Maybe<void> RmsPropUpdateInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "mean_square", 0)); + if (conf.attr<bool>("centered")) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "mean_gradient", 0)); + } + return Maybe<void>::Ok(); +} + +Maybe<void> LarsUpdateInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); + JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "momentum", 0)); + return Maybe<void>::Ok(); } Maybe<void> InferRmsPropUpdateTensorDesc(user_op::InferContext* ctx) { @@ -370,10 +418,7 @@ REGISTER_NO_GRAD_USER_OP("sgd_update") } return Maybe<void>::Ok(); }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> void { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - }) + .SetInputArgModifyFn(SgdInputArgModifyFn) .SetDataTypeInferFn(InferSGDUpdateDataType); REGISTER_NO_GRAD_USER_OP("indexed_slices_sgd_update") @@ -404,10 +449,7 @@ REGISTER_NO_GRAD_USER_OP("indexed_slices_sgd_update") } return Maybe<void>::Ok(); }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> void { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - }) + .SetInputArgModifyFn(IndexedSlicesSgdInputArgModifyFn) .SetDataTypeInferFn(InferIndexedSlicesSGDUpdateDataType); REGISTER_NO_GRAD_USER_OP("momentum_update") @@ -436,11 +478,7 @@ REGISTER_NO_GRAD_USER_OP("momentum_update") } return Maybe<void>::Ok(); }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> void { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "momentum", 0); - }) + .SetInputArgModifyFn(MomentumInputArgModifyFn) .SetDataTypeInferFn(InferMomentumUpdateDataType); REGISTER_NO_GRAD_USER_OP("indexed_slices_momentum_update") @@ -475,11 +513,7 @@ REGISTER_NO_GRAD_USER_OP("indexed_slices_momentum_update") } return Maybe<void>::Ok(); }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> void { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "momentum", 0); - }) + .SetInputArgModifyFn(IndexedSlicesMomentumInputArgModifyFn) .SetDataTypeInferFn(InferIndexedSlicesMomentumUpdateDataType); REGISTER_NO_GRAD_USER_OP("adam_update") @@ -638,14 +672,7 @@ REGISTER_NO_GRAD_USER_OP("rmsprop_update") } return Maybe<void>::Ok(); }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> void { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "mean_square", 0); - if (conf.attr<bool>("centered")) { - SetInputArgModifierMutable(GetInputArgModifierFn, "mean_gradient", 0); - } - }) + .SetInputArgModifyFn(RmsPropUpdateInputArgModifyFn) .SetDataTypeInferFn(InferRmsPropUpdateDataType); REGISTER_NO_GRAD_USER_OP("lars_update") @@ -675,11 +702,7 @@ REGISTER_NO_GRAD_USER_OP("lars_update") } return Maybe<void>::Ok(); }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> void { - SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0); - SetInputArgModifierMutable(GetInputArgModifierFn, "momentum", 0); - }) + .SetInputArgModifyFn(LarsUpdateInputArgModifyFn) .SetDataTypeInferFn(InferLarsUpdateDataType); } // namespace diff --git a/oneflow/user/ops/moving_average_min_max_observer_op.cpp b/oneflow/user/ops/moving_average_min_max_observer_op.cpp index 02f23e50838771f05f811f8832b22b5094d59356..8c4c59dc8e180e6a171aebeb3c2cd29c06ef4e8b 100644 --- a/oneflow/user/ops/moving_average_min_max_observer_op.cpp +++ b/oneflow/user/ops/moving_average_min_max_observer_op.cpp @@ -59,25 +59,26 @@ REGISTER_NO_GRAD_USER_OP("moving_average_min_max_observer") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); - CHECK(in != nullptr); + CHECK_OR_RETURN(in != nullptr); in->set_requires_grad(false); user_op::InputArgModifier* current_train_step = GetInputArgModifierFn("current_train_step", 0); - CHECK(current_train_step != nullptr); + CHECK_OR_RETURN(current_train_step != nullptr); current_train_step->set_requires_grad(false); user_op::InputArgModifier* moving_max = GetInputArgModifierFn("moving_max", 0); - CHECK(moving_max != nullptr); + CHECK_OR_RETURN(moving_max != nullptr); moving_max->set_requires_grad(false); moving_max->set_is_mutable(true); user_op::InputArgModifier* moving_min = GetInputArgModifierFn("moving_min", 0); - CHECK(moving_min != nullptr); + CHECK_OR_RETURN(moving_min != nullptr); moving_min->set_requires_grad(false); moving_min->set_is_mutable(true); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { // NOTE(Liang Depeng): all inputs need to be broadcast in order to accuratly calculate the diff --git a/oneflow/user/ops/nd_index_slice_ops.cpp b/oneflow/user/ops/nd_index_slice_ops.cpp index f2a8da1f9fcdd65cd32ed105b1ca0d191ae86b57..761969fe03bcbaf702ed4d9f51994cb30dc5e0b3 100644 --- a/oneflow/user/ops/nd_index_slice_ops.cpp +++ b/oneflow/user/ops/nd_index_slice_ops.cpp @@ -129,10 +129,11 @@ REGISTER_USER_OP("gather_nd") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& params_tensor = @@ -175,10 +176,11 @@ REGISTER_USER_OP("scatter_nd") .SetTensorDescInferFn(InferScatterNdTensorDesc) .SetDataTypeInferFn(InferScatterNdDataType) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& indices_desc = @@ -257,10 +259,11 @@ REGISTER_USER_OP("tensor_scatter_nd_update") .SetDataTypeInferFn(InferTensorScatterNdOptDataType) .SetGetSbpFn(GetTensorScatterNdOptSbpSignatures) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }); REGISTER_USER_OP("tensor_scatter_nd_add") @@ -272,10 +275,11 @@ REGISTER_USER_OP("tensor_scatter_nd_add") .SetDataTypeInferFn(InferTensorScatterNdOptDataType) .SetGetSbpFn(GetTensorScatterNdOptSbpSignatures) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }); REGISTER_USER_OP_GRAD("gather_nd") diff --git a/oneflow/user/ops/normalization_op.cpp b/oneflow/user/ops/normalization_op.cpp index 4b418d19cc2a9000c5128022447945a95eec47ef..f81d0e6e8cd9623c431869f09c4d4e03418cf193 100644 --- a/oneflow/user/ops/normalization_op.cpp +++ b/oneflow/user/ops/normalization_op.cpp @@ -68,8 +68,8 @@ std::function<Maybe<void>(const std::string&)> MakeSetParamDataTypeFn(user_op::I }; } -void FwInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { +Maybe<void> FwInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { bool training; if (conf.op_type_name() == "normalization") { training = conf.attr<bool>("training"); @@ -77,13 +77,14 @@ void FwInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierF training = true; } user_op::InputArgModifier* moving_mean_modifier = GetInputArgModifierFn("moving_mean", 0); - CHECK(moving_mean_modifier != nullptr); + CHECK_OR_RETURN(moving_mean_modifier != nullptr); moving_mean_modifier->set_is_mutable(training); moving_mean_modifier->set_requires_grad(false); user_op::InputArgModifier* moving_variance_modifier = GetInputArgModifierFn("moving_variance", 0); - CHECK(moving_variance_modifier != nullptr); + CHECK_OR_RETURN(moving_variance_modifier != nullptr); moving_variance_modifier->set_is_mutable(training); moving_variance_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); } Maybe<void> FwGetSbpFn(user_op::SbpContext* ctx) { diff --git a/oneflow/user/ops/ofrecord_decoder_ops.cpp b/oneflow/user/ops/ofrecord_decoder_ops.cpp index 2fa02969bff2684fb19c247c2f2b8628be84b82a..ff19fe78952c53c9b34a1343baaf41689121562c 100644 --- a/oneflow/user/ops/ofrecord_decoder_ops.cpp +++ b/oneflow/user/ops/ofrecord_decoder_ops.cpp @@ -38,10 +38,11 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_raw_decoder") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL(in_modifier); + CHECK_NOTNULL_OR_RETURN(in_modifier); in_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { ctx->NewBuilder() @@ -70,10 +71,11 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_bytes_decoder") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL(in_modifier); + CHECK_NOTNULL_OR_RETURN(in_modifier); in_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { @@ -97,10 +99,11 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_image_decoder") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL(in_modifier); + CHECK_NOTNULL_OR_RETURN(in_modifier); in_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { ctx->NewBuilder() @@ -142,10 +145,11 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_image_decoder_random_crop") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL(in_modifier); + CHECK_NOTNULL_OR_RETURN(in_modifier); in_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); diff --git a/oneflow/user/ops/one_hot_op.cpp b/oneflow/user/ops/one_hot_op.cpp index 58b965839b742c5ac3f4ae36d7a2d18f4c49c726..ac50d4aef6f8db861d3e63364b022f8251ba9ac6 100644 --- a/oneflow/user/ops/one_hot_op.cpp +++ b/oneflow/user/ops/one_hot_op.cpp @@ -40,10 +40,11 @@ REGISTER_NO_GRAD_USER_OP("one_hot") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK(indices_modifier != nullptr); + CHECK_OR_RETURN(indices_modifier != nullptr); indices_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& indices_tensor = diff --git a/oneflow/user/ops/onerec_decoder_op.cpp b/oneflow/user/ops/onerec_decoder_op.cpp index 4c59f59c338a53188029c68c68420922f8772b3a..9506e5b3693ac8e1d82e46103d26d637533838c5 100644 --- a/oneflow/user/ops/onerec_decoder_op.cpp +++ b/oneflow/user/ops/onerec_decoder_op.cpp @@ -41,10 +41,11 @@ REGISTER_NO_GRAD_CPU_ONLY_USER_OP("onerec_decoder") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL(in_modifier); + CHECK_NOTNULL_OR_RETURN(in_modifier); in_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { ctx->NewBuilder() diff --git a/oneflow/user/ops/pad2d_ops.cpp b/oneflow/user/ops/pad2d_ops.cpp index fa5a51d9641ee566b3ee32517f3032de47d99906..c130099857b8a8a16095acb897d650db14c7de3a 100644 --- a/oneflow/user/ops/pad2d_ops.cpp +++ b/oneflow/user/ops/pad2d_ops.cpp @@ -83,10 +83,11 @@ REGISTER_USER_OP("reflection_pad2d") }) .SetGetSbpFn(GetOpSbpSignature) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); - CHECK_NOTNULL(x_modifier); + CHECK_NOTNULL_OR_RETURN(x_modifier); x_modifier->set_requires_grad(true); + return Maybe<void>::Ok(); }) .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); @@ -175,10 +176,11 @@ REGISTER_USER_OP("replication_pad2d") }) .SetGetSbpFn(GetOpSbpSignature) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); - CHECK_NOTNULL(x_modifier); + CHECK_NOTNULL_OR_RETURN(x_modifier); x_modifier->set_requires_grad(true); + return Maybe<void>::Ok(); }) .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); @@ -269,10 +271,11 @@ REGISTER_USER_OP("constant_pad2d") }) .SetGetSbpFn(GetOpSbpSignature) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); - CHECK_NOTNULL(x_modifier); + CHECK_NOTNULL_OR_RETURN(x_modifier); x_modifier->set_requires_grad(true); + return Maybe<void>::Ok(); }) .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); diff --git a/oneflow/user/ops/partial_fc_sample_op.cpp b/oneflow/user/ops/partial_fc_sample_op.cpp index 603e6ae0c2b864375549db7e054c5e74fd700151..8586c94d3aa3c157ea71414228ae5d8b237b0e35 100644 --- a/oneflow/user/ops/partial_fc_sample_op.cpp +++ b/oneflow/user/ops/partial_fc_sample_op.cpp @@ -63,10 +63,11 @@ REGISTER_USER_OP("distributed_partial_fc_sample") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); - CHECK_NOTNULL(label_modifier); + CHECK_NOTNULL_OR_RETURN(label_modifier); label_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { ctx->NewBuilder() diff --git a/oneflow/user/ops/reduce_like_ops.cpp b/oneflow/user/ops/reduce_like_ops.cpp index a8e29123402b79e2bf5abc2348777fb45f21cc05..c8fc889debc5e5e3403ab42a38ed353fb63e4800 100644 --- a/oneflow/user/ops/reduce_like_ops.cpp +++ b/oneflow/user/ops/reduce_like_ops.cpp @@ -85,10 +85,11 @@ REGISTER_NO_GRAD_USER_OP("reduce_sum_like") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* like_arg_modifier = GetInputArgModifierFn("like", 0); - CHECK(like_arg_modifier != nullptr); + CHECK_OR_RETURN(like_arg_modifier != nullptr); like_arg_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }); } // namespace oneflow diff --git a/oneflow/user/ops/reshape_like_op.cpp b/oneflow/user/ops/reshape_like_op.cpp index f7e9b9c4d2d714fc79867d89347dcd2160fab20a..8a3f360f98d8a00d98dd2b45f176ba9b3fdd2271 100644 --- a/oneflow/user/ops/reshape_like_op.cpp +++ b/oneflow/user/ops/reshape_like_op.cpp @@ -40,10 +40,11 @@ REGISTER_USER_OP("reshape_like") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); - CHECK_NOTNULL(like_modifier); + CHECK_NOTNULL_OR_RETURN(like_modifier); like_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); diff --git a/oneflow/user/ops/sigmoid_cross_entropy_op.cpp b/oneflow/user/ops/sigmoid_cross_entropy_op.cpp index 086c4d2ef4eedb2523bba00c3b56e4aeb1e49269..225d833ac6ee98c653ef5c48c73909c0675bc082 100644 --- a/oneflow/user/ops/sigmoid_cross_entropy_op.cpp +++ b/oneflow/user/ops/sigmoid_cross_entropy_op.cpp @@ -22,9 +22,10 @@ REGISTER_USER_OP("sigmoid_cross_entropy") .Input("label") .Output("loss") .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); cond_arg_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); @@ -58,9 +59,10 @@ REGISTER_USER_OP("sigmoid_cross_entropy_grad") .Input("label") .Output("prediction_diff") .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); cond_arg_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp index 2ae35451549a7efbe3c298472a39b1b14a321cd6..f8c2cf0047fb738da4de2fb8c6ae0438220383e2 100644 --- a/oneflow/user/ops/slice_op.cpp +++ b/oneflow/user/ops/slice_op.cpp @@ -137,14 +137,15 @@ Maybe<void> GetSliceGradOpSbpSignature(user_op::SbpContext* ctx) { return Maybe<void>::Ok(); } -void InferSliceGradInputArgModifier(user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { +Maybe<void> InferSliceGradInputArgModifier(user_op::GetInputArgModifier GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* dy_modifier = GetInputArgModifierFn("dy", 0); - CHECK_NOTNULL(dy_modifier); + CHECK_NOTNULL_OR_RETURN(dy_modifier); dy_modifier->set_requires_grad(false); user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); - CHECK_NOTNULL(like_modifier); + CHECK_NOTNULL_OR_RETURN(like_modifier); like_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); } Maybe<void> InferSliceUpdateOpTensorDesc(user_op::InferContext* ctx) { @@ -269,14 +270,15 @@ Maybe<void> GetLogicalSliceAssignSbpSignatures(user_op::SbpContext* ctx) { return Maybe<void>::Ok(); } -void InferLogicalSliceAssignInputArgModifier(user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { +Maybe<void> InferLogicalSliceAssignInputArgModifier( + user_op::GetInputArgModifier GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* ref_modifier = GetInputArgModifierFn("ref", 0); - CHECK(ref_modifier != nullptr); + CHECK_OR_RETURN(ref_modifier != nullptr); ref_modifier->set_is_mutable(true); user_op::InputArgModifier* value_modifier = GetInputArgModifierFn("value", 0); - CHECK(value_modifier != nullptr); + CHECK_OR_RETURN(value_modifier != nullptr); value_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); } Maybe<void> InferLogicalSliceTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/smooth_l1_loss_op.cpp b/oneflow/user/ops/smooth_l1_loss_op.cpp index 3d55fd24e8cc6585f6b1be48c4ce9d3da7ed6252..fd753fbdc82cd1b9e25c77f333d3b1f8bcc1a48a 100644 --- a/oneflow/user/ops/smooth_l1_loss_op.cpp +++ b/oneflow/user/ops/smooth_l1_loss_op.cpp @@ -31,10 +31,11 @@ REGISTER_USER_OP("smooth_l1_loss") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); - CHECK(label_modifier != nullptr); + CHECK_OR_RETURN(label_modifier != nullptr); label_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const user_op::TensorDesc& prediction_tensor = diff --git a/oneflow/user/ops/softmax_cross_entropy_op.cpp b/oneflow/user/ops/softmax_cross_entropy_op.cpp index cfcb1da0742e07d96349053bc8943b4bc34ab9e7..da8189d61c986e0e8cb73c14ab83cf073e314cd0 100644 --- a/oneflow/user/ops/softmax_cross_entropy_op.cpp +++ b/oneflow/user/ops/softmax_cross_entropy_op.cpp @@ -23,9 +23,10 @@ REGISTER_USER_OP("softmax_cross_entropy") .Output("prob") //'prob' is just for compute prediction's grad, prob's grad will be ignored .Output("out") .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); cond_arg_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); diff --git a/oneflow/user/ops/sparse_cross_entropy_op.cpp b/oneflow/user/ops/sparse_cross_entropy_op.cpp index 2d74bc04e4bf29dbf7007c78143d4aedc3499386..8972b7d14fa2bcf81c653622b0eb538753c90e88 100644 --- a/oneflow/user/ops/sparse_cross_entropy_op.cpp +++ b/oneflow/user/ops/sparse_cross_entropy_op.cpp @@ -157,10 +157,11 @@ void GenBackwardOpConf4SparseCrossEntropy(const std::string& op_type_name, .Attr<int64_t>("depth") \ .SetTensorDescInferFn(InferTensorDescFn) \ .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, \ - const user_op::UserOpConfWrapper&) { \ + const user_op::UserOpConfWrapper&) -> Maybe<void> { \ user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); \ - CHECK(label_modifier != nullptr); \ + CHECK_OR_RETURN(label_modifier != nullptr); \ label_modifier->set_requires_grad(false); \ + return Maybe<void>::Ok(); \ }) \ .SetGetSbpFn(GetSbpFn<sbp_sig>) \ .SetDataTypeInferFn(InferDataType); diff --git a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp index ee778b14006dc2e2c44a7efaa587702bfea29799..e8b8e2df8a1fb42ba66925f77c03620a0a633deb 100644 --- a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp +++ b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp @@ -162,10 +162,11 @@ void GenBackwardOpConf4SparseSoftmaxCrossEntropy(const std::string& op_type_name .Attr<int64_t>("depth") \ .SetTensorDescInferFn(InferTensorDescFn) \ .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, \ - const user_op::UserOpConfWrapper&) { \ + const user_op::UserOpConfWrapper&) -> Maybe<void> { \ user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); \ - CHECK(label_modifier != nullptr); \ + CHECK_OR_RETURN(label_modifier != nullptr); \ label_modifier->set_requires_grad(false); \ + return Maybe<void>::Ok(); \ }) \ .SetGetSbpFn(GetSbpFn<sbp_sig>) \ .SetDataTypeInferFn(InferDataType); diff --git a/oneflow/user/ops/split_like_op.cpp b/oneflow/user/ops/split_like_op.cpp index f5da325d76554381c94d99992fe6658fe2e0f1de..06249e0e372913ed15568ac56477dfeef16986e7 100644 --- a/oneflow/user/ops/split_like_op.cpp +++ b/oneflow/user/ops/split_like_op.cpp @@ -67,13 +67,14 @@ Maybe<void> InferDataType(user_op::InferContext* ctx) { return Maybe<void>::Ok(); } -void SetLikeArgModifier(user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper& user_op_conf) { +Maybe<void> SetLikeArgModifier(user_op::GetInputArgModifier GetInputArgModifierFn, + const user_op::UserOpConfWrapper& user_op_conf) { FOR_RANGE(int32_t, i, 0, user_op_conf.input_size("like")) { user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", i); - CHECK_NOTNULL(like_modifier); + CHECK_NOTNULL_OR_RETURN(like_modifier); like_modifier->set_requires_grad(false); } + return Maybe<void>::Ok(); } Maybe<void> GetSbpSignature(user_op::SbpContext* ctx) { diff --git a/oneflow/user/ops/two_stage_reduce_ops.cpp b/oneflow/user/ops/two_stage_reduce_ops.cpp index 47cdf0516432ff5d68ed6fa39cf7229fe97e55e3..722af2d5b135bb06d0f580531bd60f63e91e0efd 100644 --- a/oneflow/user/ops/two_stage_reduce_ops.cpp +++ b/oneflow/user/ops/two_stage_reduce_ops.cpp @@ -264,10 +264,11 @@ REGISTER_REDUCE_DEVICE_STAGE_USER_OP_GRAD("reduce_max_device_stage", "reduce_max .SetTensorDescInferFn(InferReduceGlobalStageTensorDescFn) \ .SetDataTypeInferFn(InferReduceGlobalStageDtypeFn) \ .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, \ - const user_op::UserOpConfWrapper&) { \ + const user_op::UserOpConfWrapper&) -> Maybe<void> { \ user_op::InputArgModifier* device_count_modifier = \ GetInputArgModifierFn("device_count", 0); \ device_count_modifier->set_requires_grad(false); \ + return Maybe<void>::Ok(); \ }) \ .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { \ ctx->NewBuilder() \ diff --git a/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp b/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp index 7200a83ef6d805641d3ebea9641dd00ce6e86d36..675bda4fa8b3156c2f60338c8a5d8b99b81d58fc 100644 --- a/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp +++ b/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp @@ -50,10 +50,11 @@ REGISTER_USER_OP("unsorted_batch_segment_sum") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); - CHECK_NOTNULL(segment_ids_modifier); + CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); segment_ids_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const int64_t segment_ids_num_axes = diff --git a/oneflow/user/ops/unsorted_segment_sum_op.cpp b/oneflow/user/ops/unsorted_segment_sum_op.cpp index 8048e15062d0059cdc27a706610eb804237998ac..6a072a911208afdb1271f5976354d4384b7b6749 100644 --- a/oneflow/user/ops/unsorted_segment_sum_op.cpp +++ b/oneflow/user/ops/unsorted_segment_sum_op.cpp @@ -46,10 +46,11 @@ REGISTER_USER_OP("unsorted_segment_sum") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); - CHECK_NOTNULL(segment_ids_modifier); + CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); segment_ids_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const int64_t data_num_axes = @@ -131,13 +132,14 @@ REGISTER_USER_OP("unsorted_segment_sum_like") return Maybe<void>::Ok(); }) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); - CHECK_NOTNULL(segment_ids_modifier); + CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); segment_ids_modifier->set_requires_grad(false); user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); - CHECK_NOTNULL(like_modifier); + CHECK_NOTNULL_OR_RETURN(like_modifier); like_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { const int64_t data_num_axes = diff --git a/oneflow/user/ops/where_op.cpp b/oneflow/user/ops/where_op.cpp index 78795a92b636091630d1bb7ae85b0e691f8e812b..5b9350657a2cd3fc19256b1010dcd8ede9170456 100644 --- a/oneflow/user/ops/where_op.cpp +++ b/oneflow/user/ops/where_op.cpp @@ -56,9 +56,10 @@ REGISTER_USER_OP("where") .Output("out") .SetTensorDescInferFn(InferWhereTensorDesc) .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { + const user_op::UserOpConfWrapper&) -> Maybe<void> { user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("condition", 0); cond_arg_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); }) .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { const DataType& cond_dtype = ctx->InputDType("condition", 0);