From bf4bdd62a37f31a5972fd7e7a66209c7056b5cb4 Mon Sep 17 00:00:00 2001
From: Li Xinqi <lixinqi2010@gmail.com>
Date: Fri, 30 Jul 2021 17:04:23 +0800
Subject: [PATCH] rebase (#5601)

* rebase

* check in gen py

* merge master and fix bugs

* address pr comments

* address pr comments

* auto format by CI

* functional python_arg

* auto format by CI

* remove unused files

* fix return type error on gcc 4.8.5

Signed-off-by: daquexian <daquexian566@gmail.com>

* auto format by CI

* fix return type error in xrt

Signed-off-by: daquexian <daquexian566@gmail.com>

* fix tick ibn sbp signature

* auto format by CI

Co-authored-by: tsai <jackalcooper@gmail.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: daquexian <daquexian566@gmail.com>
---
 oneflow/api/python/autograd/autograd.cpp      |   3 +-
 oneflow/api/python/framework/device.cpp       |   6 +-
 oneflow/api/python/framework/op_expr.cpp      |   3 +-
 oneflow/api/python/framework/tensor.cpp       |  16 +-
 oneflow/api/python/functional/python_arg.cpp  |   6 +
 .../api/python/symbol/placement_symbol.cpp    |   4 +-
 oneflow/api/python/symbol/sbp_symbol.cpp      |   5 +-
 oneflow/core/autograd/autograd_meta.cpp       |  36 ++++-
 oneflow/core/autograd/autograd_meta.h         |  12 +-
 oneflow/core/common/symbol.h                  | 122 +++++++++-------
 .../consistent_tensor_infer_cache.cpp         | 106 ++++++++++++--
 .../framework/consistent_tensor_infer_cache.h |  50 ++++++-
 oneflow/core/framework/device.cpp             |  19 +--
 oneflow/core/framework/op_interpreter.h       |  35 ++++-
 .../eager_consistent_op_interpreter.cpp       |  30 ++--
 .../eager_mirrored_op_interpreter.cpp         |  10 +-
 .../op_interpreter/op_interpreter_util.cpp    |  56 +++++--
 .../op_interpreter/op_interpreter_util.h      |  19 ++-
 oneflow/core/functional/functional_api.yaml   |   9 +-
 .../functional/impl/activation_functor.cpp    |   2 +-
 .../core/functional/impl/array_functor.cpp    |  67 ++++++++-
 oneflow/core/functional/impl/binary_functor.h |   2 +-
 oneflow/core/functional/impl/math_functor.cpp |   5 +-
 oneflow/core/functional/impl/nn_functor.cpp   |   2 +-
 .../core/functional/impl/random_functor.cpp   |   3 +-
 oneflow/core/functional/impl/unary_functor.h  |   2 +-
 oneflow/core/functional/value_types.h         |   3 +
 oneflow/core/job/parallel_desc.cpp            |  19 +++
 oneflow/core/job/parallel_desc.h              |   3 +
 oneflow/core/operator/user_op.cpp             |  17 ++-
 oneflow/user/ops/constant_op.cpp              |  23 ++-
 python/oneflow/framework/distribute.py        |   4 +-
 python/oneflow/nn/modules/constant.py         | 137 ++++++++++++++----
 python/oneflow/sbp.py                         |   2 +
 python/oneflow/test/modules/test_constant.py  |  16 +-
 tools/generate_functional_api.py              |  21 +--
 36 files changed, 681 insertions(+), 194 deletions(-)

diff --git a/oneflow/api/python/autograd/autograd.cpp b/oneflow/api/python/autograd/autograd.cpp
index a2386b5ec..44e429303 100644
--- a/oneflow/api/python/autograd/autograd.cpp
+++ b/oneflow/api/python/autograd/autograd.cpp
@@ -54,9 +54,8 @@ Maybe<one::TensorTuple> CheckAndInitOutGrads(const one::TensorTuple& outputs,
       CHECK_OR_RETURN(IsScalarTensor(*outputs.at(i)))
           << "Grad can be implicitly created only for scalar outputs";
       const auto& ones_like = JUST(op_expr_helper::OnesLikeOp());
-      const auto& interpreter = JUST(one::OpInterpUtil::GetInterpreter());
       one::TensorTuple grad_output(1);
-      JUST(interpreter->Apply(*ones_like, one::TensorTuple{outputs.at(i)}, &grad_output));
+      JUST(one::OpInterpUtil::Dispatch(*ones_like, one::TensorTuple{outputs.at(i)}, &grad_output));
       gradients->at(i) = grad_output.at(0);
     } else {
       CHECK_OR_RETURN(*(outputs.at(i)->shape()) == *(out_grads.at(i)->shape()))
diff --git a/oneflow/api/python/framework/device.cpp b/oneflow/api/python/framework/device.cpp
index 842a72099..1d629f61f 100644
--- a/oneflow/api/python/framework/device.cpp
+++ b/oneflow/api/python/framework/device.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include <pybind11/pybind11.h>
+#include <pybind11/operators.h>
 #include "oneflow/api/python/common.h"
 #include "oneflow/api/python/framework/device.h"
 #include "oneflow/api/python/of_api_registry.h"
@@ -50,9 +51,10 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
       }))
       .def_property_readonly("type", [](const Symbol<Device>& d) { return d->type(); })
       .def_property_readonly("index", [](const Symbol<Device>& d) { return d->device_id(); })
-      .def("__eq__", [](const Symbol<Device>& d1, const Symbol<Device>& d2) { return *d1 == *d2; })
       .def("__str__", [](const Symbol<Device>& d) { return d->ToString(); })
-      .def("__repr__", [](const Symbol<Device>& d) { return d->ToRepr(); });
+      .def("__repr__", [](const Symbol<Device>& d) { return d->ToRepr(); })
+      .def(py::self == py::self)
+      .def(py::hash(py::self));
 }
 
 }  // namespace oneflow
diff --git a/oneflow/api/python/framework/op_expr.cpp b/oneflow/api/python/framework/op_expr.cpp
index b2c3e6240..e708f8daa 100644
--- a/oneflow/api/python/framework/op_expr.cpp
+++ b/oneflow/api/python/framework/op_expr.cpp
@@ -37,8 +37,7 @@ Maybe<one::TensorTuple> Interpret(const one::OpExpr& op, const one::TensorTuple&
       << "The operation requires " << op.input_size() << " inputs, but " << inputs.size()
       << " is given.";
   auto outputs = std::make_shared<one::TensorTuple>(op.output_size());
-  auto interperter = JUST(one::OpInterpUtil::GetInterpreter());
-  JUST(interperter->Apply(op, inputs, outputs.get(), attrs));
+  JUST(one::OpInterpUtil::Dispatch(op, inputs, outputs.get(), attrs));
   return outputs;
 }
 
diff --git a/oneflow/api/python/framework/tensor.cpp b/oneflow/api/python/framework/tensor.cpp
index 0dda48d98..f0b2744af 100644
--- a/oneflow/api/python/framework/tensor.cpp
+++ b/oneflow/api/python/framework/tensor.cpp
@@ -251,6 +251,19 @@ bool ApiIsContiguous(const std::shared_ptr<Tensor>& tensor) {
   return IsContiguous(tensor).GetOrThrow();
 }
 
+Maybe<py::tuple> TensorGetPyTupleOfSbp(const Tensor& tensor) {
+  const auto& nd_sbp = JUST(tensor.parallel_distribution());
+  const auto& tuple = std::make_shared<py::tuple>(nd_sbp->sbp_parallel_size());
+  for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) {
+    (*tuple)[i] = SymbolOf(nd_sbp->sbp_parallel(i));
+  }
+  return tuple;
+}
+
+py::tuple ApiTensorGetPyTupleOfSbp(const Tensor& tensor) {
+  return *TensorGetPyTupleOfSbp(tensor).GetPtrOrThrow();
+}
+
 Maybe<Tensor> NewTensor(py::args args, py::kwargs kwargs, const DType* desired_dtype,
                         bool treat_single_int_as_size) {
   Symbol<Device> device;
@@ -399,7 +412,8 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
       .def("_get_copy_mirrored_tensor_from_numpy_func_name",
            &ApiGetCopyMirroredTensorFromNumpyFuncName)
       // consistent tensor only
-      .def_property_readonly("placement", &TensorGetParallelDesc);
+      .def_property_readonly("placement", &TensorGetParallelDesc)
+      .def_property_readonly("sbp", &ApiTensorGetPyTupleOfSbp);
 
   auto nn = m.def_submodule("nn");
   py::class_<Parameter, std::shared_ptr<Parameter>, Tensor>(nn, "Parameter")
diff --git a/oneflow/api/python/functional/python_arg.cpp b/oneflow/api/python/functional/python_arg.cpp
index 21656d8d0..662607774 100644
--- a/oneflow/api/python/functional/python_arg.cpp
+++ b/oneflow/api/python/functional/python_arg.cpp
@@ -21,6 +21,7 @@ limitations under the License.
 #include "oneflow/core/common/data_type.cfg.h"
 #include "oneflow/core/framework/attr_map.h"
 #include "oneflow/core/framework/dtype.h"
+#include "oneflow/core/framework/device.h"
 #include "oneflow/core/framework/tensor.h"
 #include "oneflow/core/framework/tensor_tuple.h"
 #include "oneflow/core/framework/user_op_attr.cfg.h"
@@ -168,6 +169,11 @@ Maybe<one::Generator> PythonArg::ObjectAs<one::Generator>() const {
   return *JUST(detail::cast<std::shared_ptr<one::Generator>>(Borrow()));
 }
 
+template<>
+Maybe<Symbol<Device>> PythonArg::ObjectAs<Symbol<Device>>() const {
+  return **JUST(detail::cast<std::shared_ptr<Symbol<Device>>>(Borrow()));
+}
+
 template<>
 Maybe<Symbol<ParallelDesc>> PythonArg::ObjectAs<Symbol<ParallelDesc>>() const {
   return **JUST(detail::cast<std::shared_ptr<Symbol<ParallelDesc>>>(Borrow()));
diff --git a/oneflow/api/python/symbol/placement_symbol.cpp b/oneflow/api/python/symbol/placement_symbol.cpp
index 4764292b8..bf255f494 100644
--- a/oneflow/api/python/symbol/placement_symbol.cpp
+++ b/oneflow/api/python/symbol/placement_symbol.cpp
@@ -239,7 +239,9 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
                              })
       .def_property_readonly("hierarchy", [](Symbol<ParallelDesc> p) { return p->hierarchy(); })
       .def("__str__", &PlacementSymbolExportUtil::PlacementSymbol2String)
-      .def("__repr__", &PlacementSymbolExportUtil::PlacementSymbol2String);
+      .def("__repr__", &PlacementSymbolExportUtil::PlacementSymbol2String)
+      .def(py::self == py::self)
+      .def(py::hash(py::self));
   m.def("AllDevicePlacement", &PlacementSymbolExportUtil::AllDevicePlacement);
 }
 
diff --git a/oneflow/api/python/symbol/sbp_symbol.cpp b/oneflow/api/python/symbol/sbp_symbol.cpp
index cff2ed0f2..f3c6a516e 100644
--- a/oneflow/api/python/symbol/sbp_symbol.cpp
+++ b/oneflow/api/python/symbol/sbp_symbol.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include <pybind11/pybind11.h>
+#include <pybind11/operators.h>
 #include "oneflow/api/python/of_api_registry.h"
 #include "oneflow/core/common/util.h"
 #include "oneflow/core/common/maybe.h"
@@ -88,7 +89,9 @@ ONEFLOW_API_PYBIND11_MODULE("sbp", m) {
   m.attr("max_split_axis") = kMaxSplitAxis;
   py::class_<Symbol<cfg::SbpParallel>, std::shared_ptr<Symbol<cfg::SbpParallel>>>(m, "sbp")
       .def("__str__", &SbpParallelSymbolToString)
-      .def("__repr__", &SbpParallelSymbolToString);
+      .def("__repr__", &SbpParallelSymbolToString)
+      .def(py::self == py::self)
+      .def(py::hash(py::self));
   m.def(
       "split", [](int axis) { return GetSplitSbpParallel(axis).GetOrThrow(); }, py::arg("axis"));
   m.def("broadcast", []() { return GetBroadcastSbpParallel().GetOrThrow(); });
diff --git a/oneflow/core/autograd/autograd_meta.cpp b/oneflow/core/autograd/autograd_meta.cpp
index b82dcef88..348a8969a 100644
--- a/oneflow/core/autograd/autograd_meta.cpp
+++ b/oneflow/core/autograd/autograd_meta.cpp
@@ -23,9 +23,41 @@ namespace oneflow {
 
 namespace one {
 
-TensorInfo::TensorInfo(const Tensor& tensor) : shape_(tensor.shape()), dtype_(tensor.dtype()) {}
+TensorInfo::TensorInfo(const Tensor& tensor) : shape_(tensor.shape()), dtype_(tensor.dtype()) {
+  if (TRY(tensor.device()).IsOk()) { device_ = CHECK_JUST(tensor.device()); }
+  if (TRY(tensor.parallel_desc()).IsOk()) { parallel_desc_ = CHECK_JUST(tensor.parallel_desc()); }
+  if (TRY(tensor.parallel_distribution()).IsOk()) {
+    parallel_distribution_ = CHECK_JUST(tensor.parallel_distribution());
+  }
+}
 
-Maybe<Tensor> TensorInfo::zeros() const { return functional::Constant(*shape_.get(), 0, dtype_); }
+Maybe<const std::vector<Symbol<cfg::SbpParallel>>&> GetSbpTuple(
+    Symbol<cfg::ParallelDistribution> parallel_distribution) {
+  static thread_local HashMap<Symbol<cfg::ParallelDistribution>,
+                              std::vector<Symbol<cfg::SbpParallel>>>
+      map;
+  auto iter = map.find(parallel_distribution);
+  if (iter == map.end()) {
+    std::vector<Symbol<cfg::SbpParallel>> sbp_tuple;
+    for (const auto& sbp_parallel : parallel_distribution->sbp_parallel()) {
+      sbp_tuple.push_back(SymbolOf(sbp_parallel));
+    }
+    iter = map.emplace(parallel_distribution, sbp_tuple).first;
+  }
+  return iter->second;
+}
+
+Maybe<Tensor> TensorInfo::zeros() const {
+  if (device_.has_value()) {
+    const auto& device = JUST(device_.value());
+    return functional::Constant(*shape_.get(), 0, dtype_, device);
+  } else {
+    const auto& parallel_desc = JUST(parallel_desc_.value());
+    const auto& parallel_distribution = JUST(parallel_distribution_.value());
+    const auto& sbp_tuple = JUST(GetSbpTuple(parallel_distribution));
+    return functional::ConsistentConstant(*shape_.get(), 0, dtype_, parallel_desc, sbp_tuple);
+  }
+}
 
 }  // namespace one
 
diff --git a/oneflow/core/autograd/autograd_meta.h b/oneflow/core/autograd/autograd_meta.h
index 47aea2542..e4894912e 100644
--- a/oneflow/core/autograd/autograd_meta.h
+++ b/oneflow/core/autograd/autograd_meta.h
@@ -21,11 +21,19 @@ limitations under the License.
 #include "oneflow/core/common/data_type.pb.h"
 #include "oneflow/core/framework/tensor_arg.h"
 #include "oneflow/core/common/util.h"
+#include "oneflow/core/common/symbol.h"
+#include "oneflow/core/common/optional.h"
 
 namespace oneflow {
 
 class Shape;
 
+class Device;
+class ParallelDesc;
+namespace cfg {
+class ParallelDistribution;
+}
+
 namespace one {
 
 class Tensor;
@@ -86,7 +94,9 @@ class TensorInfo final {
  private:
   std::shared_ptr<const Shape> shape_;
   DataType dtype_;
-  // TODO: Add device info
+  Optional<Symbol<Device>> device_;                                    // for local tensor
+  Optional<Symbol<ParallelDesc>> parallel_desc_;                       // for consistent tensor
+  Optional<Symbol<cfg::ParallelDistribution>> parallel_distribution_;  // for consistent tensor
 };
 
 }  // namespace one
diff --git a/oneflow/core/common/symbol.h b/oneflow/core/common/symbol.h
index b4e5a9be4..fe65bdf30 100644
--- a/oneflow/core/common/symbol.h
+++ b/oneflow/core/common/symbol.h
@@ -19,12 +19,17 @@ limitations under the License.
 #include <mutex>
 #include <memory>
 #include <unordered_map>
+#include <unordered_set>
 #include <glog/logging.h>
 #include "oneflow/core/common/type_traits.h"
+#include "oneflow/core/common/maybe.h"
 #include "oneflow/core/common/hash_eq_trait_ptr.h"
 
 namespace oneflow {
 
+template<typename T>
+class SymbolUtil;
+
 template<typename T>
 class Symbol final {
  public:
@@ -51,6 +56,8 @@ class Symbol final {
   std::shared_ptr<const T> shared_from_symbol() const;
 
  private:
+  template<typename SymbolT>
+  friend struct SymbolUtil;
   static const T* GetOrCreatePtr(const T& obj);
 
   const T* ptr_;
@@ -61,80 +68,83 @@ struct IsScalarType<Symbol<T>> final {
   static const bool value = true;
 };
 
-namespace sym {
 template<typename T>
-using SymbolTable = std::unordered_map<HashEqTraitPtr<const T>, std::shared_ptr<const T>>;
+struct SymbolUtil final {
+  using SymbolMap = std::unordered_map<HashEqTraitPtr<const T>, std::shared_ptr<const T>>;
 
-template<typename T>
-SymbolTable<T>* GlobalSymbolTable() {
-  static SymbolTable<T> symbol_table;
-  return &symbol_table;
-}
+  static SymbolMap* GlobalSymbolMap() {
+    static SymbolMap symbol_map;
+    return &symbol_map;
+  }
 
-template<typename T>
-std::mutex* GlobalSymbolTableMutex() {
-  static std::mutex mutex;
-  return &mutex;
-}
+  static std::mutex* GlobalSymbolMapMutex() {
+    static std::mutex mutex;
+    return &mutex;
+  }
 
-template<typename T>
-SymbolTable<T>* ThreadLocalSymbolTable() {
-  static thread_local SymbolTable<T> thread_local_symbol_table;
-  return &thread_local_symbol_table;
-}
+  static SymbolMap* ThreadLocalSymbolMap() {
+    static thread_local SymbolMap thread_local_symbol_map;
+    return &thread_local_symbol_map;
+  }
 
-template<typename T,
-         typename SymbolTable<T>::iterator (*GetIter4ObjectAndHashValue)(const T&, size_t)>
-std::shared_ptr<const T> LocalThreadGetOr(const T& obj) {
-  auto* thread_local_symbol_table = ThreadLocalSymbolTable<T>();
-  size_t hash_value = std::hash<T>()(obj);
-  HashEqTraitPtr<const T> obj_ptr_wraper(&obj, hash_value);
-  const auto& local_iter = thread_local_symbol_table->find(obj_ptr_wraper);
-  if (local_iter != thread_local_symbol_table->end()) { return local_iter->second; }
-  const auto& iter = GetIter4ObjectAndHashValue(obj, hash_value);
-  (*thread_local_symbol_table)[iter->first] = iter->second;
-  return iter->second;
-}
+  static std::unordered_set<const T*>* ThreadLocalSymbolPtrSet() {
+    static thread_local std::unordered_set<const T*> thread_local_symbol_ptr_set;
+    return &thread_local_symbol_ptr_set;
+  }
 
-template<typename T>
-typename SymbolTable<T>::iterator FindGlobalSymbol(const T& obj, size_t hash_value) {
-  HashEqTraitPtr<const T> new_obj_ptr_wraper(&obj, hash_value);
-  auto* symbol_table = GlobalSymbolTable<T>();
-  std::unique_lock<std::mutex> lock(*GlobalSymbolTableMutex<T>());
-  const auto& iter = symbol_table->find(new_obj_ptr_wraper);
-  CHECK(iter != symbol_table->end());
-  return iter;
-}
+  template<typename SymbolMap::iterator (*GetIter4ObjectAndHashValue)(const T&, size_t)>
+  static std::shared_ptr<const T> LocalThreadGetOr(const T& obj) {
+    auto* thread_local_symbol_map = ThreadLocalSymbolMap();
+    size_t hash_value = std::hash<T>()(obj);
+    HashEqTraitPtr<const T> obj_ptr_wraper(&obj, hash_value);
+    const auto& local_iter = thread_local_symbol_map->find(obj_ptr_wraper);
+    if (local_iter != thread_local_symbol_map->end()) { return local_iter->second; }
+    const auto& iter = GetIter4ObjectAndHashValue(obj, hash_value);
+    (*thread_local_symbol_map)[iter->first] = iter->second;
+    CHECK(ThreadLocalSymbolPtrSet()->emplace(iter->second.get()).second);
+    return iter->second;
+  }
 
-template<typename T>
-std::shared_ptr<const T> SharedFromObject(const T& obj) {
-  return LocalThreadGetOr<T, FindGlobalSymbol<T>>(obj);
-}
+  static typename SymbolMap::iterator FindGlobalSymbol(const T& obj, size_t hash_value) {
+    HashEqTraitPtr<const T> new_obj_ptr_wraper(&obj, hash_value);
+    auto* symbol_map = GlobalSymbolMap();
+    std::unique_lock<std::mutex> lock(*GlobalSymbolMapMutex());
+    const auto& iter = symbol_map->find(new_obj_ptr_wraper);
+    CHECK(iter != symbol_map->end());
+    return iter;
+  }
 
-template<typename T>
-typename SymbolTable<T>::iterator CreateGlobalSymbol(const T& obj, size_t hash_value) {
-  std::shared_ptr<const T> ptr(new T(obj));
-  HashEqTraitPtr<const T> new_obj_ptr_wraper(ptr.get(), hash_value);
-  std::unique_lock<std::mutex> lock(*GlobalSymbolTableMutex<T>());
-  return GlobalSymbolTable<T>()->emplace(new_obj_ptr_wraper, ptr).first;
-}
+  static std::shared_ptr<const T> SharedFromObject(const T& obj) {
+    return LocalThreadGetOr<FindGlobalSymbol>(obj);
+  }
 
-template<typename T>
-std::shared_ptr<const T> GetOrCreatePtr(const T& obj) {
-  return LocalThreadGetOr<T, CreateGlobalSymbol<T>>(obj);
-}
+  static typename SymbolMap::iterator CreateGlobalSymbol(const T& obj, size_t hash_value) {
+    std::shared_ptr<const T> ptr(new T(obj));
+    HashEqTraitPtr<const T> new_obj_ptr_wraper(ptr.get(), hash_value);
+    std::unique_lock<std::mutex> lock(*GlobalSymbolMapMutex());
+    return GlobalSymbolMap()->emplace(new_obj_ptr_wraper, ptr).first;
+  }
 
-}  // namespace sym
+  static std::shared_ptr<const T> GetOrCreatePtr(const T& obj) {
+    return LocalThreadGetOr<CreateGlobalSymbol>(obj);
+  }
+  static Maybe<Symbol<T>> GetSymbolByExistedRawPtr(const T* ptr) {
+    CHECK_GT_OR_RETURN(ThreadLocalSymbolPtrSet()->count(ptr), 0) << "ptr: " << ptr;
+    Symbol<T> symbol;
+    symbol.ptr_ = ptr;
+    return symbol;
+  }
+};
 
 template<typename T>
 std::shared_ptr<const T> Symbol<T>::shared_from_symbol() const {
   if (this->ptr_ == nullptr) { return std::shared_ptr<const T>(); }
-  return sym::SharedFromObject(*this->ptr_);
+  return SymbolUtil<T>::SharedFromObject(*this->ptr_);
 }
 
 template<typename T>
 const T* Symbol<T>::GetOrCreatePtr(const T& obj) {
-  return sym::GetOrCreatePtr(obj).get();
+  return SymbolUtil<T>::GetOrCreatePtr(obj).get();
 }
 
 template<typename T>
diff --git a/oneflow/core/framework/consistent_tensor_infer_cache.cpp b/oneflow/core/framework/consistent_tensor_infer_cache.cpp
index 4fc2e82fd..07c4865e6 100644
--- a/oneflow/core/framework/consistent_tensor_infer_cache.cpp
+++ b/oneflow/core/framework/consistent_tensor_infer_cache.cpp
@@ -15,7 +15,6 @@ limitations under the License.
 */
 #include "oneflow/core/framework/consistent_tensor_infer_cache.h"
 #include "oneflow/core/framework/tensor_tuple.h"
-#include "oneflow/core/job/placement_scope.h"
 #include "oneflow/core/operator/operator.h"
 #include "oneflow/core/framework/to_string.h"
 #include "oneflow/core/framework/tensor.h"
@@ -44,8 +43,7 @@ void InputConsistentTensorMeta::assign(
 }
 
 size_t ConsistentTensorMetaInferArgs::hash_value() const {
-  size_t hash_value = std::hash<Symbol<PlacementScope>>()(placement_scope_);
-  hash_value ^= std::hash<AttrMap>()(attrs_);
+  size_t hash_value = std::hash<AttrMap>()(attrs_);
   const auto& tensor_meta_hash_functor = std::hash<InputConsistentTensorMeta>();
   for (const auto& tensor_meta : input_consistent_tensor_metas_) {
     HashCombine(&hash_value, tensor_meta_hash_functor(tensor_meta));
@@ -53,9 +51,22 @@ size_t ConsistentTensorMetaInferArgs::hash_value() const {
   return hash_value;
 }
 
+size_t SrcOpConsistentTensorMetaInferArgs::hash_value() const {
+  size_t hash_value = std::hash<AttrMap>()(attrs_);
+  hash_value ^= std::hash<Symbol<ParallelDesc>>()(parallel_desc_);
+  hash_value ^= std::hash<Symbol<cfg::ParallelDistribution>>()(parallel_distribution_);
+  return hash_value;
+}
+
 bool ConsistentTensorMetaInferArgs::operator==(const ConsistentTensorMetaInferArgs& other) const {
-  return this->input_consistent_tensor_metas_ == other.input_consistent_tensor_metas_
-         && this->placement_scope_ == other.placement_scope_ && this->attrs_ == other.attrs_;
+  return this->attrs_ == other.attrs_
+         && this->input_consistent_tensor_metas_ == other.input_consistent_tensor_metas_;
+}
+
+bool SrcOpConsistentTensorMetaInferArgs::operator==(
+    const SrcOpConsistentTensorMetaInferArgs& other) const {
+  return this->attrs_ == other.attrs_ && this->parallel_desc_ == other.parallel_desc_
+         && this->parallel_distribution_ == other.parallel_distribution_;
 }
 
 Maybe<void> ConsistentTensorMetaInferArgs::MakeParallelDistributionConstraints(
@@ -101,16 +112,25 @@ Maybe<void> ConsistentTensorMetaInferArgs::MakeParallelDistributionInferHints(
 }
 
 Maybe<ConsistentTensorMetaInferArgs> ConsistentTensorMetaInferArgs::New(
-    const TensorTuple& input_tensors, Symbol<PlacementScope> placement_scope,
-    const AttrMap& attrs) {
+    const AttrMap& attrs, const TensorTuple& input_tensors) {
   std::shared_ptr<ConsistentTensorMetaInferArgs> infer_args(new ConsistentTensorMetaInferArgs());
-  infer_args->input_consistent_tensor_metas_.resize(input_tensors.size());
-  infer_args->placement_scope_ = placement_scope;
   infer_args->attrs_ = attrs;
+  infer_args->input_consistent_tensor_metas_.resize(input_tensors.size());
   JUST(infer_args->InitInputConsistentTensorMetas(input_tensors));
   return infer_args;
 }
 
+Maybe<SrcOpConsistentTensorMetaInferArgs> SrcOpConsistentTensorMetaInferArgs::New(
+    const AttrMap& attrs, Symbol<ParallelDesc> parallel_desc,
+    Symbol<cfg::ParallelDistribution> parallel_distribution) {
+  std::shared_ptr<SrcOpConsistentTensorMetaInferArgs> infer_args(
+      new SrcOpConsistentTensorMetaInferArgs());
+  infer_args->attrs_ = attrs;
+  infer_args->parallel_desc_ = parallel_desc;
+  infer_args->parallel_distribution_ = parallel_distribution;
+  return infer_args;
+}
+
 Maybe<void> ConsistentTensorMetaInferArgs::InitInputConsistentTensorMetas(
     const TensorTuple& input_tensors) {
   for (int i = 0; i < input_tensors.size(); ++i) {
@@ -132,16 +152,31 @@ Maybe<Operator> MakeOp(const UserOpExpr& user_op_expr, const AttrMap& attrs,
   return JUST(ConstructOp(op_conf, device_type));
 }
 
+Maybe<void> CheckInputParallelDescIdentical(const ConsistentTensorMetaInferArgs& infer_args) {
+  if (infer_args.input_consistent_tensor_metas().empty()) { return Maybe<void>::Ok(); }
+  const auto& first_parallel_desc =
+      infer_args.input_consistent_tensor_metas().begin()->tensor_meta()->parallel_desc();
+  for (const auto& input_meta : infer_args.input_consistent_tensor_metas()) {
+    CHECK_OR_RETURN(first_parallel_desc == input_meta.tensor_meta()->parallel_desc());
+  }
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> CheckIsDeviceSupportedByOp(const ParallelDesc& parallel_desc,
+                                       const std::string& op_type_name) {
+  if (IsCpuOnly(op_type_name)) { CHECK_EQ_OR_RETURN(parallel_desc.device_tag(), "cpu"); }
+  return Maybe<void>::Ok();
+}
+
 }  // namespace
 
 /* static */ Maybe<const ConsistentTensorInferResult> ConsistentTensorInferCache::Infer(
     const UserOpExpr& user_op_expr, const ConsistentTensorMetaInferArgs& infer_args) {
-  Symbol<ParallelDesc> parallel_desc;
-  {
-    // Get parallel description.
-    const auto& placement_scope = infer_args.placement_scope();
-    parallel_desc = JUST(placement_scope->GetParallelDesc(user_op_expr.op_type_name()));
-  }
+  CHECK_GT_OR_RETURN(infer_args.input_consistent_tensor_metas().size(), 0);
+  Symbol<ParallelDesc> parallel_desc =
+      infer_args.input_consistent_tensor_metas().at(0).tensor_meta()->parallel_desc();
+  JUST(CheckInputParallelDescIdentical(infer_args));
+  JUST(CheckIsDeviceSupportedByOp(*parallel_desc, user_op_expr.op_type_name()));
   std::vector<OpArgMutConsistentTensorMeta> output_mut_metas(user_op_expr.output_size());
   {
     // Infer OpArgMutConsistentTensorMeta.
@@ -194,6 +229,35 @@ Maybe<Operator> MakeOp(const UserOpExpr& user_op_expr, const AttrMap& attrs,
   return std::shared_ptr<const ConsistentTensorInferResult>(result);
 }
 
+/* static */ Maybe<const ConsistentTensorInferResult> ConsistentTensorInferCache::Infer(
+    const UserOpExpr& user_op_expr, const SrcOpConsistentTensorMetaInferArgs& infer_args) {
+  Symbol<ParallelDesc> parallel_desc = infer_args.parallel_desc();
+  JUST(CheckIsDeviceSupportedByOp(*parallel_desc, user_op_expr.op_type_name()));
+  std::vector<OpArgMutConsistentTensorMeta> output_mut_metas(user_op_expr.output_size());
+  {
+    // Infer OpArgMutConsistentTensorMeta.
+    const auto& GetInputTensorMeta = [](int32_t i) {
+      UNIMPLEMENTED();
+      return nullptr;
+    };
+    JUST(user_op_expr.InferLogicalShapeAndDType(
+        infer_args.attrs(), parallel_desc->device_tag(), GetInputTensorMeta,
+        [&](int32_t i) { return output_mut_metas.at(i).mut_tensor_meta(); }));
+  }
+  auto* result =
+      new ConsistentTensorInferResult(user_op_expr.input_size(), user_op_expr.output_size());
+  auto* output_metas = result->mut_output_tensor_metas();
+  for (int32_t i = 0; i < user_op_expr.output_size(); ++i) {
+    const auto& output_mut_meta = output_mut_metas.at(i);
+    const auto& shape = output_mut_meta.tensor_meta().shape_ptr();
+    DataType data_type = output_mut_meta.tensor_meta().data_type();
+    const auto& parallel_distribution = infer_args.parallel_distribution();
+    ConsistentTensorMeta tensor_meta(shape, data_type, parallel_distribution, parallel_desc);
+    output_metas->at(i) = SymbolOf(tensor_meta);
+  }
+  return std::shared_ptr<const ConsistentTensorInferResult>(result);
+}
+
 Maybe<const ConsistentTensorInferResult> ConsistentTensorInferCache::GetOrInfer(
     const ConsistentTensorMetaInferArgs& infer_args) {
   auto iter = cache_.find(infer_args);
@@ -206,5 +270,17 @@ Maybe<const ConsistentTensorInferResult> ConsistentTensorInferCache::GetOrInfer(
   return iter->second;
 }
 
+Maybe<const ConsistentTensorInferResult> ConsistentTensorInferCache::GetOrInfer(
+    const SrcOpConsistentTensorMetaInferArgs& infer_args) {
+  auto iter = src_op_cache_.find(infer_args);
+  if (iter == src_op_cache_.end()) {
+    const auto& user_op_expr = user_op_expr_.lock();
+    CHECK_OR_RETURN(static_cast<bool>(user_op_expr));
+    const auto& output_tensor_metas = JUST(Infer(*user_op_expr, infer_args));
+    iter = src_op_cache_.emplace(infer_args, output_tensor_metas).first;
+  }
+  return iter->second;
+}
+
 }  // namespace one
 }  // namespace oneflow
diff --git a/oneflow/core/framework/consistent_tensor_infer_cache.h b/oneflow/core/framework/consistent_tensor_infer_cache.h
index 852441cff..0520f3511 100644
--- a/oneflow/core/framework/consistent_tensor_infer_cache.h
+++ b/oneflow/core/framework/consistent_tensor_infer_cache.h
@@ -30,7 +30,7 @@ namespace cfg {
 class ParallelDistribution;
 }
 
-class PlacementScope;
+class ParallelDesc;
 
 namespace one {
 
@@ -75,7 +75,6 @@ class ConsistentTensorMetaInferArgs final {
   const std::vector<InputConsistentTensorMeta>& input_consistent_tensor_metas() const {
     return input_consistent_tensor_metas_;
   }
-  Symbol<PlacementScope> placement_scope() const { return placement_scope_; }
   const AttrMap& attrs() const { return attrs_; }
 
   size_t hash_value() const;
@@ -93,17 +92,41 @@ class ConsistentTensorMetaInferArgs final {
       const UserOpExpr& user_op_expr, const std::vector<BlobDesc>& blob_descs,
       std::vector<ParallelDistributionInferHint>* hints) const;
 
-  static Maybe<ConsistentTensorMetaInferArgs> New(const TensorTuple& input_tensors,
-                                                  Symbol<PlacementScope> placement_scope,
-                                                  const AttrMap& attrs);
+  static Maybe<ConsistentTensorMetaInferArgs> New(const AttrMap& attrs,
+                                                  const TensorTuple& input_tensors);
 
  private:
   ConsistentTensorMetaInferArgs() = default;
   Maybe<void> InitInputConsistentTensorMetas(const TensorTuple& input_tensors);
 
+  AttrMap attrs_;
   std::vector<InputConsistentTensorMeta> input_consistent_tensor_metas_;
-  Symbol<PlacementScope> placement_scope_;
+};
+
+class SrcOpConsistentTensorMetaInferArgs final {
+ public:
+  SrcOpConsistentTensorMetaInferArgs(const SrcOpConsistentTensorMetaInferArgs&) = default;
+  SrcOpConsistentTensorMetaInferArgs(SrcOpConsistentTensorMetaInferArgs&&) = default;
+  ~SrcOpConsistentTensorMetaInferArgs() = default;
+
+  Symbol<ParallelDesc> parallel_desc() const { return parallel_desc_; }
+  Symbol<cfg::ParallelDistribution> parallel_distribution() const { return parallel_distribution_; }
+  const AttrMap& attrs() const { return attrs_; }
+
+  size_t hash_value() const;
+
+  bool operator==(const SrcOpConsistentTensorMetaInferArgs& other) const;
+
+  static Maybe<SrcOpConsistentTensorMetaInferArgs> New(
+      const AttrMap& attrs, Symbol<ParallelDesc> parallel_desc,
+      Symbol<cfg::ParallelDistribution> parallel_distribution);
+
+ private:
+  SrcOpConsistentTensorMetaInferArgs() = default;
+
   AttrMap attrs_;
+  Symbol<ParallelDesc> parallel_desc_;
+  Symbol<cfg::ParallelDistribution> parallel_distribution_;
 };
 
 class OpArgMutConsistentTensorMeta final {
@@ -142,6 +165,13 @@ struct hash<oneflow::one::ConsistentTensorMetaInferArgs> final {
   }
 };
 
+template<>
+struct hash<oneflow::one::SrcOpConsistentTensorMetaInferArgs> final {
+  size_t operator()(const oneflow::one::SrcOpConsistentTensorMetaInferArgs& val) const {
+    return val.hash_value();
+  }
+};
+
 }  // namespace std
 
 namespace oneflow {
@@ -185,9 +215,17 @@ class ConsistentTensorInferCache final {
   static Maybe<const ConsistentTensorInferResult> Infer(
       const UserOpExpr& user_op_expr, const ConsistentTensorMetaInferArgs& infer_args);
 
+  Maybe<const ConsistentTensorInferResult> GetOrInfer(
+      const SrcOpConsistentTensorMetaInferArgs& infer_args);
+
+  static Maybe<const ConsistentTensorInferResult> Infer(
+      const UserOpExpr& user_op_expr, const SrcOpConsistentTensorMetaInferArgs& infer_args);
+
  private:
   std::weak_ptr<const UserOpExpr> user_op_expr_;
   HashMap<ConsistentTensorMetaInferArgs, std::shared_ptr<const ConsistentTensorInferResult>> cache_;
+  HashMap<SrcOpConsistentTensorMetaInferArgs, std::shared_ptr<const ConsistentTensorInferResult>>
+      src_op_cache_;
 };
 
 }  // namespace one
diff --git a/oneflow/core/framework/device.cpp b/oneflow/core/framework/device.cpp
index b34808236..e4e1d3ec5 100644
--- a/oneflow/core/framework/device.cpp
+++ b/oneflow/core/framework/device.cpp
@@ -61,20 +61,21 @@ Maybe<void> Device::Init() {
 }
 
 /* static */ Maybe<Symbol<Device>> Device::New(const std::string& type, int64_t device_id) {
-  Device device(type, device_id);
-  JUST(device.Init());
-  return SymbolOf(device);
+  return ThreadLocalGetOrNew(type, device_id);
 }
 
 /* static */ Maybe<Symbol<Device>> Device::ThreadLocalGetOrNew(const std::string& type,
                                                                int64_t device_id) {
   CHECK_GE_OR_RETURN(device_id, 0);
-  static thread_local HashMap<std::string, std::vector<Symbol<Device>>> type2device_id2device;
-  auto* vec = &type2device_id2device[type];
-  if (vec->size() <= device_id) { vec->resize(device_id + 1); }
-  auto* pptr = &vec->at(device_id);
-  if (!*pptr) { *pptr = JUST(New(type, device_id)); }
-  return *pptr;
+  static thread_local HashMap<std::string, HashMap<int64_t, Symbol<Device>>> map;
+  auto* device_id2symbol = &map[type];
+  auto iter = device_id2symbol->find(device_id);
+  if (iter == device_id2symbol->end()) {
+    Device device(type, device_id);
+    JUST(device.Init());
+    iter = device_id2symbol->emplace(device_id, SymbolOf(device)).first;
+  }
+  return iter->second;
 }
 
 /* static */ Maybe<Symbol<Device>> Device::New(const std::string& type) {
diff --git a/oneflow/core/framework/op_interpreter.h b/oneflow/core/framework/op_interpreter.h
index 3d7f6a7de..abaa45aba 100644
--- a/oneflow/core/framework/op_interpreter.h
+++ b/oneflow/core/framework/op_interpreter.h
@@ -21,8 +21,16 @@ limitations under the License.
 #include "oneflow/core/framework/tensor.h"
 #include "oneflow/core/framework/tensor_tuple.h"
 #include "oneflow/core/framework/op_kernel.h"
+#include "oneflow/core/common/optional.h"
 
 namespace oneflow {
+
+class Device;
+class ParallelDesc;
+namespace cfg {
+class ParallelDistribution;
+}
+
 namespace one {
 
 class OpExprInterpState {
@@ -43,7 +51,24 @@ class OpExprInterpState {
 };
 
 struct OpExprInterpContext {
+  OpExprInterpContext(const AttrMap& attrs_arg) : attrs(attrs_arg) {}
+  OpExprInterpContext(const AttrMap& attrs_arg, Symbol<Device> device_arg)
+      : attrs(attrs_arg), device(device_arg) {}
+  OpExprInterpContext(const AttrMap& attrs_arg, std::shared_ptr<user_op::OpKernelState> state_arg)
+      : attrs(attrs_arg), state(state_arg) {}
+  OpExprInterpContext(const AttrMap& attrs_arg, Symbol<Device> device_arg,
+                      std::shared_ptr<user_op::OpKernelState> state_arg)
+      : attrs(attrs_arg), device(device_arg), state(state_arg) {}
+  OpExprInterpContext(const AttrMap& attrs_arg, Symbol<ParallelDesc> parallel_desc_arg,
+                      Symbol<cfg::ParallelDistribution> parallel_distribution_arg)
+      : attrs(attrs_arg),
+        parallel_desc(parallel_desc_arg),
+        parallel_distribution(parallel_distribution_arg) {}
+
   AttrMap attrs;
+  Optional<Symbol<Device>> device;                                    // for local op
+  Optional<Symbol<ParallelDesc>> parallel_desc;                       // for consistent op
+  Optional<Symbol<cfg::ParallelDistribution>> parallel_distribution;  // for consistent op
   std::shared_ptr<user_op::OpKernelState> state;
 };
 
@@ -54,7 +79,7 @@ class OpExprInterpreter {
 
   Maybe<void> Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs,
                     const AttrMap& attrs) const {
-    return Apply(op, inputs, outputs, OpExprInterpContext{attrs, nullptr});
+    return Apply(op, inputs, outputs, OpExprInterpContext(attrs));
   }
 
   Maybe<void> Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs) const {
@@ -92,7 +117,7 @@ class LazyInterpreter : public OpExprInterpreter {
 
   Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs,
                     const AttrMap& attrs) const {
-    return Apply(op_expr, inputs, outputs, OpExprInterpContext{attrs, nullptr});
+    return Apply(op_expr, inputs, outputs, OpExprInterpContext(attrs));
   }
 
   Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs,
@@ -113,7 +138,7 @@ class EagerInterpreter : public OpExprInterpreter {
 
   Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs,
                     const AttrMap& attrs) const {
-    return Apply(op_expr, inputs, outputs, OpExprInterpContext{attrs, nullptr});
+    return Apply(op_expr, inputs, outputs, OpExprInterpContext(attrs));
   }
 
   Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs,
@@ -156,11 +181,11 @@ class AutogradInterpreter {
 
   Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs,
                     const AttrMap& attrs) const {
-    return Apply(op_expr, inputs, outputs, OpExprInterpContext{attrs, nullptr});
+    return Apply(op_expr, inputs, outputs, OpExprInterpContext(attrs));
   }
 
   Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs) const {
-    return Apply(op_expr, inputs, outputs, OpExprInterpContext{});
+    return Apply(op_expr, inputs, outputs, OpExprInterpContext(AttrMap{}));
   }
 
   Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs,
diff --git a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp
index 99f185ad4..d42c5b4bd 100644
--- a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp
+++ b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp
@@ -26,7 +26,6 @@ limitations under the License.
 #include "oneflow/core/framework/tensor_name_scope.h"
 #include "oneflow/core/framework/tensor_tuple.h"
 #include "oneflow/core/framework/consistent_tensor_infer_cache.h"
-#include "oneflow/core/job/placement_scope.h"
 #include "oneflow/core/eager/foreign_boxing_util.h"
 #include "oneflow/core/operator/operator.h"
 #include "oneflow/user/kernels/stateful_local_opkernel.h"
@@ -36,6 +35,12 @@ namespace one {
 
 namespace {
 
+Maybe<Symbol<ParallelDesc>> GetParallelDesc(const TensorTuple& inputs,
+                                            const OpExprInterpContext& ctx) {
+  if (!inputs.empty()) { return inputs.at(0)->parallel_desc(); }
+  return ctx.parallel_desc.value();
+}
+
 std::string GetDynamicOpConsistentFailedDebugString(const UserOpExpr& user_op_expr,
                                                     const StatefulLocalOpKernel& kernel) {
   CHECK(!kernel.output_tuple_indexes4mut2_obns().empty());
@@ -58,16 +63,19 @@ std::string GetDynamicOpConsistentFailedDebugString(const UserOpExpr& user_op_ex
 Maybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
                       TensorTuple* outputs, const OpExprInterpContext& ctx) {
   CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size());
-  const auto& placement_scope = JUST(GetCurrentScope())->placement_scope();
-  const auto& infer_args =
-      JUST(ConsistentTensorMetaInferArgs::New(inputs, placement_scope, ctx.attrs));
-  const auto& result =
-      JUST(user_op_expr.mut_consistent_tensor_infer_cache()->GetOrInfer(*infer_args));
+  const auto& parallel_desc = JUST(GetParallelDesc(inputs, ctx));
+  std::shared_ptr<const ConsistentTensorInferResult> result;
+  if (inputs.empty()) {
+    const auto& infer_args = JUST(SrcOpConsistentTensorMetaInferArgs::New(
+        ctx.attrs, parallel_desc, JUST(ctx.parallel_distribution.value())));
+    result = JUST(user_op_expr.mut_consistent_tensor_infer_cache()->GetOrInfer(*infer_args));
+  } else {
+    const auto& infer_args = JUST(ConsistentTensorMetaInferArgs::New(ctx.attrs, inputs));
+    result = JUST(user_op_expr.mut_consistent_tensor_infer_cache()->GetOrInfer(*infer_args));
+  }
   const auto& output_tensor_metas = result->output_tensor_metas();
-  const auto& parallel_desc =
-      JUST(placement_scope->GetParallelDesc(user_op_expr.op_type_name())).shared_from_symbol();
   int64_t parallel_id = -1;
-  const auto& device = JUST(parallel_desc->GetDevice4CurrentProcessCtx(&parallel_id));
+  const auto& device = JUST(GetDevice4CurrentProcessCtx(parallel_desc, &parallel_id));
   using TensorImpl = EagerConsistentTensorImpl;
   TensorImpl::NewMethod New =
       (device ? &TensorImpl::NewWithPhyTensor : &TensorImpl::NewWithoutPhyTensor);
@@ -97,7 +105,7 @@ Maybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
   const auto& instr_type_name = JUST(GetLocalCallInstructionName(parallel_desc->device_tag()));
   JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
     return builder->LocalCallOpKernel(kernel, input_eager_blob_objects, output_eager_blob_objects,
-                                      ctx, parallel_desc, instr_type_name);
+                                      ctx, parallel_desc.shared_from_symbol(), instr_type_name);
   }));
   return Maybe<void>::Ok();
 }
@@ -105,7 +113,7 @@ Maybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
 Maybe<void> EagerConsistentInterpreter::ApplyImpl(const UserOpExpr& op_expr,
                                                   const TensorTuple& inputs, TensorTuple* outputs,
                                                   const OpExprInterpContext& ctx) const {
-  OF_UNIMPLEMENTED();
+  return Interpret(op_expr, inputs, outputs, ctx);
 }
 
 Maybe<void> EagerConsistentInterpreter::ApplyImpl(const VariableOpExpr& op_expr,
diff --git a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp
index 86ada7f46..728ab0880 100644
--- a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp
+++ b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp
@@ -37,7 +37,10 @@ namespace one {
 
 namespace {
 
-Maybe<Symbol<Device>> GetDefaultDevice() { return Device::New("cpu", 0); }
+Maybe<Symbol<Device>> GetDefaultDevice(const OpExprInterpContext& ctx) {
+  if (ctx.device.has_value()) { return ctx.device.value(); }
+  return Device::New("cpu", 0);
+}
 
 Maybe<EagerMirroredTensorImpl*> TensorImpl4Tensor(const std::shared_ptr<Tensor>& tensor) {
   CHECK_OR_RETURN(static_cast<bool>(tensor));
@@ -145,8 +148,7 @@ Maybe<void> RunEmptyOp(TensorTuple* outputs) {
   const auto& device = tensor_impl->device();
   const auto empty_expr = JUST(op_expr_helper::EmptyOp(*shape, data_type));
   std::shared_ptr<TensorTuple> inputs = std::make_shared<TensorTuple>();
-  JUST(NaiveInterpret(*empty_expr, *inputs, device, outputs,
-                      OpExprInterpContext{AttrMap{}, nullptr}));
+  JUST(NaiveInterpret(*empty_expr, *inputs, device, outputs, OpExprInterpContext(AttrMap{})));
   return Maybe<void>::Ok();
 }
 
@@ -155,7 +157,7 @@ static Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTu
   CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size());
   Symbol<Device> default_device;
   if (inputs.empty()) {
-    default_device = JUST(GetDefaultDevice());
+    default_device = JUST(GetDefaultDevice(ctx));
   } else {
     default_device = JUST(inputs.at(0)->device());
   }
diff --git a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp
index eaf0d20ac..f855e295e 100644
--- a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp
+++ b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp
@@ -45,30 +45,60 @@ std::shared_ptr<AutogradInterpreter> BuildLazyInterpreter() {
   return std::make_shared<AutogradInterpreter>(internal);
 }
 
-}  // namespace
-
-/* static */ Maybe<AutogradInterpreter> OpInterpUtil::GetInterpreter() {
+Maybe<AutogradInterpreter> GetInterpreter(const TensorTuple& inputs,
+                                          const OpExprInterpContext& ctx) {
   static const auto& g_lazy_interpreter = BuildLazyInterpreter();
   static const auto& g_eager_consistent_interpreter = BuildEagerInterpreter(/*is_mirrored=*/false);
   static const auto& g_eager_mirrored_interpreter = BuildEagerInterpreter(/*is_mirrored=*/true);
   if (!LazyMode::is_enabled()) {
-    const auto& session = JUST(GetDefaultSession());
-    bool is_mirrored_strategy_enabled = session->is_mirrored_strategy_enabled_stack()->empty()
-                                        || JUST(session->IsMirroredStrategyEnabled());
-    if (is_mirrored_strategy_enabled) {
-      return g_eager_mirrored_interpreter;
+    if (inputs.empty()) {
+      if (ctx.parallel_desc.has_value()) {
+        JUST(ctx.parallel_distribution.value());
+        CHECK_OR_RETURN(!ctx.device.has_value());
+        return g_eager_consistent_interpreter;
+      } else {
+        CHECK_OR_RETURN(!ctx.parallel_distribution.has_value());
+        return g_eager_mirrored_interpreter;
+      }
     } else {
-      return g_eager_consistent_interpreter;
+      if (inputs.at(0)->is_consistent()) {
+        if (inputs.size() == 1) {
+          // do nothing
+        } else if (inputs.size() == 2) {
+          CHECK_OR_RETURN(inputs.at(1)->is_consistent());  // unroll loop for efficiency
+        } else if (inputs.size() == 3) {
+          CHECK_OR_RETURN(inputs.at(1)->is_consistent());  // unroll loop for efficiency
+          CHECK_OR_RETURN(inputs.at(2)->is_consistent());  // unroll loop for efficiency
+        } else {
+          for (const auto& tensor : inputs) { CHECK_OR_RETURN(tensor->is_consistent()); }
+        }
+        return g_eager_consistent_interpreter;
+      } else {
+        if (inputs.size() == 1) {
+          // do nothing
+        } else if (inputs.size() == 2) {
+          CHECK_OR_RETURN(inputs.at(1)->is_local());  // unroll loop for efficiency
+        } else if (inputs.size() == 3) {
+          CHECK_OR_RETURN(inputs.at(1)->is_local());  // unroll loop for efficiency
+          CHECK_OR_RETURN(inputs.at(2)->is_local());  // unroll loop for efficiency
+        } else {
+          for (const auto& tensor : inputs) { CHECK_OR_RETURN(tensor->is_local()); }
+        }
+        return g_eager_mirrored_interpreter;
+      }
     }
+    UNIMPLEMENTED_THEN_RETURN();
   }
   return g_lazy_interpreter;
 }
 
+}  // namespace
+
 template<>
 /* static */ Maybe<TensorTuple> OpInterpUtil::Dispatch<TensorTuple>(
     const OpExpr& op_expr, const TensorTuple& inputs, const OpExprInterpContext& ctx) {
   auto outputs = std::make_shared<TensorTuple>(op_expr.output_size());
-  JUST(JUST(GetInterpreter())->Apply(op_expr, inputs, outputs.get(), ctx));
+  JUST(Dispatch(op_expr, inputs, outputs.get(), ctx));
   return outputs;
 }
 
@@ -79,6 +109,12 @@ template<>
   return JUST(Dispatch<TensorTuple>(op_expr, inputs, ctx))->at(0);
 }
 
+/* static */ Maybe<void> OpInterpUtil::Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,
+                                                TensorTuple* outputs,
+                                                const OpExprInterpContext& ctx) {
+  return JUST(GetInterpreter(inputs, ctx))->Apply(op_expr, inputs, outputs, ctx);
+}
+
 /* static */ Maybe<cfg::OpAttribute> OpInterpUtil::AddOpAndInferOpAttribute(
     const OperatorConf& op_conf, const bool is_mirrored_strategy_enabled) {
   std::shared_ptr<OpAttribute> op_attribute = JUST([&]() -> Maybe<OpAttribute> {
diff --git a/oneflow/core/framework/op_interpreter/op_interpreter_util.h b/oneflow/core/framework/op_interpreter/op_interpreter_util.h
index 0f24c020a..b080ba14f 100644
--- a/oneflow/core/framework/op_interpreter/op_interpreter_util.h
+++ b/oneflow/core/framework/op_interpreter/op_interpreter_util.h
@@ -31,22 +31,33 @@ namespace one {
 
 class OpInterpUtil {
  public:
-  static Maybe<AutogradInterpreter> GetInterpreter();
-
   template<typename T>
   static Maybe<T> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, const AttrMap& attrs) {
-    return Dispatch<T>(op_expr, inputs, OpExprInterpContext{attrs, nullptr});
+    return Dispatch<T>(op_expr, inputs, OpExprInterpContext(attrs));
   }
 
   template<typename T>
   static Maybe<T> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs) {
-    return Dispatch<T>(op_expr, inputs, OpExprInterpContext{AttrMap{}, nullptr});
+    return Dispatch<T>(op_expr, inputs, OpExprInterpContext(AttrMap{}));
   }
 
   template<typename T>
   static Maybe<T> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,
                            const OpExprInterpContext& ctx);
 
+  static Maybe<void> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,
+                              TensorTuple* outputs, const AttrMap& attrs) {
+    return Dispatch(op_expr, inputs, outputs, OpExprInterpContext(attrs));
+  }
+
+  static Maybe<void> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,
+                              TensorTuple* outputs) {
+    return Dispatch(op_expr, inputs, outputs, OpExprInterpContext(AttrMap{}));
+  }
+
+  static Maybe<void> Dispatch(const OpExpr& op_expr, const TensorTuple& inputs,
+                              TensorTuple* outputs, const OpExprInterpContext& ctx);
+
   static Maybe<cfg::OpAttribute> AddOpAndInferOpAttribute(const OperatorConf& op_conf,
                                                           const bool is_mirrored_strategy_enabled);
 
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index d30e68d95..276ed2e10 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -16,7 +16,8 @@
 # {
 #   "Tensor", "TensorTuple", "Scalar", "Int", "Int32", "Int64", "Float", "Double", "String", "Bool",
 #   "ScalarList", "IntList", "Int32List", "Int64List", "FloatList", "DoubleList", "StringList",
-#   "BoolList", "DataType", "Shape", "Generator", "TensorIndex"
+#   "BoolList", "DataType", "Shape", "Generator", "TensorIndex", "Device", "Placement",
+#   "Sbp", "SbpList"
 # }
 
 - name: "add_n"
@@ -264,7 +265,11 @@
   bind_python: True
 
 - name: "constant"
-  signature: "Tensor Constant(*, Shape shape, Scalar value, DataType dtype)"
+  signature: "Tensor Constant(*, Shape shape, Scalar value, DataType dtype, Device device=None)"
+  bind_python: True
+
+- name: "consistent_constant"
+  signature: "Tensor ConsistentConstant(*, Shape shape, Scalar value, DataType dtype, Placement placement, SbpList sbp_tuple)"
   bind_python: True
 
 - name: "zeros_like"
diff --git a/oneflow/core/functional/impl/activation_functor.cpp b/oneflow/core/functional/impl/activation_functor.cpp
index c12981038..06af6ef03 100644
--- a/oneflow/core/functional/impl/activation_functor.cpp
+++ b/oneflow/core/functional/impl/activation_functor.cpp
@@ -42,7 +42,7 @@ class ReluFunctor {
     if (inplace) {
       std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
       outputs->at(0) = x;
-      JUST(JUST(OpInterpUtil::GetInterpreter())->Apply(*op_, {x}, outputs.get(), AttrMap{}));
+      JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), AttrMap{}));
       return outputs->at(0);
     } else {
       return OpInterpUtil::Dispatch<Tensor>(*op_, {x});
diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp
index c1394ca3d..f40ebc64e 100644
--- a/oneflow/core/functional/impl/array_functor.cpp
+++ b/oneflow/core/functional/impl/array_functor.cpp
@@ -25,6 +25,11 @@ limitations under the License.
 #include "oneflow/core/functional/impl/common.h"
 #include "oneflow/core/functional/impl/unary_functor.h"
 #include "oneflow/core/functional/scalar.h"
+#include "oneflow/core/job/parallel_desc.h"
+#include "oneflow/core/job/global_for.h"
+#include "oneflow/core/common/global.h"
+#include "oneflow/core/common/optional.h"
+#include "oneflow/core/common/protobuf.h"
 
 namespace oneflow {
 namespace one {
@@ -32,10 +37,57 @@ namespace functional {
 
 namespace impl {
 
+class ConsistentConstantFunctor {
+ public:
+  ConsistentConstantFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("constant").Output("out").Build());
+  }
+  Maybe<Tensor> operator()(const Shape& shape, const Scalar& value, const DataType& dtype,
+                           const Symbol<ParallelDesc>& placement,
+                           const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<Shape>("shape", shape));
+    JUST(attrs.SetAttr<DataType>("dtype", dtype));
+    if (IsIntegralDataType(dtype)) {
+      JUST(attrs.SetAttr<bool>("is_floating_value", false));
+      JUST(attrs.SetAttr<int64_t>("integer_value", JUST(value.As<int64_t>())));
+    } else {
+      JUST(attrs.SetAttr<bool>("is_floating_value", true));
+      JUST(attrs.SetAttr<double>("floating_value", JUST(value.As<double>())));
+    }
+    const auto& parallel_distribution = JUST(MakeParallelDistribution(sbp_tuple));
+    if (!JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
+      JUST(attrs.SetAttr<std::string>("nd_sbp", parallel_distribution->DebugString()));
+    }
+    return OpInterpUtil::Dispatch<Tensor>(
+        *op_, {}, OpExprInterpContext(attrs, placement, parallel_distribution));
+  }
+
+  Maybe<Symbol<cfg::ParallelDistribution>> MakeParallelDistribution(
+      const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const {
+    static thread_local std::map<std::vector<Symbol<cfg::SbpParallel>>,
+                                 Symbol<cfg::ParallelDistribution>>
+        map;
+    auto iter = map.find(sbp_tuple);
+    if (iter == map.end()) {
+      cfg::ParallelDistribution parallel_distribution;
+      for (const auto& sbp_parallel : sbp_tuple) {
+        *parallel_distribution.mutable_sbp_parallel()->Add() = *sbp_parallel;
+      }
+      iter = map.emplace(sbp_tuple, SymbolOf(parallel_distribution)).first;
+    }
+    return iter->second;
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
 class ConstantFunctor {
  public:
   ConstantFunctor() { op_ = CHECK_JUST(one::OpBuilder("constant").Output("out").Build()); }
-  Maybe<Tensor> operator()(const Shape& shape, const Scalar& value, const DataType& dtype) const {
+  Maybe<Tensor> operator()(const Shape& shape, const Scalar& value, const DataType& dtype,
+                           const Optional<Symbol<Device>>& device) const {
     MutableAttrMap attrs;
     JUST(attrs.SetAttr<Shape>("shape", shape));
     JUST(attrs.SetAttr<DataType>("dtype", dtype));
@@ -46,7 +98,17 @@ class ConstantFunctor {
       JUST(attrs.SetAttr<bool>("is_floating_value", true));
       JUST(attrs.SetAttr<double>("floating_value", JUST(value.As<double>())));
     }
-    return OpInterpUtil::Dispatch<Tensor>(*op_, {}, attrs);
+    {
+      ParallelDistribution parallel_distribution;
+      parallel_distribution.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();
+      JUST(attrs.SetAttr<std::string>("nd_sbp", PbMessage2TxtString(parallel_distribution)));
+    }
+    if (device.has_value()) {
+      Symbol<Device> device_symbol = JUST(device.value());
+      return OpInterpUtil::Dispatch<Tensor>(*op_, {}, OpExprInterpContext(attrs, device_symbol));
+    } else {
+      return OpInterpUtil::Dispatch<Tensor>(*op_, {}, attrs);
+    }
   }
 
  private:
@@ -953,6 +1015,7 @@ class TensorSetItemFunctor {
 }  // namespace impl
 
 ONEFLOW_FUNCTION_LIBRARY(m) {
+  m.add_functor<impl::ConsistentConstantFunctor>("ConsistentConstant");
   m.add_functor<impl::ConstantFunctor>("Constant");
   m.add_functor<impl::ZerosLikeFunctor>("ZerosLike");
   m.add_functor<impl::OnesLikeFunctor>("OnesLike");
diff --git a/oneflow/core/functional/impl/binary_functor.h b/oneflow/core/functional/impl/binary_functor.h
index 7be347c82..46e7bffb2 100644
--- a/oneflow/core/functional/impl/binary_functor.h
+++ b/oneflow/core/functional/impl/binary_functor.h
@@ -48,7 +48,7 @@ class InplaceableBinaryFunctor {
     if (inplace) {
       std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
       outputs->at(0) = x;
-      JUST(JUST(OpInterpUtil::GetInterpreter())->Apply(*op_, {x, y}, outputs.get()));
+      JUST(OpInterpUtil::Dispatch(*op_, {x, y}, outputs.get()));
       return outputs->at(0);
     } else {
       return OpInterpUtil::Dispatch<Tensor>(*op_, {x, y});
diff --git a/oneflow/core/functional/impl/math_functor.cpp b/oneflow/core/functional/impl/math_functor.cpp
index 079eba44f..a451ccbef 100644
--- a/oneflow/core/functional/impl/math_functor.cpp
+++ b/oneflow/core/functional/impl/math_functor.cpp
@@ -50,8 +50,7 @@ class AddNFunctor {
       if (i == 0 && inplace) {
         std::shared_ptr<TensorTuple> outs = std::make_shared<TensorTuple>(1);
         outs->at(0) = partial_inputs.at(0);
-        JUST(JUST(OpInterpUtil::GetInterpreter())
-                 ->Apply(*op_.at(size - 1), partial_inputs, outs.get()));
+        JUST(OpInterpUtil::Dispatch(*op_.at(size - 1), partial_inputs, outs.get()));
         outputs.push_back(outs->at(0));
       } else {
         outputs.push_back(JUST(OpInterpUtil::Dispatch<Tensor>(*op_.at(size - 1), partial_inputs)));
@@ -87,7 +86,7 @@ class ScalarAddFunctor {
     if (inplace) {
       std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
       outputs->at(0) = x;
-      JUST(JUST(OpInterpUtil::GetInterpreter())->Apply(*op_, {x}, outputs.get(), attrs));
+      JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs));
       return outputs->at(0);
     } else {
       return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp
index d00eeb783..29ba46df2 100644
--- a/oneflow/core/functional/impl/nn_functor.cpp
+++ b/oneflow/core/functional/impl/nn_functor.cpp
@@ -489,7 +489,7 @@ class DropoutFunctor {
 
     const auto& mask = JUST(OpInterpUtil::Dispatch<Tensor>(
         *random_mask_like_op_, {x},
-        OpExprInterpContext{.attrs = random_mask_like_attrs, .state = random_mask_like_state}));
+        OpExprInterpContext(random_mask_like_attrs, random_mask_like_state)));
     float scale = 1.0;
     if (p != 1.0) { scale = 1.0 / (1.0 - p); }
     MutableAttrMap dropout_attrs;
diff --git a/oneflow/core/functional/impl/random_functor.cpp b/oneflow/core/functional/impl/random_functor.cpp
index f630bd28c..5216dccd0 100644
--- a/oneflow/core/functional/impl/random_functor.cpp
+++ b/oneflow/core/functional/impl/random_functor.cpp
@@ -55,8 +55,7 @@ class BernoulliFunctor {
     const auto& bernoulli_kernel_state = std::make_shared<BernoulliKernelState>(gen);
 
     return OpInterpUtil::Dispatch<Tensor>(
-        *bernoulli_op_, {x},
-        OpExprInterpContext{.attrs = bernoulli_attrs, .state = bernoulli_kernel_state});
+        *bernoulli_op_, {x}, OpExprInterpContext(bernoulli_attrs, bernoulli_kernel_state));
   }
 
  private:
diff --git a/oneflow/core/functional/impl/unary_functor.h b/oneflow/core/functional/impl/unary_functor.h
index dac994aa7..a13f08d4b 100644
--- a/oneflow/core/functional/impl/unary_functor.h
+++ b/oneflow/core/functional/impl/unary_functor.h
@@ -46,7 +46,7 @@ class InplaceableUnaryFunctor {
     if (inplace) {
       std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
       outputs->at(0) = x;
-      JUST(JUST(OpInterpUtil::GetInterpreter())->Apply(*op_, {x}, outputs.get()));
+      JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get()));
       return outputs->at(0);
     } else {
       return OpInterpUtil::Dispatch<Tensor>(*op_, {x});
diff --git a/oneflow/core/functional/value_types.h b/oneflow/core/functional/value_types.h
index e88b48c52..29834f4db 100644
--- a/oneflow/core/functional/value_types.h
+++ b/oneflow/core/functional/value_types.h
@@ -30,6 +30,7 @@ class AttrMap;
 template<typename T>
 class Symbol;
 
+class Device;
 class ParallelDesc;
 
 namespace cfg {
@@ -88,6 +89,7 @@ enum ValueType {
   kGENERATOR_REF,
   kGENERATOR_MAYBE,
   kTENSOR_INDEX,
+  kDEVICE,
   kPARALLEL_DESC,
   kSBP_PARALLEL,
   kSBP_PARALLEL_LIST,
@@ -141,6 +143,7 @@ VALUE_TYPE_OF_IMPL(one::Generator, kGENERATOR);
 VALUE_TYPE_OF_IMPL(std::shared_ptr<one::Generator>, kGENERATOR_REF);
 VALUE_TYPE_OF_IMPL(Maybe<one::Generator>, kGENERATOR_MAYBE);
 VALUE_TYPE_OF_IMPL(TensorIndex, kTENSOR_INDEX);
+VALUE_TYPE_OF_IMPL(Symbol<Device>, kDEVICE);
 VALUE_TYPE_OF_IMPL(Symbol<ParallelDesc>, kPARALLEL_DESC);
 VALUE_TYPE_OF_IMPL(Symbol<cfg::SbpParallel>, kSBP_PARALLEL);
 VALUE_TYPE_OF_IMPL(std::vector<Symbol<cfg::SbpParallel>>, kSBP_PARALLEL_LIST);
diff --git a/oneflow/core/job/parallel_desc.cpp b/oneflow/core/job/parallel_desc.cpp
index e491262fc..81a60f162 100644
--- a/oneflow/core/job/parallel_desc.cpp
+++ b/oneflow/core/job/parallel_desc.cpp
@@ -159,6 +159,25 @@ Maybe<Symbol<Device>> ParallelDesc::GetDevice4CurrentProcessCtx(int64_t* paralle
   }
 }
 
+Maybe<Symbol<Device>> GetDevice4CurrentProcessCtx(Symbol<ParallelDesc> parallel_desc,
+                                                  int64_t* parallel_id) {
+  static thread_local HashMap<Symbol<ParallelDesc>, int64_t> parallel_desc2parallel_id;
+  static thread_local HashMap<Symbol<ParallelDesc>, Symbol<Device>> parallel_desc2device;
+  auto parallel_id_iter = parallel_desc2parallel_id.find(parallel_desc);
+  auto device_iter = parallel_desc2device.find(parallel_desc);
+  if (device_iter == parallel_desc2device.end()) {
+    CHECK_OR_RETURN(parallel_id_iter == parallel_desc2parallel_id.end());
+    int64_t id_val = 0;
+    const auto& device = JUST(parallel_desc->GetDevice4CurrentProcessCtx(&id_val));
+    parallel_id_iter = parallel_desc2parallel_id.emplace(parallel_desc, id_val).first;
+    device_iter = parallel_desc2device.emplace(parallel_desc, device).first;
+  } else {
+    CHECK_OR_RETURN(parallel_id_iter != parallel_desc2parallel_id.end());
+  }
+  *parallel_id = parallel_id_iter->second;
+  return device_iter->second;
+}
+
 bool ParallelDesc::TryGetParallelId(int64_t machine_id, int64_t device_id,
                                     int64_t* parallel_id) const {
   const auto& machine_iter = machine_id2device_id2parallel_id_.find(machine_id);
diff --git a/oneflow/core/job/parallel_desc.h b/oneflow/core/job/parallel_desc.h
index 127c5d128..221ea8ec9 100644
--- a/oneflow/core/job/parallel_desc.h
+++ b/oneflow/core/job/parallel_desc.h
@@ -131,6 +131,9 @@ class ParallelDesc final {
   std::shared_ptr<cfg::ParallelConf> cfg_parallel_conf_;
 };
 
+Maybe<Symbol<Device>> GetDevice4CurrentProcessCtx(Symbol<ParallelDesc> parallel_desc,
+                                                  int64_t* parallel_id);
+
 inline bool operator==(const ParallelConf& lhs, const ParallelConf& rhs) {
   return ParallelDesc(lhs) == ParallelDesc(rhs);
 }
diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp
index 59ab04844..8cb46e10c 100644
--- a/oneflow/core/operator/user_op.cpp
+++ b/oneflow/core/operator/user_op.cpp
@@ -768,12 +768,23 @@ Maybe<void> UserOp::InferParallelDistributionSignature(
     UserOpInferParallelDistributionFnContext ctx(this, parallel_distribution_signature,
                                                  parallel_distribution_constraints,
                                                  ParallelDistributionInferHint4Ibn);
-    return val_->parallel_distribution_infer_fn(&ctx);
+    JUST(val_->parallel_distribution_infer_fn(&ctx));
   } else {
-    return Operator::InferParallelDistributionSignature(
+    JUST(Operator::InferParallelDistributionSignature(
         parallel_distribution_signature, parallel_distribution_constraints, parallel_desc,
-        ParallelDistributionInferHint4Ibn);
+        ParallelDistributionInferHint4Ibn));
   }
+  std::string tick_bn = GenRepeatedBn(user_op::kUserSourceOpTickInputArgName, 0);
+  if (std::find(input_bns().begin(), input_bns().end(), tick_bn) != input_bns().end()) {
+    auto* map = parallel_distribution_signature->mutable_bn_in_op2parallel_distribution();
+    if (map->count(tick_bn) == 0) {
+      auto* sbp_list = (*map)[tick_bn].mutable_sbp_parallel();
+      for (int i = 0; i < parallel_desc.hierarchy()->NumAxes(); ++i) {
+        sbp_list->Add()->mutable_broadcast_parallel();
+      }
+    }
+  }
+  return Maybe<void>::Ok();
 }
 
 Symbol<OperatorConf> UserOp::GetOpConfWithoutOpNameAndLbn() const {
diff --git a/oneflow/user/ops/constant_op.cpp b/oneflow/user/ops/constant_op.cpp
index 2ec6d60f3..1b5457881 100644
--- a/oneflow/user/ops/constant_op.cpp
+++ b/oneflow/user/ops/constant_op.cpp
@@ -14,8 +14,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/common/protobuf.h"
+#include "oneflow/core/common/global.h"
+#include "oneflow/core/job/global_for.h"
 
 namespace oneflow {
+
+Maybe<void> InferConstantParallelDistribution(user_op::InferParallelDistributionFnContext* ctx);
+
 REGISTER_NO_GRAD_USER_OP("constant")
     .Output("out")
     .SetOutputBufferNum(1)
@@ -24,6 +30,7 @@ REGISTER_NO_GRAD_USER_OP("constant")
     .Attr<bool>("is_floating_value")
     .Attr<DataType>("dtype")
     .Attr<Shape>("shape")
+    .Attr<std::string>("nd_sbp")
     .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
       Shape* out_shape = ctx->OutputShape("out", 0);
       const Shape& shape = ctx->Attr<Shape>("shape");
@@ -43,6 +50,20 @@ REGISTER_NO_GRAD_USER_OP("constant")
       auto dtype = ctx->Attr<DataType>("dtype");
       *ctx->OutputDType("out", 0) = dtype;
       return Maybe<void>::Ok();
-    });
+    })
+    .SetParallelDistributionInferFn(&InferConstantParallelDistribution);
+
+Maybe<void> InferConstantParallelDistribution(user_op::InferParallelDistributionFnContext* ctx) {
+  cfg::ParallelDistribution* out = ctx->ParallelDistribution4ArgNameAndIndex("out", 0);
+  if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
+    const auto& pb_str = ctx->user_op_conf().attr<std::string>("nd_sbp");
+    ParallelDistribution pb;
+    CHECK_OR_RETURN(TxtString2PbMessage(pb_str, &pb));
+    out->InitFromProto(pb);
+  } else {
+    out->mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();
+  }
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
diff --git a/python/oneflow/framework/distribute.py b/python/oneflow/framework/distribute.py
index 117ab34e7..3bdfae3b7 100644
--- a/python/oneflow/framework/distribute.py
+++ b/python/oneflow/framework/distribute.py
@@ -205,9 +205,7 @@ def is_multi_client():
     return oneflow._oneflow_internal.IsMultiClient()
 
 
-def split_sbp(
-    axis: int,
-) -> oneflow._oneflow_internal.oneflow.core.job.sbp_parallel.SbpParallel:
+def split_sbp(axis: int) -> oneflow._oneflow_internal.sbp.sbp:
     """Generate a split scheme in which op will be splitted at `axis`.
 
     Args:
diff --git a/python/oneflow/nn/modules/constant.py b/python/oneflow/nn/modules/constant.py
index f0de614c8..050213cea 100644
--- a/python/oneflow/nn/modules/constant.py
+++ b/python/oneflow/nn/modules/constant.py
@@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import oneflow as flow
 from oneflow.framework.tensor import register_tensor_op
@@ -29,6 +29,10 @@ class _ConstantBase(Module):
         value: Union[float, int],
         dtype: Optional[flow.dtype],
         device: Union[flow.device, str] = None,
+        placement: flow.placement = None,
+        sbp: Union[
+            flow._oneflow_internal.sbp.sbp, List[flow._oneflow_internal.sbp.sbp]
+        ] = None,
         requires_grad: bool = False,
     ) -> None:
         super().__init__()
@@ -37,32 +41,63 @@ class _ConstantBase(Module):
             size, (int, tuple, flow.Size)
         ), "shape should be int or tuple int!"
         self.device = device
+        if isinstance(self.device, str):
+            self.device = flow.device(self.device)
         self.requires_grad = requires_grad
         size = _single(size)
         if dtype is None:
             dtype = flow.float32
-        if device is None:
-            self.device = flow.device("cpu")
+        if placement is None:
+            if device is None:
+                self.device = flow.device("cpu")
+        else:
+            assert device is None
+        self.placement = placement
+        self.sbp = sbp
+        if placement is not None:
+            assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), "sbp: %s" % sbp
+            if isinstance(self.sbp, flow.sbp.sbp):
+                self.sbp = (self.sbp,)
+            else:
+                for elem in sbp:
+                    assert isinstance(elem, flow.sbp.sbp), "sbp: %s" % sbp
+            assert len(self.sbp) == len(placement.hierarchy)
+        else:
+            assert sbp is None, "sbp: %s" % sbp
         self.shape = size
         self.value = value
         self.dtype = dtype
 
     def forward(self):
-        res = flow.F.constant(self.shape, self.value, self.dtype)
-        res = res.to(device=self.device)
+        if self.placement is not None:
+            res = flow.F.consistent_constant(
+                self.shape, self.value, self.dtype, self.placement, self.sbp,
+            )
+        else:
+            res = flow.F.constant(self.shape, self.value, self.dtype, self.device,)
         res.requires_grad = self.requires_grad
         return res
 
 
 class Ones(_ConstantBase):
-    def __init__(self, size, dtype=None, device=None, requires_grad=False):
-        super().__init__(size, 1, dtype, device, requires_grad)
+    def __init__(
+        self,
+        size,
+        dtype=None,
+        device=None,
+        placement=None,
+        sbp=None,
+        requires_grad=False,
+    ):
+        super().__init__(size, 1, dtype, device, placement, sbp, requires_grad)
 
 
 def ones_op(
     size: Union[_size_any_t, flow.Size],
     dtype: Optional[flow.dtype] = None,
     device: Union[flow.device, str, None] = None,
+    placement: flow.placement = None,
+    sbp: flow._oneflow_internal.sbp.sbp = None,
     requires_grad: bool = False,
 ):
     """
@@ -73,7 +108,9 @@ def ones_op(
         size (an integer or tuple of integer values) – defining the shape of the output tensor. Can be \\
          a variable number of arguments or a collection like a list or tuple.
         dtype (flow.dtype, optional) – the desired data type of returned tensor.
-        device (torch.device, optional) – the desired device of returned tensor. Default: if None, uses the current device for the default tensor type
+        device (flow.device, optional) – the desired device of returned tensor. Default: if None, uses the current device for the default tensor type
+        placement (flow.placement, optional) – the desired placement of returned consistent tensor. Default: if None, the returned tensor is local one using the argument `device`.
+        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional) – the desired sbp descriptor of returned consistent tensor. Default: if None, the returned tensor is local one using the argument `device`.
         requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
 
     For example:
@@ -86,18 +123,28 @@ def ones_op(
         tensor([1., 1., 1., 1., 1.], dtype=oneflow.float32)
 
     """
-    return Ones(size, dtype, device, requires_grad)()
+    return Ones(size, dtype, device, placement, sbp, requires_grad)()
 
 
 class Zeros(_ConstantBase):
-    def __init__(self, size, dtype=None, device=None, requires_grad=False):
-        super().__init__(size, 0, dtype, device, requires_grad)
+    def __init__(
+        self,
+        size,
+        dtype=None,
+        device=None,
+        placement=None,
+        sbp=None,
+        requires_grad=False,
+    ):
+        super().__init__(size, 0, dtype, device, placement, sbp, requires_grad)
 
 
 def zeros_op(
     size: Union[_size_any_t, flow.Size],
     dtype: Optional[flow.dtype] = None,
     device: Union[flow.device, str, None] = None,
+    placement: flow.placement = None,
+    sbp: flow._oneflow_internal.sbp.sbp = None,
     requires_grad: bool = False,
 ):
     """
@@ -108,7 +155,9 @@ def zeros_op(
         size(an integer or tuple of integer values) - defining the shape of the output tensor. Can be \\
          a variable number of arguments or a collection like a list or tuple.
         dtype (flow.dtype, optional) – the desired data type of returned tensor.
-        device (torch.device, optional) – the desired device of returned tensor. Default: if None, uses the current device for the default tensor type
+        device (flow.device, optional) – the desired device of returned tensor. Default: if None, uses the current device for the default tensor type
+        placement (flow.placement, optional) – the desired placement of returned consistent tensor. Default: if None, the returned tensor is local one using the argument `device`.
+        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional) – the desired sbp descriptor of returned consistent tensor. Default: if None, the returned tensor is local one using the argument `device`.
         requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
 
     For example:
@@ -121,7 +170,7 @@ def zeros_op(
         tensor([0., 0., 0., 0., 0.], dtype=oneflow.float32)
 
     """
-    return Zeros(size, dtype, device, requires_grad)()
+    return Zeros(size, dtype, device, placement, sbp, requires_grad)()
 
 
 class ZerosLike(Module):
@@ -192,10 +241,16 @@ class NewOnes(Module):
         size: Union[_size_any_t, flow.Size] = None,
         dtype: Optional[flow.dtype] = None,
         device: Union[flow.device, str] = None,
+        placement: flow.placement = None,
+        sbp: flow._oneflow_internal.sbp.sbp = None,
         requires_grad: bool = False,
     ):
         super().__init__()
         self.device = device
+        if isinstance(self.device, str):
+            self.device = flow.device(self.device)
+        self.placement = placement
+        self.sbp = sbp
         self.requires_grad = requires_grad
         if size != None:
             size = _single(size)
@@ -206,59 +261,89 @@ class NewOnes(Module):
         new_size = self.size
         new_dtype = self.dtype
         new_device = self.device
+        new_placement = self.placement
+        new_sbp = self.sbp
         new_requires_grad = self.requires_grad
         if self.size is None:
             new_size = x.shape
         if self.dtype is None:
             new_dtype = x.dtype
         if self.device is None:
-            new_device = x.device
+            new_device = x.device if x.is_local else None
+        if self.placement is None:
+            new_placement = x.placement if x.is_consistent else None
+        if self.sbp is None:
+            new_sbp = x.sbp if x.is_consistent else None
+        if new_placement is not None:
+            assert self.device is None
+            assert new_sbp is not None
         assert isinstance(
             new_size, (int, tuple, flow.Size)
         ), f"size parameter not correct, please check!"
         assert isinstance(
             new_dtype, flow.dtype
         ), f"dtype parameter not correct, please check!"
-        assert isinstance(
-            new_device, (str, flow.device)
-        ), f"device parameter not correct, please check!"
+        if new_placement is not None:
+            assert isinstance(
+                new_placement, flow.placement
+            ), f"device parameter not correct, please check!"
+            assert isinstance(
+                new_sbp, flow.sbp.sbp
+            ), f"device parameter not correct, please check!"
+        else:
+            assert isinstance(
+                new_device, (str, flow.device)
+            ), f"device parameter not correct, please check!"
         assert isinstance(
             new_requires_grad, bool
         ), f"requires_grad parameter not correct, please check!"
-        res = flow.F.constant(new_size, 1.0, new_dtype)
-        res = res.to(new_device)
+        if self.placement is not None:
+            res = flow.F.consistent_constant(
+                new_size, 1.0, new_dtype, self.placement, self.sbp
+            )
+        else:
+            res = flow.F.constant(new_size, 1.0, new_dtype, new_device)
         res.requires_grad = new_requires_grad
         return res
 
 
 @register_tensor_op("new_ones")
-def new_ones_op(x, size=None, dtype=None, device=None, requires_grad=False):
+def new_ones_op(
+    x, size=None, dtype=None, device=None, placement=None, sbp=None, requires_grad=False
+):
     """
-    
+
     Returns a Tensor of size size filled with 1. By default, the returned Tensor has the same torch.dtype and torch.device as this tensor.
 
     Args:
         size (int...): a list, tuple, or flow.Size of integers defining the shape of the output tensor.
         dtype (flow.dtype, optional):  the desired type of returned tensor. Default: if None, same flow.dtype as this tensor.
         device (flow.device, optional): the desired device of returned tensor. Default: if None, same flow.device as this tensor.
+        placement (flow.placement, optional) – the desired placement of returned consistent tensor. Default: if None, the returned tensor is local one using the argument `device`.
+        sbp (flow.sbp.sbp or tuple of flow.sbp.sbp, optional) – the desired sbp descriptor of returned consistent tensor. Default: if None, the returned tensor is local one using the argument `device`.
         requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.
-    
+
     For example:
 
     .. code-block:: python
 
         >>> import numpy as np
         >>> import oneflow as flow
-        
+
         >>> x = flow.Tensor(np.ones((1, 2, 3)))
         >>> y = x.new_ones((2, 2))
         >>> y
         tensor([[1., 1.],
                 [1., 1.]], dtype=oneflow.float32)
     """
-    return NewOnes(size=size, dtype=dtype, device=device, requires_grad=requires_grad)(
-        x
-    )
+    return NewOnes(
+        size=size,
+        dtype=dtype,
+        device=device,
+        placement=placement,
+        sbp=sbp,
+        requires_grad=requires_grad,
+    )(x)
 
 
 if __name__ == "__main__":
diff --git a/python/oneflow/sbp.py b/python/oneflow/sbp.py
index 8eb354821..e652f7744 100644
--- a/python/oneflow/sbp.py
+++ b/python/oneflow/sbp.py
@@ -15,6 +15,8 @@ limitations under the License.
 """
 import oneflow
 from oneflow.framework.distribute import split_sbp as split
+import oneflow._oneflow_internal
 
+sbp = oneflow._oneflow_internal.sbp.sbp
 broadcast = oneflow._oneflow_internal.sbp.broadcast()
 partial_sum = oneflow._oneflow_internal.sbp.partial_sum()
diff --git a/python/oneflow/test/modules/test_constant.py b/python/oneflow/test/modules/test_constant.py
index deaeb0606..84d33da60 100644
--- a/python/oneflow/test/modules/test_constant.py
+++ b/python/oneflow/test/modules/test_constant.py
@@ -85,15 +85,10 @@ def _test_zeros_like(test_case, device, shape):
 
 
 def _test_new_ones(test_case, device, shape):
-    x = flow.Tensor(np.ones(shape), device=flow.device(device))
+    x = flow.ones(shape, device=flow.device("cpu"))
     y = x.new_ones(shape, device=device)
     test_case.assertTrue(x.dtype == y.dtype)
-    test_case.assertTrue(x.device == y.device)
-    test_case.assertTrue(x.requires_grad == y.requires_grad)
-    x = flow.Tensor(np.ones(shape), device=flow.device(device))
-    y = x.new_ones(x.shape, device=device)
-    test_case.assertTrue(x.dtype == y.dtype)
-    test_case.assertTrue(x.device == y.device)
+    test_case.assertEqual(flow.device(device), y.device)
     test_case.assertTrue(x.requires_grad == y.requires_grad)
     x = flow.Tensor(np.ones(shape), device=flow.device(device))
     x = x.new_ones(shape, device=device, requires_grad=True)
@@ -104,6 +99,13 @@ def _test_new_ones(test_case, device, shape):
 
 @flow.unittest.skip_unless_1n1d()
 class TestConstantModule(flow.unittest.TestCase):
+    def test_consistent_naive(test_case):
+        placement = flow.placement("cpu", {0: [0]})
+        sbp = (flow.sbp.broadcast,)
+        x = flow.ones((16, 16), placement=placement, sbp=sbp)
+        test_case.assertEqual(x.sbp, sbp)
+        test_case.assertEqual(x.placement, placement)
+
     def test_cast(test_case):
         arg_dict = OrderedDict()
         arg_dict["test_fun"] = [
diff --git a/tools/generate_functional_api.py b/tools/generate_functional_api.py
index 9b71af013..00f64934a 100644
--- a/tools/generate_functional_api.py
+++ b/tools/generate_functional_api.py
@@ -151,9 +151,10 @@ types_allowed = {
     "Shape",
     "Generator",
     "TensorIndex",
-    "ParallelDesc",
-    "SbpParallel",
-    "SbpParallelList",
+    "Device",
+    "Placement",
+    "Sbp",
+    "SbpList",
 }
 
 generic_type_aliases = {
@@ -182,9 +183,10 @@ argument_type_aliases = {
     "Shape": "const Shape&",
     "Generator": "const std::shared_ptr<one::Generator>&",
     "TensorIndex": "const TensorIndex&",
-    "ParallelDesc": "const Symbol<ParallelDesc>&",
-    "SbpParallel": "const Symbol<cfg::SbpParallel>&",
-    "SbpParallelList": "const std::vector<Symbol<cfg::SbpParallel>>&",
+    "Device": "const Symbol<Device>&",
+    "Placement": "const Symbol<ParallelDesc>&",
+    "Sbp": "const Symbol<cfg::SbpParallel>&",
+    "SbpList": "const std::vector<Symbol<cfg::SbpParallel>>&",
     **generic_type_aliases,
 }
 
@@ -205,9 +207,10 @@ optional_argument_type_aliases = {
     "Shape": "const Optional<Shape>&",
     "Generator": "const Optional<one::Generator>&",
     "TensorIndex": "const Optional<TensorIndex>&",
-    "ParallelDesc": "const Optional<Symbol<ParallelDesc>>&",
-    "SbpParallel": "const Optional<Symbol<SbpParallel>>&",
-    "SbpParallelList": "const Optional<std::vector<Symbol<cfg::SbpParallel>>>&",
+    "Device": "const Optional<Symbol<Device>>&",
+    "Placement": "const Optional<Symbol<ParallelDesc>>&",
+    "Sbp": "const Optional<Symbol<SbpParallel>>&",
+    "SbpList": "const Optional<std::vector<Symbol<cfg::SbpParallel>>>&",
     **{k: "const Optional<{0}>".format(v) for k, v in generic_type_aliases.items()},
 }
 
-- 
GitLab