diff --git a/oneflow/core/common/error.cpp b/oneflow/core/common/error.cpp index 09da92016b25a5a7be8145c9c7ab18c9a0a5cbe1..cb350c269c9090dc143fb72aa084863a85d2edb9 100644 --- a/oneflow/core/common/error.cpp +++ b/oneflow/core/common/error.cpp @@ -1,5 +1,6 @@ #include "oneflow/core/common/error.h" #include "oneflow/core/common/protobuf.h" +#include "oneflow/core/common/util.h" namespace oneflow { @@ -31,4 +32,27 @@ Error Error::JobTypeNotSet() { return error; } +Error Error::CheckFailed() { + auto error = std::make_shared<ErrorProto>(); + error->mutable_check_failed(); + return error; +} + +Error Error::Todo() { + auto error = std::make_shared<ErrorProto>(); + error->mutable_todo_error(); + return error; +} + +Error Error::Unimplemented() { + auto error = std::make_shared<ErrorProto>(); + error->mutable_unimplemented_error(); + return error; +} + +Error&& operator<=(const std::string& log_str, Error&& error) { + LOG(ERROR) << log_str << error->msg(); + return std::move(error); +} + } // namespace oneflow diff --git a/oneflow/core/common/error.h b/oneflow/core/common/error.h index 5fd3a8abc37d291a0bb194c08dc6bbdfbc6cb291..1dfd60460fdf849854a7b7d654208a9f350b9d97 100644 --- a/oneflow/core/common/error.h +++ b/oneflow/core/common/error.h @@ -17,6 +17,9 @@ class Error final { static Error JobSetEmpty(); static Error DeviceTagNotFound(); static Error JobTypeNotSet(); + static Error CheckFailed(); + static Error Todo(); + static Error Unimplemented(); std::shared_ptr<ErrorProto> error_proto() const { return error_proto_; } ErrorProto* operator->() const { return error_proto_.get(); } @@ -34,6 +37,9 @@ Error&& operator<<(Error&& error, const T& x) { return std::move(error); } +// for LOG(ERROR) +Error&& operator<=(const std::string& log_str, Error&& error); + } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_ERROR_H_ diff --git a/oneflow/core/common/error.proto b/oneflow/core/common/error.proto index 54ec48414d33a4098c60eb00d2cb4f179837aa0e..ee234c4eb033275f948d4a5ac2ce9ddd0f9b9e51 100644 --- a/oneflow/core/common/error.proto +++ b/oneflow/core/common/error.proto @@ -62,8 +62,13 @@ enum JobBuildAndInferError { kUnknownJobBuildAndInferError = 500; } -message ProtoParseFailedError { -} +message ProtoParseFailedError { } + +message CheckFailed { } + +message TodoError { } + +message UnimplementedError { } message UnkownError { } @@ -74,6 +79,9 @@ message ErrorProto { ConfigResourceUnavailableError config_resource_unavailable_error = 3; JobBuildAndInferError job_build_and_infer_error = 4; ProtoParseFailedError proto_parse_failed_error = 5; - UnkownError unknown_error = 6; + CheckFailed check_failed = 6; + TodoError todo_error = 7; + UnimplementedError unimplemented_error = 8; + UnkownError unknown_error = 100; } } diff --git a/oneflow/core/common/maybe.h b/oneflow/core/common/maybe.h index a006936d661b06ac74a65b72a7750f9477fc6145..862d052cfc90e56d917bbe9d6b9ffd7551aeccfd 100644 --- a/oneflow/core/common/maybe.h +++ b/oneflow/core/common/maybe.h @@ -4,6 +4,7 @@ #include "oneflow/core/common/util.h" #include "oneflow/core/common/either_ptr.h" #include "oneflow/core/common/error.h" +#include "oneflow/core/common/preprocessor.h" namespace oneflow { @@ -56,7 +57,7 @@ inline Maybe<T> MaybeFuncSafeCallWrapper(Maybe<T>&& maybe) { return maybe; } -#define __MAYBE_CALL_LOC__ __FILE__ ":" OF_PP_STRINGIZE(__LINE__) "\n" +#define __LOC__ __FILE__ ":" OF_PP_STRINGIZE(__LINE__) "\n" #if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__) @@ -65,7 +66,7 @@ inline Maybe<T> MaybeFuncSafeCallWrapper(Maybe<T>&& maybe) { ({ \ const auto& maybe = MaybeFuncSafeCallWrapper(__VA_ARGS__); \ if (!maybe.IsOk()) { \ - LOG(INFO) << "maybe failed:" << __MAYBE_CALL_LOC__; \ + LOG(INFO) << "maybe failed:" << __LOC__; \ return maybe.error(); \ } \ maybe.data(); \ @@ -83,98 +84,43 @@ inline Maybe<T> MaybeFuncSafeCallWrapper(Maybe<T>&& maybe) { } // namespace oneflow -namespace { +#define OF_CHECK(expr) \ + if (!(expr)) \ + return __LOC__ <= Error::CheckFailed() << " Check failed: " << OF_PP_STRINGIZE(expr) << "\t" -enum class ErrorType { - kUnknown = 0, - kCondition = 1, - kEnforce = 2, -}; - -template<ErrorType type> -std::ostringstream& SerializeExprError(std::ostringstream& oss, const std::string& expr) { - oss << "Unknown type error `" << expr << "` occurs."; - return oss; -} - -template<> -std::ostringstream& SerializeExprError<ErrorType::kCondition>(std::ostringstream& oss, - const std::string& expr) { - oss << "Condition expression `" << expr << "` check failed."; - return oss; -} - -template<> -std::ostringstream& SerializeExprError<ErrorType::kEnforce>(std::ostringstream& oss, - const std::string& expr) { - oss << "Enforce error `" << expr << "` occurs."; - return oss; -} - -std::string Sprintf() { return ""; } - -template<typename... Args> -std::string Sprintf(const Args&... args) { - char buffer[2048]; - snprintf(buffer, sizeof(buffer), std::forward<const Args>(args)...); - return std::string(buffer); -} - -} // namespace - -#define OF_TEST_EQ(lhs, rhs) ((lhs) == (rhs)) -#define OF_TEST_GE(lhs, rhs) ((lhs) >= (rhs)) -#define OF_TEST_GT(lhs, rhs) ((lhs) > (rhs)) -#define OF_TEST_LE(lhs, rhs) ((lhs) <= (rhs)) -#define OF_TEST_LT(lhs, rhs) ((lhs) < (rhs)) -#define OF_TEST_NE(lhs, rhs) ((lhs) != (rhs)) - -#define GEN_ERROR_MSG(type, expr, ...) \ - [&]() -> std::string { \ - std::string detail = Sprintf(__VA_ARGS__); \ - std::ostringstream oss; \ - SerializeExprError<type>(oss, expr); \ - if (!detail.empty()) { oss << " " << detail; } \ - return oss.str(); \ - }() +#define OF_CHECK_NOTNULL(ptr) OF_CHECK(ptr != nullptr) +#define OF_CHECK_ISNULL(ptr) OF_CHECK(ptr == nullptr) +#define OF_CHECK_STREQ(lhs, rhs) OF_CHECK_EQ(std::string(lhs), std::string(rhs)) +#define OF_CHECK_STRNE(lhs, rhs) OF_CHECK_NE(std::string(lhs), std::string(rhs)) -#define CHECK_OR_RETURN(expr, ...) \ - { \ - if (!(expr)) { \ - auto error = std::make_shared<ErrorProto>(); \ - error->set_msg(GEN_ERROR_MSG(ErrorType::kCondition, #expr, __VA_ARGS__)); \ - error->mutable_unknown_error(); \ - return error; \ - } \ - } +#define OF_CHECK_EQ(lhs, rhs) OF_CHECK((lhs) == (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " +#define OF_CHECK_NE(lhs, rhs) OF_CHECK((lhs) != (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " +#define OF_CHECK_GT(lhs, rhs) OF_CHECK((lhs) > (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " +#define OF_CHECK_GE(lhs, rhs) OF_CHECK((lhs) >= (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " +#define OF_CHECK_LT(lhs, rhs) OF_CHECK((lhs) < (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " +#define OF_CHECK_LE(lhs, rhs) OF_CHECK((lhs) <= (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " -#define CHECK_EQ_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(OF_TEST_EQ(lhs, rhs), __VA_ARGS__) +#define OF_TODO() return __LOC__ <= Error::Todo() +#define OF_UNIMPLEMENTED() return __LOC__ <= Error::Unimplemented() -#define CHECK_GE_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(OF_TEST_GE(lhs, rhs), __VA_ARGS__) +#define CHECK_OR_RETURN(expr) OF_CHECK(expr) -#define CHECK_GT_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(OF_TEST_GT(lhs, rhs), __VA_ARGS__) +#define CHECK_EQ_OR_RETURN(lhs, rhs) OF_CHECK_EQ(lhs, rhs) -#define CHECK_LE_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(OF_TEST_LE(lhs, rhs), __VA_ARGS__) +#define CHECK_GE_OR_RETURN(lhs, rhs) OF_CHECK_GE(lhs, rhs) -#define CHECK_LT_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(OF_TEST_LT(lhs, rhs), __VA_ARGS__) +#define CHECK_GT_OR_RETURN(lhs, rhs) OF_CHECK_GT(lhs, rhs) -#define CHECK_NE_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(OF_TEST_NE(lhs, rhs), __VA_ARGS__) +#define CHECK_LE_OR_RETURN(lhs, rhs) OF_CHECK_LE(lhs, rhs) -#define CHECK_STREQ_OR_RETURN(lhs, rhs, ...) \ - CHECK_EQ_OR_RETURN(std::string(lhs), std::string(rhs), __VA_ARGS__) +#define CHECK_LT_OR_RETURN(lhs, rhs) OF_CHECK_LT(lhs, rhs) -#define ENFORCE_THEN_RETURN(type, ...) \ - { \ - auto error = std::make_shared<ErrorProto>(); \ - error->set_msg(GEN_ERROR_MSG(ErrorType::kEnforce, #type, __VA_ARGS__)); \ - error->mutable_unknown_error(); \ - return error; \ - } +#define CHECK_NE_OR_RETURN(lhs, rhs) OF_CHECK_NE(lhs, rhs) -#define UNSUPPORTED_THEN_RETURN(...) ENFORCE_THEN_RETURN(OF_TEST_UNSUPPORTED, __VA_ARGS__) +#define CHECK_STREQ_OR_RETURN(lhs, rhs) OF_CHECK_STREQ(lhs, rhs) -#define TODO_THEN_RETURN(...) ENFORCE_THEN_RETURN(OF_TEST_TODO, __VA_ARGS__) +#define TODO_THEN_RETURN() OF_TODO() -#define UNIMPLEMENTED_THEN_RETURN(...) ENFORCE_THEN_RETURN(OF_TEST_UNIMPLEMENTED, __VA_ARGS__) +#define UNIMPLEMENTED_THEN_RETURN() OF_UNIMPLEMENTED() #endif // ONEFLOW_CORE_COMMON_MAYBE_H_ diff --git a/oneflow/core/operator/acc_op.cpp b/oneflow/core/operator/acc_op.cpp index 80ca97b3e2c873ff45055b3062e6b88da19a4823..b0a96c69d31ff6a095523f58f4a604af19aa05ae 100644 --- a/oneflow/core/operator/acc_op.cpp +++ b/oneflow/core/operator/acc_op.cpp @@ -20,7 +20,7 @@ Maybe<void> AccOp::InferOutputBlobTimeShape( const ParallelContext* parallel_ctx, Shape* time_shape) const { const int32_t max_acc_num = op_conf().acc_conf().max_acc_num(); // CHECK_GE(GetTimeShape4BnInOp("one")->elem_cnt(), max_acc_num); - CHECK_GE_OR_RETURN(GetTimeShape4BnInOp("one")->elem_cnt(), max_acc_num, ""); + CHECK_GE_OR_RETURN(GetTimeShape4BnInOp("one")->elem_cnt(), max_acc_num); *time_shape = Shape({GetTimeShape4BnInOp("one")->elem_cnt() / max_acc_num}); return Maybe<void>::Ok(); } diff --git a/oneflow/core/operator/conv_bias_grad_op.cpp b/oneflow/core/operator/conv_bias_grad_op.cpp index a359dc94f27f5d254da6339a77d5d4c228544660..cbf036a1bf94025dd32b8bbf6333dc75a6fd04ac 100644 --- a/oneflow/core/operator/conv_bias_grad_op.cpp +++ b/oneflow/core/operator/conv_bias_grad_op.cpp @@ -28,7 +28,7 @@ Maybe<void> ConvBiasGradOp::InferBlobDescs( } else if (conf.data_format() == "channels_last") { bias_diff->mut_shape() = Shape({dy->shape().At(dy->shape().NumAxes() - 1)}); } else { - CHECK_OR_RETURN(false, "UNIMPLEMENTED"); + OF_UNIMPLEMENTED(); } return Maybe<void>::Ok(); } diff --git a/oneflow/core/operator/define_test_blob_op.cpp b/oneflow/core/operator/define_test_blob_op.cpp index b95f0d280aad2e7971525df97af0dffdef234a88..df716604f5d7689b399f53dd8cbf5e6765009e51 100644 --- a/oneflow/core/operator/define_test_blob_op.cpp +++ b/oneflow/core/operator/define_test_blob_op.cpp @@ -31,7 +31,7 @@ Maybe<void> DefineTestBlobOp::InferBlobDescs( if (conf.has_dim0_inner_shape()) { out_blob_desc->mut_dim0_inner_shape() = Shape(conf.dim0_inner_shape()); } - if (conf.has_dim0_valid_num()) { CHECK_OR_RETURN(conf.has_dim0_inner_shape()); } + if (conf.has_dim0_valid_num()) { OF_CHECK(conf.has_dim0_inner_shape()); } return Maybe<void>::Ok(); }