Skip to content
Snippets Groups Projects
Unverified Commit 66ca02c1 authored by Houjiang Chen's avatar Houjiang Chen Committed by GitHub
Browse files

Throw exception if check failed (#5457)

* Throw exception if check failed.

* Fix undefined symbol
parent 17482025
No related branches found
No related tags found
No related merge requests found
......@@ -17,6 +17,7 @@ limitations under the License.
#include "oneflow/api/python/functional/python_arg.h"
#include "oneflow/api/python/functional/unpack_call.h"
#include "oneflow/api/python/framework/throw.h"
namespace py = pybind11;
......@@ -27,9 +28,9 @@ namespace functional {
template<typename SchemaT>
inline py::object PyFunction(py::args args, py::kwargs kwargs) {
// TODO(): Support multiple function signatures.
CHECK_LE(args.size(), SchemaT::max_positionals)
CHECK_LE_OR_THROW(args.size(), SchemaT::max_positionals)
<< "The maximum count of positional arguments is " << SchemaT::max_positionals;
CHECK_LE(kwargs.size(), SchemaT::max_keywords)
CHECK_LE_OR_THROW(kwargs.size(), SchemaT::max_keywords)
<< "The maximum count of keyword arguments is " << SchemaT::max_keywords;
// TODO(): Check argument types.
......@@ -40,7 +41,7 @@ inline py::object PyFunction(py::args args, py::kwargs kwargs) {
if (kwargs.contains(arg.name.c_str())) {
_args[i] = PythonArg(kwargs[arg.name.c_str()]);
} else {
CHECK(arg.has_default_value)
CHECK_OR_THROW(arg.has_default_value)
<< "Argument " << arg.name << " is required, and the function def is \""
<< SchemaT::signature << "\".";
_args[i] = PythonArg(arg.default_value);
......
......@@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/api/python/functional/python_arg.h"
#include <tuple>
#include "oneflow/api/python/framework/throw.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/common/function_traits.h"
......@@ -52,8 +53,8 @@ template<typename F, typename R>
struct unpack_call {
static R apply(const F& f, const std::vector<PythonArg>& args) {
constexpr size_t nargs = function_traits<F>::nargs;
CHECK_EQ(nargs, args.size()) << "Requires " << nargs << " arguments, but " << args.size()
<< " is given.";
CHECK_EQ_OR_THROW(nargs, args.size())
<< "Requires " << nargs << " arguments, but " << args.size() << " is given.";
return unpack_call_dispatcher<F, R, nargs, 0>::apply(f, args);
}
};
......@@ -63,7 +64,7 @@ struct unpack_call {
struct unpack_call<F, K> { \
static R apply(const F& f, const std::vector<PythonArg>& args) { \
constexpr size_t nargs = function_traits<F>::nargs; \
CHECK_EQ(nargs, args.size()) \
CHECK_EQ_OR_THROW(nargs, args.size()) \
<< "Requires " << nargs << " arguments, but " << args.size() << " is given."; \
return (return_fn)(unpack_call_dispatcher<F, K, nargs, 0>::apply(f, args)); \
} \
......
......@@ -456,6 +456,18 @@ class FunctionalGenerator:
schema_fmt += " static std::vector<ArgumentDef> argument_def;\n"
schema_fmt += "};\n"
schema_fmt += "\n"
schema_fmt += "constexpr size_t {0}Schema::max_args;\n".format(
signature._name
)
schema_fmt += "constexpr size_t {0}Schema::max_positionals;\n".format(
signature._name
)
schema_fmt += "constexpr size_t {0}Schema::max_keywords;\n".format(
signature._name
)
schema_fmt += "constexpr char const* {0}Schema::signature;\n".format(
signature._name
)
schema_fmt += "ReturnDef {0}Schema::return_def = ReturnDef(ValueTypeOf<{1}>());\n".format(
signature._name, return_type,
)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment