Skip to content
Snippets Groups Projects
Unverified Commit 50325e92 authored by Houjiang Chen's avatar Houjiang Chen Committed by GitHub
Browse files

Support symbol placement type in functional. (#5627)


* Support symbol placement type in functional.

* add sbp and sbp list arg

* refine

Co-authored-by: default avatarclackhan <han_binbin@163.com>
Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 975d9a96
No related branches found
No related tags found
No related merge requests found
......@@ -120,7 +120,6 @@ template<>
Maybe<AttrMap> PythonArg::ObjectAs<AttrMap>() const {
const auto& attrs = *(JUST(detail::cast<std::shared_ptr<MutableCfgAttrMap>>(Borrow())));
return std::make_shared<AttrMap>(*attrs);
;
}
template<>
......@@ -169,6 +168,26 @@ Maybe<one::Generator> PythonArg::ObjectAs<one::Generator>() const {
return *JUST(detail::cast<std::shared_ptr<one::Generator>>(Borrow()));
}
template<>
Maybe<Symbol<ParallelDesc>> PythonArg::ObjectAs<Symbol<ParallelDesc>>() const {
return **JUST(detail::cast<std::shared_ptr<Symbol<ParallelDesc>>>(Borrow()));
}
template<>
Maybe<Symbol<cfg::SbpParallel>> PythonArg::ObjectAs<Symbol<cfg::SbpParallel>>() const {
return **JUST(detail::cast<std::shared_ptr<Symbol<cfg::SbpParallel>>>(Borrow()));
}
template<>
Maybe<std::vector<Symbol<cfg::SbpParallel>>>
PythonArg::ObjectAs<std::vector<Symbol<cfg::SbpParallel>>>() const {
const auto& v =
JUST(detail::cast<std::vector<std::shared_ptr<Symbol<cfg::SbpParallel>>>>(Borrow()));
auto sbp_list = std::make_shared<std::vector<Symbol<cfg::SbpParallel>>>(v->size());
for (int i = 0; i < v->size(); ++i) { sbp_list->at(i) = *(v->at(i)); }
return sbp_list;
}
template<>
Maybe<TensorIndex> PythonArg::ObjectAs<TensorIndex>() const {
auto tensor_index = std::make_shared<TensorIndex>();
......
......@@ -27,8 +27,14 @@ namespace oneflow {
class Shape;
class AttrMap;
template<typename T>
class Symbol;
class ParallelDesc;
namespace cfg {
class AttrValue;
class SbpParallel;
} // namespace cfg
namespace one {
......@@ -82,6 +88,9 @@ enum ValueType {
kGENERATOR_REF,
kGENERATOR_MAYBE,
kTENSOR_INDEX,
kPARALLEL_DESC,
kSBP_PARALLEL,
kSBP_PARALLEL_LIST,
};
#define VALUE_TYPE_OF_IMPL(cpp_type, value_type) \
......@@ -132,6 +141,9 @@ VALUE_TYPE_OF_IMPL(one::Generator, kGENERATOR);
VALUE_TYPE_OF_IMPL(std::shared_ptr<one::Generator>, kGENERATOR_REF);
VALUE_TYPE_OF_IMPL(Maybe<one::Generator>, kGENERATOR_MAYBE);
VALUE_TYPE_OF_IMPL(TensorIndex, kTENSOR_INDEX);
VALUE_TYPE_OF_IMPL(Symbol<ParallelDesc>, kPARALLEL_DESC);
VALUE_TYPE_OF_IMPL(Symbol<cfg::SbpParallel>, kSBP_PARALLEL);
VALUE_TYPE_OF_IMPL(std::vector<Symbol<cfg::SbpParallel>>, kSBP_PARALLEL_LIST);
#undef VALUE_TYPE_OF_IMPL
......
......@@ -361,4 +361,4 @@ from oneflow.ops.user_op_builder import (
api_user_op_module_builder as user_op_module_builder,
)
from . import autograd, distributed, linalg, optim, saved_model
from . import autograd, distributed, linalg, optim, saved_model, sbp
......@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow
from oneflow.framework.distribute import split_sbp as split
broadcast = oneflow._oneflow_internal.sbp.broadcast()
......
......@@ -151,6 +151,9 @@ types_allowed = {
"Shape",
"Generator",
"TensorIndex",
"ParallelDesc",
"SbpParallel",
"SbpParallelList",
}
generic_type_aliases = {
......@@ -179,6 +182,9 @@ argument_type_aliases = {
"Shape": "const Shape&",
"Generator": "const std::shared_ptr<one::Generator>&",
"TensorIndex": "const TensorIndex&",
"ParallelDesc": "const Symbol<ParallelDesc>&",
"SbpParallel": "const Symbol<cfg::SbpParallel>&",
"SbpParallelList": "const std::vector<Symbol<cfg::SbpParallel>>&",
**generic_type_aliases,
}
......@@ -199,6 +205,9 @@ optional_argument_type_aliases = {
"Shape": "const Optional<Shape>&",
"Generator": "const Optional<one::Generator>&",
"TensorIndex": "const Optional<TensorIndex>&",
"ParallelDesc": "const Optional<Symbol<ParallelDesc>>&",
"SbpParallel": "const Optional<Symbol<SbpParallel>>&",
"SbpParallelList": "const Optional<std::vector<Symbol<cfg::SbpParallel>>>&",
**{k: "const Optional<{0}>".format(v) for k, v in generic_type_aliases.items()},
}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment