diff --git a/oneflow/api/python/functional/python_arg.cpp b/oneflow/api/python/functional/python_arg.cpp index c572a2885fda464bbf8eca2ca11acac1e02343ab..57d4280f678121eca8c8af36bf2fcb26e4a34f5f 100644 --- a/oneflow/api/python/functional/python_arg.cpp +++ b/oneflow/api/python/functional/python_arg.cpp @@ -70,6 +70,11 @@ Maybe<std::shared_ptr<one::Tensor>> PythonArg::ObjectAs<std::shared_ptr<one::Ten return detail::cast<std::shared_ptr<one::Tensor>>(Borrow()); } +template<> +Maybe<one::Tensor> PythonArg::ObjectAs<one::Tensor>() const { + return *JUST(detail::cast<std::shared_ptr<one::Tensor>>(Borrow())); +} + template<> Maybe<std::shared_ptr<one::TensorTuple>> PythonArg::ObjectAs<std::shared_ptr<one::TensorTuple>>() const { diff --git a/oneflow/api/python/functional/python_arg.h b/oneflow/api/python/functional/python_arg.h index 519f07e07a7eb20150bf863af83cc856f834d2fc..1ac8d231ebc787b371af7d855650d8dda8d307e6 100644 --- a/oneflow/api/python/functional/python_arg.h +++ b/oneflow/api/python/functional/python_arg.h @@ -61,6 +61,21 @@ class PythonArg { virtual ~PythonArg() = default; + template<typename T> + friend class ObjectAsHelper; + + template<typename T> + struct ObjectAsHelper { + Maybe<T> operator()(const PythonArg* self) { return self->ObjectAs<T>(); } + }; + template<typename T> + struct ObjectAsHelper<Optional<T>> { + Maybe<Optional<T>> operator()(const PythonArg* self) { + if (self->object_ == Py_None) { return std::make_shared<Optional<T>>(); } + return std::make_shared<Optional<T>>(JUST(self->ObjectAs<T>())); + } + }; + template<typename T> operator T() const { if (active_tag_ == HAS_IMMEDIATE) { @@ -70,7 +85,7 @@ class PythonArg { return *reinterpret_cast<const T*>(immediate_->Ptr()); } CHECK_EQ_OR_THROW(active_tag_, HAS_OBJECT); - return this->ObjectAs<oneflow::detail::remove_cvref_t<T>>().GetOrThrow(); + return ObjectAsHelper<oneflow::detail::remove_cvref_t<T>>()(this).GetOrThrow(); } private: diff --git a/oneflow/core/autograd/gradient_funcs/conv.cpp b/oneflow/core/autograd/gradient_funcs/conv.cpp index d87ecd8b928744d1de3653835f82bb3f517d302c..8080a93161878ea6b88cc4429aa030ee05b7056a 100644 --- a/oneflow/core/autograd/gradient_funcs/conv.cpp +++ b/oneflow/core/autograd/gradient_funcs/conv.cpp @@ -13,103 +13,93 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
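A note on the ObjectAsHelper indirection added above: ObjectAs<T>() is a member function template, and function templates cannot be partially specialized, so the Optional-aware branch has to live in a nested class template instead. The same pattern in isolation, as a minimal standalone sketch (all names below are illustrative stand-ins, not part of this patch):

#include <iostream>

// Stand-in for oneflow::Optional; only the specialization mechanics matter here.
// Value-initialization (T{}) leaves has_value == false.
template<typename T>
struct Opt {
  bool has_value;
};

struct Arg {
  bool is_none = false;

  // Primary helper: the ordinary conversion path.
  template<typename T>
  struct AsHelper {
    T operator()(const Arg*) const { return T{}; }
  };

  // Partial specialization: a missing (None) value becomes an empty optional.
  template<typename T>
  struct AsHelper<Opt<T>> {
    Opt<T> operator()(const Arg* self) const { return Opt<T>{!self->is_none}; }
  };

  // Mirrors PythonArg::operator T(): dispatch through the helper struct.
  template<typename T>
  operator T() const {
    return AsHelper<T>{}(this);
  }
};

int main() {
  Arg none_arg;
  none_arg.is_none = true;
  Opt<int> v = none_arg;  // resolves to AsHelper<Opt<int>>
  std::cout << std::boolalpha << v.has_value << std::endl;  // prints false
  return 0;
}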
*/ +#include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/op_expr.h" -#include "oneflow/core/framework/op_expr_helper.h" -#include "oneflow/core/framework/user_op_conf_trait.h" +#include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { -namespace { - -struct ConvInterpState : public OpExprInterpState { - bool weight_requires_grad = true; - bool input_requires_grad = true; +struct ConvolutionNdInterpState : public OpExprInterpState { + bool input_requires_grad = false; + bool weight_requires_grad = false; + size_t input_index; + size_t weight_index; + + std::string data_format; + std::vector<int32_t> padding_before; + std::vector<int32_t> kernel_size; + std::vector<int32_t> strides; + std::vector<int32_t> dilation_rate; + int32_t groups; }; -class ConvNdGrad : public OpExprGradFunction<ConvInterpState> { +class ConvolutionNd : public OpExprGradFunction<ConvolutionNdInterpState> { public: Maybe<void> Init(const OpExpr& op) override; - Maybe<void> Capture(ConvInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, - const AttrMap& attrs) const override; - Maybe<void> Apply(const ConvInterpState* ctx, const TensorTuple& out_grads, + Maybe<void> Capture(ConvolutionNdInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override; + Maybe<void> Apply(const ConvolutionNdInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: - std::shared_ptr<user_op::UserOpConfTrait> op_trait_; - std::shared_ptr<std::string> data_format_; - std::shared_ptr<std::vector<int32_t>> padding_before_; - std::shared_ptr<std::vector<int32_t>> kernel_size_; - std::shared_ptr<std::vector<int32_t>> strides_; - std::shared_ptr<std::vector<int32_t>> dilation_rate_; - int32_t groups_; - - std::shared_ptr<OpExpr> data_grad_op_; - std::shared_ptr<OpExpr> weight_grad_op_; + AttrMap base_attrs_; }; -Maybe<void> ConvNdGrad::Init(const OpExpr& op) { +Maybe<void> ConvolutionNd::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); - const std::string& op_name = fw_op_expr->op_name(); - op_trait_ = std::make_shared<user_op::UserOpConfTrait>(op_name, fw_op_expr->proto()); - - data_format_ = JUST(op_trait_->GetAttr<std::string>("data_format")); - padding_before_ = JUST(op_trait_->GetAttr<std::vector<int32_t>>("padding_before")); - kernel_size_ = JUST(op_trait_->GetAttr<std::vector<int32_t>>("kernel_size")); - strides_ = JUST(op_trait_->GetAttr<std::vector<int32_t>>("strides")); - dilation_rate_ = JUST(op_trait_->GetAttr<std::vector<int32_t>>("dilation_rate")); - groups_ = JUST(op_trait_->GetAttr<int32_t>("groups")); - int32_t ndims = kernel_size_->size(); - CHECK_EQ_OR_RETURN(ndims, strides_->size()); - CHECK_EQ_OR_RETURN(ndims, dilation_rate_->size()); - data_grad_op_ = JUST(op_expr_helper::ConvNdDataGradOp(*kernel_size_, *strides_, *padding_before_, - *dilation_rate_, groups_, *data_format_)); - - weight_grad_op_ = JUST(op_expr_helper::ConvNdFilterGradOp( - *kernel_size_, *strides_, *padding_before_, *dilation_rate_, groups_, *data_format_)); - + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe<void>::Ok(); } -Maybe<void> ConvNdGrad::Capture(ConvInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& 
outputs, const AttrMap& attrs) const { +Maybe<void> ConvolutionNd::Capture(ConvolutionNdInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const { + CHECK_EQ_OR_RETURN(inputs.size(), 2); ctx->input_requires_grad = inputs.at(0)->requires_grad(); ctx->weight_requires_grad = inputs.at(1)->requires_grad(); - ctx->SaveTensorForBackward(inputs.at(0)); // x + if (!ctx->input_requires_grad && !ctx->weight_requires_grad) { return Maybe<void>::Ok(); } if (ctx->input_requires_grad) { - ctx->SaveTensorForBackward(inputs.at(1)); // weight + ctx->weight_index = ctx->SaveTensorForBackward(inputs.at(1)); // weight } + ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0)); // input + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format")); + ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_before")); + ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size")); + ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("strides")); + ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation_rate")); + ctx->groups = JUST(composed_attrs.GetAttr<int32_t>("groups")); return Maybe<void>::Ok(); } -Maybe<void> ConvNdGrad::Apply(const ConvInterpState* ctx, const TensorTuple& out_grads, - TensorTuple* in_grads) const { - CHECK_EQ_OR_RETURN(out_grads.size(), 1); - const auto& dy = out_grads.at(0); - +Maybe<void> ConvolutionNd::Apply(const ConvolutionNdInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const { in_grads->resize(2); + size_t num_spatial_dims = ctx->kernel_size.size(); if (ctx->input_requires_grad) { - const auto& x = ctx->SavedTensors().at(0); - const auto& weight = ctx->SavedTensors().at(1); - in_grads->at(0) = - JUST(OpInterpUtil::Dispatch<Tensor>(*data_grad_op_, {dy, weight, x}, /*attrs=*/{})); + const auto& weight = ctx->SavedTensors().at(ctx->weight_index); + const auto& input = ctx->SavedTensors().at(ctx->input_index); + in_grads->at(0) = JUST(functional::ConvDataGrad( + out_grads.at(0), weight, input, num_spatial_dims, ctx->kernel_size, ctx->strides, + ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format)); } if (ctx->weight_requires_grad) { - const auto& x = ctx->SavedTensors().at(0); - in_grads->at(1) = JUST(OpInterpUtil::Dispatch<Tensor>(*weight_grad_op_, {dy, x}, /*attrs=*/{})); + const auto& input = ctx->SavedTensors().at(ctx->input_index); + in_grads->at(1) = JUST(functional::ConvFilterGrad( + out_grads.at(0), input, num_spatial_dims, ctx->kernel_size, ctx->strides, + ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format)); } return Maybe<void>::Ok(); } -} // namespace - -REGISTER_OP_EXPR_GRAD_FUNCTION("conv1d", ConvNdGrad); -REGISTER_OP_EXPR_GRAD_FUNCTION("conv2d", ConvNdGrad); -REGISTER_OP_EXPR_GRAD_FUNCTION("conv3d", ConvNdGrad); +REGISTER_OP_EXPR_GRAD_FUNCTION("conv1d", ConvolutionNd); +REGISTER_OP_EXPR_GRAD_FUNCTION("conv2d", ConvolutionNd); +REGISTER_OP_EXPR_GRAD_FUNCTION("conv3d", ConvolutionNd); } // namespace one } // namespace oneflow diff --git a/oneflow/core/common/optional.h b/oneflow/core/common/optional.h new file mode 100644 index 0000000000000000000000000000000000000000..866b6f5c339da118c798455c3d6757baa95454d2 --- /dev/null +++ b/oneflow/core/common/optional.h @@ -0,0 +1,202 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. 
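The rewritten Capture above resolves attributes per call through ComposedAttrMap instead of caching them at Init time; the composed map presumably consults the runtime attrs first and falls back to base_attrs_ recorded from the op proto. A hedged sketch of that lookup in isolation (the free function is hypothetical, added only for illustration):

#include <cstdint>
#include <vector>

#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/attr_map.h"

namespace oneflow {

// Hypothetical helper: resolve "strides" the same way ConvolutionNd::Capture does,
// preferring the attrs passed at dispatch time over the ones baked into the op expr.
Maybe<std::vector<int32_t>> ResolveStrides(const AttrMap& runtime_attrs,
                                           const AttrMap& base_attrs) {
  ComposedAttrMap composed(runtime_attrs, base_attrs);
  return composed.GetAttr<std::vector<int32_t>>("strides");
}

}  // namespace oneflow

Note also that Capture now saves the weight tensor only when the input requires a gradient and stores the indices returned by SaveTensorForBackward in the state, rather than relying on fixed tensor positions.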
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef ONEFLOW_CORE_COMMON_OPTIONAL_H_ +#define ONEFLOW_CORE_COMMON_OPTIONAL_H_ + +#include "oneflow/core/common/type_traits.h" +#include "oneflow/core/common/maybe.h" + +namespace oneflow { + +namespace internal { + +template<typename T, typename U = void> +class Storage; + +template<typename T> +class Storage<T, typename std::enable_if<IsScalarType<T>::value>::type> { + public: + Storage() = default; + + template<typename... Args, + typename std::enable_if<std::is_constructible<T, Args...>::value, int>::type = 0> + Storage(Args&&... args) { + new (&value_) T(std::forward<Args>(args)...); + } + + Storage& operator=(const T& value) { + value_ = value; + return *this; + } + Storage& operator=(T&& value) { + value_ = std::move(value); + return *this; + } + Storage& operator=(const Storage<T>& rhs) { + value_ = rhs.value_; + return *this; + } + Storage& operator=(Storage<T>&& rhs) { + value_ = std::move(rhs.value_); + return *this; + } + + Maybe<T> value() const { return value_; } + + private: + T value_; +}; + +template<typename T> +class Storage<T, typename std::enable_if<!IsScalarType<T>::value>::type> { + public: + Storage() = default; + + template<typename... Args, + typename std::enable_if<std::is_constructible<T, Args...>::value, int>::type = 0> + Storage(Args&&... args) { + value_ = std::make_shared<T>(std::forward<Args>(args)...); + } + + Storage(const std::shared_ptr<T>& value) : value_(value) {} + + Storage& operator=(const T& value) { + if (value_) { + *value_ = value; + } else { + value_ = std::make_shared<T>(value); + } + return *this; + } + Storage& operator=(T&& value) { + if (value_) { + *value_ = std::move(value); + } else { + value_ = std::make_shared<T>(value); + } + return *this; + } + Storage& operator=(const Storage<T>& rhs) { + value_ = rhs.value_; + return *this; + } + Storage& operator=(Storage<T>&& rhs) { + value_ = std::move(rhs.value_); + return *this; + } + + Maybe<T> value() const { return value_; } + + private: + std::shared_ptr<T> value_; +}; + +} // namespace internal + +template<typename T> +class Optional { + public: + Optional() : init_(false) {} + + template<typename... Args, + typename std::enable_if<std::is_constructible<internal::Storage<T>, Args...>::value, + int>::type = 0> + Optional(Args&&... args) : init_(true), storage_(std::forward<Args>(args)...) 
{} + + ~Optional() = default; + + Optional(const Optional<T>& rhs) : init_(rhs.init_) { + if (init_) { storage_ = rhs.storage_; } + } + + Optional(Optional<T>&& rhs) : init_(rhs.init_) { + if (init_) { storage_ = std::move(rhs.storage_); } + } + + Optional& operator=(const T& val) { + init_ = true; + storage_ = val; + return *this; + } + + Optional& operator=(T&& val) { + init_ = true; + storage_ = std::move(val); + return *this; + } + + Optional& operator=(const Optional<T>& rhs) { + init_ = rhs.init_; + if (init_) { storage_ = rhs.storage_; } + return *this; + } + + Optional& operator=(Optional<T>&& rhs) { + init_ = rhs.init_; + if (init_) { storage_ = std::move(rhs.storage_); } + return *this; + } + + Maybe<T> value() const { + CHECK_OR_RETURN(has_value()) << "Optional has no value."; + return storage_.value(); + } + + bool has_value() const { return init_; } + operator bool() const { return has_value(); } + + private: + bool init_; + internal::Storage<T> storage_; +}; + +template<typename T> +class Optional<T&> { + public: + Optional() : value_ptr_(nullptr) {} + + Optional(T& val) : value_ptr_(&val) {} + + ~Optional() = default; + + Optional& operator=(T& val) { + value_ptr_ = &val; + return *this; + } + + Optional& operator=(const Optional<T&>& rhs) { + value_ptr_ = rhs.value_ptr_; + return *this; + } + + Maybe<T&> value() const { + CHECK_OR_RETURN(has_value()) << "Optional has no value."; + return *value_ptr_; + } + + void Clear() { value_ptr_ = nullptr; } + + bool has_value() const { return value_ptr_ != nullptr; } + operator bool() const { return has_value(); } + + private: + T* value_ptr_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_COMMON_OPTIONAL_H_ diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 24884e18f2d95f746e50669dda70490d3d18a4e2..2175ce0598bfe6e7443e02b89f3f875361065c48 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -238,6 +238,31 @@ signature: "Tensor BiasAdd(Tensor x, Tensor bias, *, Int32 axis=1)" bind_python: True +- name: "conv2d" + signature: + "Tensor Conv2D(Tensor x, Tensor weight, *, Tensor bias=None, Int32List stride, + Int32List padding, Int32List dilation, Int32 groups=1)" + bind_python: True + +- name: "conv_data_grad" + signature: + "Tensor ConvDataGrad(Tensor dy, Tensor weight, Tensor x, *, Int32 num_spatial_dims, + Int32List kernel_size, Int32List strides, Int32List padding_before, + Int32List dilation_rate, Int32 groups=1, + String data_format=\"channels_first\")" + bind_python: False + +- name: "conv_filter_grad" + signature: + "Tensor ConvFilterGrad(Tensor dy, Tensor x, *, Int32 num_spatial_dims, Int32List kernel_size, + Int32List strides, Int32List padding_before, Int32List dilation_rate, + Int32 groups=1, String data_format=\"channels_first\")" + bind_python: False + +- name: "conv_bias_grad" + signature: "Tensor ConvBiasGrad(Tensor dy, *, Int32 num_spatial_dims, String data_format=\"channels_first\")" + bind_python: False + - name: "expand" signature: "Tensor Expand(Tensor x, *, Int32List in_shape, Int32List out_shape, Int32List stride)" bind_python: True diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 3fb8358a724c28ede4b62b1f482fc10f5ec8a7cb..4c3606a0e9048c2d63811f9fec246e065d0bfda7 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -14,6 +14,7 @@ See the License for the specific language 
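For reference, the conv2d entry above keeps bias keyword-only with a None default; going by the argument_type_aliases and optional_argument_type_aliases tables in tools/generate_functional_api.py (later in this patch), the generated C++ declaration should come out roughly as below. This is a sketch of the expected shape, not the actual generated file:

#include <cstdint>
#include <memory>
#include <vector>

#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/framework/tensor.h"

namespace oneflow {
namespace one {
namespace functional {

// Tensor arguments map to std::shared_ptr<one::Tensor>, Int32List to
// std::vector<int32_t>, and the None default selects the Optional alias.
Maybe<Tensor> Conv2D(const std::shared_ptr<one::Tensor>& x,
                     const std::shared_ptr<one::Tensor>& weight,
                     const Optional<one::Tensor>& bias,
                     const std::vector<int32_t>& stride,
                     const std::vector<int32_t>& padding,
                     const std::vector<int32_t>& dilation,
                     const int32_t& groups);

}  // namespace functional
}  // namespace one
}  // namespace oneflow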
governing permissions and limitations under the License. */ +#include "oneflow/core/common/optional.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/op_expr.h" @@ -47,6 +48,44 @@ class BiasAddFunctor { std::shared_ptr<OpExpr> op_; }; +class Conv2DFunctor { + public: + Conv2DFunctor() { + conv_op_ = + CHECK_JUST(one::OpBuilder("conv2d").Input("in").Input("weight").Output("out").Build()); + bias_op_ = CHECK_JUST(one::OpBuilder("bias_add").Input("a").Input("b").Output("out").Build()); + } + Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, + const std::shared_ptr<one::Tensor>& weight, + const Optional<one::Tensor>& bias, const std::vector<int32_t>& stride, + const std::vector<int32_t>& padding, + const std::vector<int32_t>& dilation, const int32_t& groups) const { + MutableAttrMap conv_attrs; + std::vector<int32_t> kernel_size_vec; + for (int i = 0; i < 2; i++) { kernel_size_vec.push_back((weight->shape())->At(i + 2)); } + JUST(conv_attrs.SetAttr<int32_t>("filters", (weight->shape())->At(0))); + JUST(conv_attrs.SetAttr<std::vector<int32_t>>("padding_before", padding)); + JUST(conv_attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size_vec)); + JUST(conv_attrs.SetAttr<std::vector<int32_t>>("strides", stride)); + JUST(conv_attrs.SetAttr<std::vector<int32_t>>("dilation_rate", dilation)); + JUST(conv_attrs.SetAttr<int32_t>("groups", groups)); + JUST(conv_attrs.SetAttr<std::string>("data_format", std::string("channels_first"))); + const std::shared_ptr<one::Tensor>& conv_out = + JUST(OpInterpUtil::Dispatch<Tensor>(*conv_op_, {x, weight}, conv_attrs)); + if (bias) { + MutableAttrMap bias_attrs; + JUST(bias_attrs.SetAttr<int32_t>("axis", 1)); + return OpInterpUtil::Dispatch<Tensor>(*bias_op_, {conv_out, JUST(bias.value())}, bias_attrs); + } else { + return conv_out; + } + } + + private: + std::shared_ptr<OpExpr> conv_op_; + std::shared_ptr<OpExpr> bias_op_; +}; + class MatMulBaseFunctor { public: MatMulBaseFunctor() = default; @@ -217,6 +256,7 @@ class NormalizationFunctor { ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::BiasAddFunctor>("BiasAdd"); + m.add_functor<impl::Conv2DFunctor>("Conv2D"); m.add_functor<impl::MatMulFunctor>("MatMul"); m.add_functor<impl::BatchMatMulFunctor>("BatchMatMul"); m.add_functor<impl::BroadcastMatMulFunctor>("BroadcastMatMul"); diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6892ea4669ae5e02b9e2e8079748b0332e410e23 --- /dev/null +++ b/oneflow/core/functional/impl/nn_grad_functor.cpp @@ -0,0 +1,122 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
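A hedged usage sketch of the Conv2DFunctor registered above, as reached from C++ through the generated functional::Conv2D wrapper (the wrapper signature is assumed from the YAML entry; the helper function below is illustrative):

#include <memory>

#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

// Illustrative helper: a stride-1, unpadded convolution without bias.
// An empty Optional<one::Tensor> plays the role of bias=None on the Python side,
// so the functor skips the trailing bias_add dispatch.
Maybe<Tensor> Conv2DNoBias(const std::shared_ptr<Tensor>& x,
                           const std::shared_ptr<Tensor>& weight) {
  return functional::Conv2D(x, weight, Optional<one::Tensor>(),
                            /*stride=*/{1, 1}, /*padding=*/{0, 0},
                            /*dilation=*/{1, 1}, /*groups=*/1);
}

}  // namespace one
}  // namespace oneflow

The functor always emits data_format "channels_first" and derives filters and kernel_size from the weight shape, so callers only pass the hyper-parameters that vary.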
+*/ + +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/functional/function_library.h" +#include "oneflow/core/functional/impl/common.h" +#include "oneflow/core/functional/impl/unary_functor.h" +#include "oneflow/core/functional/scalar.h" + +namespace oneflow { +namespace one { +namespace functional { + +namespace impl { + +class ConvBiasGradFunctor { + public: + ConvBiasGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("conv_bias_grad").Input("dy").Output("bias_diff").Build()); + } + Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy, const int32_t& num_spatial_dims, + const std::string& data_format) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr<int32_t>("num_spatial_dims", num_spatial_dims)); + JUST(attrs.SetAttr<std::string>("data_format", data_format)); + return OpInterpUtil::Dispatch<Tensor>(*op_, {dy}, attrs); + } + + private: + std::shared_ptr<OpExpr> op_; +}; + +class ConvFilterGradFunctor { + public: + ConvFilterGradFunctor() { + op_ = CHECK_JUST( + one::OpBuilder("conv_filter_grad").Input("dy").Input("x").Output("filter_diff").Build()); + } + Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy, + const std::shared_ptr<one::Tensor>& x, const int32_t& num_spatial_dims, + const std::vector<int32_t>& kernel_size, + const std::vector<int32_t>& strides, + const std::vector<int32_t>& padding_before, + const std::vector<int32_t>& dilation_rate, const int32_t& groups, + const std::string& data_format) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr<int32_t>("num_spatial_dims", num_spatial_dims)); + JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size)); + JUST(attrs.SetAttr<std::vector<int32_t>>("strides", strides)); + JUST(attrs.SetAttr<std::vector<int32_t>>("padding_before", padding_before)); + JUST(attrs.SetAttr<std::vector<int32_t>>("dilation_rate", dilation_rate)); + JUST(attrs.SetAttr<int32_t>("groups", groups)); + JUST(attrs.SetAttr<std::string>("data_format", data_format)); + return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs); + } + + private: + std::shared_ptr<OpExpr> op_; +}; + +class ConvDataGradFunctor { + public: + ConvDataGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("conv_data_grad") + .Input("dy") + .Input("filter") + .Input("x_like") + .Output("dx") + .Build()); + } + Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy, + const std::shared_ptr<one::Tensor>& weight, + const std::shared_ptr<one::Tensor>& x, const int32_t& num_spatial_dims, + const std::vector<int32_t>& kernel_size, + const std::vector<int32_t>& strides, + const std::vector<int32_t>& padding_before, + const std::vector<int32_t>& dilation_rate, const int32_t& groups, + const std::string& data_format) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr<int32_t>("num_spatial_dims", num_spatial_dims)); + JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size)); + JUST(attrs.SetAttr<std::vector<int32_t>>("strides", strides)); + JUST(attrs.SetAttr<std::vector<int32_t>>("padding_before", padding_before)); + JUST(attrs.SetAttr<std::vector<int32_t>>("dilation_rate", dilation_rate)); + JUST(attrs.SetAttr<int32_t>("groups", groups)); + JUST(attrs.SetAttr<std::string>("data_format", data_format)); + return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, weight, 
x}, attrs); + } + + private: + std::shared_ptr<OpExpr> op_; +}; + +} // namespace impl + +ONEFLOW_FUNCTION_LIBRARY(m) { + m.add_functor<impl::ConvBiasGradFunctor>("ConvBiasGrad"); + m.add_functor<impl::ConvFilterGradFunctor>("ConvFilterGrad"); + m.add_functor<impl::ConvDataGradFunctor>("ConvDataGrad"); +}; + +} // namespace functional +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/functional/value_types.h b/oneflow/core/functional/value_types.h index cd8904face4017123eedb497a2c3560de6e94a10..f47035a8e1dc3c0ff76a8dad69428bae48264474 100644 --- a/oneflow/core/functional/value_types.h +++ b/oneflow/core/functional/value_types.h @@ -21,6 +21,7 @@ limitations under the License. #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/optional.h" namespace oneflow { class Shape; @@ -79,6 +80,11 @@ enum ValueType { #define VALUE_TYPE_OF_IMPL(cpp_type, value_type) \ template<typename T, typename std::enable_if<std::is_same<T, cpp_type>::value, int>::type = 0> \ + inline ValueType ValueTypeOf() { \ + return value_type; \ + } \ + template<typename T, \ + typename std::enable_if<std::is_same<T, Optional<cpp_type>>::value, int>::type = 0> \ inline ValueType ValueTypeOf() { \ return value_type; \ } diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py index 069e94b6a61fd5ea8eb8d80dcc7e733cfb2b5751..c1ab45fc754fe4dc5aeaafaf9b60950f27a4a47f 100644 --- a/oneflow/python/nn/modules/conv.py +++ b/oneflow/python/nn/modules/conv.py @@ -209,60 +209,22 @@ class Conv2d(Module): super().__init__() assert padding_mode == "zeros" - kernel_size = _pair(kernel_size) - self.kernel_size = kernel_size - stride = _pair(stride) - padding = _pair(padding) - dilation = _pair(dilation) + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) self.groups = groups assert in_channels % groups == 0 assert out_channels % groups == 0 + self.in_channels = in_channels self.out_channels = out_channels self.weight = flow.nn.Parameter( - flow.Tensor(out_channels, in_channels // groups, *kernel_size) + flow.Tensor(out_channels, in_channels // groups, *self.kernel_size) ) self.out_channel_groups = out_channels // groups self.bias = None - self._bias_add_op = None if bias: self.bias = flow.nn.Parameter(flow.Tensor(out_channels)) - self._bias_add_op = ( - flow.builtin_op("bias_add") - .Input("a") - .Input("b") - .Output("out") - .Attr("axis", 1) - .Build() - ) - - self._op = ( - flow.builtin_op("conv2d") - .Input("in") - .Input("weight") - .Attr("filters", out_channels) - .Attr("padding_before", padding) - .Attr("strides", stride) - .Attr("kernel_size", kernel_size) - .Attr("dilation_rate", dilation) - .Attr("groups", groups) - .Attr("data_format", "channels_first") - .Output("out") - .Build() - ) - self._cpu_op = ( - flow.builtin_op("conv2d") - .Input("in") - .Input("weight") - .Attr("filters", out_channels // groups) - .Attr("padding_before", padding) - .Attr("strides", stride) - .Attr("kernel_size", kernel_size) - .Attr("dilation_rate", dilation) - .Attr("groups", 1) - .Attr("data_format", "channels_first") - .Output("out") - .Build() - ) self.reset_parameters() def reset_parameters(self) -> None: @@ -273,6 +235,8 @@ class Conv2d(Module): init.uniform_(self.bias, -bound, bound) def forward(self, x): + if x.shape[1] != self.in_channels: + raise ValueError("The input channels should be equal to self.in_channels") if x.device.type == "cpu" and 
self.groups > 1: in_channel_axis = 1 in_split_list = ConvUtil.split( @@ -281,7 +245,7 @@ class Conv2d(Module): out_list = [] for i in range(len(in_split_list)): out_list.append( - self._cpu_op( + flow.F.conv2d( in_split_list[i], self.weight[ i @@ -291,14 +255,30 @@ class Conv2d(Module): :, :, ], - )[0] + self.bias[ + i + * self.out_channel_groups : (i + 1) + * self.out_channel_groups + ] + if self.bias + else None, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=1, + ) ) res = flow.experimental.cat(out_list, dim=in_channel_axis) else: - res = self._op(x, self.weight)[0] - - if self._bias_add_op is not None: - res = self._bias_add_op(res, self.bias)[0] + res = flow.F.conv2d( + x, + self.weight, + self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) return res diff --git a/oneflow/python/test/modules/test_conv.py b/oneflow/python/test/modules/test_conv.py index 9057dc77d3c36b1f9a891645a9c569af330dfd51..4f2e3e73b6b88fbd050d5a9f54ce4429286aa896 100644 --- a/oneflow/python/test/modules/test_conv.py +++ b/oneflow/python/test/modules/test_conv.py @@ -1478,6 +1478,119 @@ def _test_conv2d_large_in_channel(test_case, device): test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6)) +def _test_conv2d_large_out_channel(test_case, device): + np_arr = np.array( + [ + [ + [ + [0.56573248, -0.19689320, -0.67875558, 0.34328273, 0.31964567], + [-1.33715475, 0.33422229, -1.27643383, 0.37904647, 0.35891593], + [0.84579802, 2.12729621, -0.51423287, 0.61297560, -1.31156564], + [-0.71047139, 1.02679253, -0.76686019, -0.72969633, 0.73425150], + [-0.13592879, -1.03207183, -0.22554775, 0.74148071, 0.96601510], + ], + [ + [0.51595992, 0.49624804, 0.91145641, 0.49247262, 0.41002217], + [-1.08001196, 1.55497086, -0.81963140, -0.45511565, -0.60269165], + [0.05563145, -0.94318372, -1.17058158, -0.73568577, 0.57810956], + [-0.40260276, -0.10309298, 1.12378800, -0.23510537, -0.73893374], + [-0.52712536, -0.00717016, -1.85051966, -1.50790560, 1.38335907], + ], + ] + ] + ) + input = flow.Tensor( + np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True + ) + weight = np.array( + [ + [ + [ + [-0.19489679, -0.32377058, 0.21736273], + [0.04095296, -0.21552679, -0.14626531], + [-0.19359522, -0.00742865, -0.19832158], + ] + ], + [ + [ + [0.29926914, 0.00931164, 0.26197660], + [0.27611443, -0.15439281, -0.19027126], + [-0.28909120, 0.30367029, -0.05168664], + ] + ], + [ + [ + [-0.03155736, 0.17610769, 0.22111714], + [0.22790670, -0.32897446, -0.03260243], + [-0.10274851, -0.06903386, -0.19438276], + ] + ], + [ + [ + [-0.24573688, -0.06723209, -0.21363299], + [-0.02136187, -0.24994437, -0.18691199], + [0.12189507, 0.29469389, 0.03398871], + ] + ], + ] + ) + m = flow.nn.Conv2d(2, 4, 3, groups=2, bias=False) + m.weight = flow.nn.Parameter(flow.Tensor(weight), requires_grad=True) + m = m.to(device) + output = m(input) + print(output) + np_out = np.array( + [ + [ + [ + [-0.21170563, 0.03652292, 0.25926736], + [-0.19168918, 0.49044561, 0.25099146], + [-1.02489340, 0.25361472, -0.51828313], + ], + [ + [0.23977707, -0.56090075, -0.19285655], + [-0.17167747, 0.24558367, -0.30935860], + [-0.33303234, 1.52472734, -0.49013454], + ], + [ + [-0.17137986, 1.21333742, 0.18988736], + [0.31785482, -0.12121570, -0.18676008], + [-0.10680684, -0.30298883, 0.41809759], + ], + [ + [-0.87821335, -0.51665992, -0.44061098], + [0.74804580, 0.53107250, 0.50418228], + [-0.00512899, -0.36455840, -0.23643512], + ], + ] + ] + ) + 
test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-6, 1e-6)) + output = output.sum() + output.backward() + np_grad = np.array( + [ + [ + [ + [0.10437235, -0.21008658, 0.26925275, 0.16488039, 0.47933933], + [0.42143974, -0.26293880, -0.12013602, -0.54157579, 0.14280275], + [-0.06124666, -0.44938356, -0.55658901, -0.49534237, -0.10720548], + [-0.16561902, -0.23929697, -0.82584178, -0.66022277, -0.58654481], + [-0.48268640, -0.18644476, -0.43645298, 0.04623342, -0.25000823], + ], + [ + [-0.27729425, -0.16841865, -0.16093449, 0.11635975, 0.00748415], + [-0.07074942, -0.54079264, -0.75282294, -0.68207347, -0.21203026], + [-0.05160286, -0.29598606, -0.66841042, -0.61680746, -0.37242430], + [0.22569139, -0.12756741, -0.50747585, -0.73316729, -0.37990844], + [0.01914656, 0.24480659, 0.08441254, 0.06526598, -0.16039404], + ], + ] + ] + ) + test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6)) + + @unittest.skipIf( not flow.unittest.env.eager_execution_enabled(), ".numpy() doesn't work in lazy mode", @@ -1677,7 +1790,7 @@ class TestConv2d(flow.unittest.TestCase): device=device, ) - def test_large_channel_group_conv(test_case): + def test_large_in_channel_group_conv(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_conv2d_large_in_channel, @@ -1686,6 +1799,15 @@ class TestConv2d(flow.unittest.TestCase): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + def test_large_out_channel_group_conv(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [ + _test_conv2d_large_out_channel, + ] + arg_dict["device"] = ["cuda", "cpu"] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + if __name__ == "__main__": unittest.main() diff --git a/oneflow/user/kernels/conv_cudnn_kernels.cpp b/oneflow/user/kernels/conv_cudnn_kernels.cpp index a095e2404410c9b1fdd739c7dd3056509d35c339..c32975f92765a2acbfc7d24a60765378ba8e8f9f 100644 --- a/oneflow/user/kernels/conv_cudnn_kernels.cpp +++ b/oneflow/user/kernels/conv_cudnn_kernels.cpp @@ -149,8 +149,8 @@ class ConvGpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - std::shared_ptr<user_op::OpKernelState> CreateOpKernelState( - user_op::KernelInitContext* ctx) const { + std::shared_ptr<ConvCudnnOpKernelState> CreateConvCudnnOpKernelState( + user_op::KernelComputeContext* ctx) const { const auto& data_format = ctx->Attr<std::string>("data_format"); int32_t filters = ctx->Attr<int32_t>("filters"); @@ -166,7 +166,7 @@ class ConvGpuKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); user_op::Tensor* buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); @@ -185,8 +185,8 @@ class ConvGpuKernel final : public user_op::OpKernel { const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0); if (bias != nullptr) { - ConvCudnnOpKernelState* conv_state = dynamic_cast<ConvCudnnOpKernelState*>(state); - CHECK_NOTNULL(conv_state); + const auto& conv_state = CreateConvCudnnOpKernelState(ctx); + CHECK_NOTNULL(conv_state.get()); OF_CUDNN_CHECK(cudnnAddTensor(ctx->device_ctx()->cudnn_handle(), CudnnSPOnePtr<T>(), conv_state->bias_desc->Get(), bias->dptr<T>(), CudnnSPOnePtr<T>(), args.ydesc.Get(), out->mut_dptr<T>())); @@ -352,8 +352,8 @@ 
class ConvBiasGradGpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - std::shared_ptr<user_op::OpKernelState> CreateOpKernelState( - user_op::KernelInitContext* ctx) const { + std::shared_ptr<ConvBiasGradState> CreateConvBiasGradState( + user_op::KernelComputeContext* ctx) const { const auto* bias_diff = ctx->TensorDesc4ArgNameAndIndex("bias_diff", 0); const auto* dy = ctx->TensorDesc4ArgNameAndIndex("dy", 0); const auto& data_format = ctx->Attr<std::string>("data_format"); @@ -375,7 +375,7 @@ class ConvBiasGradGpuKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* bias_diff = ctx->Tensor4ArgNameAndIndex("bias_diff", 0); CHECK_EQ(bias_diff->shape().NumAxes(), 1); @@ -386,8 +386,8 @@ class ConvBiasGradGpuKernel final : public user_op::OpKernel { std::unique_ptr<CudnnTensorDesc> dy_desc; dy_desc.reset(new CudnnTensorDesc(dy->data_type(), dy->shape(), data_format)); - auto* bias_grad_state = dynamic_cast<ConvBiasGradState*>(state); - CHECK_NOTNULL(bias_grad_state); + const auto& bias_grad_state = CreateConvBiasGradState(ctx); + CHECK_NOTNULL(bias_grad_state.get()); OF_CUDNN_CHECK(cudnnConvolutionBackwardBias( ctx->device_ctx()->cudnn_handle(), CudnnSPOnePtr<T>(), dy_desc->Get(), dy->dptr<T>(), CudnnSPZeroPtr<T>(), bias_grad_state->bias_diff_desc->Get(), bias_diff->mut_dptr<T>())); diff --git a/oneflow/user/kernels/conv_kernels.cpp b/oneflow/user/kernels/conv_kernels.cpp index 6c5a857347499abd34262c63a4073fb1cd6d5667..bb01035f37773297375fc30c4b818ba9801a6721 100644 --- a/oneflow/user/kernels/conv_kernels.cpp +++ b/oneflow/user/kernels/conv_kernels.cpp @@ -326,10 +326,10 @@ struct ConvOpKernelState final : public user_op::OpKernelState { }; template<typename T> -std::shared_ptr<user_op::OpKernelState> CreateConvOpKernelState(user_op::KernelInitContext* ctx, - const std::string& in_name, - const std::string& out_name, - const std::string& weight_name) { +std::shared_ptr<ConvOpKernelState<T>> CreateConvOpKernelState(user_op::KernelComputeContext* ctx, + const std::string& in_name, + const std::string& out_name, + const std::string& weight_name) { const auto& data_format = ctx->Attr<std::string>("data_format"); std::shared_ptr<ConvOpKernelState<T>> state(new ConvOpKernelState<T>()); @@ -394,13 +394,11 @@ class ConvCpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - std::shared_ptr<user_op::OpKernelState> CreateOpKernelState( - user_op::KernelInitContext* ctx) const { - return CreateConvOpKernelState<T>(ctx, "in", "out", "weight"); - } - private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx) const override { + const auto& conv_state = CreateConvOpKernelState<T>(ctx, "in", "out", "weight"); + CHECK_NOTNULL(conv_state.get()); + const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); @@ -408,9 +406,6 @@ class ConvCpuKernel final : public user_op::OpKernel { T* col_buf_dptr = tmp_buffer->mut_dptr<T>(); - auto* conv_state = 
dynamic_cast<ConvOpKernelState<T>*>(state); - conv_state->Update(in->shape(), out->shape()); - CHECK_NOTNULL(conv_state); bool is_bias_mul_inited = false; for (int64_t i = 0; i < in->shape().At(0); ++i) { conv_state->im2col_func_(GetImgDptr<T>(in, i), ShapeView(conv_state->in_5d_shape_), @@ -495,20 +490,16 @@ class ConvDataGradCpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - std::shared_ptr<user_op::OpKernelState> CreateOpKernelState( - user_op::KernelInitContext* ctx) const { - return CreateConvOpKernelState<T>(ctx, "dx", "dy", "filter"); - } - private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* conv_state = dynamic_cast<ConvOpKernelState<T>*>(state); - CHECK_NOTNULL(conv_state); + void Compute(user_op::KernelComputeContext* ctx) const override { + const auto& conv_state = CreateConvOpKernelState<T>(ctx, "dx", "dy", "filter"); + CHECK_NOTNULL(conv_state.get()); + const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* filter = ctx->Tensor4ArgNameAndIndex("filter", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - conv_state->Update(dx->shape(), dy->shape()); + Memset<DeviceType::kCPU>(ctx->device_ctx(), dx->mut_dptr<T>(), 0, dx->shape().elem_cnt() * sizeof(T)); @@ -571,21 +562,15 @@ class ConvFilterGradCpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - std::shared_ptr<user_op::OpKernelState> CreateOpKernelState( - user_op::KernelInitContext* ctx) const { - return CreateConvOpKernelState<T>(ctx, "x", "dy", "filter_diff"); - } - private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* conv_state = dynamic_cast<ConvOpKernelState<T>*>(state); - CHECK_NOTNULL(conv_state); + void Compute(user_op::KernelComputeContext* ctx) const override { + const auto& conv_state = CreateConvOpKernelState<T>(ctx, "x", "dy", "filter_diff"); + CHECK_NOTNULL(conv_state.get()); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* filter_diff = ctx->Tensor4ArgNameAndIndex("filter_diff", 0); user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - conv_state->Update(x->shape(), dy->shape()); Memset<DeviceType::kCPU>(ctx->device_ctx(), filter_diff->mut_dptr<T>(), 0, filter_diff->shape().elem_cnt() * sizeof(T)); diff --git a/tools/generate_functional_api.py b/tools/generate_functional_api.py index 4164dfc3b6f78196b7786d2d2cb0c474dd3d1340..409858934ad92125077a2a9d3836bf3c53db4233 100644 --- a/tools/generate_functional_api.py +++ b/tools/generate_functional_api.py @@ -61,6 +61,7 @@ header_fmt = ( #ifndef ONEFLOW_CORE_FUNCTIONAL_GENERATED_FUNCTIONAL_API_H_ #define ONEFLOW_CORE_FUNCTIONAL_GENERATED_FUNCTIONAL_API_H_ +#include "oneflow/core/common/optional.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/functional/scalar.h" @@ -104,6 +105,7 @@ pybind_fmt = ( #include "oneflow/api/python/functional/function_def.h" #include "oneflow/api/python/functional/py_function.h" #include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/optional.h" #include "oneflow/core/functional/functional.h" namespace oneflow {{ @@ -174,6 +176,24 @@ argument_type_aliases = { 
     **generic_type_aliases,
 }
 
+optional_argument_type_aliases = {
+    "Tensor": "const Optional<one::Tensor>&",
+    "TensorTuple": "const Optional<TensorTuple>&",
+    "Scalar": "const Optional<Scalar>&",
+    "ScalarList": "const Optional<std::vector<Scalar>>&",
+    "IntList": "const Optional<std::vector<int32_t>>&",
+    "Int32List": "const Optional<std::vector<int32_t>>&",
+    "Int64List": "const Optional<std::vector<int64_t>>&",
+    "FloatList": "const Optional<std::vector<float>>&",
+    "DoubleList": "const Optional<std::vector<double>>&",
+    "String": "const Optional<std::string>&",
+    "StringList": "const Optional<std::vector<std::string>>&",
+    "BoolList": "const Optional<std::vector<bool>>&",
+    "DataType": "const Optional<DataType>&",
+    "Shape": "const Optional<Shape>&",
+    **{k: "const Optional<{0}>".format(v) for k, v in generic_type_aliases.items()},
+}
+
 return_type_aliases = {
     "Void": "Maybe<void>",
     "Tensor": "Maybe<one::Tensor>",
@@ -254,20 +274,27 @@ class Argument:
         self._type = _normalize(fmt[0:sp])
         assert self._type in types_allowed, "Unknow type: " + self._type
-        if self._type in argument_type_aliases:
-            self._cpp_type = argument_type_aliases[self._type]
-        else:
-            self._cpp_type = self._type
+        optional = False
         self._name = _normalize(fmt[sp + 1 :])
         sp = self._name.find("=")
         if sp != -1:
             self._default_value = _normalize(self._name[sp + 1 :])
-            if self._default_value in value_aliases:
+            if self._default_value == "None":
+                optional = True
+                self._default_cpp_value = ""
+            elif self._default_value in value_aliases:
                 self._default_cpp_value = value_aliases[self._default_value]
             else:
                 self._default_cpp_value = self._default_value
             self._name = _normalize(self._name[0:sp])
 
+        if not optional and self._type in argument_type_aliases:
+            self._cpp_type = argument_type_aliases[self._type]
+        elif optional and self._type in optional_argument_type_aliases:
+            self._cpp_type = optional_argument_type_aliases[self._type]
+        else:
+            self._cpp_type = self._type
+
     @property
     def has_default_value(self):
         return self._default_value is not None
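Taken together with oneflow/core/common/optional.h introduced earlier in this patch, a None default in the YAML now surfaces as an empty Optional on the C++ side, while ValueTypeOf<Optional<T>>() (value_types.h) and optional_argument_type_aliases (above) keep the generator and the argument parser in agreement. A minimal sketch of the Optional<T> API itself, using an illustrative helper:

#include <cstdint>

#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/optional.h"

namespace oneflow {

// Illustrative helper: resolve an optional "groups"-style argument the way a
// functor with a None default would, falling back to 1 when nothing was passed.
Maybe<int64_t> GroupsOrDefault(const Optional<int64_t>& groups) {
  if (groups.has_value()) {
    return groups.value();  // forwards the stored value as a Maybe
  }
  return static_cast<int64_t>(1);
}

}  // namespace oneflow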