From 66ca02c15bb54155847c5b1596015390fd65ab74 Mon Sep 17 00:00:00 2001
From: Houjiang Chen <chenhoujiangcug@gmail.com>
Date: Mon, 12 Jul 2021 09:35:46 +0800
Subject: [PATCH] Throw exception if check failed (#5457)

* Throw exception if check failed.

* Fix undefined symbol
---
 oneflow/api/python/functional/py_function.h |  7 ++++---
 oneflow/api/python/functional/unpack_call.h |  7 ++++---
 tools/generate_functional_api.py            | 12 ++++++++++++
 3 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/oneflow/api/python/functional/py_function.h b/oneflow/api/python/functional/py_function.h
index acca9658f..64d415ea3 100644
--- a/oneflow/api/python/functional/py_function.h
+++ b/oneflow/api/python/functional/py_function.h
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "oneflow/api/python/functional/python_arg.h"
 #include "oneflow/api/python/functional/unpack_call.h"
+#include "oneflow/api/python/framework/throw.h"
 
 namespace py = pybind11;
 
@@ -27,9 +28,9 @@ namespace functional {
 template<typename SchemaT>
 inline py::object PyFunction(py::args args, py::kwargs kwargs) {
   // TODO(): Support multiple function signatures.
-  CHECK_LE(args.size(), SchemaT::max_positionals)
+  CHECK_LE_OR_THROW(args.size(), SchemaT::max_positionals)
       << "The maximum count of positional arguments is " << SchemaT::max_positionals;
-  CHECK_LE(kwargs.size(), SchemaT::max_keywords)
+  CHECK_LE_OR_THROW(kwargs.size(), SchemaT::max_keywords)
       << "The maximum count of keyword arguments is " << SchemaT::max_keywords;
 
   // TODO(): Check argument types.
@@ -40,7 +41,7 @@ inline py::object PyFunction(py::args args, py::kwargs kwargs) {
     if (kwargs.contains(arg.name.c_str())) {
       _args[i] = PythonArg(kwargs[arg.name.c_str()]);
     } else {
-      CHECK(arg.has_default_value)
+      CHECK_OR_THROW(arg.has_default_value)
           << "Argument " << arg.name << " is required, and the function def is \""
           << SchemaT::signature << "\".";
       _args[i] = PythonArg(arg.default_value);
diff --git a/oneflow/api/python/functional/unpack_call.h b/oneflow/api/python/functional/unpack_call.h
index 7ee6e4eee..2c7effa11 100644
--- a/oneflow/api/python/functional/unpack_call.h
+++ b/oneflow/api/python/functional/unpack_call.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include "oneflow/api/python/functional/python_arg.h"
 
 #include <tuple>
+#include "oneflow/api/python/framework/throw.h"
 #include "oneflow/core/framework/tensor.h"
 #include "oneflow/core/framework/tensor_tuple.h"
 #include "oneflow/core/common/function_traits.h"
@@ -52,8 +53,8 @@ template<typename F, typename R>
 struct unpack_call {
   static R apply(const F& f, const std::vector<PythonArg>& args) {
     constexpr size_t nargs = function_traits<F>::nargs;
-    CHECK_EQ(nargs, args.size()) << "Requires " << nargs << " arguments, but " << args.size()
-                                 << " is given.";
+    CHECK_EQ_OR_THROW(nargs, args.size())
+        << "Requires " << nargs << " arguments, but " << args.size() << " is given.";
     return unpack_call_dispatcher<F, R, nargs, 0>::apply(f, args);
   }
 };
@@ -63,7 +64,7 @@ struct unpack_call {
   struct unpack_call<F, K> {                                                            \
     static R apply(const F& f, const std::vector<PythonArg>& args) {                    \
       constexpr size_t nargs = function_traits<F>::nargs;                               \
-      CHECK_EQ(nargs, args.size())                                                      \
+      CHECK_EQ_OR_THROW(nargs, args.size())                                             \
           << "Requires " << nargs << " arguments, but " << args.size() << " is given."; \
       return (return_fn)(unpack_call_dispatcher<F, K, nargs, 0>::apply(f, args));       \
     }                                                                                   \
diff --git a/tools/generate_functional_api.py b/tools/generate_functional_api.py
index 97a7598e8..8a583988c 100644
--- a/tools/generate_functional_api.py
+++ b/tools/generate_functional_api.py
@@ -456,6 +456,18 @@ class FunctionalGenerator:
             schema_fmt += "  static std::vector<ArgumentDef> argument_def;\n"
             schema_fmt += "};\n"
             schema_fmt += "\n"
+            schema_fmt += "constexpr size_t {0}Schema::max_args;\n".format(
+                signature._name
+            )
+            schema_fmt += "constexpr size_t {0}Schema::max_positionals;\n".format(
+                signature._name
+            )
+            schema_fmt += "constexpr size_t {0}Schema::max_keywords;\n".format(
+                signature._name
+            )
+            schema_fmt += "constexpr char const* {0}Schema::signature;\n".format(
+                signature._name
+            )
             schema_fmt += "ReturnDef {0}Schema::return_def = ReturnDef(ValueTypeOf<{1}>());\n".format(
                 signature._name, return_type,
             )
-- 
GitLab