From 66ca02c15bb54155847c5b1596015390fd65ab74 Mon Sep 17 00:00:00 2001 From: Houjiang Chen <chenhoujiangcug@gmail.com> Date: Mon, 12 Jul 2021 09:35:46 +0800 Subject: [PATCH] Throw exception if check failed (#5457) * Throw exception if check failed. * Fix undefined symbol --- oneflow/api/python/functional/py_function.h | 7 ++++--- oneflow/api/python/functional/unpack_call.h | 7 ++++--- tools/generate_functional_api.py | 12 ++++++++++++ 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/oneflow/api/python/functional/py_function.h b/oneflow/api/python/functional/py_function.h index acca9658f..64d415ea3 100644 --- a/oneflow/api/python/functional/py_function.h +++ b/oneflow/api/python/functional/py_function.h @@ -17,6 +17,7 @@ limitations under the License. #include "oneflow/api/python/functional/python_arg.h" #include "oneflow/api/python/functional/unpack_call.h" +#include "oneflow/api/python/framework/throw.h" namespace py = pybind11; @@ -27,9 +28,9 @@ namespace functional { template<typename SchemaT> inline py::object PyFunction(py::args args, py::kwargs kwargs) { // TODO(): Support multiple function signatures. - CHECK_LE(args.size(), SchemaT::max_positionals) + CHECK_LE_OR_THROW(args.size(), SchemaT::max_positionals) << "The maximum count of positional arguments is " << SchemaT::max_positionals; - CHECK_LE(kwargs.size(), SchemaT::max_keywords) + CHECK_LE_OR_THROW(kwargs.size(), SchemaT::max_keywords) << "The maximum count of keyword arguments is " << SchemaT::max_keywords; // TODO(): Check argument types. @@ -40,7 +41,7 @@ inline py::object PyFunction(py::args args, py::kwargs kwargs) { if (kwargs.contains(arg.name.c_str())) { _args[i] = PythonArg(kwargs[arg.name.c_str()]); } else { - CHECK(arg.has_default_value) + CHECK_OR_THROW(arg.has_default_value) << "Argument " << arg.name << " is required, and the function def is \"" << SchemaT::signature << "\"."; _args[i] = PythonArg(arg.default_value); diff --git a/oneflow/api/python/functional/unpack_call.h b/oneflow/api/python/functional/unpack_call.h index 7ee6e4eee..2c7effa11 100644 --- a/oneflow/api/python/functional/unpack_call.h +++ b/oneflow/api/python/functional/unpack_call.h @@ -19,6 +19,7 @@ limitations under the License. #include "oneflow/api/python/functional/python_arg.h" #include <tuple> +#include "oneflow/api/python/framework/throw.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/common/function_traits.h" @@ -52,8 +53,8 @@ template<typename F, typename R> struct unpack_call { static R apply(const F& f, const std::vector<PythonArg>& args) { constexpr size_t nargs = function_traits<F>::nargs; - CHECK_EQ(nargs, args.size()) << "Requires " << nargs << " arguments, but " << args.size() - << " is given."; + CHECK_EQ_OR_THROW(nargs, args.size()) + << "Requires " << nargs << " arguments, but " << args.size() << " is given."; return unpack_call_dispatcher<F, R, nargs, 0>::apply(f, args); } }; @@ -63,7 +64,7 @@ struct unpack_call { struct unpack_call<F, K> { \ static R apply(const F& f, const std::vector<PythonArg>& args) { \ constexpr size_t nargs = function_traits<F>::nargs; \ - CHECK_EQ(nargs, args.size()) \ + CHECK_EQ_OR_THROW(nargs, args.size()) \ << "Requires " << nargs << " arguments, but " << args.size() << " is given."; \ return (return_fn)(unpack_call_dispatcher<F, K, nargs, 0>::apply(f, args)); \ } \ diff --git a/tools/generate_functional_api.py b/tools/generate_functional_api.py index 97a7598e8..8a583988c 100644 --- a/tools/generate_functional_api.py +++ b/tools/generate_functional_api.py @@ -456,6 +456,18 @@ class FunctionalGenerator: schema_fmt += " static std::vector<ArgumentDef> argument_def;\n" schema_fmt += "};\n" schema_fmt += "\n" + schema_fmt += "constexpr size_t {0}Schema::max_args;\n".format( + signature._name + ) + schema_fmt += "constexpr size_t {0}Schema::max_positionals;\n".format( + signature._name + ) + schema_fmt += "constexpr size_t {0}Schema::max_keywords;\n".format( + signature._name + ) + schema_fmt += "constexpr char const* {0}Schema::signature;\n".format( + signature._name + ) schema_fmt += "ReturnDef {0}Schema::return_def = ReturnDef(ValueTypeOf<{1}>());\n".format( signature._name, return_type, ) -- GitLab