From bf4bdd62a37f31a5972fd7e7a66209c7056b5cb4 Mon Sep 17 00:00:00 2001
From: Li Xinqi <lixinqi2010@gmail.com>
Date: Fri, 30 Jul 2021 17:04:23 +0800
Subject: [PATCH] rebase (#5601)

* rebase

* check in gen py

* merge master and fix bugs

* address pr comments

* address pr comments

* auto format by CI

* functional python_arg

* auto format by CI

* remove unused files

* fix return type error on gcc 4.8.5

Signed-off-by: daquexian <daquexian566@gmail.com>

* auto format by CI

* fix return type error in xrt

Signed-off-by: daquexian <daquexian566@gmail.com>

* fix tick ibn sbp signature

* auto format by CI

Co-authored-by: tsai <jackalcooper@gmail.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: daquexian <daquexian566@gmail.com>
---
 oneflow/api/python/autograd/autograd.cpp      |   3 +-
 oneflow/api/python/framework/device.cpp       |   6 +-
 oneflow/api/python/framework/op_expr.cpp      |   3 +-
 oneflow/api/python/framework/tensor.cpp       |  16 +-
 oneflow/api/python/functional/python_arg.cpp  |   6 +
 .../api/python/symbol/placement_symbol.cpp    |   4 +-
 oneflow/api/python/symbol/sbp_symbol.cpp      |   5 +-
 oneflow/core/autograd/autograd_meta.cpp       |  36 ++++-
 oneflow/core/autograd/autograd_meta.h         |  12 +-
 oneflow/core/common/symbol.h                  | 122 +++++++++-------
 .../consistent_tensor_infer_cache.cpp         | 106 ++++++++++++--
 .../framework/consistent_tensor_infer_cache.h |  50 ++++++-
 oneflow/core/framework/device.cpp             |  19 +--
 oneflow/core/framework/op_interpreter.h       |  35 ++++-
 .../eager_consistent_op_interpreter.cpp       |  30 ++--
 .../eager_mirrored_op_interpreter.cpp         |  10 +-
 .../op_interpreter/op_interpreter_util.cpp    |  56 +++++--
 .../op_interpreter/op_interpreter_util.h      |  19 ++-
 oneflow/core/functional/functional_api.yaml   |   9 +-
 .../functional/impl/activation_functor.cpp    |   2 +-
 .../core/functional/impl/array_functor.cpp    |  67 ++++++++-
 oneflow/core/functional/impl/binary_functor.h |   2 +-
 oneflow/core/functional/impl/math_functor.cpp |   5 +-
 oneflow/core/functional/impl/nn_functor.cpp   |   2 +-
 .../core/functional/impl/random_functor.cpp   |   3 +-
 oneflow/core/functional/impl/unary_functor.h  |   2 +-
 oneflow/core/functional/value_types.h         |   3 +
 oneflow/core/job/parallel_desc.cpp            |  19 +++
 oneflow/core/job/parallel_desc.h              |   3 +
 oneflow/core/operator/user_op.cpp             |  17 ++-
 oneflow/user/ops/constant_op.cpp              |  23 ++-
 python/oneflow/framework/distribute.py        |   4 +-
 python/oneflow/nn/modules/constant.py         | 137 ++++++++++++++----
 python/oneflow/sbp.py                         |   2 +
 python/oneflow/test/modules/test_constant.py  |  16 +-
 tools/generate_functional_api.py              |  21 +--
 36 files changed, 681 insertions(+), 194 deletions(-)

diff --git a/oneflow/api/python/autograd/autograd.cpp b/oneflow/api/python/autograd/autograd.cpp
index a2386b5ec..44e429303 100644
--- a/oneflow/api/python/autograd/autograd.cpp
+++ b/oneflow/api/python/autograd/autograd.cpp
@@ -54,9 +54,8 @@ Maybe<one::TensorTuple> CheckAndInitOutGrads(const one::TensorTuple& outputs,
       CHECK_OR_RETURN(IsScalarTensor(*outputs.at(i)))
           << "Grad can be implicitly created only for scalar outputs";
       const auto& ones_like = JUST(op_expr_helper::OnesLikeOp());
-      const auto& interpreter = JUST(one::OpInterpUtil::GetInterpreter());
       one::TensorTuple grad_output(1);
-      JUST(interpreter->Apply(*ones_like, one::TensorTuple{outputs.at(i)}, &grad_output));
+      JUST(one::OpInterpUtil::Dispatch(*ones_like, one::TensorTuple{outputs.at(i)}, &grad_output));
       gradients->at(i) = grad_output.at(0);
     } else {
       CHECK_OR_RETURN(*(outputs.at(i)->shape()) == *(out_grads.at(i)->shape()))
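To make the call-site change above concrete: the patch collapses the two-step "fetch an interpreter, then Apply" idiom into a single static OpInterpUtil::Dispatch that chooses the interpreter internally. Below is a minimal, self-contained sketch of that shape with toy stand-in types (OpExpr, TensorTuple, and the selection logic here are illustrative, not OneFlow's real API):

#include <iostream>
#include <vector>

// Toy stand-ins; the real types live in oneflow/core/framework.
struct OpExpr { const char* name; };
using TensorTuple = std::vector<double>;

struct Interpreter {
  void Apply(const OpExpr& op, const TensorTuple& in, TensorTuple* out) const {
    out->assign(in.begin(), in.end());  // pretend to run the op
    std::cout << "ran " << op.name << " on " << in.size() << " input(s)\n";
  }
};

struct OpInterpUtil {
  // Dispatch hides interpreter lookup from every call site, which is what the
  // patch does for callers that previously wrote GetInterpreter()->Apply(...).
  static void Dispatch(const OpExpr& op, const TensorTuple& in, TensorTuple* out) {
    GetInterpreter().Apply(op, in, out);
  }

 private:
  static const Interpreter& GetInterpreter() {
    static Interpreter interpreter;  // real code inspects inputs/ctx here
    return interpreter;
  }
};

int main() {
  TensorTuple grad_output(1);
  OpInterpUtil::Dispatch(OpExpr{"ones_like"}, TensorTuple{1.0}, &grad_output);
  return 0;
}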
diff --git a/oneflow/api/python/framework/device.cpp b/oneflow/api/python/framework/device.cpp
index 842a72099..1d629f61f 100644
--- a/oneflow/api/python/framework/device.cpp
+++ b/oneflow/api/python/framework/device.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include <pybind11/pybind11.h>
+#include <pybind11/operators.h>
 #include "oneflow/api/python/common.h"
 #include "oneflow/api/python/framework/device.h"
 #include "oneflow/api/python/of_api_registry.h"
@@ -50,9 +51,10 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
       }))
       .def_property_readonly("type", [](const Symbol<Device>& d) { return d->type(); })
       .def_property_readonly("index", [](const Symbol<Device>& d) { return d->device_id(); })
-      .def("__eq__", [](const Symbol<Device>& d1, const Symbol<Device>& d2) { return *d1 == *d2; })
      .def("__str__", [](const Symbol<Device>& d) { return d->ToString(); })
-      .def("__repr__", [](const Symbol<Device>& d) { return d->ToRepr(); });
+      .def("__repr__", [](const Symbol<Device>& d) { return d->ToRepr(); })
+      .def(py::self == py::self)
+      .def(py::hash(py::self));
 }

 }  // namespace oneflow

diff --git a/oneflow/api/python/framework/op_expr.cpp b/oneflow/api/python/framework/op_expr.cpp
index b2c3e6240..e708f8daa 100644
--- a/oneflow/api/python/framework/op_expr.cpp
+++ b/oneflow/api/python/framework/op_expr.cpp
@@ -37,8 +37,7 @@ Maybe<one::TensorTuple> Interpret(const one::OpExpr& op, const one::TensorTuple&
       << "The operation requires " << op.input_size() << " inputs, but " << inputs.size()
       << " is given.";
   auto outputs = std::make_shared<one::TensorTuple>(op.output_size());
-  auto interperter = JUST(one::OpInterpUtil::GetInterpreter());
-  JUST(interperter->Apply(op, inputs, outputs.get(), attrs));
+  JUST(one::OpInterpUtil::Dispatch(op, inputs, outputs.get(), attrs));
   return outputs;
 }

diff --git a/oneflow/api/python/framework/tensor.cpp b/oneflow/api/python/framework/tensor.cpp
index 0dda48d98..f0b2744af 100644
--- a/oneflow/api/python/framework/tensor.cpp
+++ b/oneflow/api/python/framework/tensor.cpp
@@ -251,6 +251,19 @@ bool ApiIsContiguous(const std::shared_ptr<Tensor>& tensor) {
   return IsContiguous(tensor).GetOrThrow();
 }

+Maybe<py::tuple> TensorGetPyTupleOfSbp(const Tensor& tensor) {
+  const auto& nd_sbp = JUST(tensor.parallel_distribution());
+  const auto& tuple = std::make_shared<py::tuple>(nd_sbp->sbp_parallel_size());
+  for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) {
+    (*tuple)[i] = SymbolOf(nd_sbp->sbp_parallel(i));
+  }
+  return tuple;
+}
+
+py::tuple ApiTensorGetPyTupleOfSbp(const Tensor& tensor) {
+  return *TensorGetPyTupleOfSbp(tensor).GetPtrOrThrow();
+}
+
 Maybe<Tensor> NewTensor(py::args args, py::kwargs kwargs, const DType* desired_dtype,
                         bool treat_single_int_as_size) {
   Symbol<Device> device;
@@ -399,7 +412,8 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
       .def("_get_copy_mirrored_tensor_from_numpy_func_name",
           &ApiGetCopyMirroredTensorFromNumpyFuncName)
       // consistent tensor only
-      .def_property_readonly("placement", &TensorGetParallelDesc);
+      .def_property_readonly("placement", &TensorGetParallelDesc)
+      .def_property_readonly("sbp", &ApiTensorGetPyTupleOfSbp);

   auto nn = m.def_submodule("nn");
   py::class_<Parameter, std::shared_ptr<Parameter>, Tensor>(nn, "Parameter")
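The device, placement, and sbp bindings in this patch all switch to pybind11's operator helpers: py::self == py::self exposes __eq__ through the C++ operator==, and py::hash(py::self) exposes __hash__ through std::hash, keeping Python's eq/hash contract so the symbols can serve as dict keys. A minimal sketch of the pattern, with a hypothetical ToyDevice standing in for Symbol<Device>:

#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include <functional>
#include <string>

namespace py = pybind11;

// Hypothetical value type; Symbol<Device> plays this role in the patch.
struct ToyDevice {
  ToyDevice(std::string t, int i) : type(std::move(t)), index(i) {}
  std::string type;
  int index;
  bool operator==(const ToyDevice& rhs) const { return type == rhs.type && index == rhs.index; }
};

// py::hash(py::self) requires a std::hash specialization for the wrapped type.
namespace std {
template<>
struct hash<ToyDevice> {
  size_t operator()(const ToyDevice& d) const {
    return hash<string>()(d.type) ^ (hash<int>()(d.index) << 1);
  }
};
}  // namespace std

PYBIND11_MODULE(toy_device, m) {
  py::class_<ToyDevice>(m, "ToyDevice")
      .def(py::init<std::string, int>())
      .def(py::self == py::self)  // __eq__ via C++ operator==
      .def(py::hash(py::self));   // __hash__ via std::hash<ToyDevice>
}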
diff --git a/oneflow/api/python/functional/python_arg.cpp b/oneflow/api/python/functional/python_arg.cpp
index 21656d8d0..662607774 100644
--- a/oneflow/api/python/functional/python_arg.cpp
+++ b/oneflow/api/python/functional/python_arg.cpp
@@ -21,6 +21,7 @@ limitations under the License.
 #include "oneflow/core/common/data_type.cfg.h"
 #include "oneflow/core/framework/attr_map.h"
 #include "oneflow/core/framework/dtype.h"
+#include "oneflow/core/framework/device.h"
 #include "oneflow/core/framework/tensor.h"
 #include "oneflow/core/framework/tensor_tuple.h"
 #include "oneflow/core/framework/user_op_attr.cfg.h"
@@ -168,6 +169,11 @@ Maybe<one::Generator> PythonArg::ObjectAs<one::Generator>() const {
   return *JUST(detail::cast<std::shared_ptr<one::Generator>>(Borrow()));
 }

+template<>
+Maybe<Symbol<Device>> PythonArg::ObjectAs<Symbol<Device>>() const {
+  return **JUST(detail::cast<std::shared_ptr<Symbol<Device>>>(Borrow()));
+}
+
 template<>
 Maybe<Symbol<ParallelDesc>> PythonArg::ObjectAs<Symbol<ParallelDesc>>() const {
   return **JUST(detail::cast<std::shared_ptr<Symbol<ParallelDesc>>>(Borrow()));

diff --git a/oneflow/api/python/symbol/placement_symbol.cpp b/oneflow/api/python/symbol/placement_symbol.cpp
index 4764292b8..bf255f494 100644
--- a/oneflow/api/python/symbol/placement_symbol.cpp
+++ b/oneflow/api/python/symbol/placement_symbol.cpp
@@ -239,7 +239,9 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
       })
      .def_property_readonly("hierarchy", [](Symbol<ParallelDesc> p) { return p->hierarchy(); })
      .def("__str__", &PlacementSymbolExportUtil::PlacementSymbol2String)
-      .def("__repr__", &PlacementSymbolExportUtil::PlacementSymbol2String);
+      .def("__repr__", &PlacementSymbolExportUtil::PlacementSymbol2String)
+      .def(py::self == py::self)
+      .def(py::hash(py::self));
   m.def("AllDevicePlacement", &PlacementSymbolExportUtil::AllDevicePlacement);
 }
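The PythonArg::ObjectAs<Symbol<Device>> addition above follows the file's existing pattern: one member function template with an explicit specialization per supported conversion target. A condensed, self-contained sketch of that pattern (the Arg class and payload handling here are illustrative, not OneFlow's detail::cast machinery):

#include <iostream>
#include <memory>
#include <string>

class Arg {
 public:
  explicit Arg(std::shared_ptr<void> payload) : payload_(std::move(payload)) {}

  // Only explicit specializations are defined; an unsupported T fails to link,
  // which is the same guarantee the PythonArg specializations provide.
  template<typename T>
  T As() const;

 private:
  std::shared_ptr<void> payload_;  // assumes the caller stored a T here
};

template<>
int Arg::As<int>() const {
  return *std::static_pointer_cast<int>(payload_);
}

template<>
std::string Arg::As<std::string>() const {
  return *std::static_pointer_cast<std::string>(payload_);
}

int main() {
  Arg a(std::make_shared<int>(42));
  std::cout << a.As<int>() << "\n";  // prints 42
  return 0;
}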
diff --git a/oneflow/api/python/symbol/sbp_symbol.cpp b/oneflow/api/python/symbol/sbp_symbol.cpp
index cff2ed0f2..f3c6a516e 100644
--- a/oneflow/api/python/symbol/sbp_symbol.cpp
+++ b/oneflow/api/python/symbol/sbp_symbol.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include <pybind11/pybind11.h>
+#include <pybind11/operators.h>
 #include "oneflow/api/python/of_api_registry.h"
 #include "oneflow/core/common/util.h"
 #include "oneflow/core/common/maybe.h"
@@ -88,7 +89,9 @@ ONEFLOW_API_PYBIND11_MODULE("sbp", m) {
   m.attr("max_split_axis") = kMaxSplitAxis;
   py::class_<Symbol<cfg::SbpParallel>, std::shared_ptr<Symbol<cfg::SbpParallel>>>(m, "sbp")
       .def("__str__", &SbpParallelSymbolToString)
-      .def("__repr__", &SbpParallelSymbolToString);
+      .def("__repr__", &SbpParallelSymbolToString)
+      .def(py::self == py::self)
+      .def(py::hash(py::self));
   m.def(
       "split", [](int axis) { return GetSplitSbpParallel(axis).GetOrThrow(); }, py::arg("axis"));
   m.def("broadcast", []() { return GetBroadcastSbpParallel().GetOrThrow(); });

diff --git a/oneflow/core/autograd/autograd_meta.cpp b/oneflow/core/autograd/autograd_meta.cpp
index b82dcef88..348a8969a 100644
--- a/oneflow/core/autograd/autograd_meta.cpp
+++ b/oneflow/core/autograd/autograd_meta.cpp
@@ -23,9 +23,41 @@ namespace oneflow {

 namespace one {

-TensorInfo::TensorInfo(const Tensor& tensor) : shape_(tensor.shape()), dtype_(tensor.dtype()) {}
+TensorInfo::TensorInfo(const Tensor& tensor) : shape_(tensor.shape()), dtype_(tensor.dtype()) {
+  if (TRY(tensor.device()).IsOk()) { device_ = CHECK_JUST(tensor.device()); }
+  if (TRY(tensor.parallel_desc()).IsOk()) { parallel_desc_ = CHECK_JUST(tensor.parallel_desc()); }
+  if (TRY(tensor.parallel_distribution()).IsOk()) {
+    parallel_distribution_ = CHECK_JUST(tensor.parallel_distribution());
+  }
+}

-Maybe<Tensor> TensorInfo::zeros() const { return functional::Constant(*shape_.get(), 0, dtype_); }
+Maybe<const std::vector<Symbol<cfg::SbpParallel>>&> GetSbpTuple(
+    Symbol<cfg::ParallelDistribution> parallel_distribution) {
+  static thread_local HashMap<Symbol<cfg::ParallelDistribution>,
+                              std::vector<Symbol<cfg::SbpParallel>>>
+      map;
+  auto iter = map.find(parallel_distribution);
+  if (iter == map.end()) {
+    std::vector<Symbol<cfg::SbpParallel>> sbp_tuple;
+    for (const auto& sbp_parallel : parallel_distribution->sbp_parallel()) {
+      sbp_tuple.push_back(SymbolOf(sbp_parallel));
+    }
+    iter = map.emplace(parallel_distribution, sbp_tuple).first;
+  }
+  return iter->second;
+}
+
+Maybe<Tensor> TensorInfo::zeros() const {
+  if (device_.has_value()) {
+    const auto& device = JUST(device_.value());
+    return functional::Constant(*shape_.get(), 0, dtype_, device);
+  } else {
+    const auto& parallel_desc = JUST(parallel_desc_.value());
+    const auto& parallel_distribution = JUST(parallel_distribution_.value());
+    const auto& sbp_tuple = JUST(GetSbpTuple(parallel_distribution));
+    return functional::ConsistentConstant(*shape_.get(), 0, dtype_, parallel_desc, sbp_tuple);
+  }
+}

 }  // namespace one
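GetSbpTuple above uses a thread-local memoization idiom: derive a value from an interned key once per thread, cache it, and return a reference into the cache (safe because entries are never erased). A minimal standalone sketch of the same idiom, with std::string keys standing in for Symbol<cfg::ParallelDistribution>:

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Per-thread cache: no locking is needed, at the cost of one copy per thread.
const std::vector<std::string>& SplitWordsCached(const std::string& key) {
  static thread_local std::unordered_map<std::string, std::vector<std::string>> cache;
  auto iter = cache.find(key);
  if (iter == cache.end()) {
    std::vector<std::string> words;
    std::string cur;
    for (char c : key) {
      if (c == ' ') {
        words.push_back(cur);
        cur.clear();
      } else {
        cur.push_back(c);
      }
    }
    words.push_back(cur);
    iter = cache.emplace(key, std::move(words)).first;
  }
  return iter->second;  // stable reference; entries are never erased
}

int main() {
  for (const auto& w : SplitWordsCached("split broadcast partial_sum")) { std::cout << w << "\n"; }
  return 0;
}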
#include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/framework/tensor_arg.h" #include "oneflow/core/common/util.h" +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/common/optional.h" namespace oneflow { class Shape; +class Device; +class ParallelDesc; +namespace cfg { +class ParallelDistribution; +} + namespace one { class Tensor; @@ -86,7 +94,9 @@ class TensorInfo final { private: std::shared_ptr<const Shape> shape_; DataType dtype_; - // TODO: Add device info + Optional<Symbol<Device>> device_; // for local tensor + Optional<Symbol<ParallelDesc>> parallel_desc_; // for consistent tensor + Optional<Symbol<cfg::ParallelDistribution>> parallel_distribution_; // for consistent tensor }; } // namespace one diff --git a/oneflow/core/common/symbol.h b/oneflow/core/common/symbol.h index b4e5a9be4..fe65bdf30 100644 --- a/oneflow/core/common/symbol.h +++ b/oneflow/core/common/symbol.h @@ -19,12 +19,17 @@ limitations under the License. #include <mutex> #include <memory> #include <unordered_map> +#include <unordered_set> #include <glog/logging.h> #include "oneflow/core/common/type_traits.h" +#include "oneflow/core/common/maybe.h" #include "oneflow/core/common/hash_eq_trait_ptr.h" namespace oneflow { +template<typename T> +class SymbolUtil; + template<typename T> class Symbol final { public: @@ -51,6 +56,8 @@ class Symbol final { std::shared_ptr<const T> shared_from_symbol() const; private: + template<typename SymbolT> + friend struct SymbolUtil; static const T* GetOrCreatePtr(const T& obj); const T* ptr_; @@ -61,80 +68,83 @@ struct IsScalarType<Symbol<T>> final { static const bool value = true; }; -namespace sym { template<typename T> -using SymbolTable = std::unordered_map<HashEqTraitPtr<const T>, std::shared_ptr<const T>>; +struct SymbolUtil final { + using SymbolMap = std::unordered_map<HashEqTraitPtr<const T>, std::shared_ptr<const T>>; -template<typename T> -SymbolTable<T>* GlobalSymbolTable() { - static SymbolTable<T> symbol_table; - return &symbol_table; -} + static SymbolMap* GlobalSymbolMap() { + static SymbolMap symbol_map; + return &symbol_map; + } -template<typename T> -std::mutex* GlobalSymbolTableMutex() { - static std::mutex mutex; - return &mutex; -} + static std::mutex* GlobalSymbolMapMutex() { + static std::mutex mutex; + return &mutex; + } -template<typename T> -SymbolTable<T>* ThreadLocalSymbolTable() { - static thread_local SymbolTable<T> thread_local_symbol_table; - return &thread_local_symbol_table; -} + static SymbolMap* ThreadLocalSymbolMap() { + static thread_local SymbolMap thread_local_symbol_map; + return &thread_local_symbol_map; + } -template<typename T, - typename SymbolTable<T>::iterator (*GetIter4ObjectAndHashValue)(const T&, size_t)> -std::shared_ptr<const T> LocalThreadGetOr(const T& obj) { - auto* thread_local_symbol_table = ThreadLocalSymbolTable<T>(); - size_t hash_value = std::hash<T>()(obj); - HashEqTraitPtr<const T> obj_ptr_wraper(&obj, hash_value); - const auto& local_iter = thread_local_symbol_table->find(obj_ptr_wraper); - if (local_iter != thread_local_symbol_table->end()) { return local_iter->second; } - const auto& iter = GetIter4ObjectAndHashValue(obj, hash_value); - (*thread_local_symbol_table)[iter->first] = iter->second; - return iter->second; -} + static std::unordered_set<const T*>* ThreadLocalSymbolPtrSet() { + static thread_local std::unordered_set<const T*> thread_local_symbol_ptr_set; + return &thread_local_symbol_ptr_set; + } -template<typename T> -typename SymbolTable<T>::iterator FindGlobalSymbol(const T& 
obj, size_t hash_value) { - HashEqTraitPtr<const T> new_obj_ptr_wraper(&obj, hash_value); - auto* symbol_table = GlobalSymbolTable<T>(); - std::unique_lock<std::mutex> lock(*GlobalSymbolTableMutex<T>()); - const auto& iter = symbol_table->find(new_obj_ptr_wraper); - CHECK(iter != symbol_table->end()); - return iter; -} + template<typename SymbolMap::iterator (*GetIter4ObjectAndHashValue)(const T&, size_t)> + static std::shared_ptr<const T> LocalThreadGetOr(const T& obj) { + auto* thread_local_symbol_map = ThreadLocalSymbolMap(); + size_t hash_value = std::hash<T>()(obj); + HashEqTraitPtr<const T> obj_ptr_wraper(&obj, hash_value); + const auto& local_iter = thread_local_symbol_map->find(obj_ptr_wraper); + if (local_iter != thread_local_symbol_map->end()) { return local_iter->second; } + const auto& iter = GetIter4ObjectAndHashValue(obj, hash_value); + (*thread_local_symbol_map)[iter->first] = iter->second; + CHECK(ThreadLocalSymbolPtrSet()->emplace(iter->second.get()).second); + return iter->second; + } -template<typename T> -std::shared_ptr<const T> SharedFromObject(const T& obj) { - return LocalThreadGetOr<T, FindGlobalSymbol<T>>(obj); -} + static typename SymbolMap::iterator FindGlobalSymbol(const T& obj, size_t hash_value) { + HashEqTraitPtr<const T> new_obj_ptr_wraper(&obj, hash_value); + auto* symbol_map = GlobalSymbolMap(); + std::unique_lock<std::mutex> lock(*GlobalSymbolMapMutex()); + const auto& iter = symbol_map->find(new_obj_ptr_wraper); + CHECK(iter != symbol_map->end()); + return iter; + } -template<typename T> -typename SymbolTable<T>::iterator CreateGlobalSymbol(const T& obj, size_t hash_value) { - std::shared_ptr<const T> ptr(new T(obj)); - HashEqTraitPtr<const T> new_obj_ptr_wraper(ptr.get(), hash_value); - std::unique_lock<std::mutex> lock(*GlobalSymbolTableMutex<T>()); - return GlobalSymbolTable<T>()->emplace(new_obj_ptr_wraper, ptr).first; -} + static std::shared_ptr<const T> SharedFromObject(const T& obj) { + return LocalThreadGetOr<FindGlobalSymbol>(obj); + } -template<typename T> -std::shared_ptr<const T> GetOrCreatePtr(const T& obj) { - return LocalThreadGetOr<T, CreateGlobalSymbol<T>>(obj); -} + static typename SymbolMap::iterator CreateGlobalSymbol(const T& obj, size_t hash_value) { + std::shared_ptr<const T> ptr(new T(obj)); + HashEqTraitPtr<const T> new_obj_ptr_wraper(ptr.get(), hash_value); + std::unique_lock<std::mutex> lock(*GlobalSymbolMapMutex()); + return GlobalSymbolMap()->emplace(new_obj_ptr_wraper, ptr).first; + } -} // namespace sym + static std::shared_ptr<const T> GetOrCreatePtr(const T& obj) { + return LocalThreadGetOr<CreateGlobalSymbol>(obj); + } + static Maybe<Symbol<T>> GetSymbolByExistedRawPtr(const T* ptr) { + CHECK_GT_OR_RETURN(ThreadLocalSymbolPtrSet()->count(ptr), 0) << "ptr: " << ptr; + Symbol<T> symbol; + symbol.ptr_ = ptr; + return symbol; + } +}; template<typename T> std::shared_ptr<const T> Symbol<T>::shared_from_symbol() const { if (this->ptr_ == nullptr) { return std::shared_ptr<const T>(); } - return sym::SharedFromObject(*this->ptr_); + return SymbolUtil<T>::SharedFromObject(*this->ptr_); } template<typename T> const T* Symbol<T>::GetOrCreatePtr(const T& obj) { - return sym::GetOrCreatePtr(obj).get(); + return SymbolUtil<T>::GetOrCreatePtr(obj).get(); } template<typename T> diff --git a/oneflow/core/framework/consistent_tensor_infer_cache.cpp b/oneflow/core/framework/consistent_tensor_infer_cache.cpp index 4fc2e82fd..07c4865e6 100644 --- a/oneflow/core/framework/consistent_tensor_infer_cache.cpp +++ 
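The SymbolUtil refactor above keeps the existing interning scheme: a mutex-guarded global table holds one canonical object per value, and a lock-free thread-local table caches lookups so the common path takes no lock. A condensed sketch under simplified assumptions (keys by value rather than OneFlow's HashEqTraitPtr, and no Maybe plumbing):

#include <cassert>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

template<typename T>
const T* Intern(const T& value) {
  static std::unordered_map<T, std::shared_ptr<const T>> global;
  static std::mutex global_mutex;
  static thread_local std::unordered_map<T, std::shared_ptr<const T>> local;

  auto it = local.find(value);
  if (it != local.end()) { return it->second.get(); }  // fast path, no locking
  std::shared_ptr<const T> canonical;
  {
    std::lock_guard<std::mutex> lock(global_mutex);
    auto g = global.find(value);
    if (g == global.end()) { g = global.emplace(value, std::make_shared<const T>(value)).first; }
    canonical = g->second;
  }
  local.emplace(value, canonical);  // warm this thread's cache
  return canonical.get();
}

int main() {
  const std::string* a = Intern(std::string("cuda:0"));
  const std::string* b = Intern(std::string("cuda:0"));
  assert(a == b);  // equal values intern to one pointer, so == is pointer comparison
  return 0;
}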
diff --git a/oneflow/core/framework/consistent_tensor_infer_cache.cpp b/oneflow/core/framework/consistent_tensor_infer_cache.cpp
index 4fc2e82fd..07c4865e6 100644
--- a/oneflow/core/framework/consistent_tensor_infer_cache.cpp
+++ b/oneflow/core/framework/consistent_tensor_infer_cache.cpp
@@ -15,7 +15,6 @@ limitations under the License.
 */
 #include "oneflow/core/framework/consistent_tensor_infer_cache.h"
 #include "oneflow/core/framework/tensor_tuple.h"
-#include "oneflow/core/job/placement_scope.h"
 #include "oneflow/core/operator/operator.h"
 #include "oneflow/core/framework/to_string.h"
 #include "oneflow/core/framework/tensor.h"
@@ -44,8 +43,7 @@ void InputConsistentTensorMeta::assign(
 }

 size_t ConsistentTensorMetaInferArgs::hash_value() const {
-  size_t hash_value = std::hash<Symbol<PlacementScope>>()(placement_scope_);
-  hash_value ^= std::hash<AttrMap>()(attrs_);
+  size_t hash_value = std::hash<AttrMap>()(attrs_);
   const auto& tensor_meta_hash_functor = std::hash<InputConsistentTensorMeta>();
   for (const auto& tensor_meta : input_consistent_tensor_metas_) {
     HashCombine(&hash_value, tensor_meta_hash_functor(tensor_meta));
@@ -53,9 +51,22 @@ size_t ConsistentTensorMetaInferArgs::hash_value() const {
   return hash_value;
 }

+size_t SrcOpConsistentTensorMetaInferArgs::hash_value() const {
+  size_t hash_value = std::hash<AttrMap>()(attrs_);
+  hash_value ^= std::hash<Symbol<ParallelDesc>>()(parallel_desc_);
+  hash_value ^= std::hash<Symbol<cfg::ParallelDistribution>>()(parallel_distribution_);
+  return hash_value;
+}
+
 bool ConsistentTensorMetaInferArgs::operator==(const ConsistentTensorMetaInferArgs& other) const {
-  return this->input_consistent_tensor_metas_ == other.input_consistent_tensor_metas_
-         && this->placement_scope_ == other.placement_scope_ && this->attrs_ == other.attrs_;
+  return this->attrs_ == other.attrs_
+         && this->input_consistent_tensor_metas_ == other.input_consistent_tensor_metas_;
+}
+
+bool SrcOpConsistentTensorMetaInferArgs::operator==(
+    const SrcOpConsistentTensorMetaInferArgs& other) const {
+  return this->attrs_ == other.attrs_ && this->parallel_desc_ == other.parallel_desc_
+         && this->parallel_distribution_ == other.parallel_distribution_;
 }

 Maybe<void> ConsistentTensorMetaInferArgs::MakeParallelDistributionConstraints(
@@ -101,16 +112,25 @@ Maybe<void> ConsistentTensorMetaInferArgs::MakeParallelDistributionInferHints(
 }

 Maybe<ConsistentTensorMetaInferArgs> ConsistentTensorMetaInferArgs::New(
-    const TensorTuple& input_tensors, Symbol<PlacementScope> placement_scope,
-    const AttrMap& attrs) {
+    const AttrMap& attrs, const TensorTuple& input_tensors) {
   std::shared_ptr<ConsistentTensorMetaInferArgs> infer_args(new ConsistentTensorMetaInferArgs());
-  infer_args->input_consistent_tensor_metas_.resize(input_tensors.size());
-  infer_args->placement_scope_ = placement_scope;
   infer_args->attrs_ = attrs;
+  infer_args->input_consistent_tensor_metas_.resize(input_tensors.size());
   JUST(infer_args->InitInputConsistentTensorMetas(input_tensors));
   return infer_args;
 }

+Maybe<SrcOpConsistentTensorMetaInferArgs> SrcOpConsistentTensorMetaInferArgs::New(
+    const AttrMap& attrs, Symbol<ParallelDesc> parallel_desc,
+    Symbol<cfg::ParallelDistribution> parallel_distribution) {
+  std::shared_ptr<SrcOpConsistentTensorMetaInferArgs> infer_args(
+      new SrcOpConsistentTensorMetaInferArgs());
+  infer_args->attrs_ = attrs;
+  infer_args->parallel_desc_ = parallel_desc;
+  infer_args->parallel_distribution_ = parallel_distribution;
+  return infer_args;
+}
+
 Maybe<void> ConsistentTensorMetaInferArgs::InitInputConsistentTensorMetas(
     const TensorTuple& input_tensors) {
   for (int i = 0; i < input_tensors.size(); ++i) {
@@ -132,16 +152,31 @@ Maybe<Operator> MakeOp(const UserOpExpr& user_op_expr, const AttrMap& attrs,
   return JUST(ConstructOp(op_conf, device_type));
 }

+Maybe<void> CheckInputParallelDescIdentical(const ConsistentTensorMetaInferArgs& infer_args) {
+  if (infer_args.input_consistent_tensor_metas().empty()) { return Maybe<void>::Ok(); }
+  const auto& first_parallel_desc =
+      infer_args.input_consistent_tensor_metas().begin()->tensor_meta()->parallel_desc();
+  for (const auto& input_meta : infer_args.input_consistent_tensor_metas()) {
+    CHECK_OR_RETURN(first_parallel_desc == input_meta.tensor_meta()->parallel_desc());
+  }
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> CheckIsDeviceSupportedByOp(const ParallelDesc& parallel_desc,
+                                       const std::string& op_type_name) {
+  if (IsCpuOnly(op_type_name)) { CHECK_EQ_OR_RETURN(parallel_desc.device_tag(), "cpu"); }
+  return Maybe<void>::Ok();
+}
+
 }  // namespace

 /* static */ Maybe<const ConsistentTensorInferResult> ConsistentTensorInferCache::Infer(
     const UserOpExpr& user_op_expr, const ConsistentTensorMetaInferArgs& infer_args) {
-  Symbol<ParallelDesc> parallel_desc;
-  {
-    // Get parallel description.
-    const auto& placement_scope = infer_args.placement_scope();
-    parallel_desc = JUST(placement_scope->GetParallelDesc(user_op_expr.op_type_name()));
-  }
+  CHECK_GT_OR_RETURN(infer_args.input_consistent_tensor_metas().size(), 0);
+  Symbol<ParallelDesc> parallel_desc =
+      infer_args.input_consistent_tensor_metas().at(0).tensor_meta()->parallel_desc();
+  JUST(CheckInputParallelDescIdentical(infer_args));
+  JUST(CheckIsDeviceSupportedByOp(*parallel_desc, user_op_expr.op_type_name()));
   std::vector<OpArgMutConsistentTensorMeta> output_mut_metas(user_op_expr.output_size());
   {
     // Infer OpArgMutConsistentTensorMeta.
@@ -194,6 +229,35 @@ Maybe<Operator> MakeOp(const UserOpExpr& user_op_expr, const AttrMap& attrs,
   return std::shared_ptr<const ConsistentTensorInferResult>(result);
 }

+/* static */ Maybe<const ConsistentTensorInferResult> ConsistentTensorInferCache::Infer(
+    const UserOpExpr& user_op_expr, const SrcOpConsistentTensorMetaInferArgs& infer_args) {
+  Symbol<ParallelDesc> parallel_desc = infer_args.parallel_desc();
+  JUST(CheckIsDeviceSupportedByOp(*parallel_desc, user_op_expr.op_type_name()));
+  std::vector<OpArgMutConsistentTensorMeta> output_mut_metas(user_op_expr.output_size());
+  {
+    // Infer OpArgMutConsistentTensorMeta.
+    const auto& GetInputTensorMeta = [](int32_t i) {
+      UNIMPLEMENTED();
+      return nullptr;
+    };
+    JUST(user_op_expr.InferLogicalShapeAndDType(
+        infer_args.attrs(), parallel_desc->device_tag(), GetInputTensorMeta,
+        [&](int32_t i) { return output_mut_metas.at(i).mut_tensor_meta(); }));
+  }
+  auto* result =
+      new ConsistentTensorInferResult(user_op_expr.input_size(), user_op_expr.output_size());
+  auto* output_metas = result->mut_output_tensor_metas();
+  for (int32_t i = 0; i < user_op_expr.output_size(); ++i) {
+    const auto& output_mut_meta = output_mut_metas.at(i);
+    const auto& shape = output_mut_meta.tensor_meta().shape_ptr();
+    DataType data_type = output_mut_meta.tensor_meta().data_type();
+    const auto& parallel_distribution = infer_args.parallel_distribution();
+    ConsistentTensorMeta tensor_meta(shape, data_type, parallel_distribution, parallel_desc);
+    output_metas->at(i) = SymbolOf(tensor_meta);
+  }
+  return std::shared_ptr<const ConsistentTensorInferResult>(result);
+}
+
 Maybe<const ConsistentTensorInferResult> ConsistentTensorInferCache::GetOrInfer(
     const ConsistentTensorMetaInferArgs& infer_args) {
   auto iter = cache_.find(infer_args);
@@ -206,5 +270,17 @@ Maybe<const ConsistentTensorInferResult> ConsistentTensorInferCache::GetOrInfer(
   return iter->second;
 }

+Maybe<const ConsistentTensorInferResult> ConsistentTensorInferCache::GetOrInfer(
+    const SrcOpConsistentTensorMetaInferArgs& infer_args) {
+  auto iter = src_op_cache_.find(infer_args);
+  if (iter == src_op_cache_.end()) {
+    const auto& user_op_expr = user_op_expr_.lock();
+    CHECK_OR_RETURN(static_cast<bool>(user_op_expr));
+    const auto& output_tensor_metas = JUST(Infer(*user_op_expr, infer_args));
+    iter = src_op_cache_.emplace(infer_args, output_tensor_metas).first;
+  }
+  return iter->second;
+}
+
 }  // namespace one
 }  // namespace oneflow
diff --git a/oneflow/core/framework/consistent_tensor_infer_cache.h b/oneflow/core/framework/consistent_tensor_infer_cache.h
index 852441cff..0520f3511 100644
--- a/oneflow/core/framework/consistent_tensor_infer_cache.h
+++ b/oneflow/core/framework/consistent_tensor_infer_cache.h
@@ -30,7 +30,7 @@ namespace cfg {
 class ParallelDistribution;
 }

-class PlacementScope;
+class ParallelDesc;

 namespace one {

@@ -75,7 +75,6 @@ class ConsistentTensorMetaInferArgs final {
   const std::vector<InputConsistentTensorMeta>& input_consistent_tensor_metas() const {
     return input_consistent_tensor_metas_;
   }
-  Symbol<PlacementScope> placement_scope() const { return placement_scope_; }
   const AttrMap& attrs() const { return attrs_; }

   size_t hash_value() const;
@@ -93,17 +92,41 @@ class ConsistentTensorMetaInferArgs final {
       const UserOpExpr& user_op_expr, const std::vector<BlobDesc>& blob_descs,
       std::vector<ParallelDistributionInferHint>* hints) const;

-  static Maybe<ConsistentTensorMetaInferArgs> New(const TensorTuple& input_tensors,
-                                                  Symbol<PlacementScope> placement_scope,
-                                                  const AttrMap& attrs);
+  static Maybe<ConsistentTensorMetaInferArgs> New(const AttrMap& attrs,
+                                                  const TensorTuple& input_tensors);

 private:
  ConsistentTensorMetaInferArgs() = default;
  Maybe<void> InitInputConsistentTensorMetas(const TensorTuple& input_tensors);

+  AttrMap attrs_;
  std::vector<InputConsistentTensorMeta> input_consistent_tensor_metas_;
-  Symbol<PlacementScope> placement_scope_;
+};
+
+class SrcOpConsistentTensorMetaInferArgs final {
+ public:
+  SrcOpConsistentTensorMetaInferArgs(const SrcOpConsistentTensorMetaInferArgs&) = default;
+  SrcOpConsistentTensorMetaInferArgs(SrcOpConsistentTensorMetaInferArgs&&) = default;
+  ~SrcOpConsistentTensorMetaInferArgs() = default;
+
+  Symbol<ParallelDesc> parallel_desc() const { return parallel_desc_; }
+  Symbol<cfg::ParallelDistribution> parallel_distribution() const { return parallel_distribution_; }
+  const AttrMap& attrs() const { return attrs_; }
+
+  size_t hash_value() const;
+
+  bool operator==(const SrcOpConsistentTensorMetaInferArgs& other) const;
+
+  static Maybe<SrcOpConsistentTensorMetaInferArgs> New(
+      const AttrMap& attrs, Symbol<ParallelDesc> parallel_desc,
+      Symbol<cfg::ParallelDistribution> parallel_distribution);
+
+ private:
+  SrcOpConsistentTensorMetaInferArgs() = default;
+  AttrMap attrs_;
+  Symbol<ParallelDesc> parallel_desc_;
+  Symbol<cfg::ParallelDistribution> parallel_distribution_;
 };

 class OpArgMutConsistentTensorMeta final {
@@ -142,6 +165,13 @@ struct hash<oneflow::one::ConsistentTensorMetaInferArgs> final {
   }
 };

+template<>
+struct hash<oneflow::one::SrcOpConsistentTensorMetaInferArgs> final {
+  size_t operator()(const oneflow::one::SrcOpConsistentTensorMetaInferArgs& val) const {
+    return val.hash_value();
+  }
+};
+
 }  // namespace std

 namespace oneflow {
@@ -185,9 +215,17 @@ class ConsistentTensorInferCache final {
   static Maybe<const ConsistentTensorInferResult> Infer(
       const UserOpExpr& user_op_expr, const ConsistentTensorMetaInferArgs& infer_args);

+  Maybe<const ConsistentTensorInferResult> GetOrInfer(
+      const SrcOpConsistentTensorMetaInferArgs& infer_args);
+
+  static Maybe<const ConsistentTensorInferResult> Infer(
+      const UserOpExpr& user_op_expr, const SrcOpConsistentTensorMetaInferArgs& infer_args);
+
 private:
  std::weak_ptr<const UserOpExpr> user_op_expr_;
  HashMap<ConsistentTensorMetaInferArgs, std::shared_ptr<const ConsistentTensorInferResult>> cache_;
+  HashMap<SrcOpConsistentTensorMetaInferArgs, std::shared_ptr<const ConsistentTensorInferResult>>
+      src_op_cache_;
 };

 }  // namespace one

diff --git a/oneflow/core/framework/device.cpp b/oneflow/core/framework/device.cpp
index b34808236..e4e1d3ec5 100644
--- a/oneflow/core/framework/device.cpp
+++ b/oneflow/core/framework/device.cpp
@@ -61,20 +61,21 @@ Maybe<void> Device::Init() {
 }

 /* static */ Maybe<Symbol<Device>> Device::New(const std::string& type, int64_t device_id) {
-  Device device(type, device_id);
-  JUST(device.Init());
-  return SymbolOf(device);
+  return ThreadLocalGetOrNew(type, device_id);
 }

 /* static */ Maybe<Symbol<Device>> Device::ThreadLocalGetOrNew(const std::string& type,
                                                                int64_t device_id) {
   CHECK_GE_OR_RETURN(device_id, 0);
-  static thread_local HashMap<std::string, std::vector<Symbol<Device>>> type2device_id2device;
-  auto* vec = &type2device_id2device[type];
-  if (vec->size() <= device_id) { vec->resize(device_id + 1); }
-  auto* pptr = &vec->at(device_id);
-  if (!*pptr) { *pptr = JUST(New(type, device_id)); }
-  return *pptr;
+  static thread_local HashMap<std::string, HashMap<int64_t, Symbol<Device>>> map;
+  auto* device_id2symbol = &map[type];
+  auto iter = device_id2symbol->find(device_id);
+  if (iter == device_id2symbol->end()) {
+    Device device(type, device_id);
+    JUST(device.Init());
+    iter = device_id2symbol->emplace(device_id, SymbolOf(device)).first;
+  }
+  return iter->second;
 }

 /* static */ Maybe<Symbol<Device>> Device::New(const std::string& type) {
#include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/op_kernel.h" +#include "oneflow/core/common/optional.h" namespace oneflow { + +class Device; +class ParallelDesc; +namespace cfg { +class ParallelDistribution; +} + namespace one { class OpExprInterpState { @@ -43,7 +51,24 @@ class OpExprInterpState { }; struct OpExprInterpContext { + OpExprInterpContext(const AttrMap& attrs_arg) : attrs(attrs_arg) {} + OpExprInterpContext(const AttrMap& attrs_arg, Symbol<Device> device_arg) + : attrs(attrs_arg), device(device_arg) {} + OpExprInterpContext(const AttrMap& attrs_arg, std::shared_ptr<user_op::OpKernelState> state_arg) + : attrs(attrs_arg), state(state_arg) {} + OpExprInterpContext(const AttrMap& attrs_arg, Symbol<Device> device_arg, + std::shared_ptr<user_op::OpKernelState> state_arg) + : attrs(attrs_arg), device(device_arg), state(state_arg) {} + OpExprInterpContext(const AttrMap& attrs_arg, Symbol<ParallelDesc> parallel_desc_arg, + Symbol<cfg::ParallelDistribution> parallel_distribution_arg) + : attrs(attrs_arg), + parallel_desc(parallel_desc_arg), + parallel_distribution(parallel_distribution_arg) {} + AttrMap attrs; + Optional<Symbol<Device>> device; // for local op + Optional<Symbol<ParallelDesc>> parallel_desc; // for consistent op + Optional<Symbol<cfg::ParallelDistribution>> parallel_distribution; // for consistent op std::shared_ptr<user_op::OpKernelState> state; }; @@ -54,7 +79,7 @@ class OpExprInterpreter { Maybe<void> Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs, const AttrMap& attrs) const { - return Apply(op, inputs, outputs, OpExprInterpContext{attrs, nullptr}); + return Apply(op, inputs, outputs, OpExprInterpContext(attrs)); } Maybe<void> Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs) const { @@ -92,7 +117,7 @@ class LazyInterpreter : public OpExprInterpreter { Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const AttrMap& attrs) const { - return Apply(op_expr, inputs, outputs, OpExprInterpContext{attrs, nullptr}); + return Apply(op_expr, inputs, outputs, OpExprInterpContext(attrs)); } Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, @@ -113,7 +138,7 @@ class EagerInterpreter : public OpExprInterpreter { Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const AttrMap& attrs) const { - return Apply(op_expr, inputs, outputs, OpExprInterpContext{attrs, nullptr}); + return Apply(op_expr, inputs, outputs, OpExprInterpContext(attrs)); } Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, @@ -156,11 +181,11 @@ class AutogradInterpreter { Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const AttrMap& attrs) const { - return Apply(op_expr, inputs, outputs, OpExprInterpContext{attrs, nullptr}); + return Apply(op_expr, inputs, outputs, OpExprInterpContext(attrs)); } Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs) const { - return Apply(op_expr, inputs, outputs, OpExprInterpContext{}); + return Apply(op_expr, inputs, outputs, OpExprInterpContext(AttrMap{})); } Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, diff --git a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp index 
diff --git a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp
index 99f185ad4..d42c5b4bd 100644
--- a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp
+++ b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp
@@ -26,7 +26,6 @@ limitations under the License.
 #include "oneflow/core/framework/tensor_name_scope.h"
 #include "oneflow/core/framework/tensor_tuple.h"
 #include "oneflow/core/framework/consistent_tensor_infer_cache.h"
-#include "oneflow/core/job/placement_scope.h"
 #include "oneflow/core/eager/foreign_boxing_util.h"
 #include "oneflow/core/operator/operator.h"
 #include "oneflow/user/kernels/stateful_local_opkernel.h"
@@ -36,6 +35,12 @@ namespace one {

 namespace {

+Maybe<Symbol<ParallelDesc>> GetParallelDesc(const TensorTuple& inputs,
+                                            const OpExprInterpContext& ctx) {
+  if (!inputs.empty()) { return inputs.at(0)->parallel_desc(); }
+  return ctx.parallel_desc.value();
+}
+
 std::string GetDynamicOpConsistentFailedDebugString(const UserOpExpr& user_op_expr,
                                                     const StatefulLocalOpKernel& kernel) {
   CHECK(!kernel.output_tuple_indexes4mut2_obns().empty());
@@ -58,16 +63,19 @@ std::string GetDynamicOpConsistentFailedDebugString(const UserOpExpr& user_op_ex
 Maybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
                       TensorTuple* outputs, const OpExprInterpContext& ctx) {
   CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size());
-  const auto& placement_scope = JUST(GetCurrentScope())->placement_scope();
-  const auto& infer_args =
-      JUST(ConsistentTensorMetaInferArgs::New(inputs, placement_scope, ctx.attrs));
-  const auto& result =
-      JUST(user_op_expr.mut_consistent_tensor_infer_cache()->GetOrInfer(*infer_args));
+  const auto& parallel_desc = JUST(GetParallelDesc(inputs, ctx));
+  std::shared_ptr<const ConsistentTensorInferResult> result;
+  if (inputs.empty()) {
+    const auto& infer_args = JUST(SrcOpConsistentTensorMetaInferArgs::New(
+        ctx.attrs, parallel_desc, JUST(ctx.parallel_distribution.value())));
+    result = JUST(user_op_expr.mut_consistent_tensor_infer_cache()->GetOrInfer(*infer_args));
+  } else {
+    const auto& infer_args = JUST(ConsistentTensorMetaInferArgs::New(ctx.attrs, inputs));
+    result = JUST(user_op_expr.mut_consistent_tensor_infer_cache()->GetOrInfer(*infer_args));
+  }
   const auto& output_tensor_metas = result->output_tensor_metas();
-  const auto& parallel_desc =
-      JUST(placement_scope->GetParallelDesc(user_op_expr.op_type_name())).shared_from_symbol();
   int64_t parallel_id = -1;
-  const auto& device = JUST(parallel_desc->GetDevice4CurrentProcessCtx(&parallel_id));
+  const auto& device = JUST(GetDevice4CurrentProcessCtx(parallel_desc, &parallel_id));
   using TensorImpl = EagerConsistentTensorImpl;
   TensorImpl::NewMethod New =
       (device ? &TensorImpl::NewWithPhyTensor : &TensorImpl::NewWithoutPhyTensor);
@@ -97,7 +105,7 @@ Maybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
   const auto& instr_type_name = JUST(GetLocalCallInstructionName(parallel_desc->device_tag()));
   JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
     return builder->LocalCallOpKernel(kernel, input_eager_blob_objects, output_eager_blob_objects,
-                                      ctx, parallel_desc, instr_type_name);
+                                      ctx, parallel_desc.shared_from_symbol(), instr_type_name);
   }));
   return Maybe<void>::Ok();
 }
@@ -105,7 +113,7 @@ Maybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
 Maybe<void> EagerConsistentInterpreter::ApplyImpl(const UserOpExpr& op_expr,
                                                   const TensorTuple& inputs, TensorTuple* outputs,
                                                   const OpExprInterpContext& ctx) const {
-  OF_UNIMPLEMENTED();
+  return Interpret(op_expr, inputs, outputs, ctx);
 }

 Maybe<void> EagerConsistentInterpreter::ApplyImpl(const VariableOpExpr& op_expr,

diff --git a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp
index 86ada7f46..728ab0880 100644
--- a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp
+++ b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp
@@ -37,7 +37,10 @@ namespace one {

 namespace {

-Maybe<Symbol<Device>> GetDefaultDevice() { return Device::New("cpu", 0); }
+Maybe<Symbol<Device>> GetDefaultDevice(const OpExprInterpContext& ctx) {
+  if (ctx.device.has_value()) { return ctx.device.value(); }
+  return Device::New("cpu", 0);
+}

 Maybe<EagerMirroredTensorImpl*> TensorImpl4Tensor(const std::shared_ptr<Tensor>& tensor) {
   CHECK_OR_RETURN(static_cast<bool>(tensor));
@@ -145,8 +148,7 @@ Maybe<void> RunEmptyOp(TensorTuple* outputs) {
   const auto& device = tensor_impl->device();
   const auto empty_expr = JUST(op_expr_helper::EmptyOp(*shape, data_type));
   std::shared_ptr<TensorTuple> inputs = std::make_shared<TensorTuple>();
-  JUST(NaiveInterpret(*empty_expr, *inputs, device, outputs,
-                      OpExprInterpContext{AttrMap{}, nullptr}));
+  JUST(NaiveInterpret(*empty_expr, *inputs, device, outputs, OpExprInterpContext(AttrMap{})));
   return Maybe<void>::Ok();
 }

@@ -155,7 +157,7 @@ static Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTu
   CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size());
   Symbol<Device> default_device;
   if (inputs.empty()) {
-    default_device = JUST(GetDefaultDevice());
+    default_device = JUST(GetDefaultDevice(ctx));
   } else {
     default_device = JUST(inputs.at(0)->device());
   }
diff --git a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp
index eaf0d20ac..f855e295e 100644
--- a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp
+++ b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp
@@ -45,30 +45,60 @@ std::shared_ptr<AutogradInterpreter> BuildLazyInterpreter() {
   return std::make_shared<AutogradInterpreter>(internal);
 }

-}  // namespace
-
-/* static */ Maybe<AutogradInterpreter> OpInterpUtil::GetInterpreter() {
+Maybe<AutogradInterpreter> GetInterpreter(const TensorTuple& inputs,
+                                          const OpExprInterpContext& ctx) {
   static const auto& g_lazy_interpreter = BuildLazyInterpreter();
   static const auto& g_eager_consistent_interpreter = BuildEagerInterpreter(/*is_mirrored=*/false);
   static const auto& g_eager_mirrored_interpreter = BuildEagerInterpreter(/*is_mirrored=*/true);
   if (!LazyMode::is_enabled()) {
-    const auto& session = JUST(GetDefaultSession());
-    bool is_mirrored_strategy_enabled = session->is_mirrored_strategy_enabled_stack()->empty()
-                                        || JUST(session->IsMirroredStrategyEnabled());
-    if (is_mirrored_strategy_enabled) {
-      return g_eager_mirrored_interpreter;
+    if (inputs.empty()) {
+      if (ctx.parallel_desc.has_value()) {
+        JUST(ctx.parallel_distribution.value());
+        CHECK_OR_RETURN(!ctx.device.has_value());
+        return g_eager_consistent_interpreter;
+      } else {
+        CHECK_OR_RETURN(!ctx.parallel_distribution.has_value());
+        return g_eager_mirrored_interpreter;
+      }
     } else {
-      return g_eager_consistent_interpreter;
+      if (inputs.at(0)->is_consistent()) {
+        if (inputs.size() == 1) {
+          // do nothing
+        } else if (inputs.size() == 2) {
+          CHECK_OR_RETURN(inputs.at(1)->is_consistent());  // unroll loop for efficiency
+        } else if (inputs.size() == 3) {
+          CHECK_OR_RETURN(inputs.at(1)->is_consistent());  // unroll loop for efficiency
+          CHECK_OR_RETURN(inputs.at(2)->is_consistent());  // unroll loop for efficiency
+        } else {
+          for (const auto& tensor : inputs) { CHECK_OR_RETURN(tensor->is_consistent()); }
+        }
+        return g_eager_consistent_interpreter;
+      } else {
+        if (inputs.size() == 1) {
+          // do nothing
+        } else if (inputs.size() == 2) {
+          CHECK_OR_RETURN(inputs.at(1)->is_local());  // unroll loop for efficiency
+        } else if (inputs.size() == 3) {
+          CHECK_OR_RETURN(inputs.at(1)->is_local());  // unroll loop for efficiency
+          CHECK_OR_RETURN(inputs.at(2)->is_local());  // unroll loop for efficiency
+        } else {
+          for (const auto& tensor : inputs) { CHECK_OR_RETURN(tensor->is_local()); }
+        }
+        return g_eager_mirrored_interpreter;
+      }
     }
+    UNIMPLEMENTED_THEN_RETURN();
   }
   return g_lazy_interpreter;
 }

+}  // namespace
+
 template<>
 /* static */ Maybe<TensorTuple> OpInterpUtil::Dispatch<TensorTuple>(
     const OpExpr& op_expr, const TensorTuple& inputs, const OpExprInterpContext& ctx) {
   auto outputs = std::make_shared<TensorTuple>(op_expr.output_size());
-  JUST(JUST(GetInterpreter())->Apply(op_expr, inputs, outputs.get(), ctx));
+  JUST(Dispatch(op_expr, inputs, outputs.get(), ctx));
   return outputs;
 }

@@ -79,6 +109,12 @@ template<>
   return JUST(Dispatch<TensorTuple>(op_expr, inputs, ctx))->at(0);
 }

+/* static */ Maybe<void> OpInterpUtil::Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,
+                                                TensorTuple* outputs,
+                                                const OpExprInterpContext& ctx) {
+  return JUST(GetInterpreter(inputs, ctx))->Apply(op_expr, inputs, outputs, ctx);
+}
+
 /* static */ Maybe<cfg::OpAttribute> OpInterpUtil::AddOpAndInferOpAttribute(
     const OperatorConf& op_conf, const bool is_mirrored_strategy_enabled) {
   std::shared_ptr<OpAttribute> op_attribute = JUST([&]() -> Maybe<OpAttribute> {
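The dispatch rule GetInterpreter() now encodes, stripped of OneFlow machinery: with no inputs the context decides (placement present means consistent, else local), and with inputs the first tensor decides while the rest must agree. A toy sketch of just that decision logic:

#include <cassert>
#include <stdexcept>
#include <vector>

enum class Kind { kLocal, kConsistent };

struct ToyTensor { bool consistent; };
struct ToyCtx { bool has_placement; };

Kind ChooseInterpreter(const std::vector<ToyTensor>& inputs, const ToyCtx& ctx) {
  // Source ops (no inputs): the interp context carries the decision.
  if (inputs.empty()) { return ctx.has_placement ? Kind::kConsistent : Kind::kLocal; }
  // Otherwise the first input decides; mixing local and consistent is an error.
  bool first = inputs[0].consistent;
  for (const auto& t : inputs) {
    if (t.consistent != first) { throw std::runtime_error("mixed local/consistent inputs"); }
  }
  return first ? Kind::kConsistent : Kind::kLocal;
}

int main() {
  assert(ChooseInterpreter({}, ToyCtx{true}) == Kind::kConsistent);
  assert(ChooseInterpreter({}, ToyCtx{false}) == Kind::kLocal);
  assert(ChooseInterpreter({{true}, {true}}, ToyCtx{false}) == Kind::kConsistent);
  return 0;
}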
diff --git a/oneflow/core/framework/op_interpreter/op_interpreter_util.h b/oneflow/core/framework/op_interpreter/op_interpreter_util.h
index 0f24c020a..b080ba14f 100644
--- a/oneflow/core/framework/op_interpreter/op_interpreter_util.h
+++ b/oneflow/core/framework/op_interpreter/op_interpreter_util.h
@@ -31,22 +31,33 @@ namespace one {

 class OpInterpUtil {
  public:
-  static Maybe<AutogradInterpreter> GetInterpreter();
-
   template<typename T>
   static Maybe<T> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,
                            const AttrMap& attrs) {
-    return Dispatch<T>(op_expr, inputs, OpExprInterpContext{attrs, nullptr});
+    return Dispatch<T>(op_expr, inputs, OpExprInterpContext(attrs));
   }

   template<typename T>
   static Maybe<T> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs) {
-    return Dispatch<T>(op_expr, inputs, OpExprInterpContext{AttrMap{}, nullptr});
+    return Dispatch<T>(op_expr, inputs, OpExprInterpContext(AttrMap{}));
   }

   template<typename T>
   static Maybe<T> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,
                            const OpExprInterpContext& ctx);

+  static Maybe<void> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,
+                              TensorTuple* outputs, const AttrMap& attrs) {
+    return Dispatch(op_expr, inputs, outputs, OpExprInterpContext(attrs));
+  }
+
+  static Maybe<void> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,
+                              TensorTuple* outputs) {
+    return Dispatch(op_expr, inputs, outputs, OpExprInterpContext(AttrMap{}));
+  }
+
+  static Maybe<void> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,
+                              TensorTuple* outputs, const OpExprInterpContext& ctx);
+
   static Maybe<cfg::OpAttribute> AddOpAndInferOpAttribute(const OperatorConf& op_conf,
                                                           const bool is_mirrored_strategy_enabled);

diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index d30e68d95..276ed2e10 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -16,7 +16,8 @@
 # {
 #   "Tensor", "TensorTuple", "Scalar", "Int", "Int32", "Int64", "Float", "Double", "String", "Bool",
 #   "ScalarList", "IntList", "Int32List", "Int64List", "FloatList", "DoubleList", "StringList",
-#   "BoolList", "DataType", "Shape", "Generator", "TensorIndex"
+#   "BoolList", "DataType", "Shape", "Generator", "TensorIndex", "Device", "Placement",
+#   "Sbp", "SbpList"
 # }

 - name: "add_n"
@@ -264,7 +265,11 @@
   bind_python: True

 - name: "constant"
-  signature: "Tensor Constant(*, Shape shape, Scalar value, DataType dtype)"
+  signature: "Tensor Constant(*, Shape shape, Scalar value, DataType dtype, Device device=None)"
+  bind_python: True
+
+- name: "consistent_constant"
+  signature: "Tensor ConsistentConstant(*, Shape shape, Scalar value, DataType dtype, Placement placement, SbpList sbp_tuple)"
   bind_python: True

 - name: "zeros_like"

diff --git a/oneflow/core/functional/impl/activation_functor.cpp b/oneflow/core/functional/impl/activation_functor.cpp
index c12981038..06af6ef03 100644
--- a/oneflow/core/functional/impl/activation_functor.cpp
+++ b/oneflow/core/functional/impl/activation_functor.cpp
@@ -42,7 +42,7 @@ class ReluFunctor {
     if (inplace) {
       std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
       outputs->at(0) = x;
-      JUST(JUST(OpInterpUtil::GetInterpreter())->Apply(*op_, {x}, outputs.get(), AttrMap{}));
+      JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), AttrMap{}));
       return outputs->at(0);
     } else {
       return OpInterpUtil::Dispatch<Tensor>(*op_, {x});
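A "Device device=None" default in functional_api.yaml surfaces in C++ as an optional trailing parameter that the functor inspects (the generated code uses Optional<Symbol<Device>>; the generator internals are not shown in this patch). A hedged sketch of that calling convention with toy types:

#include <iostream>
#include <optional>
#include <string>

struct ToyTensor { std::string device; };

// Stand-in for the generated functional::Constant entry point.
ToyTensor Constant(double value, const std::optional<std::string>& device) {
  // Fall back to cpu:0 when the caller passed None, mirroring GetDefaultDevice().
  return ToyTensor{device.value_or("cpu:0")};
}

int main() {
  std::cout << Constant(1.0, std::nullopt).device << "\n";           // cpu:0
  std::cout << Constant(1.0, std::string("cuda:0")).device << "\n";  // cuda:0
  return 0;
}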
#include "oneflow/core/functional/impl/common.h" #include "oneflow/core/functional/impl/unary_functor.h" #include "oneflow/core/functional/scalar.h" +#include "oneflow/core/job/parallel_desc.h" +#include "oneflow/core/job/global_for.h" +#include "oneflow/core/common/global.h" +#include "oneflow/core/common/optional.h" +#include "oneflow/core/common/protobuf.h" namespace oneflow { namespace one { @@ -32,10 +37,57 @@ namespace functional { namespace impl { +class ConsistentConstantFunctor { + public: + ConsistentConstantFunctor() { + op_ = CHECK_JUST(one::OpBuilder("constant").Output("out").Build()); + } + Maybe<Tensor> operator()(const Shape& shape, const Scalar& value, const DataType& dtype, + const Symbol<ParallelDesc>& placement, + const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr<Shape>("shape", shape)); + JUST(attrs.SetAttr<DataType>("dtype", dtype)); + if (IsIntegralDataType(dtype)) { + JUST(attrs.SetAttr<bool>("is_floating_value", false)); + JUST(attrs.SetAttr<int64_t>("integer_value", JUST(value.As<int64_t>()))); + } else { + JUST(attrs.SetAttr<bool>("is_floating_value", true)); + JUST(attrs.SetAttr<double>("floating_value", JUST(value.As<double>()))); + } + const auto& parallel_distribution = JUST(MakeParallelDistribution(sbp_tuple)); + if (!JUST(*Global<Maybe<bool>, MultiClient>::Get())) { + JUST(attrs.SetAttr<std::string>("nd_sbp", parallel_distribution->DebugString())); + } + return OpInterpUtil::Dispatch<Tensor>( + *op_, {}, OpExprInterpContext(attrs, placement, parallel_distribution)); + } + + Maybe<Symbol<cfg::ParallelDistribution>> MakeParallelDistribution( + const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const { + static thread_local std::map<std::vector<Symbol<cfg::SbpParallel>>, + Symbol<cfg::ParallelDistribution>> + map; + auto iter = map.find(sbp_tuple); + if (iter == map.end()) { + cfg::ParallelDistribution parallel_distribution; + for (const auto& sbp_parallel : sbp_tuple) { + *parallel_distribution.mutable_sbp_parallel()->Add() = *sbp_parallel; + } + iter = map.emplace(sbp_tuple, SymbolOf(parallel_distribution)).first; + } + return iter->second; + } + + private: + std::shared_ptr<OpExpr> op_; +}; + class ConstantFunctor { public: ConstantFunctor() { op_ = CHECK_JUST(one::OpBuilder("constant").Output("out").Build()); } - Maybe<Tensor> operator()(const Shape& shape, const Scalar& value, const DataType& dtype) const { + Maybe<Tensor> operator()(const Shape& shape, const Scalar& value, const DataType& dtype, + const Optional<Symbol<Device>>& device) const { MutableAttrMap attrs; JUST(attrs.SetAttr<Shape>("shape", shape)); JUST(attrs.SetAttr<DataType>("dtype", dtype)); @@ -46,7 +98,17 @@ class ConstantFunctor { JUST(attrs.SetAttr<bool>("is_floating_value", true)); JUST(attrs.SetAttr<double>("floating_value", JUST(value.As<double>()))); } - return OpInterpUtil::Dispatch<Tensor>(*op_, {}, attrs); + { + ParallelDistribution parallel_distribution; + parallel_distribution.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel(); + JUST(attrs.SetAttr<std::string>("nd_sbp", PbMessage2TxtString(parallel_distribution))); + } + if (device.has_value()) { + Symbol<Device> device_symbol = JUST(device.value()); + return OpInterpUtil::Dispatch<Tensor>(*op_, {}, OpExprInterpContext(attrs, device_symbol)); + } else { + return OpInterpUtil::Dispatch<Tensor>(*op_, {}, attrs); + } } private: @@ -953,6 +1015,7 @@ class TensorSetItemFunctor { } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { + 
m.add_functor<impl::ConsistentConstantFunctor>("ConsistentConstant"); m.add_functor<impl::ConstantFunctor>("Constant"); m.add_functor<impl::ZerosLikeFunctor>("ZerosLike"); m.add_functor<impl::OnesLikeFunctor>("OnesLike"); diff --git a/oneflow/core/functional/impl/binary_functor.h b/oneflow/core/functional/impl/binary_functor.h index 7be347c82..46e7bffb2 100644 --- a/oneflow/core/functional/impl/binary_functor.h +++ b/oneflow/core/functional/impl/binary_functor.h @@ -48,7 +48,7 @@ class InplaceableBinaryFunctor { if (inplace) { std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1); outputs->at(0) = x; - JUST(JUST(OpInterpUtil::GetInterpreter())->Apply(*op_, {x, y}, outputs.get())); + JUST(OpInterpUtil::Dispatch(*op_, {x, y}, outputs.get())); return outputs->at(0); } else { return OpInterpUtil::Dispatch<Tensor>(*op_, {x, y}); diff --git a/oneflow/core/functional/impl/math_functor.cpp b/oneflow/core/functional/impl/math_functor.cpp index 079eba44f..a451ccbef 100644 --- a/oneflow/core/functional/impl/math_functor.cpp +++ b/oneflow/core/functional/impl/math_functor.cpp @@ -50,8 +50,7 @@ class AddNFunctor { if (i == 0 && inplace) { std::shared_ptr<TensorTuple> outs = std::make_shared<TensorTuple>(1); outs->at(0) = partial_inputs.at(0); - JUST(JUST(OpInterpUtil::GetInterpreter()) - ->Apply(*op_.at(size - 1), partial_inputs, outs.get())); + JUST(OpInterpUtil::Dispatch(*op_.at(size - 1), partial_inputs, outs.get())); outputs.push_back(outs->at(0)); } else { outputs.push_back(JUST(OpInterpUtil::Dispatch<Tensor>(*op_.at(size - 1), partial_inputs))); @@ -87,7 +86,7 @@ class ScalarAddFunctor { if (inplace) { std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1); outputs->at(0) = x; - JUST(JUST(OpInterpUtil::GetInterpreter())->Apply(*op_, {x}, outputs.get(), attrs)); + JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs)); return outputs->at(0); } else { return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs); diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index d00eeb783..29ba46df2 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -489,7 +489,7 @@ class DropoutFunctor { const auto& mask = JUST(OpInterpUtil::Dispatch<Tensor>( *random_mask_like_op_, {x}, - OpExprInterpContext{.attrs = random_mask_like_attrs, .state = random_mask_like_state})); + OpExprInterpContext(random_mask_like_attrs, random_mask_like_state))); float scale = 1.0; if (p != 1.0) { scale = 1.0 / (1.0 - p); } MutableAttrMap dropout_attrs; diff --git a/oneflow/core/functional/impl/random_functor.cpp b/oneflow/core/functional/impl/random_functor.cpp index f630bd28c..5216dccd0 100644 --- a/oneflow/core/functional/impl/random_functor.cpp +++ b/oneflow/core/functional/impl/random_functor.cpp @@ -55,8 +55,7 @@ class BernoulliFunctor { const auto& bernoulli_kernel_state = std::make_shared<BernoulliKernelState>(gen); return OpInterpUtil::Dispatch<Tensor>( - *bernoulli_op_, {x}, - OpExprInterpContext{.attrs = bernoulli_attrs, .state = bernoulli_kernel_state}); + *bernoulli_op_, {x}, OpExprInterpContext(bernoulli_attrs, bernoulli_kernel_state)); } private: diff --git a/oneflow/core/functional/impl/unary_functor.h b/oneflow/core/functional/impl/unary_functor.h index dac994aa7..a13f08d4b 100644 --- a/oneflow/core/functional/impl/unary_functor.h +++ b/oneflow/core/functional/impl/unary_functor.h @@ -46,7 +46,7 @@ class InplaceableUnaryFunctor { if (inplace) { std::shared_ptr<TensorTuple> 
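Both constant functors above split the scalar into either integer_value or floating_value attrs, with is_floating_value telling the kernel which one to read. A toy sketch of just that attr-splitting step (ToyAttrs is a stand-in for MutableAttrMap):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>

struct ToyAttrs {
  std::map<std::string, std::string> values;
  template<typename T>
  void Set(const std::string& name, const T& v) { values[name] = std::to_string(v); }
};

void SetConstantValueAttrs(ToyAttrs* attrs, double value, bool is_integral_dtype) {
  if (is_integral_dtype) {
    attrs->Set("is_floating_value", 0);
    attrs->Set("integer_value", static_cast<int64_t>(value));  // truncates, as an int dtype would
  } else {
    attrs->Set("is_floating_value", 1);
    attrs->Set("floating_value", value);
  }
}

int main() {
  ToyAttrs attrs;
  SetConstantValueAttrs(&attrs, 1.0, /*is_integral_dtype=*/true);
  std::cout << attrs.values["integer_value"] << "\n";  // 1
  return 0;
}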
diff --git a/oneflow/core/functional/impl/unary_functor.h b/oneflow/core/functional/impl/unary_functor.h
index dac994aa7..a13f08d4b 100644
--- a/oneflow/core/functional/impl/unary_functor.h
+++ b/oneflow/core/functional/impl/unary_functor.h
@@ -46,7 +46,7 @@ class InplaceableUnaryFunctor {
     if (inplace) {
       std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
       outputs->at(0) = x;
-      JUST(JUST(OpInterpUtil::GetInterpreter())->Apply(*op_, {x}, outputs.get()));
+      JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get()));
       return outputs->at(0);
     } else {
       return OpInterpUtil::Dispatch<Tensor>(*op_, {x});

diff --git a/oneflow/core/functional/value_types.h b/oneflow/core/functional/value_types.h
index e88b48c52..29834f4db 100644
--- a/oneflow/core/functional/value_types.h
+++ b/oneflow/core/functional/value_types.h
@@ -30,6 +30,7 @@ class AttrMap;
 template<typename T>
 class Symbol;

+class Device;
 class ParallelDesc;

 namespace cfg {
@@ -88,6 +89,7 @@ enum ValueType {
   kGENERATOR_REF,
   kGENERATOR_MAYBE,
   kTENSOR_INDEX,
+  kDEVICE,
   kPARALLEL_DESC,
   kSBP_PARALLEL,
   kSBP_PARALLEL_LIST,
@@ -141,6 +143,7 @@ VALUE_TYPE_OF_IMPL(one::Generator, kGENERATOR);
 VALUE_TYPE_OF_IMPL(std::shared_ptr<one::Generator>, kGENERATOR_REF);
 VALUE_TYPE_OF_IMPL(Maybe<one::Generator>, kGENERATOR_MAYBE);
 VALUE_TYPE_OF_IMPL(TensorIndex, kTENSOR_INDEX);
+VALUE_TYPE_OF_IMPL(Symbol<Device>, kDEVICE);
 VALUE_TYPE_OF_IMPL(Symbol<ParallelDesc>, kPARALLEL_DESC);
 VALUE_TYPE_OF_IMPL(Symbol<cfg::SbpParallel>, kSBP_PARALLEL);
 VALUE_TYPE_OF_IMPL(std::vector<Symbol<cfg::SbpParallel>>, kSBP_PARALLEL_LIST);

diff --git a/oneflow/core/job/parallel_desc.cpp b/oneflow/core/job/parallel_desc.cpp
index e491262fc..81a60f162 100644
--- a/oneflow/core/job/parallel_desc.cpp
+++ b/oneflow/core/job/parallel_desc.cpp
@@ -159,6 +159,25 @@ Maybe<Symbol<Device>> ParallelDesc::GetDevice4CurrentProcessCtx(int64_t* paralle
   }
 }

+Maybe<Symbol<Device>> GetDevice4CurrentProcessCtx(Symbol<ParallelDesc> parallel_desc,
+                                                  int64_t* parallel_id) {
+  static thread_local HashMap<Symbol<ParallelDesc>, int64_t> parallel_desc2parallel_id;
+  static thread_local HashMap<Symbol<ParallelDesc>, Symbol<Device>> parallel_desc2device;
+  auto parallel_id_iter = parallel_desc2parallel_id.find(parallel_desc);
+  auto device_iter = parallel_desc2device.find(parallel_desc);
+  if (device_iter == parallel_desc2device.end()) {
+    CHECK_OR_RETURN(parallel_id_iter == parallel_desc2parallel_id.end());
+    int64_t id_val = 0;
+    const auto& device = JUST(parallel_desc->GetDevice4CurrentProcessCtx(&id_val));
+    parallel_id_iter = parallel_desc2parallel_id.emplace(parallel_desc, id_val).first;
+    device_iter = parallel_desc2device.emplace(parallel_desc, device).first;
+  } else {
+    CHECK_OR_RETURN(parallel_id_iter != parallel_desc2parallel_id.end());
+  }
+  *parallel_id = parallel_id_iter->second;
+  return device_iter->second;
+}
+
 bool ParallelDesc::TryGetParallelId(int64_t machine_id, int64_t device_id,
                                     int64_t* parallel_id) const {
   const auto& machine_iter = machine_id2device_id2parallel_id_.find(machine_id);

diff --git a/oneflow/core/job/parallel_desc.h b/oneflow/core/job/parallel_desc.h
index 127c5d128..221ea8ec9 100644
--- a/oneflow/core/job/parallel_desc.h
+++ b/oneflow/core/job/parallel_desc.h
@@ -131,6 +131,9 @@ class ParallelDesc final {
   std::shared_ptr<cfg::ParallelConf> cfg_parallel_conf_;
 };

+Maybe<Symbol<Device>> GetDevice4CurrentProcessCtx(Symbol<ParallelDesc> parallel_desc,
+                                                  int64_t* parallel_id);
+
 inline bool operator==(const ParallelConf& lhs, const ParallelConf& rhs) {
   return ParallelDesc(lhs) == ParallelDesc(rhs);
 }
diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp
index 59ab04844..8cb46e10c 100644
--- a/oneflow/core/operator/user_op.cpp
+++ b/oneflow/core/operator/user_op.cpp
@@ -768,12 +768,23 @@ Maybe<void> UserOp::InferParallelDistributionSignature(
     UserOpInferParallelDistributionFnContext ctx(this, parallel_distribution_signature,
                                                  parallel_distribution_constraints,
                                                  ParallelDistributionInferHint4Ibn);
-    return val_->parallel_distribution_infer_fn(&ctx);
+    JUST(val_->parallel_distribution_infer_fn(&ctx));
   } else {
-    return Operator::InferParallelDistributionSignature(
+    JUST(Operator::InferParallelDistributionSignature(
         parallel_distribution_signature, parallel_distribution_constraints, parallel_desc,
-        ParallelDistributionInferHint4Ibn);
+        ParallelDistributionInferHint4Ibn));
   }
+  std::string tick_bn = GenRepeatedBn(user_op::kUserSourceOpTickInputArgName, 0);
+  if (std::find(input_bns().begin(), input_bns().end(), tick_bn) != input_bns().end()) {
+    auto* map = parallel_distribution_signature->mutable_bn_in_op2parallel_distribution();
+    if (map->count(tick_bn) == 0) {
+      auto* sbp_list = (*map)[tick_bn].mutable_sbp_parallel();
+      for (int i = 0; i < parallel_desc.hierarchy()->NumAxes(); ++i) {
+        sbp_list->Add()->mutable_broadcast_parallel();
+      }
+    }
+  }
+  return Maybe<void>::Ok();
 }
 
 Symbol<OperatorConf> UserOp::GetOpConfWithoutOpNameAndLbn() const {
diff --git a/oneflow/user/ops/constant_op.cpp b/oneflow/user/ops/constant_op.cpp
index 2ec6d60f3..1b5457881 100644
--- a/oneflow/user/ops/constant_op.cpp
+++ b/oneflow/user/ops/constant_op.cpp
@@ -14,8 +14,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/common/protobuf.h"
+#include "oneflow/core/common/global.h"
+#include "oneflow/core/job/global_for.h"
 
 namespace oneflow {
+
+Maybe<void> InferConstantParallelDistribution(user_op::InferParallelDistributionFnContext* ctx);
+
 REGISTER_NO_GRAD_USER_OP("constant")
     .Output("out")
     .SetOutputBufferNum(1)
@@ -24,6 +30,7 @@ REGISTER_NO_GRAD_USER_OP("constant")
     .Attr<bool>("is_floating_value")
     .Attr<DataType>("dtype")
    .Attr<Shape>("shape")
+    .Attr<std::string>("nd_sbp")
     .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
       Shape* out_shape = ctx->OutputShape("out", 0);
       const Shape& shape = ctx->Attr<Shape>("shape");
@@ -43,6 +50,20 @@ REGISTER_NO_GRAD_USER_OP("constant")
       auto dtype = ctx->Attr<DataType>("dtype");
       *ctx->OutputDType("out", 0) = dtype;
       return Maybe<void>::Ok();
-    });
+    })
+    .SetParallelDistributionInferFn(&InferConstantParallelDistribution);
+
+Maybe<void> InferConstantParallelDistribution(user_op::InferParallelDistributionFnContext* ctx) {
+  cfg::ParallelDistribution* out = ctx->ParallelDistribution4ArgNameAndIndex("out", 0);
+  if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
+    const auto& pb_str = ctx->user_op_conf().attr<std::string>("nd_sbp");
+    ParallelDistribution pb;
+    CHECK_OR_RETURN(TxtString2PbMessage(pb_str, &pb));
+    out->InitFromProto(pb);
+  } else {
+    out->mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();
+  }
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
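With the `nd_sbp` attribute in place, a consistent constant is no longer forced to broadcast sbp: the requested descriptor is serialized into the attribute and parsed back by `InferConstantParallelDistribution` in multi-client mode. A sketch of what this enables at the Python level (the two-device placement here is hypothetical):

.. code-block:: python

    import oneflow as flow

    # Hypothetical two-device placement; before this change the constant op's
    # output sbp was always broadcast, so a split constant was not expressible.
    placement = flow.placement("cuda", {0: [0, 1]})
    x = flow.ones((8, 8), placement=placement, sbp=flow.sbp.split(0))
    # Each rank materializes only its 4x8 slice of the 8x8 consistent tensor.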
diff --git a/python/oneflow/framework/distribute.py b/python/oneflow/framework/distribute.py
index 117ab34e7..3bdfae3b7 100644
--- a/python/oneflow/framework/distribute.py
+++ b/python/oneflow/framework/distribute.py
@@ -205,9 +205,7 @@ def is_multi_client():
     return oneflow._oneflow_internal.IsMultiClient()
 
 
-def split_sbp(
-    axis: int,
-) -> oneflow._oneflow_internal.oneflow.core.job.sbp_parallel.SbpParallel:
+def split_sbp(axis: int) -> oneflow._oneflow_internal.sbp.sbp:
     """Generate a split scheme in which op will be splitted at `axis`.
 
     Args:
diff --git a/python/oneflow/nn/modules/constant.py b/python/oneflow/nn/modules/constant.py
index f0de614c8..050213cea 100644
--- a/python/oneflow/nn/modules/constant.py
+++ b/python/oneflow/nn/modules/constant.py
@@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import oneflow as flow
 from oneflow.framework.tensor import register_tensor_op
@@ -29,6 +29,10 @@ class _ConstantBase(Module):
         value: Union[float, int],
         dtype: Optional[flow.dtype],
         device: Union[flow.device, str] = None,
+        placement: flow.placement = None,
+        sbp: Union[
+            flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp]
+        ] = None,
         requires_grad: bool = False,
     ) -> None:
         super().__init__()
@@ -37,32 +41,63 @@ class _ConstantBase(Module):
             size, (int, tuple, flow.Size)
         ), "shape should be int or tuple int!"
         self.device = device
+        if isinstance(self.device, str):
+            self.device = flow.device(self.device)
         self.requires_grad = requires_grad
         size = _single(size)
         if dtype is None:
             dtype = flow.float32
-        if device is None:
-            self.device = flow.device("cpu")
+        if placement is None:
+            if device is None:
+                self.device = flow.device("cpu")
+        else:
+            assert device is None
+        self.placement = placement
+        self.sbp = sbp
+        if placement is not None:
+            assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), "sbp: %s" % sbp
+            if isinstance(self.sbp, flow.sbp.sbp):
+                self.sbp = (self.sbp,)
+            else:
+                for elem in sbp:
+                    assert isinstance(elem, flow.sbp.sbp), "sbp: %s" % sbp
+            assert len(self.sbp) == len(placement.hierarchy)
+        else:
+            assert sbp is None, "sbp: %s" % sbp
         self.shape = size
         self.value = value
         self.dtype = dtype
 
     def forward(self):
-        res = flow.F.constant(self.shape, self.value, self.dtype)
-        res = res.to(device=self.device)
+        if self.placement is not None:
+            res = flow.F.consistent_constant(
+                self.shape, self.value, self.dtype, self.placement, self.sbp,
+            )
+        else:
+            res = flow.F.constant(self.shape, self.value, self.dtype, self.device,)
         res.requires_grad = self.requires_grad
         return res
 
 
 class Ones(_ConstantBase):
-    def __init__(self, size, dtype=None, device=None, requires_grad=False):
-        super().__init__(size, 1, dtype, device, requires_grad)
+    def __init__(
+        self,
+        size,
+        dtype=None,
+        device=None,
+        placement=None,
+        sbp=None,
+        requires_grad=False,
+    ):
+        super().__init__(size, 1, dtype, device, placement, sbp, requires_grad)
 
 
 def ones_op(
     size: Union[_size_any_t, flow.Size],
     dtype: Optional[flow.dtype] = None,
     device: Union[flow.device, str, None] = None,
+    placement: flow.placement = None,
+    sbp: flow._oneflow_internal.sbp.sbp = None,
     requires_grad: bool = False,
 ):
     """
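`forward` now selects between the two functional ops instead of constructing locally and moving to a device. A small sketch of the two call paths this exposes through `flow.ones` (single-rank CPU placement, mirroring the unit test added later in this patch):

.. code-block:: python

    import oneflow as flow

    # Local path: dispatches to flow.F.constant with an explicit device.
    local = flow.ones((2, 3), device="cpu")

    # Consistent path: dispatches to flow.F.consistent_constant with a
    # placement and sbp; a single sbp is normalized to a one-element tuple.
    placement = flow.placement("cpu", {0: [0]})
    consistent = flow.ones((2, 3), placement=placement, sbp=flow.sbp.broadcast)

    assert local.is_local
    assert consistent.is_consistent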
@@ -73,7 +108,9 @@
     size (an integer or tuple of integer values) – defining the shape of the output tensor. Can be \\
         a variable number of arguments or a collection like a list or tuple.
     dtype (flow.dtype, optional) – the desired data type of returned tensor.
-    device (torch.device, optional) – the desired device of returned tensor. Default: if None, uses the current device for the default tensor type
+    device (flow.device, optional) – the desired device of returned tensor. Default: if None, uses the current device for the default tensor type
+    placement (flow.placement, optional) – the desired placement of the returned consistent tensor. Default: if None, the returned tensor is a local tensor on the argument `device`.
+    sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional) – the desired sbp descriptor of the returned consistent tensor. Default: if None, the returned tensor is a local tensor on the argument `device`.
     requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
 
     For example:
@@ -86,18 +123,28 @@
         tensor([1., 1., 1., 1., 1.], dtype=oneflow.float32)
 
     """
-    return Ones(size, dtype, device, requires_grad)()
+    return Ones(size, dtype, device, placement, sbp, requires_grad)()
 
 
 class Zeros(_ConstantBase):
-    def __init__(self, size, dtype=None, device=None, requires_grad=False):
-        super().__init__(size, 0, dtype, device, requires_grad)
+    def __init__(
+        self,
+        size,
+        dtype=None,
+        device=None,
+        placement=None,
+        sbp=None,
+        requires_grad=False,
+    ):
+        super().__init__(size, 0, dtype, device, placement, sbp, requires_grad)
 
 
 def zeros_op(
     size: Union[_size_any_t, flow.Size],
     dtype: Optional[flow.dtype] = None,
     device: Union[flow.device, str, None] = None,
+    placement: flow.placement = None,
+    sbp: flow._oneflow_internal.sbp.sbp = None,
     requires_grad: bool = False,
 ):
     """
@@ -108,7 +155,9 @@
     size(an integer or tuple of integer values) - defining the shape of the output tensor. Can be \\
         a variable number of arguments or a collection like a list or tuple.
     dtype (flow.dtype, optional) – the desired data type of returned tensor.
-    device (torch.device, optional) – the desired device of returned tensor. Default: if None, uses the current device for the default tensor type
+    device (flow.device, optional) – the desired device of returned tensor. Default: if None, uses the current device for the default tensor type
+    placement (flow.placement, optional) – the desired placement of the returned consistent tensor. Default: if None, the returned tensor is a local tensor on the argument `device`.
+    sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional) – the desired sbp descriptor of the returned consistent tensor. Default: if None, the returned tensor is a local tensor on the argument `device`.
     requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
 
     For example:
@@ -121,7 +170,7 @@
         tensor([0., 0., 0., 0., 0.], dtype=oneflow.float32)
 
     """
-    return Zeros(size, dtype, device, requires_grad)()
+    return Zeros(size, dtype, device, placement, sbp, requires_grad)()
 
 
 class ZerosLike(Module):
@@ -192,10 +241,16 @@
         size: Union[_size_any_t, flow.Size] = None,
         dtype: Optional[flow.dtype] = None,
         device: Union[flow.device, str] = None,
+        placement: flow.placement = None,
+        sbp: flow._oneflow_internal.sbp.sbp = None,
         requires_grad: bool = False,
     ):
         super().__init__()
         self.device = device
+        if isinstance(self.device, str):
+            self.device = flow.device(self.device)
+        self.placement = placement
+        self.sbp = sbp
         self.requires_grad = requires_grad
         if size != None:
             size = _single(size)
@@ -206,59 +261,89 @@
         new_size = self.size
         new_dtype = self.dtype
         new_device = self.device
+        new_placement = self.placement
+        new_sbp = self.sbp
         new_requires_grad = self.requires_grad
         if self.size is None:
             new_size = x.shape
         if self.dtype is None:
             new_dtype = x.dtype
         if self.device is None:
-            new_device = x.device
+            new_device = x.device if x.is_local else None
+        if self.placement is None:
+            new_placement = x.placement if x.is_consistent else None
+        if self.sbp is None:
+            new_sbp = x.sbp if x.is_consistent else None
+        if new_placement is not None:
+            assert self.device is None
+            assert new_sbp is not None
         assert isinstance(
             new_size, (int, tuple, flow.Size)
         ), f"size parameter not correct, please check!"
         assert isinstance(
             new_dtype, flow.dtype
         ), f"dtype parameter not correct, please check!"
-        assert isinstance(
-            new_device, (str, flow.device)
-        ), f"device parameter not correct, please check!"
+        if new_placement is not None:
+            assert isinstance(
+                new_placement, flow.placement
+            ), f"placement parameter not correct, please check!"
+            assert isinstance(
+                new_sbp, flow.sbp.sbp
+            ), f"sbp parameter not correct, please check!"
+        else:
+            assert isinstance(
+                new_device, (str, flow.device)
+            ), f"device parameter not correct, please check!"
         assert isinstance(
             new_requires_grad, bool
         ), f"requires_grad parameter not correct, please check!"
-        res = flow.F.constant(new_size, 1.0, new_dtype)
-        res = res.to(new_device)
+        if self.placement is not None:
+            res = flow.F.consistent_constant(
+                new_size, 1.0, new_dtype, self.placement, self.sbp
+            )
+        else:
+            res = flow.F.constant(new_size, 1.0, new_dtype, new_device)
         res.requires_grad = new_requires_grad
         return res
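With these defaults, `new_ones` follows its input: a local input keeps a device, while a consistent input contributes its placement and sbp unless they are overridden. A hedged sketch (single-rank placement; the sbp is passed explicitly rather than relying on inheritance):

.. code-block:: python

    import oneflow as flow

    # Local input: the result inherits dtype; the device may be overridden.
    x = flow.ones((4, 4), device="cpu")
    y = x.new_ones((2, 2), device="cpu")
    assert y.is_local and y.dtype == x.dtype

    # Consistent input: placement and sbp can be supplied explicitly to
    # keep the result consistent.
    placement = flow.placement("cpu", {0: [0]})
    cx = flow.ones((4, 4), placement=placement, sbp=flow.sbp.broadcast)
    cy = cx.new_ones((2, 2), placement=placement, sbp=flow.sbp.broadcast)
    assert cy.is_consistent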
 
 
 @register_tensor_op("new_ones")
-def new_ones_op(x, size=None, dtype=None, device=None, requires_grad=False):
+def new_ones_op(
+    x, size=None, dtype=None, device=None, placement=None, sbp=None, requires_grad=False
+):
     """
-    
+
     Returns a Tensor of size size filled with 1. By default, the returned Tensor has the same torch.dtype and torch.device as this tensor.
 
     Args:
         size (int...): a list, tuple, or flow.Size of integers defining the shape of the output tensor.
         dtype (flow.dtype, optional): the desired type of returned tensor. Default: if None, same flow.dtype as this tensor.
         device (flow.device, optional): the desired device of returned tensor. Default: if None, same flow.device as this tensor.
+        placement (flow.placement, optional): the desired placement of the returned consistent tensor. Default: if None, the returned tensor is a local tensor on the argument `device`.
+        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional): the desired sbp descriptor of the returned consistent tensor. Default: if None, the returned tensor is a local tensor on the argument `device`.
         requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.
-    
+
     For example:
 
     .. code-block:: python
 
         >>> import numpy as np
         >>> import oneflow as flow
-    
+
         >>> x = flow.Tensor(np.ones((1, 2, 3)))
         >>> y = x.new_ones((2, 2))
         >>> y
         tensor([[1., 1.],
                 [1., 1.]], dtype=oneflow.float32)
     """
-    return NewOnes(size=size, dtype=dtype, device=device, requires_grad=requires_grad)(
-        x
-    )
+    return NewOnes(
+        size=size,
+        dtype=dtype,
+        device=device,
+        placement=placement,
+        sbp=sbp,
+        requires_grad=requires_grad,
+    )(x)
 
 
 if __name__ == "__main__":
diff --git a/python/oneflow/sbp.py b/python/oneflow/sbp.py
index 8eb354821..e652f7744 100644
--- a/python/oneflow/sbp.py
+++ b/python/oneflow/sbp.py
@@ -15,6 +15,8 @@ limitations under the License.
 """
 import oneflow
 from oneflow.framework.distribute import split_sbp as split
+import oneflow._oneflow_internal
 
+sbp = oneflow._oneflow_internal.sbp.sbp
 broadcast = oneflow._oneflow_internal.sbp.broadcast()
 partial_sum = oneflow._oneflow_internal.sbp.partial_sum()
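Exporting the `sbp` class itself (not just the `broadcast` and `partial_sum` instances) is what makes the Python-side `isinstance` validation in `_ConstantBase` above possible. For example:

.. code-block:: python

    import oneflow as flow

    assert isinstance(flow.sbp.broadcast, flow.sbp.sbp)
    assert isinstance(flow.sbp.partial_sum, flow.sbp.sbp)
    assert isinstance(flow.sbp.split(0), flow.sbp.sbp)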
diff --git a/python/oneflow/test/modules/test_constant.py b/python/oneflow/test/modules/test_constant.py
index deaeb0606..84d33da60 100644
--- a/python/oneflow/test/modules/test_constant.py
+++ b/python/oneflow/test/modules/test_constant.py
@@ -85,15 +85,10 @@ def _test_zeros_like(test_case, device, shape):
 
 
 def _test_new_ones(test_case, device, shape):
-    x = flow.Tensor(np.ones(shape), device=flow.device(device))
+    x = flow.ones(shape, device=flow.device("cpu"))
     y = x.new_ones(shape, device=device)
     test_case.assertTrue(x.dtype == y.dtype)
-    test_case.assertTrue(x.device == y.device)
-    test_case.assertTrue(x.requires_grad == y.requires_grad)
-    x = flow.Tensor(np.ones(shape), device=flow.device(device))
-    y = x.new_ones(x.shape, device=device)
-    test_case.assertTrue(x.dtype == y.dtype)
-    test_case.assertTrue(x.device == y.device)
+    test_case.assertEqual(flow.device(device), y.device)
     test_case.assertTrue(x.requires_grad == y.requires_grad)
     x = flow.Tensor(np.ones(shape), device=flow.device(device))
     x = x.new_ones(shape, device=device, requires_grad=True)
@@ -104,6 +99,13 @@ def _test_new_ones(test_case, device, shape):
 
 @flow.unittest.skip_unless_1n1d()
 class TestConstantModule(flow.unittest.TestCase):
+    def test_consistent_naive(test_case):
+        placement = flow.placement("cpu", {0: [0]})
+        sbp = (flow.sbp.broadcast,)
+        x = flow.ones((16, 16), placement=placement, sbp=sbp)
+        test_case.assertEqual(x.sbp, sbp)
+        test_case.assertEqual(x.placement, placement)
+
     def test_cast(test_case):
         arg_dict = OrderedDict()
         arg_dict["test_fun"] = [
diff --git a/tools/generate_functional_api.py b/tools/generate_functional_api.py
index 9b71af013..00f64934a 100644
--- a/tools/generate_functional_api.py
+++ b/tools/generate_functional_api.py
@@ -151,9 +151,10 @@ types_allowed = {
     "Shape",
     "Generator",
     "TensorIndex",
-    "ParallelDesc",
-    "SbpParallel",
-    "SbpParallelList",
+    "Device",
+    "Placement",
+    "Sbp",
+    "SbpList",
 }
 
 generic_type_aliases = {
@@ -182,9 +183,10 @@ argument_type_aliases = {
     "Shape": "const Shape&",
     "Generator": "const std::shared_ptr<one::Generator>&",
     "TensorIndex": "const TensorIndex&",
-    "ParallelDesc": "const Symbol<ParallelDesc>&",
-    "SbpParallel": "const Symbol<cfg::SbpParallel>&",
-    "SbpParallelList": "const std::vector<Symbol<cfg::SbpParallel>>&",
+    "Device": "const Symbol<Device>&",
+    "Placement": "const Symbol<ParallelDesc>&",
+    "Sbp": "const Symbol<cfg::SbpParallel>&",
+    "SbpList": "const std::vector<Symbol<cfg::SbpParallel>>&",
     **generic_type_aliases,
 }
 
@@ -205,9 +207,10 @@ optional_argument_type_aliases = {
     "Shape": "const Optional<Shape>&",
     "Generator": "const Optional<one::Generator>&",
     "TensorIndex": "const Optional<TensorIndex>&",
-    "ParallelDesc": "const Optional<Symbol<ParallelDesc>>&",
-    "SbpParallel": "const Optional<Symbol<SbpParallel>>&",
-    "SbpParallelList": "const Optional<std::vector<Symbol<cfg::SbpParallel>>>&",
+    "Device": "const Optional<Symbol<Device>>&",
+    "Placement": "const Optional<Symbol<ParallelDesc>>&",
+    "Sbp": "const Optional<Symbol<SbpParallel>>&",
+    "SbpList": "const Optional<std::vector<Symbol<cfg::SbpParallel>>>&",
     **{k: "const Optional<{0}>".format(v) for k, v in generic_type_aliases.items()},
 }
-- 
GitLab