diff --git a/CMakeLists.txt b/CMakeLists.txt
index 48557fc2b4a56438ded4944434197a661bc7cf77..0cf82571ebd846aca7ed5aae36adf0d66be17f4b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -40,6 +40,11 @@ if (WIN32)
   set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /D_ITERATOR_DEBUG_LEVEL=0")
 else()
   list(APPEND CUDA_NVCC_FLAGS -std=c++11 -w -Wno-deprecated-gpu-targets)
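+  # NOTE: sm_60/sm_61 need CUDA 8.0+ and sm_70 needs CUDA 9.0+; trim this list for older toolkits.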
+  list(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_30,code=\"sm_30,compute_30\")
+  list(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_52,code=\"sm_52,compute_52\")
+  list(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_60,code=\"sm_60,compute_60\")
+  list(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_61,code=\"sm_61,compute_61\")
+  list(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_70,code=\"sm_70,compute_70\")
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Wno-sign-compare -Wno-unused-function")
   if (RELEASE_VERSION)
     list(APPEND CUDA_NVCC_FLAGS -O3)
diff --git a/cmake/oneflow.cmake b/cmake/oneflow.cmake
index f43ab9b9a84577da39bf935c388725e4deb59822..f368c81215365a4f82c147bf2b8b586abdd65b81 100644
--- a/cmake/oneflow.cmake
+++ b/cmake/oneflow.cmake
@@ -1,5 +1,6 @@
 # main cpp
 list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/oneflow/core/job/oneflow.cpp)
+list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/oneflow/core/ndarray/ndarray_reduce_test.cpp)
 list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/tools/gen_resnet.cpp)
 list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/tools/gen_alexnet.cpp)
 list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/tools/gen_googlenet.cpp)
diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake
index e33988991f6994c795b0f2ff241cf176da669982..7e5a054376dea24ed1a320b2dbca0e7ac778adca 100644
--- a/cmake/third_party.cmake
+++ b/cmake/third_party.cmake
@@ -12,6 +12,7 @@ include(libjpeg-turbo)
 include(opencv)
 include(eigen)
 include(cocoapi)
+include(half)
 
 if (BUILD_CUDA)
   set(CUDA_SEPARABLE_COMPILATION ON)
@@ -90,6 +91,7 @@ set(oneflow_third_party_dependencies
   eigen
   cocoapi_copy_headers_to_destination
   cocoapi_copy_libs_to_destination
+  half_copy_headers_to_destination
 )
 
 include_directories(
@@ -104,6 +106,7 @@ include_directories(
     ${OPENCV_INCLUDE_DIR}
     ${EIGEN_INCLUDE_DIR}
     ${COCOAPI_INCLUDE_DIR}
+    ${HALF_INCLUDE_DIR}
 )
 
 if (BUILD_CUDA)
@@ -124,3 +127,5 @@ if (BUILD_CUDA)
     ${NCCL_INCLUDE_DIR}
 )
 endif()
+
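+# Disable half's C++11 user-defined literals (e.g. operator"" _h), presumably to avoid requiring
+# literal-operator support in every translation unit that includes half.hpp.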
+add_definitions(-DHALF_ENABLE_CPP11_USER_LITERALS=0)
diff --git a/cmake/third_party/half.cmake b/cmake/third_party/half.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..d068bf612d8dfac47d2b770f3bfb6fce61eaafe7
--- /dev/null
+++ b/cmake/third_party/half.cmake
@@ -0,0 +1,35 @@
+include (ExternalProject)
+
+set(HALF_INCLUDE_DIR ${THIRD_PARTY_DIR}/half/include)
+
+set(HALF_URL https://cfhcable.dl.sourceforge.net/project/half/half/1.12.0/half-1.12.0.zip)
+set(HALF_BASE_DIR ${CMAKE_CURRENT_BINARY_DIR}/half/src/half)
+
+set(HALF_HEADERS
+    "${HALF_BASE_DIR}/include/half.hpp"
+)
+
+if(BUILD_THIRD_PARTY)
+
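+# half is a header-only library: only download it and copy half.hpp into the third-party include dir.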
+ExternalProject_Add(half
+    PREFIX half
+    URL ${HALF_URL}
+    UPDATE_COMMAND ""
+    CONFIGURE_COMMAND ""
+    BUILD_COMMAND ""
+    BUILD_IN_SOURCE 1
+    INSTALL_COMMAND ""
+)
+
+add_custom_target(half_create_header_dir
+    COMMAND ${CMAKE_COMMAND} -E make_directory ${HALF_INCLUDE_DIR}
+    DEPENDS half)
+
+add_custom_target(half_copy_headers_to_destination
+    DEPENDS half_create_header_dir)
+
+foreach(header_file ${HALF_HEADERS})
+    add_custom_command(TARGET half_copy_headers_to_destination PRE_BUILD
+    COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${HALF_INCLUDE_DIR})
+endforeach()
+endif(BUILD_THIRD_PARTY)
diff --git a/oneflow/core/actor/actor.cpp b/oneflow/core/actor/actor.cpp
index 8134fda32299958c6f613969826f793540568c89..d3ed86f26320c46418f022307641ca60eb6cd8d8 100644
--- a/oneflow/core/actor/actor.cpp
+++ b/oneflow/core/actor/actor.cpp
@@ -193,8 +193,11 @@ int Actor::HandlerNormal(const ActorMsg& msg) {
       Regst* regst = msg.regst();
       if (naive_consumed_rs_.HasRegstDescId(regst->regst_desc_id())) {
         CHECK_EQ(0, naive_consumed_rs_.TryPushBackRegst(regst));
-        NormalProcessNaiveReadableRegstMsg(
-            naive_consumed_rs_.RegstDeq4RegstDescId(regst->regst_desc_id()));
+        const auto& rdeq = naive_consumed_rs_.RegstDeq4RegstDescId(regst->regst_desc_id());
+        CHECK(!rdeq.empty());
+        if (rdeq.front()->regst_desc()->regst_desc_type().has_data_regst_desc()) {
+          NormalProcessNaiveReadableDataRegstMsg(rdeq);
+        }
       } else if (TryUpdtStateAsProducedRegst(regst) == 0) {
         // do nothing
       } else {
@@ -325,8 +328,9 @@ void Actor::AsyncSendConsumedCtrlRegstMsgToProducer() {
     CHECK_GE(reg_deq.size(), returned_regst_num);
     for (size_t i = 0; i < returned_regst_num; ++i) {
       Regst* regst = reg_deq.at(i);
-      AsyncSendMsg(ActorMsg::BuildRegstMsgToProducer(actor_id_, regst->producer_actor_id(), regst));
+      // read regst before sending it back: once the producer receives the message, touching regst is unsafe
       regst_desc_ids.push_back(regst->regst_desc_id());
+      AsyncSendMsg(ActorMsg::BuildRegstMsgToProducer(actor_id_, regst->producer_actor_id(), regst));
     }
   });
   naive_consumed_rs_.PopFrontRegsts(regst_desc_ids);
@@ -452,8 +456,10 @@ void Actor::AsyncSendRegstMsgToProducer(Regst* regst) {
 }
 
 void Actor::AsyncSendRegstMsgToProducer(Regst* regst, int64_t producer) {
+  // read regst_desc_id before sending regst back: once the producer receives the message, touching regst is unsafe
+  int64_t regst_desc_id = regst->regst_desc_id();
   AsyncSendMsg(ActorMsg::BuildRegstMsgToProducer(actor_id_, producer, regst));
-  naive_consumed_rs_.TryPopFrontRegst(regst->regst_desc_id());
+  naive_consumed_rs_.TryPopFrontRegst(regst_desc_id);
 }
 
 Regst* Actor::GetSoleProducedRegst4RegstDescId(int64_t regst_desc_id) {
diff --git a/oneflow/core/actor/actor.h b/oneflow/core/actor/actor.h
index a3dc4d3de0bac88503c7f1975dc991b83f14b36b..5e23a330e37384cfda5ccd9a8ebac9a648c86021 100644
--- a/oneflow/core/actor/actor.h
+++ b/oneflow/core/actor/actor.h
@@ -128,7 +128,7 @@ class Actor {
 
   // Process Msg
   virtual void NormalProcessCustomizedEordMsg(const ActorMsg&) {}
-  virtual void NormalProcessNaiveReadableRegstMsg(const std::deque<Regst*>&) {}
+  virtual void NormalProcessNaiveReadableDataRegstMsg(const std::deque<Regst*>&) {}
   virtual void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) { UNIMPLEMENTED(); }
   virtual bool NormalTryProcessReadableMsgFromOtherMachine(const ActorMsg&) { return false; }
   int TryUpdtStateAsProducedRegst(Regst* regst);
diff --git a/oneflow/core/actor/boxing_actor.cpp b/oneflow/core/actor/boxing_actor.cpp
index 06394b0c3e21d8fbe60541049b5c9da0957f39a8..88325dec3bf74bfa0973d394403244a30d287c1c 100644
--- a/oneflow/core/actor/boxing_actor.cpp
+++ b/oneflow/core/actor/boxing_actor.cpp
@@ -9,7 +9,7 @@ void BoxingActor::VirtualActorInit(const TaskProto& task_proto) {
   OF_SET_MSG_HANDLER(&BoxingActor::HandlerNormal);
 }
 
-void BoxingActor::NormalProcessNaiveReadableRegstMsg(const std::deque<Regst*>& rq) {
+void BoxingActor::NormalProcessNaiveReadableDataRegstMsg(const std::deque<Regst*>& rq) {
   if (rq.back()->packed_blob()->max_col_num() > 1 && col_id_order_ == ColIdOrder::kUnCertain) {
     TrySetColIdOrder(rq.back());
   }
diff --git a/oneflow/core/actor/boxing_actor.h b/oneflow/core/actor/boxing_actor.h
index 88189f89dd9cd02387bd24751869718f88013d5b..32609ecb842cc60dd7247057a693de06bd5a8e47 100644
--- a/oneflow/core/actor/boxing_actor.h
+++ b/oneflow/core/actor/boxing_actor.h
@@ -14,7 +14,7 @@ class BoxingActor final : public Actor {
   void VirtualActorInit(const TaskProto&) override;
 
  private:
-  void NormalProcessNaiveReadableRegstMsg(const std::deque<Regst*>&) override;
+  void NormalProcessNaiveReadableDataRegstMsg(const std::deque<Regst*>&) override;
   void Act() override;
   void VirtualAsyncSendNaiveProducedRegstMsgToConsumer();
   void VirtualAsyncSendNaiveConsumedRegstMsgToProducer();
diff --git a/oneflow/core/actor/loss_print_compute_actor.h b/oneflow/core/actor/loss_print_compute_actor.h
index 551a3d391561fcc2374766f55212c6a07f1f5834..f269d45c42a4d33515f00ea494e5f6691067a3ef 100644
--- a/oneflow/core/actor/loss_print_compute_actor.h
+++ b/oneflow/core/actor/loss_print_compute_actor.h
@@ -9,9 +9,13 @@ class LossPrintCompActor final : public SinkCompActor {
  public:
   OF_DISALLOW_COPY_AND_MOVE(LossPrintCompActor);
   LossPrintCompActor() = default;
-  ~LossPrintCompActor() = default;
+  ~LossPrintCompActor() override = default;
 
  private:
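+  // timestamp_ is handed to the loss print kernel through the NewOther() hook
+  // (assumption inferred from SinkCompActor's interface; not stated in this diff).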
+  void VirtualSinkCompActorInit(const TaskProto&) override { timestamp_ = 0; }
+  void* NewOther() override { return &timestamp_; }
+
+  double timestamp_ = 0;
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/actor/naive_actor.cpp b/oneflow/core/actor/naive_actor.cpp
index 642b5ecc1132b833f7b07d3ba0454b1b6df29ee9..f32db4bb1087f2398cb0068d81c69bd6db909173 100644
--- a/oneflow/core/actor/naive_actor.cpp
+++ b/oneflow/core/actor/naive_actor.cpp
@@ -12,4 +12,6 @@ void NaiveActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {
   });
 }
 
+REGISTER_ACTOR(TaskType::kReduceIdentity, NaiveActor);
+
 }  // namespace oneflow
diff --git a/oneflow/core/actor/normal_backward_compute_actor.cpp b/oneflow/core/actor/normal_backward_compute_actor.cpp
index 61313aee6362908874ad2ee157a4238664be44ec..577136c906c8592ecc6d9a51eb44dd067cfaa7c7 100644
--- a/oneflow/core/actor/normal_backward_compute_actor.cpp
+++ b/oneflow/core/actor/normal_backward_compute_actor.cpp
@@ -33,7 +33,7 @@ void NormalBackwardCompActor::ForEachCurCustomizedReadableRegst(
   }
 }
 
-void NormalBackwardCompActor::NormalProcessNaiveReadableRegstMsg(const std::deque<Regst*>& rq) {
+void NormalBackwardCompActor::NormalProcessNaiveReadableDataRegstMsg(const std::deque<Regst*>& rq) {
   if (rq.size() == 1 && rq.front()->regst_desc_id() == any_out_diff_regst_desc_id_) {
     AsyncReturnModelRegstUntilModelVersionIdEqual(
         GetModelVersionIdFromPieceId(rq.front()->piece_id(), actual_num_of_piece_in_batch_));
diff --git a/oneflow/core/actor/normal_backward_compute_actor.h b/oneflow/core/actor/normal_backward_compute_actor.h
index fd0bc969520f71dfefcfc36e3b56f50fcd722026..4cf1ea961351ba6c90b2948819c0aa2bf9a425d7 100644
--- a/oneflow/core/actor/normal_backward_compute_actor.h
+++ b/oneflow/core/actor/normal_backward_compute_actor.h
@@ -15,7 +15,7 @@ class NormalBackwardCompActor final : public CompActor {
 
  private:
   void ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)>) const override;
-  void NormalProcessNaiveReadableRegstMsg(const std::deque<Regst*>&) override;
+  void NormalProcessNaiveReadableDataRegstMsg(const std::deque<Regst*>&) override;
   void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override;
   void Act() override;
   bool IsCustomizedReadReady() override;
diff --git a/oneflow/core/actor/reduce_concat_compute_actor.cpp b/oneflow/core/actor/reduce_concat_compute_actor.cpp
index 8cedf4dcf37355dc37a78425694c6e7c153c1910..dc7b7bf87b64bc1efc013fe941d8cf34ff539e05 100644
--- a/oneflow/core/actor/reduce_concat_compute_actor.cpp
+++ b/oneflow/core/actor/reduce_concat_compute_actor.cpp
@@ -2,16 +2,6 @@
 
 namespace oneflow {
 
-void ReduceConcatCompActor::VirtualCompActorInit(const TaskProto& proto) {
-  InputWiseCompActor::Init(proto);
-}
-
-void ReduceConcatCompActor::SetKernelCtxOther(void** other) {
-  int64_t in_bn_id = InBnId4RegstDescId(cur_processed_regst_desc_id());
-  other_val_ = std::make_pair(in_bn_id, EnableInplace());
-  *other = static_cast<void*>(&other_val_);
-}
-
 REGISTER_ACTOR(TaskType::kReduceConcat, ReduceConcatCompActor);
 
 }  // namespace oneflow
diff --git a/oneflow/core/actor/reduce_concat_compute_actor.h b/oneflow/core/actor/reduce_concat_compute_actor.h
index e482fce7d0df4003be929f0227731fbbfd591528..740c6f28088d92fffb64554a9550f2838797bb37 100644
--- a/oneflow/core/actor/reduce_concat_compute_actor.h
+++ b/oneflow/core/actor/reduce_concat_compute_actor.h
@@ -1,21 +1,15 @@
 #ifndef ONEFLOW_CORE_ACTOR_REDUCE_CONCAT_COMPUTE_ACTOR_H_
 #define ONEFLOW_CORE_ACTOR_REDUCE_CONCAT_COMPUTE_ACTOR_H_
 
-#include "oneflow/core/actor/input_wise_compute_actor.h"
+#include "oneflow/core/actor/naive_actor.h"
 
 namespace oneflow {
 
-class ReduceConcatCompActor final : public InputWiseCompActor {
+class ReduceConcatCompActor final : public NaiveActor {
  public:
   OF_DISALLOW_COPY_AND_MOVE(ReduceConcatCompActor);
   ReduceConcatCompActor() = default;
   ~ReduceConcatCompActor() = default;
-
- private:
-  void VirtualCompActorInit(const TaskProto& proto) override;
-  void SetKernelCtxOther(void** other) override;
-
-  std::pair<int64_t, bool> other_val_;
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/actor/reduce_split_compute_actor.cpp b/oneflow/core/actor/reduce_split_compute_actor.cpp
index d5e58b32fae21025d2ee5743860d13cf44529cbb..e107fa957239b0d9ed3f5460e87c51cb4a29b77a 100644
--- a/oneflow/core/actor/reduce_split_compute_actor.cpp
+++ b/oneflow/core/actor/reduce_split_compute_actor.cpp
@@ -2,15 +2,6 @@
 
 namespace oneflow {
 
-void ReduceSplitCompActor::VirtualCompActorInit(const TaskProto& proto) {
-  InputWiseCompActor::Init(proto);
-}
-
-void ReduceSplitCompActor::SetKernelCtxOther(void** other) {
-  other_val_ = EnableInplace();
-  *other = static_cast<void*>(&other_val_);
-}
-
 REGISTER_ACTOR(TaskType::kReduceSplit, ReduceSplitCompActor);
 
 }  // namespace oneflow
diff --git a/oneflow/core/actor/reduce_split_compute_actor.h b/oneflow/core/actor/reduce_split_compute_actor.h
index 7cbdc0c582b4781b8466aa29d6572ebec5466356..2a3f7752ac66dba825c3c2f1e6eb60153c8d5116 100644
--- a/oneflow/core/actor/reduce_split_compute_actor.h
+++ b/oneflow/core/actor/reduce_split_compute_actor.h
@@ -1,21 +1,15 @@
 #ifndef ONEFLOW_CORE_ACTOR_REDUCE_SPLIT_COMPUTE_ACTOR_H_
 #define ONEFLOW_CORE_ACTOR_REDUCE_SPLIT_COMPUTE_ACTOR_H_
 
-#include "oneflow/core/actor/input_wise_compute_actor.h"
+#include "oneflow/core/actor/naive_actor.h"
 
 namespace oneflow {
 
-class ReduceSplitCompActor final : public InputWiseCompActor {
+class ReduceSplitCompActor final : public NaiveActor {
  public:
   OF_DISALLOW_COPY_AND_MOVE(ReduceSplitCompActor);
   ReduceSplitCompActor() = default;
   ~ReduceSplitCompActor() = default;
-
- private:
-  void VirtualCompActorInit(const TaskProto& proto) override;
-  void SetKernelCtxOther(void** other) override;
-
-  bool other_val_;
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/comm_network/epoll/epoll_comm_network.cpp b/oneflow/core/comm_network/epoll/epoll_comm_network.cpp
index ae47cb890022b26616d554644781596c8758cef3..de85b83a808014d8dc4e699675865771db352cc3 100644
--- a/oneflow/core/comm_network/epoll/epoll_comm_network.cpp
+++ b/oneflow/core/comm_network/epoll/epoll_comm_network.cpp
@@ -114,13 +114,13 @@ void EpollCommNet::InitSockets() {
              ((this_machine.data_port_agent() != -1) ? (this_machine.data_port_agent())
                                                      : (this_listen_port)));
   } else {
-    for (this_listen_port = 1024; this_listen_port < MaxVal<uint16_t>(); ++this_listen_port) {
+    for (this_listen_port = 1024; this_listen_port < GetMaxVal<uint16_t>(); ++this_listen_port) {
       if (SockListen(listen_sockfd, this_listen_port, total_machine_num) == 0) {
         PushPort(this_machine_id, this_listen_port);
         break;
       }
     }
-    CHECK_LT(this_listen_port, MaxVal<uint16_t>());
+    CHECK_LT(this_listen_port, GetMaxVal<uint16_t>());
   }
   int32_t src_machine_count = 0;
 
diff --git a/oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp b/oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp
index 4ec79c7b2adb50c0407b250499f9ab15551cf606..1a747879b37cc02d0bf8123637bfcc53f96b90f3 100644
--- a/oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp
+++ b/oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp
@@ -65,7 +65,7 @@ void IBVerbsQP::Connect(const IBVerbsConnectionInfo& peer_info) {
   qp_attr.ah_attr.grh.dgid.global.interface_id = peer_info.interface_id();
   qp_attr.ah_attr.grh.flow_label = 0;
   qp_attr.ah_attr.grh.sgid_index = 0;
-  qp_attr.ah_attr.grh.hop_limit = MaxVal<uint8_t>();
+  qp_attr.ah_attr.grh.hop_limit = GetMaxVal<uint8_t>();
   qp_attr.ah_attr.dlid = peer_info.lid();
   qp_attr.ah_attr.sl = 0;
   qp_attr.ah_attr.src_path_bits = 0;
diff --git a/oneflow/core/common/blas.h b/oneflow/core/common/blas.h
index 4bbe52b2d9d0db1e5fa4f0754953d883d540a56f..9c8efe25f396bfc7827a15d8b71abf5d7126c686 100644
--- a/oneflow/core/common/blas.h
+++ b/oneflow/core/common/blas.h
@@ -8,14 +8,16 @@
 
 namespace oneflow {
 
-#define BLAS_NAME_SEQ        \
-  OF_PP_MAKE_TUPLE_SEQ(dot)  \
-  OF_PP_MAKE_TUPLE_SEQ(swap) \
-  OF_PP_MAKE_TUPLE_SEQ(copy) \
-  OF_PP_MAKE_TUPLE_SEQ(axpy) \
-  OF_PP_MAKE_TUPLE_SEQ(scal) \
-  OF_PP_MAKE_TUPLE_SEQ(gemv) \
-  OF_PP_MAKE_TUPLE_SEQ(gemm)
+#define BLAS_NAME_SEQ               \
+  OF_PP_MAKE_TUPLE_SEQ(dot)         \
+  OF_PP_MAKE_TUPLE_SEQ(swap)        \
+  OF_PP_MAKE_TUPLE_SEQ(copy)        \
+  OF_PP_MAKE_TUPLE_SEQ(axpy)        \
+  OF_PP_MAKE_TUPLE_SEQ(scal)        \
+  OF_PP_MAKE_TUPLE_SEQ(gemv)        \
+  OF_PP_MAKE_TUPLE_SEQ(gemm)        \
+  OF_PP_MAKE_TUPLE_SEQ(gemmBatched) \
+  OF_PP_MAKE_TUPLE_SEQ(gemmStridedBatched)
 
 #define CBLAS_TEMPLATE(name)                                                                    \
   template<typename T, typename... Args>                                                        \
diff --git a/oneflow/core/common/data_type.h b/oneflow/core/common/data_type.h
index 688afb9e19cc55858bf0cc32c58df7afc06862ec..4c763a606184ef4609f0d6e678f0fa5dbbf93b22 100644
--- a/oneflow/core/common/data_type.h
+++ b/oneflow/core/common/data_type.h
@@ -95,6 +95,37 @@ TRAIT_CONST_VAR(One, 1);
 
 #undef TRAIT_CONST_VAR
 
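+// Compile-time numeric limits; cf. the runtime helpers GetMaxVal<T>() / GetMinVal<T>() in
+// common/util.h.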
+template<typename T>
+struct MaxVal;
+template<typename T>
+struct MinVal;
+
+#define TRAIT_LIMIT_VAL(max_or_min, T, limit_value)                                               \
+  template<>                                                                                      \
+  struct max_or_min##Val<T> final {                                                               \
+    static_assert(alignof(int) == alignof(int32_t), "int32_t should be exactly int");             \
+    static_assert(alignof(long long) == alignof(int64_t), "int64_t should be exactly long long"); \
+    constexpr static T value = limit_value;                                                       \
+  }
+
+TRAIT_LIMIT_VAL(Max, int8_t, SCHAR_MAX);
+TRAIT_LIMIT_VAL(Max, int32_t, INT_MAX);
+TRAIT_LIMIT_VAL(Max, uint32_t, UINT_MAX);
+TRAIT_LIMIT_VAL(Max, int64_t, LLONG_MAX);
+TRAIT_LIMIT_VAL(Max, uint64_t, ULLONG_MAX);
+TRAIT_LIMIT_VAL(Max, float, FLT_MAX);
+TRAIT_LIMIT_VAL(Max, double, DBL_MAX);
+
+TRAIT_LIMIT_VAL(Min, int8_t, SCHAR_MIN);
+TRAIT_LIMIT_VAL(Min, int32_t, INT_MIN);
+TRAIT_LIMIT_VAL(Min, uint32_t, 0);
+TRAIT_LIMIT_VAL(Min, int64_t, LLONG_MIN);
+TRAIT_LIMIT_VAL(Min, uint64_t, 0);
+TRAIT_LIMIT_VAL(Min, float, -FLT_MAX);
+TRAIT_LIMIT_VAL(Min, double, -DBL_MAX);
+
+#undef TRAIT_LIMIT_VAL
+
 // Func
 
 bool IsIntegralDataType(DataType data_type);
diff --git a/oneflow/core/common/protobuf.cpp b/oneflow/core/common/protobuf.cpp
index 252f3c5e71f9e235e436f756714415c7eed5f3d7..72f03736aa34a703a9803ba62f87d3f2b025c23f 100644
--- a/oneflow/core/common/protobuf.cpp
+++ b/oneflow/core/common/protobuf.cpp
@@ -60,6 +60,14 @@ PbMessage* MutableMessageInPbMessage(PbMessage* msg, const std::string& field_na
   return r->MutableMessage(msg, fd);
 }
 
+PbMessage* MutableMessageInPbMessage(PbMessage* msg, int field_index) {
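+  // NOTE: field_index is interpreted as the protobuf field number declared in the .proto file,
+  // not a positional index.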
+  const auto* d = const_cast<google::protobuf::Descriptor*>(msg->GetDescriptor());
+  const auto* fd = const_cast<PbFd*>(d->FindFieldByNumber(field_index));
+  CHECK_NOTNULL(fd);
+  const auto* r = const_cast<google::protobuf::Reflection*>(msg->GetReflection());
+  return r->MutableMessage(msg, fd);
+}
+
 #define DEFINE_ADD_VAL_IN_PBRF(cpp_type, pb_type_name)                                    \
   template<>                                                                              \
   void AddValInPbRf(PbMessage* msg, const std::string& field_name, const cpp_type& val) { \
diff --git a/oneflow/core/common/protobuf.h b/oneflow/core/common/protobuf.h
index 7663e98543fd1a4312578ad6037a37de048d2943..3ff546352f289da9b968188365c5ae255215c5f7 100644
--- a/oneflow/core/common/protobuf.h
+++ b/oneflow/core/common/protobuf.h
@@ -75,12 +75,19 @@ const PbRpf<T>& GetPbRpfFromPbMessage(const PbMessage& msg, const std::string& f
   return r->GetRepeatedPtrField<T>(msg, fd);
 }
 
+template<typename T>
+PbRpf<T>* MutPbRpfFromPbMessage(PbMessage* msg, const std::string& field_name) {
+  PROTOBUF_REFLECTION((*msg), field_name);
+  return r->MutableRepeatedPtrField<T>(msg, fd);
+}
+
 // Set In PbMessage
 
 template<typename T>
 void SetValInPbMessage(PbMessage* msg, const std::string& field_name, const T& val);
 
 PbMessage* MutableMessageInPbMessage(PbMessage*, const std::string& field_name);
+PbMessage* MutableMessageInPbMessage(PbMessage*, int field_index);
 
 // Add In PbMessage RepeatedField
 
diff --git a/oneflow/core/common/shape.cpp b/oneflow/core/common/shape.cpp
index b7697b494151c2ae7989efc4df8fb0d7caed0c60..7be79b858ac0ec3528dbf4999ea9d0b312c65c7f 100644
--- a/oneflow/core/common/shape.cpp
+++ b/oneflow/core/common/shape.cpp
@@ -19,14 +19,20 @@ Shape& Shape::operator=(const Shape& shape) {
 
 bool Shape::operator==(const Shape& rhs) const { return dim_vec_ == rhs.dim_vec_; }
 
-std::string Shape::DebugStr() const {
+std::string Shape::ToString() const {
   std::stringstream ss;
-  ss << "{";
-  for (int64_t dim : dim_vec_) { ss << dim << ","; }
-  ss << "(" << elem_cnt_ << ")}";
+  int32_t idx = 0;
+  ss << "(";
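+  // dims are comma separated; a rank-1 shape keeps a trailing comma, e.g. "(5,)"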
+  for (int64_t dim : dim_vec_) {
+    ss << dim;
+    if (++idx != dim_vec_.size() || dim_vec_.size() == 1) { ss << ","; }
+  }
+  ss << ")";
   return ss.str();
 }
 
+std::string Shape::DebugStr() const { return ToString(); }
+
 void Shape::ToProto(ShapeProto* ret) const {
   *(ret->mutable_dim()) = PbRf<int64_t>(dim_vec_.begin(), dim_vec_.end());
 }
@@ -57,4 +63,11 @@ std::ostream& operator<<(std::ostream& out, const Shape& shape) {
   return out;
 }
 
+Shape Shape::CreateLeftExtendedShape(int num_axes) const {
+  CHECK_GE(num_axes, NumAxes());
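+  // prepend 1s until the shape has num_axes axes, e.g. (3, 4) -> (1, 1, 3, 4) for num_axes == 4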
+  std::vector<int64_t> dim_vec = this->dim_vec();
+  for (int i = 0; i < num_axes - NumAxes(); ++i) { dim_vec.insert(dim_vec.begin(), 1LL); }
+  return Shape(dim_vec);
+}
+
 }  // namespace oneflow
diff --git a/oneflow/core/common/shape.h b/oneflow/core/common/shape.h
index e2ad1277349de0c7ba74db85ec2ec2398173cd0f..161089d45eb89361328c16dd8fabe1c19d21339b 100644
--- a/oneflow/core/common/shape.h
+++ b/oneflow/core/common/shape.h
@@ -19,6 +19,7 @@ class Shape final {
   bool operator==(const Shape& rhs) const;
   bool operator!=(const Shape& rhs) const { return !(*this == rhs); }
   std::string DebugStr() const;
+  std::string ToString() const;
 
   void ToProto(ShapeProto*) const;
 
@@ -34,6 +35,8 @@ class Shape final {
   int64_t Count(int64_t begin_axis, int64_t end_axis) const;
   int64_t Count(int64_t begin_axis) const;
 
+  Shape CreateLeftExtendedShape(int num_axes) const;
+
  private:
   void UpdateElemCnt();
 
diff --git a/oneflow/core/common/switch_func.h b/oneflow/core/common/switch_func.h
index 4373b0160b90b918681757f512cb18c2e6e700b3..a36f232adf799ef25a21f76e2c715f2c4451b73a 100644
--- a/oneflow/core/common/switch_func.h
+++ b/oneflow/core/common/switch_func.h
@@ -29,27 +29,27 @@ auto SwitchCase(Args&&... args) -> decltype(std::make_tuple(std::forward<Args>(a
 // CTRV example: (float, DataType::kFloat)
 // TYPED_CTRV_SEQ example: (DataType, ((float, DataType::kFloat)))
 
-#define MAKE_TYPED_CTRV_SEQ(runtime_value_type, ctrv_pair_seq) (runtime_value_type, ctrv_pair_seq)
-
 #define MAKE_DATA_TYPE_CTRV_SEQ(data_type_seq) MAKE_TYPED_CTRV_SEQ(DataType, data_type_seq)
 #define MAKE_DEVICE_TYPE_CTRV_SEQ(device_type_seq) \
   MAKE_TYPED_CTRV_SEQ(DeviceType,                  \
-                      OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPE_SEQ, device_type_seq))
+                      OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ, device_type_seq))
 #define MAKE_NDIM_CTRV_SEQ(ndim_seq) \
-  MAKE_TYPED_CTRV_SEQ(int32_t, OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPE_SEQ, ndim_seq))
+  MAKE_TYPED_CTRV_SEQ(int32_t, OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ, ndim_seq))
 
 #define MAKE_STRINGIZED_DATA_TYPE_CTRV(data_type_pair) \
   (OF_PP_PAIR_FIRST(data_type_pair), OF_PP_STRINGIZE(OF_PP_PAIR_FIRST(data_type_pair)))
 #define MAKE_STRINGIZED_DATA_TYPE_CTRV_SEQ(data_type_seq) \
   (std::string, OF_PP_SEQ_MAP(MAKE_STRINGIZED_DATA_TYPE_CTRV, data_type_seq))
 
+#define MAKE_TYPED_CTRV_SEQ(runtime_value_type, ctrv_pair_seq) (runtime_value_type, ctrv_pair_seq)
+
 //  internal preprocessor macros
 
 #define OF_PP_I_MAKE_SWITCH_ENTRY_MAP_PAIR(switch_case, func_args_type, func) \
   {switch_case,                                                               \
    [](func_args_type&&... args) { return func(std::forward<func_args_type>(args)...); }},
 
-#define OF_PP_I_MAKE_REPLICATE_TUPE_SEQ(x) OF_PP_MAKE_TUPLE_SEQ(x, x)
+#define OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ(x) OF_PP_MAKE_TUPLE_SEQ(x, x)
 
 #define OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_1(make_template_func, func_name, func_args_type, \
                                          switch_case_pair0)                             \
diff --git a/oneflow/core/common/util.h b/oneflow/core/common/util.h
index e9db5351b7868b4cdc7ceccd63938937a8f6357a..4caa56a94396e1a90bdfcd13ff5bd4173f82c4c1 100644
--- a/oneflow/core/common/util.h
+++ b/oneflow/core/common/util.h
@@ -158,7 +158,9 @@ inline uint32_t NewRandomSeed() {
 #define DEVICE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU)
 #endif
 
-#define DIM_SEQ (1)(2)(3)(4)(5)(6)(7)(8)
+#define DIM_SEQ           \
+  OF_PP_MAKE_TUPLE_SEQ(1) \
+  OF_PP_MAKE_TUPLE_SEQ(2) OF_PP_MAKE_TUPLE_SEQ(3) OF_PP_MAKE_TUPLE_SEQ(4) OF_PP_MAKE_TUPLE_SEQ(5)
 
 #define BOOL_SEQ (true)(false)
 #define PARALLEL_POLICY_SEQ (ParallelPolicy::kModelParallel)(ParallelPolicy::kDataParallel)
@@ -203,12 +205,12 @@ void Erase(T& container, const std::function<bool(const typename T::value_type&)
 }
 
 template<typename T>
-inline T MinVal() {
+inline T GetMinVal() {
   return std::numeric_limits<T>::lowest();
 }
 
 template<typename T>
-inline T MaxVal() {
+inline T GetMaxVal() {
   return std::numeric_limits<T>::max();
 }
 
@@ -220,6 +222,8 @@ inline T MaxVal() {
 
 #if defined(__GNUC__)
 #define ALWAYS_INLINE __attribute__((always_inline))
+#elif defined(__CUDACC__)
+#define ALWAYS_INLINE __forceinline__
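+// NOTE: nvcc typically defines __GNUC__ as well, so on GNU hosts the branch above takes precedence.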
 #else
 #define ALWAYS_INLINE inline
 #endif
diff --git a/oneflow/core/device/cuda_util.cpp b/oneflow/core/device/cuda_util.cpp
index 7b080712f3b3836fd2462b288dc5077e68959c95..6fddfbb680f5e7aa3be2606642b9e5ef4940d1cb 100644
--- a/oneflow/core/device/cuda_util.cpp
+++ b/oneflow/core/device/cuda_util.cpp
@@ -45,8 +45,17 @@ const char* CurandGetErrorString(curandStatus_t error) {
   return "Unknown curand status";
 }
 
+cudaDeviceProp global_device_prop;
+
 }  // namespace
 
+void InitGlobalCudaDeviceProp() { cudaGetDeviceProperties(&global_device_prop, 0); }
+
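+// rough upper bound on the number of thread blocks resident on the whole device at once,
+// assuming each block uses kCudaThreadsNumPerBlock threads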
+int32_t GetSMCudaMaxBlocksNum() {
+  return global_device_prop.multiProcessorCount * global_device_prop.maxThreadsPerMultiProcessor
+         / kCudaThreadsNumPerBlock;
+}
+
 template<>
 void CudaCheck(cudaError_t error) {
   CHECK_EQ(error, cudaSuccess) << cudaGetErrorString(error);
@@ -73,6 +82,14 @@ size_t GetAvailableGpuMemSize(int dev_id) {
   return prop.totalGlobalMem;
 }
 
+cudaDataType_t GetCudaDataType(DataType val) {
+#define MAKE_ENTRY(type_cpp, type_cuda) \
+  if (val == GetDataType<type_cpp>::value) { return type_cuda; }
+  OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CUDA_DATA_TYPE_SEQ);
+#undef MAKE_ENTRY
+  UNIMPLEMENTED();
+}
+
 #endif  // WITH_CUDA
 
 }  // namespace oneflow
diff --git a/oneflow/core/device/cuda_util.h b/oneflow/core/device/cuda_util.h
index 8c830f4201333720dd149d486d125ba1156f2d1b..d1abc1252ebef5854c5995eccad4fb5e65ee3224 100644
--- a/oneflow/core/device/cuda_util.h
+++ b/oneflow/core/device/cuda_util.h
@@ -19,16 +19,29 @@ template<typename T>
 void CudaCheck(T error);
 
 // CUDA: grid stride looping
-#define CUDA_1D_KERNEL_LOOP(i, n) \
-  for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
+#define CUDA_1D_KERNEL_LOOP(i, n)                                                                 \
+  for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x, step = blockDim.x * gridDim.x; i < (n); \
+       i += step)
 
-const int32_t kCudaThreadsNumPerBlock = 512;
+const int32_t kCudaThreadsNumPerBlock = 1024;
 const int32_t kCudaMaxBlocksNum = 4096;
 
+int32_t GetSMCudaMaxBlocksNum();
+void InitGlobalCudaDeviceProp();
+
 inline int32_t BlocksNum4ThreadsNum(const int32_t n) {
   return std::min((n + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock, kCudaMaxBlocksNum);
 }
 
+inline int32_t SMBlocksNum4ThreadsNum(const int32_t n) {
+  return std::min((n + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock,
+                  GetSMCudaMaxBlocksNum());
+}
+
+#define RUN_CUDA_KERNEL(func, device_ctx_ptr, thread_num, ...)           \
+  func<<<SMBlocksNum4ThreadsNum(thread_num), kCudaThreadsNumPerBlock, 0, \
+         (device_ctx_ptr)->cuda_stream()>>>(__VA_ARGS__)
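+// usage sketch (kernel and argument names hypothetical):
+//   RUN_CUDA_KERNEL(ReluForwardGpu<float>, ctx.device_ctx, elem_cnt, elem_cnt, in_dptr, out_dptr);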
+
 size_t GetAvailableGpuMemSize(int dev_id);
 
 #define CUDA_WORK_TYPE_SEQ           \
@@ -38,14 +51,31 @@ size_t GetAvailableGpuMemSize(int dev_id);
   OF_PP_MAKE_TUPLE_SEQ(kNcclScatter) \
   OF_PP_MAKE_TUPLE_SEQ(kNcclGather)  \
   OF_PP_MAKE_TUPLE_SEQ(kMix)         \
+  OF_PP_MAKE_TUPLE_SEQ(kReduceCtrl)  \
   OF_PP_MAKE_TUPLE_SEQ(kMdUpdt)
 
 enum class CudaWorkType {
 #define DECLARE_CUDA_WORK_TYPE(type) type,
   OF_PP_FOR_EACH_TUPLE(DECLARE_CUDA_WORK_TYPE, CUDA_WORK_TYPE_SEQ)
 };
+
 inline size_t GetCudaWorkTypeSize() { return OF_PP_SEQ_SIZE(CUDA_WORK_TYPE_SEQ); }
 
+#define CUDA_DATA_TYPE_SEQ                \
+  OF_PP_MAKE_TUPLE_SEQ(float, CUDA_R_32F) \
+  OF_PP_MAKE_TUPLE_SEQ(double, CUDA_R_64F)
+
+cudaDataType_t GetCudaDataType(DataType);
+
+template<typename T>
+struct CudaDataType;
+
+#define SPECIALIZE_CUDA_DATA_TYPE(type_cpp, type_cuda) \
+  template<>                                           \
+  struct CudaDataType<type_cpp> : std::integral_constant<cudaDataType_t, type_cuda> {};
+OF_PP_FOR_EACH_TUPLE(SPECIALIZE_CUDA_DATA_TYPE, CUDA_DATA_TYPE_SEQ);
+#undef SPECIALIZE_CUDA_DATA_TYPE
+
 }  // namespace oneflow
 
 #endif  // WITH_CUDA
diff --git a/oneflow/core/graph/accuracy_compute_task_node.cpp b/oneflow/core/graph/accuracy_compute_task_node.cpp
index d7b94e3723a29ace9c17897bbb777f78c0337eaa..1185d93a955305498a3605d97b58403e90fc6e4a 100644
--- a/oneflow/core/graph/accuracy_compute_task_node.cpp
+++ b/oneflow/core/graph/accuracy_compute_task_node.cpp
@@ -6,6 +6,7 @@ namespace oneflow {
 
 void AccuracyCompTaskNode::ProduceAllRegstsAndBindEdges() {
   ProduceRegst("accuracy", false);
+  ProduceRegst("data_tmp", true);
   for (TaskEdge* edge : out_edges()) { BindEdgeWithProducedRegst(edge, "accuracy"); }
 }
 
@@ -27,6 +28,7 @@ void AccuracyCompTaskNode::BuildExecGphAndRegst() {
     accuracy_regst->AddLbi(accuracy_op->BnInOp2Lbi(obn));
     accuracy_node->BindBnWithRegst(obn, accuracy_regst);
   }
+  accuracy_node->AddBnToRegstAndBindIt(&Operator::data_tmp_bns, GetProducedRegst("data_tmp"));
   accuracy_node->InferBlobDescs(parallel_ctx());
 }
 
diff --git a/oneflow/core/graph/boxing_task_node.cpp b/oneflow/core/graph/boxing_task_node.cpp
index 8b70a7054b3ecb02598ff474c189d4226ba31764..0c33835056420994410c4a3f66e551913d907fb8 100644
--- a/oneflow/core/graph/boxing_task_node.cpp
+++ b/oneflow/core/graph/boxing_task_node.cpp
@@ -2,6 +2,7 @@
 #include "oneflow/core/common/balanced_splitter.h"
 #include "oneflow/core/graph/logical_node.h"
 #include "oneflow/core/graph/logical_graph.h"
+#include "oneflow/core/graph/op_graph.h"
 #include "oneflow/core/operator/operator.h"
 
 namespace oneflow {
@@ -55,20 +56,20 @@ DEFINE_BLD_BOXING_OP_CONF_METHOD(InBoxingTaskNode, DataConcatAndDataSplit) {
   conf->mutable_concat_box()->set_axis(0);
   BoxSplitConf* split_conf = conf->mutable_split_box();
   split_conf->set_axis(0);
-  BalancedSplitter bs(Global<JobDesc>::Get()->PieceSize(),
-                      out_logical->parallel_desc()->parallel_num());
-  SetBoxSplitPart(sorted_out_edges, bs, split_conf);
+  const std::string& out_op_name = out_logical->op_vec().at(0)->op_name();
+  SetBoxSplitPart(sorted_out_edges, Global<OpGraph>::Get()->GetBalancedSplitter(out_op_name, lbi),
+                  split_conf);
 }
 DEFINE_BLD_BOXING_OP_CONF_METHOD(OutBoxingTaskNode, DataConcatAndDataSplit) {
   conf->mutable_concat_box()->set_axis(0);
   BoxSplitConf* split_conf = conf->mutable_split_box();
   split_conf->set_axis(0);
-  BalancedSplitter in_bs(Global<JobDesc>::Get()->PieceSize(),
-                         in_logical->parallel_desc()->parallel_num());
+  const std::string& in_op_name = in_logical->op_vec().at(0)->op_name();
+  BalancedSplitter in_bs = Global<OpGraph>::Get()->GetBalancedSplitter(in_op_name, lbi);
   Range in_range =
       in_bs.At(sorted_in_edges.front().parallel_id_min, sorted_in_edges.back().parallel_id_max);
-  BalancedSplitter out_bs(Global<JobDesc>::Get()->PieceSize(),
-                          out_logical->parallel_desc()->parallel_num());
+  const std::string& out_op_name = out_logical->op_vec().at(0)->op_name();
+  BalancedSplitter out_bs = Global<OpGraph>::Get()->GetBalancedSplitter(out_op_name, lbi);
   for (const EdgeInfo& out_edge : sorted_out_edges) {
     Range out_range = out_bs.At(out_edge.parallel_id_min, out_edge.parallel_id_max);
     Range intersectant_range = FindIntersectant(in_range, out_range);
@@ -82,44 +83,89 @@ DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, DataConcatAndClone) {
 DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, DataConcatAndModelSplit) {
   conf->mutable_concat_box()->set_axis(0);
   BoxSplitConf* split_conf = conf->mutable_split_box();
-  split_conf->set_axis(out_logical->GetModelSplitAxis());
-  BalancedSplitter bs(out_logical->GetMaxModelSplitNum(),
-                      out_logical->parallel_desc()->parallel_num());
-  SetBoxSplitPart(sorted_out_edges, bs, split_conf);
+  const std::string& out_op_name = out_logical->op_vec().at(0)->op_name();
+  split_conf->set_axis(Global<OpGraph>::Get()->GetModelSplitAxis(out_op_name, lbi));
+  SetBoxSplitPart(sorted_out_edges, Global<OpGraph>::Get()->GetBalancedSplitter(out_op_name, lbi),
+                  split_conf);
 }
 DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, ModelConcatAndDataSplit) {
-  conf->mutable_concat_box()->set_axis(in_logical->GetModelSplitAxis());
+  const std::string& in_op_name = in_logical->op_vec().at(0)->op_name();
+  conf->mutable_concat_box()->set_axis(Global<OpGraph>::Get()->GetModelSplitAxis(in_op_name, lbi));
   BoxSplitConf* split_conf = conf->mutable_split_box();
   split_conf->set_axis(0);
-  BalancedSplitter bs(Global<JobDesc>::Get()->PieceSize(),
-                      out_logical->parallel_desc()->parallel_num());
-  SetBoxSplitPart(sorted_out_edges, bs, split_conf);
+
+  const std::string& out_op_name = out_logical->op_vec().at(0)->op_name();
+  SetBoxSplitPart(sorted_out_edges, Global<OpGraph>::Get()->GetBalancedSplitter(out_op_name, lbi),
+                  split_conf);
 }
 DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, ModelConcatAndClone) {
-  conf->mutable_concat_box()->set_axis(in_logical->GetModelSplitAxis());
+  const std::string& in_op_name = in_logical->op_vec().at(0)->op_name();
+  conf->mutable_concat_box()->set_axis(Global<OpGraph>::Get()->GetModelSplitAxis(in_op_name, lbi));
   conf->mutable_clone_box();
 }
 DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, AddAndDataSplit) {
   conf->mutable_add_box();
   BoxSplitConf* split_conf = conf->mutable_split_box();
   split_conf->set_axis(0);
-  BalancedSplitter bs(Global<JobDesc>::Get()->PieceSize(),
-                      out_logical->parallel_desc()->parallel_num());
-  SetBoxSplitPart(sorted_out_edges, bs, split_conf);
+
+  const std::string& out_op_name = out_logical->op_vec().at(0)->op_name();
+  SetBoxSplitPart(sorted_out_edges, Global<OpGraph>::Get()->GetBalancedSplitter(out_op_name, lbi),
+                  split_conf);
 }
 DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, AddAndModelSplit) {
   conf->mutable_add_box();
   BoxSplitConf* split_conf = conf->mutable_split_box();
-  split_conf->set_axis(out_logical->GetModelSplitAxis());
-  BalancedSplitter bs(out_logical->GetMaxModelSplitNum(),
-                      out_logical->parallel_desc()->parallel_num());
-  SetBoxSplitPart(sorted_out_edges, bs, split_conf);
+  const std::string& out_op_name = out_logical->op_vec().at(0)->op_name();
+  split_conf->set_axis(Global<OpGraph>::Get()->GetModelSplitAxis(out_op_name, lbi));
+  SetBoxSplitPart(sorted_out_edges, Global<OpGraph>::Get()->GetBalancedSplitter(out_op_name, lbi),
+                  split_conf);
 }
 DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, AddAndClone) {
   conf->mutable_add_box();
   conf->mutable_clone_box();
 }
 
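+// Translate the SBP signatures of the producer/consumer ops into a boxing op conf:
+// split -> concat box (in) / split box (out), partial_sum -> add box, broadcast -> clone box.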
+void SetBoxingOpConfBySbpParallel(
+    BoxingOpConf* conf, const LogicalBlobId& lbi, const Operator& in_op, const Operator& out_op,
+    const std::vector<BoxingTaskNode::EdgeInfo>& sorted_edges,
+    const std::function<SbpParallel(const std::string&, const LogicalBlobId&)>& GetSbpParallel) {
+  SbpParallel in_sbp = GetSbpParallel(in_op.op_name(), lbi);
+  if (in_sbp.has_split_parallel()) {
+    conf->mutable_concat_box()->set_axis(in_sbp.split_parallel().axis());
+  } else if (in_sbp.has_partial_sum_parallel()) {
+    conf->mutable_add_box();
+  } else {
+    UNIMPLEMENTED();
+  }
+  SbpParallel out_sbp = GetSbpParallel(out_op.op_name(), lbi);
+  if (out_sbp.has_split_parallel()) {
+    BoxSplitConf* split_conf = conf->mutable_split_box();
+    split_conf->set_axis(out_sbp.split_parallel().axis());
+    const auto& bs = Global<OpGraph>::Get()->GetBalancedSplitter(out_op.op_name(), lbi);
+    SetBoxSplitPart(sorted_edges, bs, split_conf);
+  } else if (out_sbp.has_broadcast_parallel()) {
+    conf->mutable_clone_box();
+  } else {
+    UNIMPLEMENTED();
+  }
+}
+
+DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, FwSbpParallel) {
+  SetBoxingOpConfBySbpParallel(conf, lbi, *in_logical->SoleOp(), *out_logical->SoleOp(),
+                               sorted_out_edges,
+                               [&](const std::string& op_name, const LogicalBlobId& lbi) {
+                                 return Global<OpGraph>::Get()->GetSbpParallel(op_name, lbi);
+                               });
+}
+
+DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, BwSbpParallel) {
+  SetBoxingOpConfBySbpParallel(
+      conf, lbi, *in_logical->SoleOp(), *out_logical->SoleOp(), sorted_out_edges,
+      [&](const std::string& op_name, const LogicalBlobId& lbi) {
+        return GetDualSbpParallel(Global<OpGraph>::Get()->GetSbpParallel(op_name, lbi));
+      });
+}
+
 void BoxingTaskNode::InitLogical2SortedEdgeInfo(
     const std::unordered_set<TaskEdge*>& (TaskNode::*GetEdges)() const,
     TaskEdge* (TaskNode::*SoleEdge)() const, TaskNode* (TaskEdge::*SoleNode)() const,
@@ -128,8 +174,8 @@ void BoxingTaskNode::InitLogical2SortedEdgeInfo(
   for (const TaskEdge* edge : (this->*GetEdges)()) {
     EdgeInfo edge_info;
     edge_info.edge = edge;
-    edge_info.parallel_id_min = MaxVal<int64_t>();
-    edge_info.parallel_id_max = MinVal<int64_t>();
+    edge_info.parallel_id_min = GetMaxVal<int64_t>();
+    edge_info.parallel_id_max = GetMinVal<int64_t>();
     std::queue<const TaskNode*> node_queue;
     node_queue.push((edge->*SoleNode)());
     const LogicalNode* logical = nullptr;
diff --git a/oneflow/core/graph/boxing_task_node.h b/oneflow/core/graph/boxing_task_node.h
index 69dfbc678a8399d65c6a9fd73259358f10f63317..447f56d78313af8aa494eb1cdeb1c246903806b4 100644
--- a/oneflow/core/graph/boxing_task_node.h
+++ b/oneflow/core/graph/boxing_task_node.h
@@ -42,6 +42,8 @@ class BoxingTaskNode : public TaskNode {
   DECLARE_BLD_BOXING_OP_CONF_METHOD(AddAndDataSplit);
   DECLARE_BLD_BOXING_OP_CONF_METHOD(AddAndModelSplit);
   DECLARE_BLD_BOXING_OP_CONF_METHOD(AddAndClone);
+  DECLARE_BLD_BOXING_OP_CONF_METHOD(FwSbpParallel);
+  DECLARE_BLD_BOXING_OP_CONF_METHOD(BwSbpParallel);
 
  private:
   void InitLogical2SortedEdgeInfo(const std::unordered_set<TaskEdge*>& (TaskNode::*GetEdges)()
diff --git a/oneflow/core/graph/chain_graph.cpp b/oneflow/core/graph/chain_graph.cpp
index 5805a3bd014cb1936902e9130de3684cf26fb5dc..8767cd79b3fecbb439611f174571214b776cdc00 100644
--- a/oneflow/core/graph/chain_graph.cpp
+++ b/oneflow/core/graph/chain_graph.cpp
@@ -1,6 +1,7 @@
 #include "oneflow/core/graph/chain_graph.h"
 #include "oneflow/core/graph/task_graph.h"
 #include "oneflow/core/graph/task_node.h"
+#include "oneflow/core/graph/normal_forward_compute_task_node.h"
 #include "oneflow/core/thread/thread_pool.h"
 #include "oneflow/core/common/blocking_counter.h"
 
@@ -66,7 +67,7 @@ void ChainMerger::InitChains() {
     Chain& cur_chain = chain_list_.back();
     cur_chain.nodes = {task_node};
     cur_chain.stream_area_id =
-        std::make_pair(task_node->area_id(), task_node->GlobalWorkStreamId());
+        std::make_pair(task_node->AreaId4ChainMerge(), task_node->GlobalWorkStreamId());
     cur_chain.ancestors.resize(bitset_num);
     cur_chain.ancestors_and_this.resize(bitset_num);
     CarefullySetBitset(&(cur_chain.ancestors_and_this), GetTaskUid(task_node));
@@ -81,7 +82,7 @@ void ChainMerger::InitChains() {
 bool ChainMerger::DoMerge(std::list<ChainIt>& chains, ChainIt rhs) {
   CHECK_EQ(rhs->nodes.size(), 1);
   // rm kMdUpdtArea chain merge
-  if (rhs->nodes.front()->area_id() == kMdUpdtArea) { return false; }
+  if (rhs->nodes.front()->AreaId4ChainMerge() == kMdUpdtArea) { return false; }
   for (auto chains_it = chains.rbegin(); chains_it != chains.rend(); ++chains_it) {
     ChainIt lhs = *chains_it;
     if (IsSubset(lhs, rhs)) {
@@ -128,6 +129,23 @@ bool ChainMerger::IsSubset(const ChainIt& lhs, const ChainIt& rhs) const {
   return true;
 }
 
+bool IsForwardOnlyTaskNode(TaskNode* node) {
+  auto* fw_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
+  if (fw_node == nullptr) { return true; }
+  return fw_node->HasBackwardCompTaskNode() == false;
+}
+
+bool NoOutRegstConsumedByBwNode(TaskNode* node) {
+  auto* fw_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
+  if (fw_node == nullptr) { return false; }
+  for (TaskEdge* edge : fw_node->out_edges()) {
+    auto* fw_consumer = dynamic_cast<NormalForwardCompTaskNode*>(edge->dst_node());
+    if (fw_consumer == nullptr) { return false; }
+    if (fw_consumer->HasBackwardCompTaskNode()) { return false; }
+  }
+  return true;
+}
+
 }  // namespace
 
 std::string ChainNode::VisualStr() const {
@@ -145,6 +163,7 @@ ChainGraph::ChainGraph(const TaskGraph& task_gph) : task_gph_(task_gph) {
   std::vector<std::vector<TaskNode*>> chains;
   GroupTaskNodesByMachineAndCollectAncestors(task_gph, &machine2tasks, &node2ancestors);
   MergeTaskNodes(machine2tasks, node2ancestors, &chains);
+  for (auto& task_nodes : chains) { PrioritizeUntrainableTaskNode(&task_nodes); }
   InitChainNode(chains);
   InitChainEdge(chains);
   SetChainId4ChainNode();
@@ -158,7 +177,7 @@ void ChainGraph::GroupTaskNodesByMachineAndCollectAncestors(
     (*machine2tasks)[node->machine_id()].emplace_back(node);
     CHECK(node2ancestors->emplace(node, HashSet<TaskNode*>()).second);
     // to reduce memory consumption
-    if (node->area_id() == kMdUpdtArea) { return; }
+    if (node->AreaId4ChainMerge() == kMdUpdtArea) { return; }
     node->ForEachNodeOnInEdge([&](TaskNode* in_node) {
       if (IsBackEdge(in_node, node)) { return; }
       (*node2ancestors)[node].insert(in_node);
@@ -191,6 +210,85 @@ void ChainGraph::MergeTaskNodes(const HashMap<int64_t, std::vector<TaskNode*>>&
   counter.WaitUntilCntEqualZero();
 }
 
+void ChainGraph::PrioritizeUntrainableTaskNode(std::vector<TaskNode*>* task_nodes) const {
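+  // Within one chain, emit forward-only task nodes whose outputs never feed a backward node first,
+  // then the remaining nodes; each group is visited in topological order.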
+  HashSet<TaskNode*> task_nodes_set(task_nodes->begin(), task_nodes->end());
+  auto IsInSubset = [&](TaskNode* node) {
+    return task_nodes_set.find(node) != task_nodes_set.end();
+  };
+  auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
+    node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
+      if (IsInSubset(node_on_in_edge)) { Handler(node_on_in_edge); }
+    });
+  };
+  auto ForEachOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
+    node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) {
+      if (IsInSubset(node_on_out_edge)) { Handler(node_on_out_edge); }
+    });
+  };
+  auto IsSourceNode = [&](TaskNode* node) {
+    int32_t in_node_num = 0;
+    ForEachInNode(node, [&](TaskNode* in_node) { ++in_node_num; });
+    return in_node_num == 0;
+  };
+  std::list<TaskNode*> starts;
+  for (TaskNode* node : task_nodes_set) {
+    if (IsSourceNode(node)) { starts.push_back(node); }
+  }
+  task_nodes->clear();
+  auto IsPrior = [&](TaskNode* node) {
+    return IsForwardOnlyTaskNode(node) && NoOutRegstConsumedByBwNode(node);
+  };
+  PartialPriorTopoForEachNode(starts, ForEachInNode, ForEachOutNode, IsPrior,
+                              [&](TaskNode* node) { task_nodes->push_back(node); });
+  HashSet<TaskNode*> task_nodes_set_check(task_nodes->begin(), task_nodes->end());
+  CHECK(task_nodes_set == task_nodes_set_check);
+}
+
+void ChainGraph::PartialPriorTopoForEachNode(
+    const std::list<TaskNode*> starts,
+    const std::function<void(TaskNode*, const std::function<void(TaskNode*)>&)>& ForEachInNode,
+    const std::function<void(TaskNode*, const std::function<void(TaskNode*)>&)>& ForEachOutNode,
+    const std::function<bool(TaskNode*)>& IsPrior,
+    const std::function<void(TaskNode*)>& Handler) const {
+  // collect prior nodes
+  HashSet<TaskNode*> prior_nodes;
+  auto IsTaskNodePrior = [&](TaskNode* node) {
+    if (!IsPrior(node)) { return false; }
+    bool is_prior = true;
+    ForEachInNode(node, [&](TaskNode* in_node) {
+      is_prior = is_prior && (prior_nodes.find(in_node) != prior_nodes.end());
+    });
+    return is_prior;
+  };
+  std::list<TaskNode*> nodes;
+  task_gph_.TopoForEachNode(starts, ForEachInNode, ForEachOutNode, [&](TaskNode* node) {
+    if (IsTaskNodePrior(node)) { CHECK(prior_nodes.emplace(node).second); }
+    nodes.push_back(node);
+  });
+  // traverse prior nodes
+  auto ForEachPriorInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
+    ForEachInNode(node, [&](TaskNode* in_node) {
+      if (prior_nodes.find(in_node) != prior_nodes.end()) { Handler(in_node); }
+    });
+  };
+  auto ForEachPriorOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
+    ForEachOutNode(node, [&](TaskNode* out_node) {
+      if (prior_nodes.find(out_node) != prior_nodes.end()) { Handler(out_node); }
+    });
+  };
+  std::list<TaskNode*> prior_starts;
+  for (TaskNode* start : starts) {
+    if (IsTaskNodePrior(start)) { prior_starts.push_back(start); }
+  }
+  task_gph_.DfsTopoForEachNodeSortByDistanceToSink(prior_starts, ForEachPriorInNode,
+                                                   ForEachPriorOutNode, Handler);
+  // traverse the remaining nodes
+  task_gph_.DfsTopoForEachNodeSortByDistanceToSink(
+      starts, ForEachInNode, ForEachOutNode, [&](TaskNode* node) {
+        if (prior_nodes.find(node) == prior_nodes.end()) { Handler(node); }
+      });
+}
+
 void ChainGraph::InitChainNode(const std::vector<std::vector<TaskNode*>>& chains) {
   for (auto& chain : chains) {
     ChainNode* chain_node = new ChainNode(chain);
diff --git a/oneflow/core/graph/chain_graph.h b/oneflow/core/graph/chain_graph.h
index 8e097f19d890d559fb135a75455ddabc2cc06b04..59a61ba9dfe00bf9c7800a6dfbe5db62e96e134e 100644
--- a/oneflow/core/graph/chain_graph.h
+++ b/oneflow/core/graph/chain_graph.h
@@ -78,6 +78,13 @@ class ChainGraph final : public Graph<ChainNode, ChainEdge> {
   void MergeTaskNodes(const HashMap<int64_t, std::vector<TaskNode*>>& machine2tasks,
                       const HashMap<TaskNode*, HashSet<TaskNode*>>& node2ancestors,
                       std::vector<std::vector<TaskNode*>>* chains) const;
+  void PrioritizeUntrainableTaskNode(std::vector<TaskNode*>* task_nodes) const;
+  void PartialPriorTopoForEachNode(
+      const std::list<TaskNode*> starts,
+      const std::function<void(TaskNode*, const std::function<void(TaskNode*)>&)>& ForEachInNode,
+      const std::function<void(TaskNode*, const std::function<void(TaskNode*)>&)>& ForEachOutNode,
+      const std::function<bool(TaskNode*)>& IsPrior,
+      const std::function<void(TaskNode*)>& Handler) const;
   void InitChainNode(const std::vector<std::vector<TaskNode*>>& chains);
   void InitChainEdge(const std::vector<std::vector<TaskNode*>>& chains);
   void SetChainId4ChainNode();
diff --git a/oneflow/core/graph/chain_logical_graph.cpp b/oneflow/core/graph/chain_logical_graph.cpp
deleted file mode 100644
index 16fcfeab6852dbec45ccef6748c7a17f9dce1173..0000000000000000000000000000000000000000
--- a/oneflow/core/graph/chain_logical_graph.cpp
+++ /dev/null
@@ -1,186 +0,0 @@
-#include "oneflow/core/operator/fully_connected_op.h"
-#include "oneflow/core/graph/chain_logical_graph.h"
-#include "oneflow/core/graph/logical_graph.h"
-
-namespace oneflow {
-
-struct ChainLogicalGraph::Chain {
-  std::vector<const LogicalNode*> nodes;
-  HashSet<const LogicalNode*> ancestors;
-  HashSet<const LogicalNode*> ancestors_and_this;
-  HashSet<const LogicalNode*> descendants;
-  HashSet<const LogicalNode*> descendants_and_this;
-  bool is_mergeable;
-
-  bool IsParallelDescEqual(const Chain& rhs) const {
-    CHECK_GT(nodes.size(), 0);
-    CHECK_GT(rhs.nodes.size(), 0);
-    return nodes.front()->parallel_desc()->Equal(rhs.nodes.front()->parallel_desc().get());
-  }
-};
-
-ChainLogicalGraph::ChainLogicalGraph(const LogicalGraph& logical_graph) {
-  std::list<Chain> chain_list;
-  HashMap<const LogicalNode*, std::list<Chain>::iterator> logical2chain_it;
-  HashMap<const LogicalNode*, size_t> logical2order_in_topo;
-
-  InitChains(logical_graph, &chain_list, &logical2chain_it, &logical2order_in_topo);
-  MergeChains(&chain_list, &logical2chain_it);
-  SortNodesInChains(&chain_list, logical2order_in_topo);
-  BuildGraph(logical_graph, &chain_list);
-  ToDotWithAutoFilePath();
-}
-
-void ChainLogicalGraph::InitChains(
-    const LogicalGraph& logical_graph, std::list<Chain>* chain_list,
-    HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it,
-    HashMap<const LogicalNode*, size_t>* logical2order_in_topo) {
-  logical_graph.ForEachNode([&](const LogicalNode* node) {
-    chain_list->emplace_back();
-    logical2chain_it->insert({node, --chain_list->end()});
-    Chain& chain = chain_list->back();
-    chain.nodes = {node};
-    chain.is_mergeable = IsLogicalNodeMergeable(node);
-    size_t order_in_topo = logical2order_in_topo->size();
-    logical2order_in_topo->emplace(node, order_in_topo);
-  });
-
-  logical_graph.TopoForEachNode([&](const LogicalNode* node) {
-    auto cur_chain = logical2chain_it->at(node);
-    for (const LogicalEdge* edge : node->in_edges()) {
-      LogicalNode* pred_node = edge->src_node();
-      auto pred_chain = logical2chain_it->at(pred_node);
-      cur_chain->ancestors.insert(pred_chain->ancestors.begin(), pred_chain->ancestors.end());
-      cur_chain->ancestors.insert(pred_node);
-    }
-    cur_chain->ancestors_and_this.insert(cur_chain->ancestors.begin(), cur_chain->ancestors.end());
-    cur_chain->ancestors_and_this.insert(cur_chain->nodes.begin(), cur_chain->nodes.end());
-  });
-
-  logical_graph.ReverseTopoForEachNode([&](const LogicalNode* node) {
-    auto cur_chain = logical2chain_it->at(node);
-    for (const LogicalEdge* edge : node->out_edges()) {
-      LogicalNode* succ_node = edge->dst_node();
-      auto succ_chain = logical2chain_it->at(succ_node);
-      cur_chain->descendants.insert(succ_chain->descendants.begin(), succ_chain->descendants.end());
-      cur_chain->descendants.insert(succ_node);
-    }
-    cur_chain->descendants_and_this.insert(cur_chain->descendants.begin(),
-                                           cur_chain->descendants.end());
-    cur_chain->descendants_and_this.insert(cur_chain->nodes.begin(), cur_chain->nodes.end());
-  });
-}
-
-void ChainLogicalGraph::MergeChains(
-    std::list<Chain>* chain_list,
-    HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it) {
-  while (chain_list->size() > 1 && TryMergeTwoChains(chain_list, logical2chain_it)) {};
-}
-
-bool ChainLogicalGraph::TryMergeTwoChains(
-    std::list<Chain>* chain_list,
-    HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it) {
-  return TryMergeTwoParallelChains(chain_list, logical2chain_it)
-         || TryMergeTwoConnectedChains(chain_list, logical2chain_it);
-}
-
-bool ChainLogicalGraph::TryMergeTwoParallelChains(
-    std::list<Chain>* chain_list,
-    HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it) {
-  for (auto lhs = chain_list->begin(); lhs != chain_list->end(); ++lhs) {
-    if (!lhs->is_mergeable) { continue; }
-    for (auto rhs = lhs; rhs != chain_list->end(); ++rhs) {
-      if (lhs == rhs) { continue; }
-      if (!rhs->is_mergeable) { continue; }
-      if (!lhs->IsParallelDescEqual(*rhs)) { continue; }
-      if (lhs->ancestors != rhs->ancestors || lhs->descendants != rhs->descendants) { continue; }
-      for (const LogicalNode* node : rhs->nodes) {
-        lhs->nodes.push_back(node);
-        lhs->ancestors_and_this.insert(node);
-        lhs->descendants_and_this.insert(node);
-        logical2chain_it->at(node) = lhs;
-      }
-      chain_list->erase(rhs);
-      return true;
-    }
-  }
-  return false;
-}
-
-bool ChainLogicalGraph::TryMergeTwoConnectedChains(
-    std::list<Chain>* chain_list,
-    HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it) {
-  for (auto succ_chain_it = chain_list->begin(); succ_chain_it != chain_list->end();
-       ++succ_chain_it) {
-    if (!succ_chain_it->is_mergeable) { continue; }
-    for (const LogicalNode* node_in_succ : succ_chain_it->nodes) {
-      for (const LogicalEdge* in_edge : node_in_succ->in_edges()) {
-        auto pred_chain_it = logical2chain_it->at(in_edge->src_node());
-        if (pred_chain_it == succ_chain_it) { continue; }
-        if (!pred_chain_it->is_mergeable) { continue; }
-        if (!pred_chain_it->IsParallelDescEqual(*succ_chain_it)) { continue; }
-        if (pred_chain_it->ancestors_and_this != succ_chain_it->ancestors
-            || pred_chain_it->descendants != succ_chain_it->descendants_and_this) {
-          continue;
-        }
-        for (const LogicalNode* node : succ_chain_it->nodes) {
-          pred_chain_it->nodes.push_back(node);
-          pred_chain_it->ancestors_and_this.insert(node);
-          pred_chain_it->descendants.erase(node);
-          logical2chain_it->at(node) = pred_chain_it;
-        }
-        chain_list->erase(succ_chain_it);
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
-void ChainLogicalGraph::SortNodesInChains(
-    std::list<Chain>* chain_list,
-    const HashMap<const LogicalNode*, size_t>& logical2order_in_topo) {
-  for (Chain& chain : *chain_list) {
-    std::sort(chain.nodes.begin(), chain.nodes.end(),
-              [&](const LogicalNode* a, const LogicalNode* b) {
-                return logical2order_in_topo.at(a) < logical2order_in_topo.at(b);
-              });
-  }
-}
-
-void ChainLogicalGraph::BuildGraph(const LogicalGraph& logical_graph,
-                                   std::list<Chain>* chain_list) {
-  HashMap<const LogicalNode*, ChainLogicalNode*> logical_node2chain_logical_node;
-
-  for (const Chain& chain : *chain_list) {
-    ChainLogicalNode* chain_logical_node = NewNode();
-    chain_logical_node->mut_logical_nodes() = chain.nodes;
-    for (const LogicalNode* node : chain.nodes) {
-      CHECK(logical_node2chain_logical_node.emplace(node, chain_logical_node).second);
-    }
-  }
-
-  std::unordered_set<std::pair<ChainLogicalNode*, ChainLogicalNode*>> pred_succ_pairs;
-  logical_graph.ForEachEdge([&](const LogicalEdge* edge) {
-    pred_succ_pairs.emplace(logical_node2chain_logical_node.at(edge->src_node()),
-                            logical_node2chain_logical_node.at(edge->dst_node()));
-  });
-
-  for (auto& pair : pred_succ_pairs) {
-    if (pair.first == pair.second) { continue; }
-    ChainLogicalEdge* edge = NewEdge();
-    Connect(pair.first, edge, pair.second);
-  }
-}
-
-bool ChainLogicalGraph::IsLogicalNodeMergeable(const LogicalNode* logical_node) const {
-  if (logical_node->parallel_desc()->policy() != kDataParallel) { return false; }
-  if (!dynamic_cast<const NormalForwardLogicalNode*>(logical_node)) { return false; }
-  for (const std::shared_ptr<Operator>& op : logical_node->op_vec()) {
-    if (dynamic_cast<FullyConnectedOp*>(op.get())) { return false; }
-    if (op->IsRecurrentOp()) { return false; }
-  }
-  return true;
-}
-
-}  // namespace oneflow
diff --git a/oneflow/core/graph/chain_logical_graph.h b/oneflow/core/graph/chain_logical_graph.h
deleted file mode 100644
index bb95bcc43368e859ed89c57927ed3e72e6977a2e..0000000000000000000000000000000000000000
--- a/oneflow/core/graph/chain_logical_graph.h
+++ /dev/null
@@ -1,60 +0,0 @@
-#ifndef ONEFLOW_CORE_GRAPH_CHAIN_LOGICAL_GRAPH_H_
-#define ONEFLOW_CORE_GRAPH_CHAIN_LOGICAL_GRAPH_H_
-
-#include "oneflow/core/graph/graph.h"
-#include "oneflow/core/graph/logical_graph.h"
-
-namespace oneflow {
-
-class ChainLogicalEdge;
-
-class ChainLogicalNode final : public Node<ChainLogicalNode, ChainLogicalEdge> {
- public:
-  OF_DISALLOW_COPY_AND_MOVE(ChainLogicalNode);
-  ChainLogicalNode() = default;
-  ~ChainLogicalNode() override = default;
-
-  const std::vector<const LogicalNode*>& logical_nodes() const { return logical_nodes_; }
-  std::vector<const LogicalNode*>& mut_logical_nodes() { return logical_nodes_; }
-
- private:
-  std::vector<const LogicalNode*> logical_nodes_;
-};
-
-class ChainLogicalEdge final : public Edge<ChainLogicalNode, ChainLogicalEdge> {
- public:
-  OF_DISALLOW_COPY_AND_MOVE(ChainLogicalEdge);
-  ChainLogicalEdge() = default;
-  ~ChainLogicalEdge() override = default;
-};
-
-class ChainLogicalGraph final : public Graph<ChainLogicalNode, ChainLogicalEdge> {
- public:
-  OF_DISALLOW_COPY_AND_MOVE(ChainLogicalGraph);
-  explicit ChainLogicalGraph(const LogicalGraph& logical_graph);
-  ~ChainLogicalGraph() override = default;
-
- private:
-  struct Chain;
-  void InitChains(const LogicalGraph& logical_graph, std::list<Chain>* chain_list,
-                  HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it,
-                  HashMap<const LogicalNode*, size_t>* logical2order_in_topo);
-  void MergeChains(std::list<Chain>* chain_list,
-                   HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it);
-  bool TryMergeTwoChains(std::list<Chain>* chain_list,
-                         HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it);
-  bool TryMergeTwoParallelChains(
-      std::list<Chain>* chain_list,
-      HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it);
-  bool TryMergeTwoConnectedChains(
-      std::list<Chain>* chain_list,
-      HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it);
-  void SortNodesInChains(std::list<Chain>* chain_list,
-                         const HashMap<const LogicalNode*, size_t>& logical2order_in_topo);
-  void BuildGraph(const LogicalGraph& logical_graph, std::list<Chain>* chain_list);
-  bool IsLogicalNodeMergeable(const LogicalNode* logical_node) const;
-};
-
-}  // namespace oneflow
-
-#endif  // ONEFLOW_CORE_GRAPH_CHAIN_LOGICAL_GRAPH_H_
diff --git a/oneflow/core/graph/exec_graph.cpp b/oneflow/core/graph/exec_graph.cpp
index d3ed963f0af535c3c90841dd882cccbd33552d97..4c94992570626e3237a73be324259a07aa31f67a 100644
--- a/oneflow/core/graph/exec_graph.cpp
+++ b/oneflow/core/graph/exec_graph.cpp
@@ -1,4 +1,5 @@
 #include "oneflow/core/graph/exec_graph.h"
+#include "oneflow/core/graph/op_graph.h"
 
 namespace oneflow {
 
@@ -51,8 +52,10 @@ void ExecNode::ToProto(bool is_forward, const ParallelContext* parallel_ctx,
 }
 
 void ExecNode::InferBlobDescs(const ParallelContext* parallel_ctx) {
-  op_->InferBlobDescsIf(GetBlobDesc4BnInOpFunc(), parallel_ctx,
+  auto GetBlobDesc4BnInOp = GetBlobDesc4BnInOpFunc();
+  op_->InferBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx, Global<JobDesc>::Get()->RecordPieceSize(),
                         [this](OpContext* op_ctx) { op_ctx_.reset(op_ctx); });
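+  // Cross-check the blob descs inferred for this task against the per-parallel-id blob
+  // descs that the global OpGraph inferred logically (see OpNode::CheckBlobDescs).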
+  Global<OpGraph>::Get()->CheckBlobDescs(op_->op_name(), GetBlobDesc4BnInOp, parallel_ctx);
 }
 
 void ExecNode::InferBwBufBlobDescs(const ParallelContext* parallel_ctx) {
diff --git a/oneflow/core/graph/graph.h b/oneflow/core/graph/graph.h
index 7b8de5479a85748d8bf5f06c873ea9aabbe53983..7a00416ba74e0314f99d69a103529a1c9f1e91dc 100644
--- a/oneflow/core/graph/graph.h
+++ b/oneflow/core/graph/graph.h
@@ -45,8 +45,8 @@ class Graph {
       const std::function<void(NodeType*)>& Handler) const;
 
   // Getters
-  const std::unordered_set<NodeType*>& source_nodes() const;
-  const std::unordered_set<NodeType*>& sink_nodes() const;
+  std::list<NodeType*> source_nodes() const;
+  std::list<NodeType*> sink_nodes() const;
   NodeType* SoleSourceNode() const;
   NodeType* SoleSinkNode() const;
   NodeType* SoleNode() const;
@@ -57,7 +57,8 @@ class Graph {
   // Setters
   template<typename DerivedNodeType = NodeType>
   DerivedNodeType* NewNode();
-  EdgeType* NewEdge();
+  template<class... Args>
+  EdgeType* NewEdge(Args&&... args);
   void AddAllocatedNode(NodeType*);
   void AddAllocatedEdge(EdgeType*);
   void DeleteNode(NodeType*);
@@ -79,23 +80,47 @@ void Graph<NodeType, EdgeType>::ForEachNode(std::function<void(NodeType*)> NodeH
 }
 
 template<typename NodeType, typename EdgeType>
-void Graph<NodeType, EdgeType>::TopoForEachNode(std::function<void(NodeType*)> NodeHandler) const {
-  std::list<NodeType*> starts;
+std::list<NodeType*> Graph<NodeType, EdgeType>::source_nodes() const {
+  std::list<NodeType*> ret;
+  ForEachNode([&](NodeType* node) {
+    if (node->in_edges().empty()) { ret.push_back(node); }
+  });
+  return ret;
+}
+
+template<typename NodeType, typename EdgeType>
+std::list<NodeType*> Graph<NodeType, EdgeType>::sink_nodes() const {
+  std::list<NodeType*> ret;
   ForEachNode([&](NodeType* node) {
-    if (node->in_edges().empty()) { starts.push_back(node); }
+    if (node->out_edges().empty()) { ret.push_back(node); }
   });
-  TopoForEachNode(starts, &NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge,
+  return ret;
+}
+
+template<typename NodeType, typename EdgeType>
+NodeType* Graph<NodeType, EdgeType>::SoleSourceNode() const {
+  std::list<NodeType*> source_nodes_list = source_nodes();
+  CHECK_EQ(source_nodes_list.size(), 1);
+  return source_nodes_list.front();
+}
+
+template<typename NodeType, typename EdgeType>
+NodeType* Graph<NodeType, EdgeType>::SoleSinkNode() const {
+  std::list<NodeType*> sink_nodes_list = sink_nodes();
+  CHECK_EQ(sink_nodes_list.size(), 1);
+  return sink_nodes_list.front();
+}
+
+template<typename NodeType, typename EdgeType>
+void Graph<NodeType, EdgeType>::TopoForEachNode(std::function<void(NodeType*)> NodeHandler) const {
+  TopoForEachNode(source_nodes(), &NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge,
                   NodeHandler);
 }
 
 template<typename NodeType, typename EdgeType>
 void Graph<NodeType, EdgeType>::ReverseTopoForEachNode(
     std::function<void(NodeType*)> NodeHandler) const {
-  std::list<NodeType*> starts;
-  ForEachNode([&](NodeType* node) {
-    if (node->out_edges().empty()) { starts.push_back(node); }
-  });
-  TopoForEachNode(starts, &NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge,
+  TopoForEachNode(sink_nodes(), &NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge,
                   NodeHandler);
 }
 
@@ -122,8 +147,9 @@ DerivedNodeType* Graph<NodeType, EdgeType>::NewNode() {
 }
 
 template<typename NodeType, typename EdgeType>
-EdgeType* Graph<NodeType, EdgeType>::NewEdge() {
-  EdgeType* ret = new EdgeType;
+template<class... Args>
+EdgeType* Graph<NodeType, EdgeType>::NewEdge(Args&&... args) {
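+  // Forward any constructor arguments to the edge type, so graphs whose edges carry
+  // payload can construct them directly (e.g. OpGraph's NewEdge(lbis, lbi2ibns)).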
+  EdgeType* ret = new EdgeType(std::forward<Args>(args)...);
   AddAllocatedEdge(ret);
   return ret;
 }
@@ -163,6 +189,7 @@ template<typename NodeType, typename EdgeType>
 void Graph<NodeType, EdgeType>::ToDotWithFilePath(const std::string& file_path) {
   auto log_stream = TeePersistentLogStream::Create(file_path);
   ToDotWithStream(log_stream);
+  log_stream->Flush();
 }
 
 template<typename NodeType, typename EdgeType>
@@ -242,17 +269,18 @@ void Graph<NodeType, EdgeType>::DfsTopoForEachNodeSortByDistanceToSink(
   }
   HashMap<NodeType*, int64_t> node2distance_to_sink;
   TopoForEachNode(sinks, ForEachOutNode, ForEachInNode, [&](NodeType* node) {
-    int64_t distince_to_sink = -1;
+    int64_t distance_to_sink = -1;
     ForEachOutNode(node, [&](NodeType* out_node) {
-      distince_to_sink = std::max(distince_to_sink, node2distance_to_sink[out_node]);
+      distance_to_sink = std::max(distance_to_sink, node2distance_to_sink[out_node]);
     });
-    node2distance_to_sink[node] = distince_to_sink + 1;
+    node2distance_to_sink[node] = distance_to_sink + 1;
   });
   auto ForEachOutNodeSortedByDistanceToSink = [&](NodeType* node,
                                                   const std::function<void(NodeType*)>& Handler) {
     std::vector<NodeType*> out_nodes;
     ForEachOutNode(node, [&](NodeType* out_node) { out_nodes.push_back(out_node); });
     std::sort(out_nodes.begin(), out_nodes.end(), [&](NodeType* lhs, NodeType* rhs) {
+      // DfsTopoForEachNode uses a stack, so sort in descending order of distance to sink
       return node2distance_to_sink.at(lhs) > node2distance_to_sink.at(rhs);
     });
     for (NodeType* out_node : out_nodes) { Handler(out_node); }
diff --git a/oneflow/core/graph/logical_graph.cpp b/oneflow/core/graph/logical_graph.cpp
index 8ffccbef733b6ed5c519924ddbecd3fed6b48651..8b6aff144b7f379f28b36acfff35cac351a970f1 100644
--- a/oneflow/core/graph/logical_graph.cpp
+++ b/oneflow/core/graph/logical_graph.cpp
@@ -1,8 +1,8 @@
 #include "oneflow/core/graph/logical_graph.h"
 #include "oneflow/core/graph/task_graph.h"
+#include "oneflow/core/graph/op_graph.h"
 #include "oneflow/core/operator/operator.h"
 #include "oneflow/core/operator/op_conf.pb.h"
-#include "oneflow/core/graph/chain_logical_graph.h"
 #include "oneflow/core/common/balanced_splitter.h"
 
 namespace oneflow {
@@ -45,7 +45,6 @@ std::function<bool(const LogicalNode*)> MakePredicatorHasActualOutDiff(const Log
 LogicalGraph::LogicalGraph(bool is_train) {
   BuildFwStruct();
   if (is_train) { GroupNodesForReduceStruct(); }
-  SetMainModelParallel();
   if (is_train) { BuildBwStruct(); }
   MergeEdge();
   SetNodeDataLbi();
@@ -59,24 +58,41 @@ LogicalGraph::LogicalGraph(bool is_train) {
 }
 
 void LogicalGraph::GroupNodesForReduceStruct() {
-  ChainLogicalGraph chain_logical_graph(*this);
-  std::vector<std::vector<const LogicalNode*>> fw_node_groups;
-  chain_logical_graph.ForEachNode(
-      [&](ChainLogicalNode* node) { fw_node_groups.emplace_back(node->logical_nodes()); });
-  for (auto& fw_node_group : fw_node_groups) {
-    if (fw_node_group.size() < Global<JobDesc>::Get()->reduce_group_size()) {
-      fw_node_groups_.emplace_back(std::move(fw_node_group));
-    } else {
-      int64_t fw_node_group_size = fw_node_group.size();
-      int64_t seg_num = fw_node_group_size / Global<JobDesc>::Get()->reduce_group_size() + 1;
-      BalancedSplitter bs(fw_node_group_size, seg_num);
-      FOR_RANGE(int64_t, idx, 0, seg_num) {
-        std::vector<const LogicalNode*> sub_fw_node_group;
-        Range range = bs.At(idx);
-        FOR_RANGE(int64_t, nid, range.begin(), range.end()) {
-          sub_fw_node_group.emplace_back(fw_node_group[nid]);
-        }
-        fw_node_groups_.emplace_back(std::move(sub_fw_node_group));
+  // get op model size
+  HashMap<std::string, size_t> op_name2model_size;
+  auto OpName2ModelSize = [&](const std::string& op_name) -> size_t {
+    if (op_name2model_size.find(op_name) == op_name2model_size.end()) { return 0; }
+    return op_name2model_size.at(op_name);
+  };
+  Global<OpGraph>::Get()->InferOpModelSize(&op_name2model_size);
+  size_t model_total_size = 0;
+  for (const auto& pair : op_name2model_size) { model_total_size += pair.second; }
+  HashMap<ParallelDesc, std::list<const LogicalNode*>> parallel_desc2fw_group;
+  size_t avg_size = model_total_size / Global<JobDesc>::Get()->all_reduce_group_num();
+  const size_t group_min_size = Global<JobDesc>::Get()->all_reduce_group_min_byte();
+  const float group_size_warmup = Global<JobDesc>::Get()->all_reduce_group_size_warmup();
+  size_t cur_group_size = group_min_size / group_size_warmup;
+  auto GetCurGroupSize = [&](int32_t group_id) {
+    if (cur_group_size < avg_size) { cur_group_size *= group_size_warmup; }
+    return std::min(cur_group_size, avg_size);
+  };
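+  // Group sizes warm up geometrically: the first group holds roughly group_min_size
+  // bytes, each later group is group_size_warmup times larger, and every group is
+  // capped at avg_size (model_total_size / all_reduce_group_num).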
+  // group fw nodes by parallel desc
+  ReverseTopoForEachNode([&](LogicalNode* fw_node) {
+    parallel_desc2fw_group[*fw_node->parallel_desc()].push_front(fw_node);
+  });
+  CHECK_GT(parallel_desc2fw_group.size(), 0);
+  for (auto& pair : parallel_desc2fw_group) {
+    fw_node_groups_.emplace_back(std::vector<const LogicalNode*>());
+    auto& fw_node_group = pair.second;
+    size_t cur_group_model_size = 0;
+    int32_t group_id = 0;
+    for (const LogicalNode* fw_node : fw_node_group) {
+      fw_node_groups_.back().emplace_back(fw_node);
+      cur_group_model_size += OpName2ModelSize(fw_node->SoleOp()->op_name());
+      if (cur_group_model_size >= GetCurGroupSize(group_id)) {
+        fw_node_groups_.emplace_back(std::vector<const LogicalNode*>());
+        cur_group_model_size = 0;
+        ++group_id;
       }
     }
   }
@@ -200,20 +216,6 @@ void LogicalGraph::LinkUnpackFw2PackFw(
   });
 }
 
-void LogicalGraph::SetMainModelParallel() {
-  ForEachNode([](LogicalNode* node) {
-    if (node->parallel_desc()->policy() == kModelParallel) { node->set_main_model_parallel(node); }
-  });
-  ForEachNode([](LogicalNode* node) {
-    LogicalNode* pred_node = node;
-    while (pred_node->SoleOp()->IsElemWiseOp()) { pred_node = pred_node->SoleInEdge()->src_node(); }
-    if (pred_node != node && pred_node->parallel_desc()->policy() == kModelParallel) {
-      node->mut_parallel_desc() = pred_node->parallel_desc();
-      node->set_main_model_parallel(pred_node);
-    }
-  });
-}
-
 void LogicalGraph::BuildBwStruct() {
   NaiveBuildBwStruct();
   AddBackwardClone();
@@ -356,6 +358,7 @@ void LogicalGraph::BuildLossPrintStruct() {
     reduce_loss_op_conf.set_device_type(loss_op->device_type());
     auto reduce_sum_conf = reduce_loss_op_conf.mutable_reduce_sum_conf();
     *(reduce_sum_conf->mutable_in_sys()) = loss_op->BnInOp2Lbi("loss");
+    reduce_sum_conf->add_axis(0);
     reduce_sum_conf->set_out("out");
     std::shared_ptr<Operator> reduce_loss_op = ConstructOp(reduce_loss_op_conf);
     loss_logical->mut_op_vec().push_back(reduce_loss_op);
@@ -508,8 +511,11 @@ void LogicalGraph::BuildModelStruct(bool is_train) {
       }
     }
   });
-  for (auto& fw_node_group : fw_node_groups_) {
+  for (int i = 0; i < fw_node_groups_.size(); ++i) {
+    auto& fw_node_group = fw_node_groups_[i];
     ReduceCtx group_reduce_ctx;
+    group_reduce_ctx.order_in_logical_graph = i;
+    int order_in_reduce_group = 0;
     for (auto& fw_node : fw_node_group) {
       auto reduce_ctx_it = fw_node2reduce_ctx.find(fw_node);
       if (reduce_ctx_it != fw_node2reduce_ctx.end()) {
@@ -518,42 +524,55 @@ void LogicalGraph::BuildModelStruct(bool is_train) {
         group_reduce_ctx.bw_logicals.emplace_back(reduce_ctx.bw_logicals.at(0));
         group_reduce_ctx.md_diff_acc_logicals.emplace_back(reduce_ctx.md_diff_acc_logicals.at(0));
         group_reduce_ctx.md_updt_logicals.emplace_back(reduce_ctx.md_updt_logicals.at(0));
+        auto* md_updt = dynamic_cast<NormalMdUpdtLogicalNode*>(reduce_ctx.md_updt_logicals.at(0));
+        md_updt->set_order_in_reduce_group(order_in_reduce_group++);
       }
     }
-    BuildReduceStruct(group_reduce_ctx);
+    if (group_reduce_ctx.fw_logicals.size() > 0) { BuildReduceStruct(group_reduce_ctx); }
   }
   SetupNormalMdUpdtOp();
 }
 
 void LogicalGraph::BuildReduceStruct(const ReduceCtx& reduce_ctx) {
-  if (reduce_ctx.fw_logicals.size() > 1) {
-    std::shared_ptr<const ParallelDesc> src_pd = reduce_ctx.fw_logicals[0]->parallel_desc();
-
-    OperatorConf reduce_concat_op_conf;
-    reduce_concat_op_conf.set_name("reduce_concat_" + NewUniqueId());
-    reduce_concat_op_conf.set_device_type(src_pd->device_type());
-    reduce_concat_op_conf.mutable_reduce_concat_conf()->set_in_num(reduce_ctx.fw_logicals.size());
-    LogicalNode* reduce_concat_node = NewNode<ReduceConcatLogicalNode>();
-    reduce_concat_node->mut_op_vec() = {ConstructOp(reduce_concat_op_conf)};
-    reduce_concat_node->mut_parallel_desc() = src_pd;
-
-    OperatorConf reduce_split_op_conf;
-    reduce_split_op_conf.set_name("reduce_split_" + NewUniqueId());
-    reduce_split_op_conf.set_device_type(src_pd->device_type());
-    reduce_split_op_conf.mutable_reduce_split_conf()->set_out_num(reduce_ctx.fw_logicals.size());
-    LogicalNode* reduce_split_node = NewNode<ReduceSplitLogicalNode>();
-    reduce_split_node->mut_op_vec() = {ConstructOp(reduce_split_op_conf)};
-    reduce_split_node->mut_parallel_desc() = src_pd;
-
-    for (auto& md_diff_acc_node : reduce_ctx.md_diff_acc_logicals) {
-      Connect(md_diff_acc_node, NewEdge(), reduce_concat_node);
-    }
-    for (auto& md_updt_node : reduce_ctx.md_updt_logicals) {
-      Connect(reduce_split_node, NewEdge(), md_updt_node);
-    }
-    AddAllReduce(reduce_concat_node, reduce_split_node);
-  } else if (reduce_ctx.fw_logicals.size() == 1) {
-    AddAllReduce(reduce_ctx.md_diff_acc_logicals.at(0), reduce_ctx.md_updt_logicals.at(0));
+  CHECK_GT(reduce_ctx.fw_logicals.size(), 0);
+  std::shared_ptr<const ParallelDesc> src_pd = reduce_ctx.fw_logicals[0]->parallel_desc();
+
+  OperatorConf reduce_concat_op_conf;
+  reduce_concat_op_conf.set_name("reduce_concat_" + NewUniqueId());
+  reduce_concat_op_conf.set_device_type(src_pd->device_type());
+  reduce_concat_op_conf.mutable_reduce_concat_conf()->set_in_num(reduce_ctx.fw_logicals.size());
+  LogicalNode* reduce_concat_node = NewNode<ReduceConcatLogicalNode>();
+  reduce_concat_node->mut_op_vec() = {ConstructOp(reduce_concat_op_conf)};
+  reduce_concat_node->mut_parallel_desc() = src_pd;
+
+  // We cannot add ctrl edges between all_reduce nodes due to the implementation of NCCL,
+  // so we add ctrl edges between the ReduceIdentityTaskNodes that are followed by the
+  // all_reduce nodes.
+  OperatorConf reduce_identity_conf;
+  reduce_identity_conf.set_name("reduce_identity_" + NewUniqueId());
+  reduce_identity_conf.set_device_type(src_pd->device_type());
+  reduce_identity_conf.mutable_reduce_identity_conf();
+  auto* reduce_identity_node = NewNode<ReduceIdentityLogicalNode>();
+  reduce_identity_node->mut_op_vec() = {ConstructOp(reduce_identity_conf)};
+  reduce_identity_node->mut_parallel_desc() = src_pd;
+  reduce_identity_node->set_order_in_logical_graph(reduce_ctx.order_in_logical_graph);
+
+  OperatorConf reduce_split_op_conf;
+  reduce_split_op_conf.set_name("reduce_split_" + NewUniqueId());
+  reduce_split_op_conf.set_device_type(src_pd->device_type());
+  reduce_split_op_conf.mutable_reduce_split_conf()->set_out_num(reduce_ctx.fw_logicals.size());
+  auto* reduce_split_node = NewNode<ReduceSplitLogicalNode>();
+  reduce_split_node->mut_op_vec() = {ConstructOp(reduce_split_op_conf)};
+  reduce_split_node->mut_parallel_desc() = src_pd;
+  reduce_split_node->set_order_in_logical_graph(reduce_ctx.order_in_logical_graph);
+
+  for (auto& md_diff_acc_node : reduce_ctx.md_diff_acc_logicals) {
+    Connect(md_diff_acc_node, NewEdge(), reduce_concat_node);
+  }
+  Connect(reduce_concat_node, NewEdge(), static_cast<LogicalNode*>(reduce_identity_node));
+  AddAllReduce(reduce_identity_node, reduce_split_node);
+  for (auto& md_updt_node : reduce_ctx.md_updt_logicals) {
+    Connect(static_cast<LogicalNode*>(reduce_split_node), NewEdge(), md_updt_node);
   }
 }
 
@@ -562,7 +581,8 @@ void LogicalGraph::AddAllReduce(LogicalNode* src, LogicalNode* dst) {
   std::shared_ptr<const ParallelDesc> dst_pd = dst->parallel_desc();
   CHECK_EQ(src_pd->parallel_num(), dst_pd->parallel_num());
   CHECK_EQ(src_pd->device_type(), dst_pd->device_type());
-  if (Global<JobDesc>::Get()->enable_nccl()) {
+
+  if (Global<JobDesc>::Get()->enable_nccl() && src_pd->device_type() == DeviceType::kGPU) {
     if (src_pd->sorted_machine_ids().size() == 1
         || Global<JobDesc>::Get()->use_nccl_inter_node_communication()) {
       AddNcclAllReduce(src, dst);
@@ -694,7 +714,8 @@ NormalMdUpdtLogicalNode* LogicalGraph::BuildNormalMdUpdtAndMdSaveStruct(
     // for model
     BuildMdSaveStruct(fw_logical, md_updt_logical);
     // TODO: remove the following ugly hard coded `if'
-    if (Global<JobDesc>::Get()->other_conf().train_conf().model_update_conf().has_momentum_conf()) {
+    if (Global<JobDesc>::Get()->other_conf().train_conf().model_update_conf().has_momentum_conf()
+        || Global<JobDesc>::Get()->other_conf().train_conf().model_update_conf().has_adam_conf()) {
       // for forward_model
       BuildMdSaveStruct(fw_logical, md_updt_logical);
     }
diff --git a/oneflow/core/graph/logical_graph.h b/oneflow/core/graph/logical_graph.h
index 2f61c869cb274a575e6c5f9866d87fee3557c453..bb90a6e253fefb846469ee80e751cdd4d832b361 100644
--- a/oneflow/core/graph/logical_graph.h
+++ b/oneflow/core/graph/logical_graph.h
@@ -28,6 +28,7 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
     std::vector<LogicalEdge*> edges;
   };
   struct ReduceCtx {
+    int32_t order_in_logical_graph;
     std::vector<LogicalNode*> fw_logicals;
     std::vector<LogicalNode*> bw_logicals;
     std::vector<LogicalNode*> md_diff_acc_logicals;
@@ -43,7 +44,6 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
   void LinkUnpackFw2PackFw(const HashMap<std::string, std::vector<LogicalNode*>>& op_name2nodes);
   void ReConnectToFwClone(LogicalNode* clone_node, const LogicalBlobId& lbi,
                           const std::vector<LogicalEdge*>& edges, const std::string& obn);
-  void SetMainModelParallel();
   void BuildBwStruct();
   void NaiveBuildBwStruct();
   void AddBackwardClone();
diff --git a/oneflow/core/graph/logical_node.cpp b/oneflow/core/graph/logical_node.cpp
index 0ceb5e6090a2499b9904d57685f4fc7e364b4b42..e3fcdc25ee2ad236fa9d05d9ef6d44511c8b7704 100644
--- a/oneflow/core/graph/logical_node.cpp
+++ b/oneflow/core/graph/logical_node.cpp
@@ -23,13 +23,24 @@
 #include "oneflow/core/graph/accuracy_accumulate_compute_task_node.h"
 #include "oneflow/core/graph/accuracy_print_compute_task_node.h"
 #include "oneflow/core/graph/task_graph.h"
+#include "oneflow/core/graph/reduce_identity_task_node.h"
+#include "oneflow/core/graph/op_graph.h"
 
 namespace oneflow {
 
 namespace {
 
+bool HasSoleIdentityOp(const LogicalNode* logical_node) {
+  const auto& op_conf = logical_node->SoleOp()->op_conf();
+  return logical_node->op_vec().size() == 1
+         && (op_conf.has_parallel_cast_conf() || op_conf.has_tuple_identity_conf());
+}
+
 BldBoxingOpConfMthd GetBldBoxingOpConfMethodByFwParallelPolicy(const LogicalNode* in_logical,
                                                                const LogicalNode* out_logical) {
+  if (HasSoleIdentityOp(in_logical) || HasSoleIdentityOp(out_logical)) {
+    return &BoxingTaskNode::BldBoxingOpConfWithFwSbpParallel;
+  }
   ParallelPolicy in_policy = in_logical->parallel_desc()->policy();
   ParallelPolicy out_policy = out_logical->parallel_desc()->policy();
   if (in_policy == kDataParallel && out_policy == kDataParallel) {
@@ -47,6 +58,9 @@ BldBoxingOpConfMthd GetBldBoxingOpConfMethodByFwParallelPolicy(const LogicalNode
 }
 BldBoxingOpConfMthd GetBldBoxingOpConfMethodByBwParallelPolicy(const LogicalNode* in_logical,
                                                                const LogicalNode* out_logical) {
+  if (HasSoleIdentityOp(in_logical) || HasSoleIdentityOp(out_logical)) {
+    return &BoxingTaskNode::BldBoxingOpConfWithBwSbpParallel;
+  }
   ParallelPolicy in_policy = in_logical->parallel_desc()->policy();
   ParallelPolicy out_policy = out_logical->parallel_desc()->policy();
   if (in_policy == kDataParallel && out_policy == kDataParallel) {
@@ -237,6 +251,10 @@ void LogicalNode::GenSortedCompTaskNodes(
             comp_task_node->set_thrd_id(id_mgr->GetGpuMixThrdId(dev_phy_id));
             break;
           }
+          case CudaWorkType::kReduceCtrl: {
+            comp_task_node->set_thrd_id(id_mgr->GetGpuReduceCtrlThrdId(dev_phy_id));
+            break;
+          }
           case CudaWorkType::kMdUpdt: {
             comp_task_node->set_thrd_id(id_mgr->GetGpuMdUpdtThrdId(dev_phy_id));
             break;
@@ -259,30 +277,6 @@ void LogicalNode::GenSortedCompTaskNodes(
   }
 }
 
-int32_t LogicalNode::GetModelSplitAxis() const {
-  CHECK_EQ(parallel_desc_->policy(), kModelParallel);
-  CHECK_NOTNULL(main_model_parallel_);
-  if (main_model_parallel_ == this) {
-    int32_t ret = SoleOp()->ModelSplitAxis();
-    CHECK_NE(ret, -1);
-    return ret;
-  } else {
-    return main_model_parallel_->GetModelSplitAxis();
-  }
-}
-
-int32_t LogicalNode::GetMaxModelSplitNum() const {
-  CHECK_EQ(parallel_desc_->policy(), kModelParallel);
-  CHECK_NOTNULL(main_model_parallel_);
-  if (main_model_parallel_ == this) {
-    int32_t ret = SoleOp()->MaxModelSplitNum();
-    CHECK_NE(ret, -1);
-    return ret;
-  } else {
-    return main_model_parallel_->GetMaxModelSplitNum();
-  }
-}
-
 bool LogicalNode::HasOpWithCondition(std::function<bool(const Operator*)> cond) const {
   for (std::shared_ptr<const Operator> op : op_vec_) {
     if (cond(op.get())) { return true; }
@@ -291,12 +285,39 @@ bool LogicalNode::HasOpWithCondition(std::function<bool(const Operator*)> cond)
 }
 
 static bool IsModelParallel121(const LogicalNode* src_node, const LogicalNode* dst_node) {
-  return src_node->main_model_parallel() == dst_node->main_model_parallel();
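+  // Two nodes are treated as 1:1 model parallel iff their parallel_num is equal and
+  // every blob on the connecting edge has the same SbpParallel on producer and consumer.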
+  if (src_node->parallel_desc()->parallel_num() != dst_node->parallel_desc()->parallel_num()) {
+    return false;
+  }
+  LogicalEdge* connect_edge = nullptr;
+  for (LogicalEdge* edge : src_node->out_edges()) {
+    if (edge->dst_node() == dst_node) { connect_edge = edge; }
+  }
+  CHECK_NOTNULL(connect_edge);
+  CHECK_GT(connect_edge->lbis().size(), 0);
+  const std::string& src_op_name = src_node->SoleOp()->op_name();
+  const std::string& dst_op_name = dst_node->SoleOp()->op_name();
+  for (const LogicalBlobId& lbi : connect_edge->lbis()) {
+    const auto& src_sbp = Global<OpGraph>::Get()->GetSbpParallel(src_op_name, lbi);
+    const auto& dst_sbp = Global<OpGraph>::Get()->GetSbpParallel(dst_op_name, lbi);
+    if (src_sbp != dst_sbp) { return false; }
+  }
+  return true;
 }
 
 BldSubTskGphMthd GetMthdForBldSubTskGph(const LogicalNode* src_node, const LogicalNode* dst_node) {
   std::shared_ptr<const ParallelDesc> src_pd = src_node->parallel_desc();
   std::shared_ptr<const ParallelDesc> dst_pd = dst_node->parallel_desc();
+  if (src_node->op_vec().size() == 1 && dst_node->op_vec().size() == 1) {
+    if (src_node->SoleOp()->op_conf().has_record_load_conf()
+        && dst_node->SoleOp()->op_conf().has_tick_conf()) {
+      CHECK(src_pd->parallel_num() == dst_pd->parallel_num());
+      CHECK(src_pd->policy() == kDataParallel && dst_pd->policy() == kDataParallel);
+    }
+    if (src_node->SoleOp()->op_conf().has_tick_conf()
+        && dst_node->SoleOp()->op_conf().has_log_counter_conf() == false) {
+      return &TaskGraph::BldSubTskGphByTickToSource;
+    }
+  }
   if (src_pd->parallel_num() == 1 && dst_pd->parallel_num() == 1) {
     return &TaskGraph::BldSubTskGphByOneToOne;
   }
@@ -403,6 +424,7 @@ REGISTER_BLD_BOXING_OP_CONF_MTHD("NormalBackward"
   OF_PP_MAKE_TUPLE_SEQ(MdDiffAcc, kDataBackwardArea)      \
   OF_PP_MAKE_TUPLE_SEQ(Print, kPrintArea)                 \
   OF_PP_MAKE_TUPLE_SEQ(ReduceConcat, kMdUpdtArea)         \
+  OF_PP_MAKE_TUPLE_SEQ(ReduceIdentity, kMdUpdtArea)       \
   OF_PP_MAKE_TUPLE_SEQ(ReduceScatter, kMdUpdtArea)        \
   OF_PP_MAKE_TUPLE_SEQ(ReduceAdd, kMdUpdtArea)            \
   OF_PP_MAKE_TUPLE_SEQ(ReduceGather, kMdUpdtArea)         \
@@ -425,7 +447,6 @@ BackwardLogicalNode* ForwardLogicalNode::NewBackwardNode() {
   bw_node_->mut_op_vec() = op_vec();
   bw_node_->mut_parallel_desc() = parallel_desc();
   bw_node_->fw_node_ = this;
-  bw_node_->set_main_model_parallel(main_model_parallel());
   return bw_node_;
 }
 
diff --git a/oneflow/core/graph/logical_node.h b/oneflow/core/graph/logical_node.h
index 5dc5e057c1ff3cf656017fdd4c8bc92883c26565..844d0ac11bf30b5e7a4e8f8576a4689bc5dbaf49 100644
--- a/oneflow/core/graph/logical_node.h
+++ b/oneflow/core/graph/logical_node.h
@@ -52,16 +52,11 @@ class LogicalNode : public Node<LogicalNode, LogicalEdge> {
                               std::vector<std::pair<int64_t, CompTaskNode*>>* nodes,
                               std::function<void(CompTaskNode*)>) const;
 
-  // model split
-  LogicalNode* main_model_parallel() const { return main_model_parallel_; }
-  void set_main_model_parallel(LogicalNode* val) { main_model_parallel_ = val; }
-  int32_t GetModelSplitAxis() const;
-  int32_t GetMaxModelSplitNum() const;
-
   virtual int64_t GetAreaId() const = 0;
+  virtual bool MayConsumeModelDiff() const { return false; }
 
  protected:
-  LogicalNode() : main_model_parallel_(nullptr) {}
+  LogicalNode() {}
   virtual CompTaskNode* NewCompTaskNode() const = 0;
   virtual void FixCompTaskNode(CompTaskNode*) const {}
 
@@ -71,7 +66,6 @@ class LogicalNode : public Node<LogicalNode, LogicalEdge> {
   std::vector<std::shared_ptr<Operator>> op_vec_;
   std::shared_ptr<const ParallelDesc> parallel_desc_;
   std::shared_ptr<const std::vector<LogicalNode*>> shared_model_nodes_;
-  LogicalNode* main_model_parallel_;
 
   HashMap<const LogicalNode*, std::vector<LogicalBlobId>> dst2data_lbis_;
 };
@@ -238,19 +232,31 @@ DECLARE_NAIVE_LOGICAL_NODE(AccuracyPrintLogicalNode);
 class NormalMdUpdtLogicalNode final : public LogicalNode {
  public:
   OF_DISALLOW_COPY_AND_MOVE(NormalMdUpdtLogicalNode);
-  NormalMdUpdtLogicalNode() : random_seed_(NewRandomSeed()) {}
+  NormalMdUpdtLogicalNode() : random_seed_(NewRandomSeed()), order_in_reduce_group_(0) {}
   ~NormalMdUpdtLogicalNode() = default;
 
   OVERRIDE_PURE_VIRTUAL_METHOD();
+  bool MayConsumeModelDiff() const override { return true; }
+
+  int order_in_reduce_group() const { return order_in_reduce_group_; }
+  void set_order_in_reduce_group(int order_in_reduce_group) {
+    order_in_reduce_group_ = order_in_reduce_group;
+  }
 
  private:
   void FixCompTaskNode(CompTaskNode*) const override;
 
   uint32_t random_seed_;
+  int order_in_reduce_group_;
 };
 
 DECLARE_NAIVE_LOGICAL_NODE(MdSaveLogicalNode);
-DECLARE_NAIVE_LOGICAL_NODE(MdDiffAccLogicalNode);
+
+class MdDiffAccLogicalNode final : public LogicalNode {
+ public:
+  LOGICAL_NODE_BOILERPLATE(MdDiffAccLogicalNode);
+  bool MayConsumeModelDiff() const override { return true; }
+};
 
 class ReduceLogicalNode : public LogicalNode {
  public:
@@ -275,23 +281,41 @@ class ReduceLogicalNode : public LogicalNode {
   }
 };
 
-#define DECLARE_REDUCE_LOGICAL_NODE(name)       \
-  class name final : public ReduceLogicalNode { \
-   public:                                      \
-    LOGICAL_NODE_BOILERPLATE(name);             \
+#define DECLARE_REDUCE_LOGICAL_NODE(name, may_consume_md_diff)                \
+  class name final : public ReduceLogicalNode {                               \
+   public:                                                                    \
+    LOGICAL_NODE_BOILERPLATE(name);                                           \
+    bool MayConsumeModelDiff() const override { return may_consume_md_diff; } \
   }
 
-DECLARE_REDUCE_LOGICAL_NODE(ReduceConcatLogicalNode);
-DECLARE_REDUCE_LOGICAL_NODE(ReduceSplitLogicalNode);
-DECLARE_REDUCE_LOGICAL_NODE(ReduceScatterLogicalNode);
-DECLARE_REDUCE_LOGICAL_NODE(ReduceGatherLogicalNode);
-DECLARE_REDUCE_LOGICAL_NODE(ReduceAddLogicalNode);
-DECLARE_REDUCE_LOGICAL_NODE(NcclAllReduceLogicalNode);
-DECLARE_REDUCE_LOGICAL_NODE(NcclAllGatherLogicalNode);
-DECLARE_REDUCE_LOGICAL_NODE(NcclReduceScatterLogicalNode);
+DECLARE_REDUCE_LOGICAL_NODE(ReduceConcatLogicalNode, true);
+DECLARE_REDUCE_LOGICAL_NODE(ReduceScatterLogicalNode, true);
+DECLARE_REDUCE_LOGICAL_NODE(ReduceGatherLogicalNode, false);
+DECLARE_REDUCE_LOGICAL_NODE(NcclAllReduceLogicalNode, true);
+DECLARE_REDUCE_LOGICAL_NODE(ReduceAddLogicalNode, false);
+DECLARE_REDUCE_LOGICAL_NODE(NcclAllGatherLogicalNode, false);
+DECLARE_REDUCE_LOGICAL_NODE(NcclReduceScatterLogicalNode, true);
 
 DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(RepeatForward);
 DECLARE_DERIVED_BACKWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(RepeatBackward);
+
+#define DECLARE_BEFORE_OR_AFTER_ALLREDUCE_REDUCE_NODE(class_name, may_consume_md_diff) \
+  class class_name final : public ReduceLogicalNode {                                  \
+   public:                                                                             \
+    LOGICAL_NODE_BOILERPLATE(class_name);                                              \
+    void set_order_in_logical_graph(int32_t order_in_logical_graph) {                  \
+      order_in_logical_graph_ = order_in_logical_graph;                                \
+    }                                                                                  \
+    int32_t order_in_logical_graph() const { return order_in_logical_graph_; }         \
+    bool MayConsumeModelDiff() const override { return may_consume_md_diff; }          \
+                                                                                       \
+   private:                                                                            \
+    int32_t order_in_logical_graph_;                                                   \
+  }
+
+DECLARE_BEFORE_OR_AFTER_ALLREDUCE_REDUCE_NODE(ReduceIdentityLogicalNode, true);
+DECLARE_BEFORE_OR_AFTER_ALLREDUCE_REDUCE_NODE(ReduceSplitLogicalNode, false);
+
 }  // namespace oneflow
 
 #endif  // ONEFLOW_CORE_GRAPH_LOGICAL_NODE_H_
diff --git a/oneflow/core/graph/loss_compute_task_node.cpp b/oneflow/core/graph/loss_compute_task_node.cpp
index d1cb346fe7113b959875d3b4ec66914630e739d2..08a9d3d5fd7f5c56bc74ccbd56882c2413b39bba 100644
--- a/oneflow/core/graph/loss_compute_task_node.cpp
+++ b/oneflow/core/graph/loss_compute_task_node.cpp
@@ -8,6 +8,7 @@ void LossCompTaskNode::ProduceAllRegstsAndBindEdges() {
   ProduceRegst("loss", false);
   ProduceRegst("out", true);
   ProduceRegst("data_tmp", true, 1, 1);
+  ProduceRegst("const_buf", false, 1, 1);
   for (TaskEdge* edge : out_edges()) {
     const LogicalNode* succ_logical = GetOneSuccLogicalNodeOnEdge(edge);
     if (succ_logical->TypeName() == "LossAcc") {
@@ -37,6 +38,7 @@ void LossCompTaskNode::BuildExecGphAndRegst() {
   }
   std::shared_ptr<RegstDesc> data_tmp_regst = GetProducedRegst("data_tmp");
   loss_node->AddBnToRegstAndBindIt(&Operator::data_tmp_bns, data_tmp_regst);
+  loss_node->AddBnToRegstAndBindIt(&Operator::const_buf_bns, GetProducedRegst("const_buf"));
 
   if (Global<JobDesc>::Get()->IsTrain()) {
     BuildRegstWhenTraining();
@@ -48,6 +50,8 @@ void LossCompTaskNode::BuildExecGphAndRegst() {
     }
   }
   mut_exec_gph().TopoForEachNode([this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); });
+  mut_exec_gph().TopoForEachNode(
+      [this](ExecNode* node) { node->FixInDiffBlobDescs(parallel_ctx()); });
 }
 
 void LossCompTaskNode::BuildRegstWhenTraining() {
diff --git a/oneflow/core/graph/normal_backward_compute_task_node.cpp b/oneflow/core/graph/normal_backward_compute_task_node.cpp
index 1bfbda612788d610d07990e8ef46d019a2b68f62..6ebb84eeaf091d8f76f4a0f878aa17cd9fa224c6 100644
--- a/oneflow/core/graph/normal_backward_compute_task_node.cpp
+++ b/oneflow/core/graph/normal_backward_compute_task_node.cpp
@@ -4,17 +4,19 @@
 
 namespace oneflow {
 
+int64_t NormalBackwardCompTaskNode::AreaId4ChainMerge() const {
+  CHECK_EQ(area_id(), AreaType::kDataBackwardArea);
+  return AreaType::kDataForwardArea;
+}
+
 void NormalBackwardCompTaskNode::ProduceAllRegstsAndBindEdges() {
   ProduceRegst("in_diff", true);
   ProduceRegst("activation_diff", true, 1, 1);
   ProduceRegst("bw_buf", true, 1, 1);
   for (TaskEdge* edge : out_edges()) {
     const LogicalNode* succ_logical = GetOneSuccLogicalNodeOnEdge(edge);
-    if (succ_logical->TypeName() == "MdDiffAcc" || succ_logical->TypeName() == "NormalMdUpdt"
-        || succ_logical->TypeName() == "ReduceScatter" || succ_logical->TypeName() == "ReduceConcat"
-        || succ_logical->TypeName() == "NcclAllReduce"
-        || succ_logical->TypeName() == "NcclReduceScatter") {
-      edge->AddRegst("model_diff", ProduceRegst("model_diff", true));
+    if (succ_logical->MayConsumeModelDiff()) {
+      edge->AddRegst("model_diff", ProduceRegst("model_diff", true, 1, 1));
       type_name4model_related_logical_node_ = succ_logical->TypeName();
     } else {
       BindEdgeWithProducedRegst(edge, "in_diff");
diff --git a/oneflow/core/graph/normal_backward_compute_task_node.h b/oneflow/core/graph/normal_backward_compute_task_node.h
index de6a83bbac56532ef790f4e2b51d1f21618b4ce4..d9eda294ea08020699b0f3674e5be787f195799c 100644
--- a/oneflow/core/graph/normal_backward_compute_task_node.h
+++ b/oneflow/core/graph/normal_backward_compute_task_node.h
@@ -16,6 +16,7 @@ class NormalBackwardCompTaskNode final : public CompTaskNode {
   void BuildExecGphAndRegst() override;
   TaskType GetTaskType() const override { return TaskType::kNormalBackward; }
   void RmUselessConsumeRelationshipToFw();
+  int64_t AreaId4ChainMerge() const override;
 
  protected:
   void BuildExecGphAndBindOutDiffRegst();
diff --git a/oneflow/core/graph/normal_forward_compute_task_node.cpp b/oneflow/core/graph/normal_forward_compute_task_node.cpp
index 0553c646c00ed58532460010c3799a5dcaf326f1..7e42af94b997e05f32b2ef156cf8726c86d26315 100644
--- a/oneflow/core/graph/normal_forward_compute_task_node.cpp
+++ b/oneflow/core/graph/normal_forward_compute_task_node.cpp
@@ -1,11 +1,39 @@
 #include "oneflow/core/graph/normal_forward_compute_task_node.h"
 #include "oneflow/core/graph/task_graph.h"
 #include "oneflow/core/graph/logical_node.h"
+#include "oneflow/core/operator/variable_op.h"
 
 namespace oneflow {
 
+namespace {
+
+bool IsAllOutputNaive(const LogicalNode* logical_node) {
+  const Operator* op = logical_node->SoleOp().get();
+  if (op->IsAllOutputConst()) {
+    return false;
+  } else if (dynamic_cast<const VariableOp*>(op) != nullptr) {
+    return false;
+  } else {
+    return true;
+  }
+}
+
+}  // namespace
+
+bool NormalForwardCompTaskNode::HasBackwardCompTaskNode() {
+  for (TaskEdge* edge : out_edges()) {
+    const LogicalNode* succ_logical = GetOneSuccLogicalNodeOnEdge(edge);
+    if (succ_logical->TypeName() == "NormalBackward") { return true; }
+  }
+  return false;
+}
+
 void NormalForwardCompTaskNode::ProduceAllRegstsAndBindEdges() {
-  ProduceRegst("out", true);
+  if (IsAllOutputNaive(logical_node())) {
+    ProduceRegst("out", true);
+  } else {
+    ProduceRegst("out", false, 1, 1);
+  }
   ProduceRegst("activation", true);
   ProduceRegst("data_tmp", true);
   ProduceRegst("fw_buf", true, 1, 1);
diff --git a/oneflow/core/graph/normal_forward_compute_task_node.h b/oneflow/core/graph/normal_forward_compute_task_node.h
index e3680db4595c85684655fe008f38fd25b2cc7ab9..545acd83b58b76a86c112e74b09fea709800b6da 100644
--- a/oneflow/core/graph/normal_forward_compute_task_node.h
+++ b/oneflow/core/graph/normal_forward_compute_task_node.h
@@ -16,6 +16,7 @@ class NormalForwardCompTaskNode final : public CompTaskNode {
   bool IsReadyForBuild() override;
 
   TaskType GetTaskType() const override { return TaskType::kNormalForward; }
+  bool HasBackwardCompTaskNode();
   virtual void ToProto(TaskProto*) override;
 
   void set_random_seed(int64_t random_seed) { random_seed_ = random_seed; }
diff --git a/oneflow/core/graph/normal_model_update_compute_task_node.cpp b/oneflow/core/graph/normal_model_update_compute_task_node.cpp
index 6f439735626d0eca98736c3067a97e7ba0f5f60e..83e89c6352e48f75ed0724f498fcd218e7b8f551 100644
--- a/oneflow/core/graph/normal_model_update_compute_task_node.cpp
+++ b/oneflow/core/graph/normal_model_update_compute_task_node.cpp
@@ -86,9 +86,12 @@ void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() {
                                                                       + "total_instance_num");
     op_conf.mutable_normal_mdupdt_conf()->set_model(lbi.op_name() + '/' + lbi.blob_name());
     if (Global<JobDesc>::Get()->IsTrain()) {
-      *(op_conf.mutable_normal_mdupdt_conf()->mutable_user_conf()) =
-          Global<JobDesc>::Get()->other_conf().train_conf().model_update_conf();
-
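+      // total_instance_num is a plain accumulated count, so it is updated with the naive
+      // updater rather than the user-configured one (e.g. momentum or Adam).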
+      if (lbi.blob_name() == "total_instance_num") {
+        op_conf.mutable_normal_mdupdt_conf()->mutable_user_conf()->mutable_naive_conf();
+      } else {
+        *(op_conf.mutable_normal_mdupdt_conf()->mutable_user_conf()) =
+            Global<JobDesc>::Get()->other_conf().train_conf().model_update_conf();
+      }
       float primary_lr = Global<JobDesc>::Get()->primary_lr();
       float secondary_lr = Global<JobDesc>::Get()->secondary_lr();
       if (secondary_lr < 0) { secondary_lr = primary_lr; }
@@ -146,6 +149,13 @@ void NormalMdUpdtCompTaskNode::ToProto(TaskProto* task_proto) {
   task_proto->set_related_init_model_task_id(related_init_model_task_id_);
 }
 
+void NormalMdUpdtCompTaskNode::FixPackedBlobDescOfProducedRegst() {
+  std::shared_ptr<RegstDesc> diff_add_out_regst = GetProducedRegst("processed_model_diff");
+  CHECK(diff_add_out_regst->IsLocked());
+  Shape& shape = diff_add_out_regst->MutBlobDesc(GenPackedLbi())->mut_shape();
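+  // Flatten the packed blob to 1-D and round its element count up to a multiple of
+  // parallel_num, presumably so the model diff can be partitioned evenly across the
+  // reduce ranks.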
+  shape = Shape({static_cast<int64_t>(RoundUp(shape.elem_cnt(), parallel_ctx()->parallel_num()))});
+}
+
 void NormalMdUpdtCompTaskNode::InferProducedDataRegstTimeShape() {
   ForEachProducedDataRegst([](const std::string& name, RegstDesc* regst) {
     if (name == "const_model") {
@@ -157,4 +167,21 @@ void NormalMdUpdtCompTaskNode::InferProducedDataRegstTimeShape() {
   });
 }
 
+void NormalMdUpdtCompTaskNode::EnableMemSharingBetweenFirstInAndProcessedMdDiffRegst() {
+  if (!IsTrainable()) { return; }
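+  // The diff-add node's first input regst and its output regst have the same byte size
+  // (checked below), so the two can safely share one memory block.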
+  ExecNode* diff_add_node = exec_gph().SoleSourceNode();
+  RegstDesc* first_in_regst =
+      diff_add_node->RegstDesc4BnInOp(diff_add_node->op()->input_bns().Get(0));
+  RegstDesc* diff_add_out_regst = diff_add_node->RegstDesc4BnInOp(diff_add_node->op()->SoleObn());
+  CHECK_EQ(diff_add_out_regst, GetProducedRegst("processed_model_diff").get());
+  CHECK(first_in_regst->HasSameMemSize(diff_add_out_regst));
+  if (!first_in_regst->HasSetMemSharedId()) {
+    int64_t mem_shared_id = Global<IDMgr>::Get()->NewMemSharedId();
+    first_in_regst->set_enable_mem_sharing(true);
+    first_in_regst->set_mem_shared_id(mem_shared_id);
+    first_in_regst->set_mem_shared_offset(0);
+  }
+  diff_add_out_regst->CopyMemSharedInfoFrom(first_in_regst);
+}
+
 }  // namespace oneflow
diff --git a/oneflow/core/graph/normal_model_update_compute_task_node.h b/oneflow/core/graph/normal_model_update_compute_task_node.h
index 3840232eb7e983b652d45d5e7fac2c2d85741b39..c80c1ec2f443a1bc358af277bbc9f0bbdf4471ad 100644
--- a/oneflow/core/graph/normal_model_update_compute_task_node.h
+++ b/oneflow/core/graph/normal_model_update_compute_task_node.h
@@ -17,6 +17,7 @@ class NormalMdUpdtCompTaskNode final : public CompTaskNode {
   bool IsReadyForBuild() override;
   void BuildExecGphAndRegst() override;
   void LockRegsts() override;
+  void EnableMemSharingBetweenFirstInAndProcessedMdDiffRegst();
 
   void set_random_seed(uint32_t val) { random_seed_ = val; }
   TaskType GetTaskType() const override { return TaskType::kNormalMdUpdt; }
@@ -26,6 +27,7 @@ class NormalMdUpdtCompTaskNode final : public CompTaskNode {
  private:
   const NormalForwardCompTaskNode* GetForwardTaskNode() const;
   bool IsTrainable() const;
+  void FixPackedBlobDescOfProducedRegst() override;
   void InferProducedDataRegstTimeShape() override;
   uint32_t random_seed_;
   int64_t related_init_model_task_id_;
diff --git a/oneflow/core/graph/op_graph.cpp b/oneflow/core/graph/op_graph.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ba46f2f6df8eb4b1482dbb05ce6f7883d9f3f180
--- /dev/null
+++ b/oneflow/core/graph/op_graph.cpp
@@ -0,0 +1,615 @@
+#include "oneflow/core/graph/op_graph.h"
+
+namespace oneflow {
+
+std::string OpEdge::VisualStr() const {
+  std::string str;
+  int32_t idx = 0;
+  for (const LogicalBlobId& lbi : lbis_) {
+    if (idx++ > 0) { str += "\\n"; }
+    str += lbi.blob_name() + ":";
+    str += src_node()->NoParallelBlobDesc4Lbi(lbi).shape().ToString();
+  }
+  return str;
+}
+
+bool* OpNode::MutIsModelBlob4Lbi(const LogicalBlobId& lbi) {
+  CHECK_EQ(ProducerOpNode4Lbi(lbi), this);
+  return &lbi2is_model_blob_[lbi];
+}
+bool OpNode::IsModelBlob4Lbi(const LogicalBlobId& lbi) const {
+  return ProducerOpNode4Lbi(lbi)->lbi2is_model_blob_.at(lbi);
+}
+
+const SbpParallel& OpNode::SbpParallel4Lbi(const LogicalBlobId& lbi) const {
+  return lbi2sbp_parallel_.at(lbi);
+}
+
+SbpParallel* OpNode::MutSbpParallel4Lbi(const LogicalBlobId& lbi) {
+  return &lbi2sbp_parallel_[lbi];
+}
+
+std::string OpNode::VisualStr() const {
+  std::string str = op().op_name();
+  {
+    for (int64_t machine_id : parallel_desc().sorted_machine_ids()) {
+      std::string dev_type;
+      if (parallel_desc().device_type() == DeviceType::kCPU) {
+        dev_type = "cpu";
+      } else if (parallel_desc().device_type() == DeviceType::kGPU) {
+        dev_type = "gpu";
+      } else {
+        UNIMPLEMENTED();
+      }
+      std::string parallel_desc_str = std::to_string(machine_id) + ":" + dev_type + ":";
+      const auto& dev_phy_ids = parallel_desc().sorted_dev_phy_ids(machine_id);
+      parallel_desc_str += std::to_string(dev_phy_ids.front());
+      if (dev_phy_ids.back() > dev_phy_ids.front()) {
+        parallel_desc_str += "-" + std::to_string(dev_phy_ids.back());
+      }
+      str += "\\n" + parallel_desc_str;
+    }
+  }
+  auto GetTimeShapeStr = [&](const Shape& shape, const std::string& prefix) {
+    std::string time_shape_str = prefix + ":";
+    time_shape_str += shape.ToString();
+    return time_shape_str;
+  };
+  if (in_edges().empty() == false) {
+    str += "\\n" + GetTimeShapeStr(*GetInputBlobTimeShape(), "in_blob_time_shape");
+  }
+  str += "\\n" + GetTimeShapeStr(out_blob_time_shape(), "out_blob_time_shape");
+  return str;
+}
+
+const BlobDesc& OpNode::NoParallelBlobDesc4Lbi(const LogicalBlobId& lbi) const {
+  return lbi2no_parallel_blob_desc_.at(lbi);
+}
+
+const BlobDesc& OpNode::LogicalBlobDesc4Lbi(const LogicalBlobId& lbi) const {
+  return lbi2logical_blob_desc_.at(lbi);
+}
+
+BlobDesc* OpNode::MutNoParallelBlobDesc(const LogicalBlobId& lbi) {
+  CHECK_EQ(lbi.op_name(), op().op_name());
+  return &lbi2no_parallel_blob_desc_[lbi];
+}
+
+BlobDesc* OpNode::MutLogicalBlobDesc4Lbi(const LogicalBlobId& lbi) {
+  CHECK_EQ(lbi.op_name(), op().op_name());
+  return &lbi2logical_blob_desc_[lbi];
+}
+
+BlobDesc* OpNode::NoParallelBlobDesc4BnInOp(const std::string& bn_in_op) {
+  return ProducerOpNode4BnInOp(bn_in_op)->MutNoParallelBlobDesc(op().BnInOp2Lbi(bn_in_op));
+}
+
+const Shape* OpNode::GetInputBlobTimeShape(const std::string& bn_in_op) const {
+  return &SrcNode4InputBnInOp(bn_in_op)->out_blob_time_shape();
+}
+
+OpNode* OpNode::ProducerOpNode4BnInOp(const std::string& bn_in_op) {
+  if (ibns_.find(bn_in_op) != ibns_.end()) { return SrcNode4InputBnInOp(bn_in_op); }
+  return this;
+}
+
+OpNode* OpNode::SrcNode4InputBnInOp(const std::string& bn_in_op) const {
+  const LogicalBlobId& lbi = op().BnInOp2Lbi(bn_in_op);
+  CHECK(ibns_.find(bn_in_op) != ibns_.end());
+  return SrcNode4InputLbi(lbi);
+}
+
+OpNode* OpNode::ProducerOpNode4Lbi(const LogicalBlobId& lbi) {
+  OpNode* producer = SrcNode4InputLbi(lbi);
+  if (producer == nullptr) { producer = this; }
+  return producer;
+}
+
+const OpNode* OpNode::ProducerOpNode4Lbi(const LogicalBlobId& lbi) const {
+  const OpNode* producer = SrcNode4InputLbi(lbi);
+  if (producer == nullptr) { producer = this; }
+  return producer;
+}
+
+OpNode* OpNode::SrcNode4InputLbi(const LogicalBlobId& lbi) const {
+  for (OpEdge* edge : in_edges()) {
+    for (const LogicalBlobId& edge_lbi : edge->lbis()) {
+      if (lbi == edge_lbi) { return edge->src_node(); }
+    }
+  }
+  return nullptr;
+}
+
+const Shape* OpNode::GetInputBlobTimeShape() const {
+  if (in_edges().empty()) { UNIMPLEMENTED(); }
+  OpNode* first_input = (*in_edges().begin())->src_node();
+  for (OpEdge* edge : in_edges()) {
+    CHECK_EQ(first_input->out_blob_time_shape(), edge->src_node()->out_blob_time_shape());
+  }
+  return &first_input->out_blob_time_shape();
+}
+
+void OpNode::ForEachParallelBlobDesc(const BlobDesc& blob_desc, const SbpParallel& sbp_parallel,
+                                     const std::function<void(const BlobDesc&)>& Handler) const {
+  if (sbp_parallel.has_split_parallel()) {
+    // split BlobDesc
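+    // e.g. with 4 parallel devices, a blob of shape (10, 5) split along axis 0 becomes
+    // per-device shapes (3, 5), (3, 5), (2, 5), (2, 5), assuming BalancedSplitter gives
+    // the remainder to the leading parts.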
+    int32_t axis = sbp_parallel.split_parallel().axis();
+    CHECK_GE(axis, 0);
+    CHECK_LT(axis, blob_desc.shape().NumAxes());
+    CHECK_GE(blob_desc.shape().At(axis), parallel_desc().parallel_num());
+    BalancedSplitter bs(blob_desc.shape().At(axis), parallel_desc().parallel_num());
+    FOR_RANGE(int64_t, axis_parallel_id, 0, parallel_desc().parallel_num()) {
+      BlobDesc sub_blob_desc(blob_desc);
+      sub_blob_desc.mut_shape().Set(axis, bs.At(axis_parallel_id).size());
+      Handler(sub_blob_desc);
+    }
+  } else {
+    CHECK(sbp_parallel.has_broadcast_parallel() || sbp_parallel.has_partial_sum_parallel());
+    // broadcast BlobDesc
+    FOR_RANGE(int64_t, axis_parallel_id, 0, parallel_desc().parallel_num()) { Handler(blob_desc); }
+  }
+}
+
+void OpNode::ConcatBlobDesc(const std::vector<BlobDesc>& blob_descs,
+                            const SbpParallel& sbp_parallel,
+                            BlobDesc* concatenated_blob_desc) const {
+  CHECK_EQ(blob_descs.size(), parallel_desc().parallel_num());
+  if (sbp_parallel.has_split_parallel()) {
+    int32_t axis = sbp_parallel.split_parallel().axis();
+    // concat BlobDesc
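+    // e.g. the per-device shapes (3, 5), (3, 5), (2, 5), (2, 5) from the example above
+    // are summed along axis 0 to recover the logical shape (10, 5); each piece is
+    // checked against the corresponding BalancedSplitter range.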
+    CHECK_GE(axis, 0);
+    CHECK_LT(axis, blob_descs.at(0).shape().NumAxes());
+    int64_t logical_blob_axis_dim = 0;
+    for (const BlobDesc& blob_desc : blob_descs) {
+      logical_blob_axis_dim += blob_desc.shape().At(axis);
+    }
+    CHECK_GE(logical_blob_axis_dim, parallel_desc().parallel_num());
+    BalancedSplitter bs(logical_blob_axis_dim, parallel_desc().parallel_num());
+    std::vector<BlobDesc> same_blob_descs(blob_descs);
+    FOR_RANGE(int64_t, axis_parallel_id, 0, parallel_desc().parallel_num()) {
+      CHECK_EQ(bs.At(axis_parallel_id).size(), blob_descs.at(axis_parallel_id).shape().At(axis));
+      same_blob_descs.at(axis_parallel_id).mut_shape().Set(axis, logical_blob_axis_dim);
+    }
+    for (const BlobDesc& blob_desc : same_blob_descs) { CHECK(blob_desc == same_blob_descs.at(0)); }
+    *concatenated_blob_desc = same_blob_descs.at(0);
+  } else {
+    // select first BlobDesc
+    for (const BlobDesc& blob_desc : blob_descs) { CHECK(blob_desc == blob_descs.at(0)); }
+    *concatenated_blob_desc = blob_descs.at(0);
+  }
+}
+
+int64_t OpNode::GetAxisParallelNum(
+    const std::function<void(bool*, int32_t*, int64_t*)>& GetAxisParallelInfo) const {
+  bool is_split = false;
+  int32_t axis = -1;
+  int64_t axis_parallel_num = 0;
+  GetAxisParallelInfo(&is_split, &axis, &axis_parallel_num);
+  return axis_parallel_num;
+}
+
+void OpNode::SplitLogicalInputBlobDesc() {
+  for (const std::string& bn : op().input_bns()) {
+    const LogicalBlobId& lbi = op().BnInOp2Lbi(bn);
+    const BlobDesc& logical_blob_desc = ProducerOpNode4BnInOp(bn)->LogicalBlobDesc4Lbi(lbi);
+    const SbpParallel& sbp_parallel = SbpParallel4Lbi(lbi);
+    ForEachParallelBlobDesc(logical_blob_desc, sbp_parallel, [&](const BlobDesc& blob_desc) {
+      lbi2parallel_id2blob_desc_[lbi].push_back(blob_desc);
+    });
+    CHECK_EQ(lbi2parallel_id2blob_desc_.at(lbi).size(), parallel_desc().parallel_num());
+  }
+}
+
+void OpNode::ConcatLogicalOutputBlobDesc() {
+  for (const std::string& bn : op().output_bns()) {
+    const LogicalBlobId& lbi = op().BnInOp2Lbi(bn);
+    const SbpParallel& sbp_parallel = SbpParallel4Lbi(lbi);
+    ConcatBlobDesc(lbi2parallel_id2blob_desc_.at(lbi), sbp_parallel, MutLogicalBlobDesc4Lbi(lbi));
+  }
+}
+
+void OpNode::CheckBlobDescs(const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
+                            const ParallelContext* parallel_ctx) const {
+  int64_t parallel_id = parallel_ctx->parallel_id();
+  auto Check = [&](const std::string& bn) {
+    const LogicalBlobId& lbi = op().BnInOp2Lbi(bn);
+    if (lbi2parallel_id2blob_desc_.find(lbi) == lbi2parallel_id2blob_desc_.end()) { return; }
+    CHECK_EQ(parallel_ctx->parallel_num(), lbi2parallel_id2blob_desc_.at(lbi).size());
+    const BlobDesc& blob_desc = *GetBlobDesc4BnInOp(bn);
+    CHECK(blob_desc == lbi2parallel_id2blob_desc_.at(lbi).at(parallel_id));
+  };
+  for (const std::string& bn : op().data_tmp_bns()) { Check(bn); }
+  for (const std::string& bn : op().fw_buf_bns()) { Check(bn); }
+  for (const std::string& bn : op().input_bns()) { Check(bn); }
+  for (const std::string& bn : op().output_bns()) { Check(bn); }
+  for (const std::string& bn : op().model_bns()) { Check(bn); }
+  for (const std::string& bn : op().const_model_bns()) { Check(bn); }
+  for (const std::string& bn : op().const_buf_bns()) { Check(bn); }
+  for (const std::string& bn : op().forward_model_bns()) { Check(bn); }
+}
+
+void OpGraph::InferOpModelSize(HashMap<std::string, size_t>* op_name2model_size) {
+  ForEachNode([&](OpNode* op_node) {
+    size_t model_size = 0;
+    for (const std::string& model_bn : op_node->op().model_bns()) {
+      int64_t elem_cnt = op_node->NoParallelBlobDesc4BnInOp(model_bn)->shape().elem_cnt();
+      model_size += elem_cnt * GetSizeOfDataType(job_desc_->DefaultDataType());
+      model_size = RoundUp(model_size, kCudaAlignSize);
+    }
+    size_t parallel_num = op_node->parallel_desc().parallel_num();
+    if (op_node->parallel_desc().policy() == ParallelPolicy::kModelParallel) {
+      model_size = (model_size + parallel_num - 1) / parallel_num;
+    }
+    CHECK(op_name2model_size->emplace(op_node->op().op_name(), model_size).second);
+  });
+}
+
+void OpGraph::Init() {
+  InitNodes();
+  ForEachNode(
+      [&](OpNode* node) { CHECK(op_name2op_node_.emplace(node->op().op_name(), node).second); });
+  InitEdges();
+  FixOpParallelDesc();
+  UpdateOpNodeHasInDiff();
+  InferTimeShape();
+  InferNoParallelBlobDesc();
+  InferIsModelBlob();
+  InferSbpParallel();
+  InferLogicalBlobDesc();
+}
+
+void OpGraph::InitNodes() {
+  auto ParallelConf4OpName = MakeGetterParallelConf4OpName(job_desc_->placement());
+  for (const auto& op_conf : job_desc_->dlnet_conf().op()) {
+    OpNode* node = new OpNode(ParallelDesc(*ParallelConf4OpName(op_conf.name())), op_conf);
+    AddAllocatedNode(node);
+  }
+}
+
+void OpGraph::InitEdges() {
+  HashMap<LogicalBlobId, OpNode*> lbi2producer;
+  ForEachNode([&](OpNode* op_node) {
+    for (const auto& obn : op_node->op().output_bns()) {
+      CHECK(lbi2producer.emplace(op_node->op().BnInOp2Lbi(obn), op_node).second);
+    }
+  });
+  ForEachNode([&](OpNode* op_node) {
+    HashMap<std::string, std::vector<LogicalBlobId>> producer_name2lbis;
+    HashMap<std::string, HashMap<LogicalBlobId, std::vector<std::string>>>
+        consumer_op_name2lbi2ibns;
+    for (const auto& ibn : op_node->op().input_bns()) {
+      const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn);
+      producer_name2lbis[lbi.op_name()].push_back(lbi);
+      consumer_op_name2lbi2ibns[op_node->op().op_name()][lbi].push_back(ibn);
+    }
+    for (const auto& pair : producer_name2lbis) {
+      const auto& lbis = pair.second;
+      const auto& lbi2ibns = consumer_op_name2lbi2ibns.at(op_node->op().op_name());
+      OpNode* producer = lbi2producer.at(lbis.at(0));
+      Connect(producer, NewEdge(lbis, lbi2ibns), op_node);
+    }
+  });
+}
+
+void OpGraph::FixOpParallelDesc() const {
+  ForEachNode([&](OpNode* node) { node->op().FixParallelDesc(node->mut_parallel_desc()); });
+}
+
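+// Propagate the diff flag forward in topological order: a node has an in-diff iff at
+// least one of its producers already has an in-diff or a model diff.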
+void OpGraph::UpdateOpNodeHasInDiff() const {
+  TopoForEachNode([&](OpNode* op_node) {
+    bool has_diff = false;
+    for (OpEdge* edge : op_node->in_edges()) {
+      if (edge->src_node()->has_in_diff() || edge->src_node()->has_model_diff()) {
+        edge->set_has_diff(true);
+        has_diff = true;
+        break;
+      }
+    }
+    op_node->set_has_in_diff(has_diff);
+  });
+}
+
+void OpGraph::InferTimeShape() const {
+  TopoForEachNode([&](OpNode* op_node) {
+    ParallelContext parallel_ctx;
+    parallel_ctx.set_parallel_id(0);
+    parallel_ctx.set_parallel_num(op_node->parallel_desc().parallel_num());
+    parallel_ctx.set_policy(op_node->parallel_desc().policy());
+    auto GetInputBlobTimeShape = [&](const std::string& bn_in_op) {
+      return op_node->GetInputBlobTimeShape(bn_in_op);
+    };
+    op_node->op().InferOutputBlobTimeShapeIf(GetInputBlobTimeShape, &parallel_ctx,
+                                             op_node->mut_out_blob_time_shape());
+  });
+}
+
+void OpGraph::InferNoParallelBlobDesc() const {
+  TopoForEachNode([&](OpNode* op_node) {
+    ParallelContext parallel_ctx;
+    parallel_ctx.set_parallel_id(0);
+    parallel_ctx.set_parallel_num(1);
+    parallel_ctx.set_policy(op_node->parallel_desc().policy());
+    // The data we actually need here are:
+    // a) the byte size of each model blob;
+    // b) the number of axes of each blob's body shape.
+    // Neither depends on the piece size, so the record_piece_size argument can be any positive
+    // number; we simply pass 1.
+    op_node->op().InferBlobDescsIf(
+        std::bind(&OpNode::NoParallelBlobDesc4BnInOp, op_node, std::placeholders::_1),
+        &parallel_ctx, 1, [](OpContext*) {});
+  });
+}
+
+void OpGraph::InferIsModelBlob() const {
+  TopoForEachNode([&](OpNode* op_node) {
+    op_node->op().InferIsModelBlob4OutputBlobsIf([&](const std::string& bn) -> bool* {
+      return op_node->ProducerOpNode4BnInOp(bn)->MutIsModelBlob4Lbi(op_node->op().BnInOp2Lbi(bn));
+    });
+  });
+}
+
+void OpGraph::InferSbpParallel() const {
+  TopoForEachNode([&](OpNode* op_node) {
+    HashMap<std::string, SbpInferHint> ibn2sbp_infer_hint;
+    for (const std::string& ibn : op_node->op().input_bns()) {
+      const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn);
+      OpNode* producer = op_node->SrcNode4InputBnInOp(ibn);
+      bool is_model_blob = producer->IsModelBlob4Lbi(lbi);
+      int64_t parallel_num = op_node->parallel_desc().parallel_num();
+      int64_t num_axes = producer->NoParallelBlobDesc4Lbi(lbi).shape().NumAxes();
+      const auto& sbp = producer->SbpParallel4Lbi(lbi);
+      ibn2sbp_infer_hint.emplace(ibn, SbpInferHint(is_model_blob, parallel_num, num_axes, sbp));
+    }
+    auto SbpParallel4BnInOp = [&](const std::string& bn) -> SbpParallel* {
+      return op_node->MutSbpParallel4Lbi(op_node->op().BnInOp2Lbi(bn));
+    };
+    auto SbpInferHint4Ibn = [&](const std::string& ibn) -> const SbpInferHint& {
+      return ibn2sbp_infer_hint.at(ibn);
+    };
+    ParallelContext parallel_ctx;
+    parallel_ctx.set_parallel_id(0);
+    parallel_ctx.set_parallel_num(op_node->parallel_desc().parallel_num());
+    parallel_ctx.set_policy(op_node->parallel_desc().policy());
+    op_node->op().InferInputOutputSbpParallelIf(SbpParallel4BnInOp, SbpInferHint4Ibn,
+                                                &parallel_ctx);
+  });
+}
+
+void OpGraph::InferLogicalBlobDesc() const {
+  TopoForEachNode([&](OpNode* op_node) {
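+    // Split every logical input blob desc into per-parallel-id copies, infer blob descs once per
+    // parallel_id on those copies, then concatenate the per-parallel-id output descs back into
+    // logical blob descs.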
+    auto* lbi2parallel_id2blob_desc = op_node->mut_lbi2parallel_id2blob_desc();
+    op_node->SplitLogicalInputBlobDesc();
+    int64_t parallel_num = op_node->parallel_desc().parallel_num();
+    const auto& input_bns = op_node->op().input_bns();
+    FOR_RANGE(int64_t, parallel_id, 0, parallel_num) {
+      auto BlobDesc4BnInOp = [&](const std::string& bn) -> BlobDesc* {
+        const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(bn);
+        if (std::find(input_bns.begin(), input_bns.end(), bn) != input_bns.end()) {
+          CHECK(lbi2parallel_id2blob_desc->find(lbi) != lbi2parallel_id2blob_desc->end());
+          CHECK_EQ(lbi2parallel_id2blob_desc->at(lbi).size(), parallel_num);
+        } else if (lbi2parallel_id2blob_desc->find(lbi) == lbi2parallel_id2blob_desc->end()) {
+          (*lbi2parallel_id2blob_desc)[lbi].resize(parallel_num);
+        } else {
+          CHECK_EQ(lbi2parallel_id2blob_desc->at(lbi).size(), parallel_num);
+        }
+        return &(*lbi2parallel_id2blob_desc)[lbi][parallel_id];
+      };
+      ParallelContext parallel_ctx;
+      parallel_ctx.set_parallel_id(parallel_id);
+      parallel_ctx.set_parallel_num(parallel_num);
+      parallel_ctx.set_policy(op_node->parallel_desc().policy());
+      op_node->op().InferBlobDescsIf(BlobDesc4BnInOp, &parallel_ctx, job_desc_->RecordPieceSize(),
+                                     [](OpContext*) {});
+    }
+    op_node->ConcatLogicalOutputBlobDesc();
+  });
+}
+
+BalancedSplitter OpGraph::GetBalancedSplitter(const std::string& op_name,
+                                              const LogicalBlobId& lbi) const {
+  OpNode* op_node = op_name2op_node_.at(GetOpNameKey(op_name, lbi));
+  const SbpParallel& sbp_parallel = GetSbpParallel(op_name, lbi);
+  CHECK(sbp_parallel.has_split_parallel());
+  int64_t split_num = GetSplitNum(op_name, lbi);
+  if (IsDataBlob(op_name, lbi)) {
+    CHECK_EQ(split_num % op_node->parallel_desc().parallel_num(), 0);
+  } else {
+    CHECK(IsModelBlob(op_name, lbi));
+    CHECK_GE(split_num, op_node->parallel_desc().parallel_num());
+  }
+  return BalancedSplitter(split_num, op_node->parallel_desc().parallel_num());
+}
+
+int32_t OpGraph::GetModelSplitAxis(const std::string& op_name, const LogicalBlobId& lbi) const {
+  const SbpParallel& sbp_parallel = GetSbpParallel(op_name, lbi);
+  CHECK(sbp_parallel.has_split_parallel());
+  return sbp_parallel.split_parallel().axis();
+}
+
+int64_t OpGraph::GetSplitNum(const std::string& op_name, const LogicalBlobId& lbi) const {
+  OpNode* op_node = op_name2op_node_.at(GetOpNameKey(op_name, lbi));
+  const LogicalBlobId& lbi_key = GetLogicalBlobIdKey(op_name, lbi);
+  const SbpParallel& sbp_parallel = op_node->SbpParallel4Lbi(lbi_key);
+  CHECK(sbp_parallel.has_split_parallel());
+  return op_node->ProducerOpNode4Lbi(lbi)->LogicalBlobDesc4Lbi(lbi_key).shape().At(
+      sbp_parallel.split_parallel().axis());
+}
+
+const SbpParallel& OpGraph::GetSbpParallel(const std::string& op_name,
+                                           const LogicalBlobId& lbi) const {
+  return op_name2op_node_.at(GetOpNameKey(op_name, lbi))
+      ->SbpParallel4Lbi(GetLogicalBlobIdKey(op_name, lbi));
+}
+
+DataType OpGraph::GetBlobDataType(const LogicalBlobId& lbi) const {
+  return op_name2op_node_.at(lbi.op_name())
+      ->NoParallelBlobDesc4Lbi(GetLogicalBlobIdKey(lbi.op_name(), lbi))
+      .data_type();
+}
+
+bool OpGraph::IsModelBlob(const std::string& op_name, const LogicalBlobId& lbi) const {
+  return op_name2op_node_.at(GetOpNameKey(op_name, lbi))
+      ->IsModelBlob4Lbi(GetLogicalBlobIdKey(op_name, lbi));
+}
+
+bool OpGraph::IsDataBlob(const std::string& op_name, const LogicalBlobId& lbi) const {
+  return !IsModelBlob(op_name, lbi);
+}
+
+void OpGraph::CheckBlobDescs(const std::string& op_name,
+                             const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
+                             const ParallelContext* parallel_ctx) const {
+  if (op_name2op_node_.find(op_name) == op_name2op_node_.end()) { return; }
+  op_name2op_node_.at(op_name)->CheckBlobDescs(GetBlobDesc4BnInOp, parallel_ctx);
+}
+
+void OpGraph::ForEachPseudoChain(
+    const std::function<void(const HashSet<OpNode*>&)>& Handler) const {
+  auto IsReachable = MakePredicatorIsReachable();
+  ForEachComponentWithSameDataParallelDescAndTimeShape(
+      [&](const std::vector<OpNode*>& nodes) { ForEachPseudoChain(nodes, IsReachable, Handler); });
+}
+
+std::function<bool(OpNode* src, OpNode* dst)> OpGraph::MakePredicatorIsReachable() const {
+  auto node2ancestors_ptr = std::make_shared<HashMap<OpNode*, HashSet<OpNode*>>>();
+  TopoForEachNode([&](OpNode* node) {
+    node->ForEachNodeOnInEdge([&](OpNode* in_node) {
+      (*node2ancestors_ptr)[node].insert(in_node);
+      (*node2ancestors_ptr)[node].insert((*node2ancestors_ptr)[in_node].begin(),
+                                         (*node2ancestors_ptr)[in_node].end());
+    });
+  });
+  return [node2ancestors_ptr](OpNode* src, OpNode* dst) -> bool {
+    return node2ancestors_ptr->at(dst).find(src) != node2ancestors_ptr->at(dst).end();
+  };
+}
+
+void OpGraph::ForEachComponentWithSameDataParallelDescAndTimeShape(
+    const std::function<void(const std::vector<OpNode*>&)>& Handler) const {
+  auto WithSameDataParallelDescAndTimeShape = [](OpNode* src, OpNode* dst) -> bool {
+    if (src->parallel_desc().policy() != ParallelPolicy::kDataParallel) { return false; }
+    if (dst->parallel_desc().policy() != ParallelPolicy::kDataParallel) { return false; }
+    if (src->in_edges().empty()) { return false; }
+    if (*src->GetInputBlobTimeShape() != src->out_blob_time_shape()) { return false; }
+    if (*dst->GetInputBlobTimeShape() != dst->out_blob_time_shape()) { return false; }
+    return src->parallel_desc() == dst->parallel_desc()
+           && src->out_blob_time_shape() == dst->out_blob_time_shape();
+  };
+  auto ForEachNext = [&](OpNode* node, const std::function<void(OpNode*)>& Handler) {
+    node->ForEachNodeOnInEdge([&](OpNode* in_node) {
+      if (WithSameDataParallelDescAndTimeShape(in_node, node)) { Handler(in_node); }
+    });
+    node->ForEachNodeOnOutEdge([&](OpNode* out_node) {
+      if (WithSameDataParallelDescAndTimeShape(node, out_node)) { Handler(out_node); }
+    });
+  };
+  HashMap<OpNode*, int32_t> op_node2component_id;
+  int32_t cur_component_id = 0;
+  ForEachNode([&](OpNode* start) {
+    if (op_node2component_id.find(start) != op_node2component_id.end()) { return; }
+    ++cur_component_id;
+    BfsForEachNode({start}, ForEachNext, [&](OpNode* node) {
+      CHECK(op_node2component_id.emplace(node, cur_component_id).second);
+    });
+  });
+  HashMap<int32_t, std::vector<OpNode*>> component_id2op_nodes;
+  for (const auto& pair : op_node2component_id) {
+    component_id2op_nodes[pair.second].push_back(pair.first);
+  }
+  for (const auto& pair : component_id2op_nodes) { Handler(pair.second); }
+}
+
+void OpGraph::ForEachPseudoChain(
+    const std::vector<OpNode*>& nodes,
+    const std::function<bool(OpNode* src, OpNode* dst)>& IsReachable,
+    const std::function<void(const HashSet<OpNode*>&)>& Handler) const {
+  if (nodes.size() <= 1) { return; }
+  if (nodes.front()->parallel_desc().device_type() == DeviceType::kCPU) { return; }
+  if (nodes.front()->parallel_desc().policy() != ParallelPolicy::kDataParallel) { return; }
+  HashSet<OpNode*> all_nodes(nodes.begin(), nodes.end());
+  while (all_nodes.size() > 1) {
+    HashSet<OpNode*> chain;
+    ReverseTopoGetPseudoChain(all_nodes, &chain, IsReachable);
+    Handler(chain);
+    for (OpNode* node_in_chain : chain) { all_nodes.erase(node_in_chain); }
+  }
+}
+
+void OpGraph::ReverseTopoGetPseudoChain(
+    const HashSet<OpNode*>& op_nodes, HashSet<OpNode*>* pseudo_chain_nodes,
+    const std::function<bool(OpNode* src, OpNode* dst)>& IsReachable) const {
+  // get sink nodes
+  std::list<OpNode*> sinks;
+  auto IsSink = [&](OpNode* node) {
+    for (OpNode* inner_node : op_nodes) {
+      if (IsReachable(node, inner_node)) { return false; }
+    }
+    return true;
+  };
+  for (OpNode* op_node : op_nodes) {
+    if (IsSink(op_node)) { sinks.push_back(op_node); }
+  }
+  // generate connections of subgraph
+  HashMap<OpNode*, std::vector<OpNode*>> node2in_nodes;
+  HashMap<OpNode*, std::vector<OpNode*>> node2out_nodes;
+  auto IsInSubset = [&](OpNode* node) { return op_nodes.find(node) != op_nodes.end(); };
+  auto ReachableToAnySink = [&](OpNode* node) {
+    for (OpNode* sink : sinks) {
+      if (node == sink) { return true; }
+      if (IsReachable(node, sink)) { return true; }
+    }
+    return false;
+  };
+  auto AnyOutputNodesNotInSubsetAndReachableIntoSink = [&](OpNode* node) {
+    for (OpEdge* edge : node->out_edges()) {
+      if (!IsInSubset(edge->dst_node()) && ReachableToAnySink(edge->dst_node())) { return true; }
+    }
+    return false;
+  };
+  for (OpNode* node : op_nodes) {
+    if (AnyOutputNodesNotInSubsetAndReachableIntoSink(node)) { continue; }
+    node->ForEachNodeOnOutEdge([&](OpNode* out_node) {
+      if (IsInSubset(out_node)) {
+        node2in_nodes[out_node].push_back(node);
+        node2out_nodes[node].push_back(out_node);
+      }
+    });
+  }
+  // get chain nodes
+  auto ForEachInNode = [&](OpNode* node, const std::function<void(OpNode*)>& Handler) {
+    for (OpNode* in_node : node2in_nodes[node]) { Handler(in_node); }
+  };
+  auto ForEachOutNode = [&](OpNode* node, const std::function<void(OpNode*)>& Handler) {
+    for (OpNode* out_node : node2out_nodes[node]) { Handler(out_node); }
+  };
+  TopoForEachNode(sinks, ForEachOutNode, ForEachInNode,
+                  [&](OpNode* node) { CHECK(pseudo_chain_nodes->emplace(node).second); });
+}
+
+std::string OpGraph::GetOpNameKey(const std::string& op_name, const LogicalBlobId& lbi) const {
+  CHECK(!lbi.has_is_packed_id());
+  if (op_name2op_node_.find(op_name) == op_name2op_node_.end()) {
+    CHECK(lbi.has_clone_id());
+    return lbi.op_name();
+  } else {
+    CHECK(!lbi.has_clone_id());
+    return op_name;
+  }
+}
+
+LogicalBlobId OpGraph::GetLogicalBlobIdKey(const std::string& op_name,
+                                           const LogicalBlobId& lbi) const {
+  CHECK(!lbi.has_is_packed_id());
+  if (op_name2op_node_.find(op_name) == op_name2op_node_.end()) {
+    CHECK(lbi.has_clone_id());
+    LogicalBlobId lbi_key;
+    lbi_key.set_op_name(lbi.op_name());
+    lbi_key.set_blob_name(lbi.blob_name());
+    return lbi_key;
+  } else {
+    CHECK(!lbi.has_clone_id());
+    return lbi;
+  }
+}
+
+}  // namespace oneflow
diff --git a/oneflow/core/graph/op_graph.h b/oneflow/core/graph/op_graph.h
new file mode 100644
index 0000000000000000000000000000000000000000..94d0974e986fb647490e53bc0082e75dc3f134bf
--- /dev/null
+++ b/oneflow/core/graph/op_graph.h
@@ -0,0 +1,160 @@
+#ifndef ONEFLOW_CORE_GRAPH_OP_GRAPH_H_
+#define ONEFLOW_CORE_GRAPH_OP_GRAPH_H_
+
+#include "oneflow/core/graph/graph.h"
+#include "oneflow/core/job/job_desc.h"
+#include "oneflow/core/job/parallel_desc.h"
+#include "oneflow/core/operator/operator.h"
+#include "oneflow/core/common/balanced_splitter.h"
+
+namespace oneflow {
+
+class OpEdge;
+class OpGraph;
+
+class OpNode final : public Node<OpNode, OpEdge> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(OpNode);
+  explicit OpNode(const ParallelDesc& parallel_desc, const OperatorConf& op_conf)
+      : parallel_desc_(parallel_desc),
+        op_(ConstructOp(op_conf, parallel_desc.device_type())),
+        ibns_(op_->input_bns().begin(), op_->input_bns().end()),
+        has_in_diff_(false) {}
+  ~OpNode() = default;
+
+  // Getters
+  const Shape& out_blob_time_shape() const { return out_blob_time_shape_; }
+  const Operator& op() const { return *op_; }
+  bool HasBackward() const { return has_in_diff() || has_model_diff(); }
+  bool has_in_diff() const { return has_in_diff_; }
+  bool has_model_diff() const { return op().model_diff_bns().size() > 0; }
+  void set_has_in_diff(bool has_in_diff) { has_in_diff_ = has_in_diff; }
+  const ParallelDesc& parallel_desc() const { return parallel_desc_; }
+
+  std::string VisualStr() const override;
+
+ private:
+  friend class OpGraph;
+  friend class OpEdge;
+  // Getters
+  const BlobDesc& NoParallelBlobDesc4Lbi(const LogicalBlobId& lbi) const;
+  const BlobDesc& LogicalBlobDesc4Lbi(const LogicalBlobId& lbi) const;
+  const SbpParallel& SbpParallel4Lbi(const LogicalBlobId& lbi) const;
+  const Shape* GetInputBlobTimeShape(const std::string& bn_in_op) const;
+  const Shape* GetInputBlobTimeShape() const;
+
+  // Setters
+  ParallelDesc* mut_parallel_desc() { return &parallel_desc_; }
+  Shape* mut_out_blob_time_shape() { return &out_blob_time_shape_; }
+  HashMap<LogicalBlobId, std::vector<BlobDesc>>* mut_lbi2parallel_id2blob_desc() {
+    return &lbi2parallel_id2blob_desc_;
+  }
+  bool IsModelBlob4Lbi(const LogicalBlobId& lbi) const;
+  bool* MutIsModelBlob4Lbi(const LogicalBlobId& lbi);
+  BlobDesc* NoParallelBlobDesc4BnInOp(const std::string& bn_in_op);
+  BlobDesc* MutNoParallelBlobDesc(const LogicalBlobId& lbi);
+  BlobDesc* MutLogicalBlobDesc4Lbi(const LogicalBlobId& lbi);
+  SbpParallel* MutSbpParallel4Lbi(const LogicalBlobId& lbi);
+  OpNode* SrcNode4InputBnInOp(const std::string& bn_in_op) const;
+  OpNode* ProducerOpNode4BnInOp(const std::string& bn_in_op);
+  OpNode* SrcNode4InputLbi(const LogicalBlobId& lbi) const;
+  OpNode* ProducerOpNode4Lbi(const LogicalBlobId& lbi);
+  const OpNode* ProducerOpNode4Lbi(const LogicalBlobId& lbi) const;
+
+  void ForEachParallelBlobDesc(const BlobDesc& blob_desc, const SbpParallel& sbp_parallel,
+                               const std::function<void(const BlobDesc&)>& Handler) const;
+  int64_t GetAxisParallelNum(
+      const std::function<void(bool*, int32_t*, int64_t*)>& GetAxisParallelInfo) const;
+  void ConcatBlobDesc(const std::vector<BlobDesc>& blob_descs, const SbpParallel& sbp_parallel,
+                      BlobDesc* concatenated_blob_desc) const;
+  void SplitLogicalInputBlobDesc();
+  void ConcatLogicalOutputBlobDesc();
+  void CheckBlobDescs(const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const;
+
+  ParallelDesc parallel_desc_;
+  std::shared_ptr<Operator> op_;
+  HashSet<std::string> ibns_;
+  bool has_in_diff_;
+  Shape out_blob_time_shape_;
+  HashMap<LogicalBlobId, BlobDesc> lbi2no_parallel_blob_desc_;
+  HashMap<LogicalBlobId, bool> lbi2is_model_blob_;
+  HashMap<LogicalBlobId, SbpParallel> lbi2sbp_parallel_;
+  HashMap<LogicalBlobId, std::vector<BlobDesc>> lbi2parallel_id2blob_desc_;
+  HashMap<LogicalBlobId, BlobDesc> lbi2logical_blob_desc_;
+};
+
+class OpEdge final : public Edge<OpNode, OpEdge> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(OpEdge);
+  explicit OpEdge(const std::vector<LogicalBlobId>& lbis,
+                  const HashMap<LogicalBlobId, std::vector<std::string>>& lbi2ibns)
+      : lbis_(lbis), lbi2ibns_(lbi2ibns), has_diff_(false) {}
+  ~OpEdge() = default;
+
+  const std::vector<LogicalBlobId>& lbis() const { return lbis_; }
+  const HashMap<LogicalBlobId, std::vector<std::string>>& lbi2ibns() const { return lbi2ibns_; }
+  bool has_diff() const { return has_diff_; }
+  std::string VisualStr() const override;
+
+  void set_has_diff(bool val) { has_diff_ = val; }
+
+ private:
+  std::vector<LogicalBlobId> lbis_;
+  HashMap<LogicalBlobId, std::vector<std::string>> lbi2ibns_;
+  bool has_diff_;
+};
+
+class OpGraph final : public Graph<OpNode, OpEdge> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(OpGraph);
+  explicit OpGraph(const JobDesc* job_desc) : job_desc_(job_desc) { Init(); }
+  ~OpGraph() = default;
+
+  void InferOpModelSize(HashMap<std::string, size_t>* op_name2model_size);
+
+  int32_t GetModelSplitAxis(const std::string& op_name, const LogicalBlobId& lbi) const;
+  BalancedSplitter GetBalancedSplitter(const std::string& op_name, const LogicalBlobId& lbi) const;
+  const SbpParallel& GetSbpParallel(const std::string& op_name, const LogicalBlobId& lbi) const;
+  DataType GetBlobDataType(const LogicalBlobId& lbi) const;
+  void CheckBlobDescs(const std::string& op_name,
+                      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const;
+
+  // A set of nodes is called a pseudo chain if they could be merged into one chain when the
+  // connections ahead of their source nodes are ignored.
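+  // For example, if ops A -> B -> C all run data-parallel on the same devices with the same time
+  // shape, {A, B, C} forms a pseudo chain even when A also consumes blobs produced outside the
+  // set.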
+  void ForEachPseudoChain(const std::function<void(const HashSet<OpNode*>&)>& Handler) const;
+
+ private:
+  void Init();
+  void InitNodes();
+  void InitEdges();
+  void FixOpParallelDesc() const;
+  void UpdateOpNodeHasInDiff() const;
+  void InferTimeShape() const;
+  void InferNoParallelBlobDesc() const;
+  void InferIsModelBlob() const;
+  void InferSbpParallel() const;
+  void InferLogicalBlobDesc() const;
+  bool IsModelBlob(const std::string& op_name, const LogicalBlobId& lbi) const;
+  bool IsDataBlob(const std::string& op_name, const LogicalBlobId& lbi) const;
+  std::string GetOpNameKey(const std::string& op_name, const LogicalBlobId& lbi) const;
+  LogicalBlobId GetLogicalBlobIdKey(const std::string& op_name, const LogicalBlobId& lbi) const;
+  void ForEachPseudoChain(const std::vector<OpNode*>& nodes,
+                          const std::function<bool(OpNode* src, OpNode* dst)>& IsReachable,
+                          const std::function<void(const HashSet<OpNode*>&)>& Handler) const;
+  void ReverseTopoGetPseudoChain(
+      const HashSet<OpNode*>& op_nodes, HashSet<OpNode*>* chain,
+      const std::function<bool(OpNode* src, OpNode* dst)>& IsReachable) const;
+  std::function<bool(OpNode* src, OpNode* dst)> MakePredicatorIsReachable() const;
+  void ForEachComponentWithSameDataParallelDescAndTimeShape(
+      const std::function<void(const std::vector<OpNode*>&)>& Handler) const;
+
+  int64_t GetSplitNum(const std::string& op_name, const LogicalBlobId& lbi) const;
+  const JobDesc* job_desc_;
+  HashMap<std::string, OpNode*> op_name2op_node_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_GRAPH_OP_GRAPH_H_
diff --git a/oneflow/core/graph/pack_forward_task_node.cpp b/oneflow/core/graph/pack_forward_task_node.cpp
index ff6bdec8b520cf75e49b4627410ddeed2e50018d..36ff141f95e91e62078c530bdc9a3e98ac693113 100644
--- a/oneflow/core/graph/pack_forward_task_node.cpp
+++ b/oneflow/core/graph/pack_forward_task_node.cpp
@@ -51,8 +51,7 @@ void PackForwardCompTaskNode::BuildExecGphAndRegst() {
   *out_blob = *in_blob;
   const PackOp* pack_op = dynamic_cast<const PackOp*>(op.get());
   CHECK_NOTNULL(pack_op);
-  CHECK_EQ(pack_op->GetPackNum(parallel_ctx()->parallel_num()),
-           related_unpack_in_blob->shape().At(0) / in_blob->shape().At(0));
+  CHECK_EQ(pack_op->GetPackNum(), related_unpack_in_blob->shape().At(0) / in_blob->shape().At(0));
   out_blob->mut_shape().Set(0, related_unpack_in_blob->shape().At(0));
   if (out_blob->has_dim0_valid_num_field()) {
     out_blob->mut_dim0_inner_shape() = related_unpack_in_blob->dim0_inner_shape();
@@ -66,7 +65,7 @@ void PackForwardCompTaskNode::InferProducedDataRegstTimeShape() {
 
   const PackOp* pack_op = dynamic_cast<const PackOp*>(logical_node()->SoleOp().get());
   CHECK_NOTNULL(pack_op);
-  int64_t pack_num = pack_op->GetPackNum(parallel_ctx()->parallel_num());
+  int64_t pack_num = pack_op->GetPackNum();
   CHECK_GT(time_shape_dim_vec.size(), 0);
   CHECK_EQ(pack_num, time_shape_dim_vec.back());
   time_shape_dim_vec.pop_back();
diff --git a/oneflow/core/graph/reduce_comp_task_node_if.h b/oneflow/core/graph/reduce_comp_task_node_if.h
index 0131bec319a1f5e9b1d75d518a1f915608b98ad5..988fcc18dca217298bfd8601ef666af480d6cb60 100644
--- a/oneflow/core/graph/reduce_comp_task_node_if.h
+++ b/oneflow/core/graph/reduce_comp_task_node_if.h
@@ -3,7 +3,7 @@
 
 #include "oneflow/core/register/register_desc.h"
 #include "oneflow/core/graph/compute_task_node.h"
-#include "logical_node.h"
+#include "oneflow/core/graph/logical_node.h"
 
 namespace oneflow {
 
diff --git a/oneflow/core/graph/reduce_concat_compute_task_node.h b/oneflow/core/graph/reduce_concat_compute_task_node.h
index 36a32511178b1cb325f944bcfd3097848c2e7583..00d8c2973c81e14b7ce70f27210059c26fef4e37 100644
--- a/oneflow/core/graph/reduce_concat_compute_task_node.h
+++ b/oneflow/core/graph/reduce_concat_compute_task_node.h
@@ -16,7 +16,7 @@ class ReduceConcatCompTaskNode final : public CompTaskNode, public ReduceCompTas
   void ConsumeAllRegsts() override;
 
   TaskType GetTaskType() const override { return TaskType::kReduceConcat; }
-  CudaWorkType GetCudaWorkType() const override { return CudaWorkType::kMix; }
+  CudaWorkType GetCudaWorkType() const override { return CudaWorkType::kReduceCtrl; }
 
   void EnableMemSharingInReduce(const ReduceMemSharingCtx& ctx) override;
 
diff --git a/oneflow/core/graph/reduce_identity_task_node.cpp b/oneflow/core/graph/reduce_identity_task_node.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c889d3cfd0a9a70b78b76a4e9d60320ceec64002
--- /dev/null
+++ b/oneflow/core/graph/reduce_identity_task_node.cpp
@@ -0,0 +1,37 @@
+#include "oneflow/core/graph/reduce_identity_task_node.h"
+#include "oneflow/core/graph/logical_node.h"
+#include "oneflow/core/operator/reduce_identity_op.h"
+
+namespace oneflow {
+
+void ReduceIdentityCompTaskNode::ProduceAllRegstsAndBindEdges() {
+  ProduceRegst("out", false);
+  BindEdgeWithProducedRegst(SoleOutEdge(), "out");
+}
+
+void ReduceIdentityCompTaskNode::EnableMemSharingInReduce(const ReduceMemSharingCtx& ctx) {
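+  // Both the produced "out" regst and the consumed "in" regst are registered at offset 0 of the
+  // reduce-shared memory, so (assuming EnableMemSharing4Regst places them in the same shared
+  // block) the identity effectively reuses one buffer for input and output.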
+  ctx.EnableMemSharing4Regst(GetProducedRegst("out").get(), 0);
+  ctx.EnableMemSharing4Regst(GetSoleConsumedRegst("in").get(), 0);
+}
+
+void ReduceIdentityCompTaskNode::ConsumeAllRegsts() {
+  ConsumeRegst("in", SoleInEdge()->GetSoleRegst());
+}
+
+void ReduceIdentityCompTaskNode::BuildExecGphAndRegst() {
+  std::shared_ptr<const Operator> op = logical_node()->SoleOp();
+  ExecNode* exec_node = mut_exec_gph().NewNode();
+  exec_node->mut_op() = op;
+  exec_node->BindBnWithRegst(op->SoleIbn(), GetSoleConsumedRegst("in"));
+
+  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
+  out_regst->AddLbi(op->BnInOp2Lbi(op->SoleObn()));
+  exec_node->BindBnWithRegst(op->SoleObn(), out_regst);
+  exec_node->InferBlobDescs(parallel_ctx());
+}
+
+void ReduceIdentityCompTaskNode::InferProducedDataRegstTimeShape() {
+  NaiveInferProducedDataRegstTimeShape();
+}
+
+}  // namespace oneflow
diff --git a/oneflow/core/graph/reduce_identity_task_node.h b/oneflow/core/graph/reduce_identity_task_node.h
new file mode 100644
index 0000000000000000000000000000000000000000..2f2c5d7b794b018af753887e825ccd774230d7e9
--- /dev/null
+++ b/oneflow/core/graph/reduce_identity_task_node.h
@@ -0,0 +1,31 @@
+#ifndef ONEFLOW_CORE_GRAPH_REDUCE_IDENTITY_TASK_NODE_H_
+#define ONEFLOW_CORE_GRAPH_REDUCE_IDENTITY_TASK_NODE_H_
+
+#include "oneflow/core/graph/compute_task_node.h"
+#include "oneflow/core/graph/logical_node.h"
+#include "oneflow/core/graph/pipe_compute_task_node.h"
+#include "oneflow/core/graph/reduce_comp_task_node_if.h"
+
+namespace oneflow {
+
+class ReduceIdentityCompTaskNode final : public CompTaskNode, public ReduceCompTaskNodeIf {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ReduceIdentityCompTaskNode);
+  ReduceIdentityCompTaskNode() = default;
+  ~ReduceIdentityCompTaskNode() override = default;
+
+  TaskType GetTaskType() const override { return TaskType::kReduceIdentity; }
+  CudaWorkType GetCudaWorkType() const override { return CudaWorkType::kReduceCtrl; }
+  void EnableMemSharingInReduce(const ReduceMemSharingCtx& ctx) override;
+
+  void ProduceAllRegstsAndBindEdges() override;
+  void ConsumeAllRegsts() override;
+
+ private:
+  void BuildExecGphAndRegst() override;
+  void InferProducedDataRegstTimeShape() override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_GRAPH_REDUCE_IDENTITY_TASK_NODE_H_
diff --git a/oneflow/core/graph/reduce_split_compute_task_node.cpp b/oneflow/core/graph/reduce_split_compute_task_node.cpp
index c31fd0c327c94cec2fa0bb04aabc4137e24e76d8..245aa773fd269a9179673f8fef58c37fb87b91e2 100644
--- a/oneflow/core/graph/reduce_split_compute_task_node.cpp
+++ b/oneflow/core/graph/reduce_split_compute_task_node.cpp
@@ -5,6 +5,19 @@
 
 namespace oneflow {
 
+namespace {
+
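+// Count only the data regst descs among the given regsts; ctrl regst descs are ignored.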
+int32_t GetDataRegstDescCnt(
+    const HashMap<std::string, std::shared_ptr<RegstDesc>>& name2regst_desc) {
+  int32_t cnt = 0;
+  for (const auto& pair : name2regst_desc) {
+    cnt += pair.second->regst_desc_type().has_data_regst_desc();
+  }
+  return cnt;
+}
+
+}  // namespace
+
 void ReduceSplitCompTaskNode::ProduceAllRegstsAndBindEdges() {
   std::vector<EdgeInfo> edge_infos;
   for (TaskEdge* edge : out_edges()) {
@@ -32,17 +45,23 @@ void ReduceSplitCompTaskNode::ConsumeAllRegsts() {
   ConsumeRegst("in", this->SoleInEdge()->GetSoleRegst());
 }
 
+TaskNode* ReduceSplitCompTaskNode::GetPrevReduceTaskNode(TaskType task_type) {
+  CHECK(task_type == TaskType::kReduceConcat || task_type == TaskType::kReduceIdentity);
+  TaskNode* task_node =
+      FindPredReduceTaskNodeIf([&](TaskNode* node) { return node->GetTaskType() == task_type; });
+  CHECK_NOTNULL(task_node);
+  return task_node;
+}
+
 void ReduceSplitCompTaskNode::BuildExecGphAndRegst() {
   ExecNode* node = mut_exec_gph().NewNode();
   std::shared_ptr<Operator> reduce_split_op = this->logical_node()->SoleOp();
   node->mut_op() = reduce_split_op;
   node->BindBnWithRegst(reduce_split_op->SoleIbn(), GetSoleConsumedRegst("in"));
 
-  TaskNode* reduce_concat_node = FindPredReduceTaskNodeIf(
-      [](TaskNode* node) { return node->GetTaskType() == TaskType::kReduceConcat; });
-  CHECK(reduce_concat_node);
+  TaskNode* reduce_concat_node = GetPrevReduceTaskNode(TaskType::kReduceConcat);
 
-  CHECK_EQ(reduce_concat_node->consumed_regsts().size(), produced_regsts().size());
+  CHECK_EQ(reduce_concat_node->consumed_regsts().size(), GetDataRegstDescCnt(produced_regsts()));
   FOR_RANGE(size_t, i, 0, reduce_split_op->output_bns().size()) {
     std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out_" + std::to_string(i));
     CHECK(out_regst.get() != nullptr);
@@ -53,7 +72,7 @@ void ReduceSplitCompTaskNode::BuildExecGphAndRegst() {
 }
 
 void ReduceSplitCompTaskNode::FixPackedBlobDescOfProducedRegst() {
-  int64_t out_regst_num = produced_regsts().size();
+  int64_t out_regst_num = GetDataRegstDescCnt(produced_regsts());
   FOR_RANGE(int64_t, idx, 0, out_regst_num) {
     std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out_" + std::to_string(idx));
     CHECK(out_regst->IsLocked());
@@ -65,7 +84,7 @@ void ReduceSplitCompTaskNode::FixPackedBlobDescOfProducedRegst() {
 
 void ReduceSplitCompTaskNode::EnableMemSharingInReduce(const ReduceMemSharingCtx& ctx) {
   CHECK_EQ(GetRankCtx().TotalSegmentCount(), 1);
-  size_t split_num = produced_regsts().size();
+  size_t split_num = GetDataRegstDescCnt(produced_regsts());
   int64_t offset = 0;
   FOR_RANGE(int32_t, idx, 0, split_num) {
     RegstDesc* split_out_regst = GetProducedRegst("out_" + std::to_string(idx)).get();
diff --git a/oneflow/core/graph/reduce_split_compute_task_node.h b/oneflow/core/graph/reduce_split_compute_task_node.h
index 8fd05420a2932f42f0f5f061bc2a8118025cf64c..87a0028de529b06278bd4b41226841d3e4d1d07e 100644
--- a/oneflow/core/graph/reduce_split_compute_task_node.h
+++ b/oneflow/core/graph/reduce_split_compute_task_node.h
@@ -16,9 +16,11 @@ class ReduceSplitCompTaskNode final : public CompTaskNode, public ReduceCompTask
   void ConsumeAllRegsts() override;
 
   TaskType GetTaskType() const override { return TaskType::kReduceSplit; }
-  CudaWorkType GetCudaWorkType() const override { return CudaWorkType::kMix; }
+  CudaWorkType GetCudaWorkType() const override { return CudaWorkType::kReduceCtrl; }
   void EnableMemSharingInReduce(const ReduceMemSharingCtx& ctx) override;
 
+  TaskNode* GetPrevReduceTaskNode(TaskType task_type);
+
  private:
   void BuildExecGphAndRegst() override;
   void FixPackedBlobDescOfProducedRegst() override;
diff --git a/oneflow/core/graph/regst_lifetime_graph.cpp b/oneflow/core/graph/regst_lifetime_graph.cpp
index dd7fa852de8c7ee7ba95eb4292e18bfdaf6c89b6..645949abf5eb928a403dae3efa5fc1e51b07af61 100644
--- a/oneflow/core/graph/regst_lifetime_graph.cpp
+++ b/oneflow/core/graph/regst_lifetime_graph.cpp
@@ -3,17 +3,17 @@
 namespace oneflow {
 
 RegstLifetimeGraph::RegstLifetimeGraph(
-    const std::list<const RegstDescProto*>& regst_descs,
+    const std::vector<const RegstDescProto*>& regst_descs,
     const std::function<void(const RegstDescProto*, HashSet<int64_t>*)>& ComputeLifetimeActorIds) {
-  std::list<RegstLifetimeNode*> nodes;
+  std::vector<RegstLifetimeNode*> nodes;
   InitNodes(regst_descs, ComputeLifetimeActorIds, &nodes);
   InitEdges(nodes);
 }
 
 void RegstLifetimeGraph::InitNodes(
-    const std::list<const RegstDescProto*>& regst_descs,
+    const std::vector<const RegstDescProto*>& regst_descs,
     const std::function<void(const RegstDescProto*, HashSet<int64_t>*)>& ComputeLifetimeActorIds,
-    std::list<RegstLifetimeNode*>* nodes) {
+    std::vector<RegstLifetimeNode*>* nodes) {
   for (const RegstDescProto* regst_desc : regst_descs) {
     auto lifetime_actor_ids = std::make_unique<HashSet<int64_t>>();
     ComputeLifetimeActorIds(regst_desc, lifetime_actor_ids.get());
@@ -23,7 +23,7 @@ void RegstLifetimeGraph::InitNodes(
   }
 }
 
-void RegstLifetimeGraph::InitEdges(const std::list<RegstLifetimeNode*>& nodes) {
+void RegstLifetimeGraph::InitEdges(const std::vector<RegstLifetimeNode*>& nodes) {
   HashMap<int64_t, HashSet<RegstLifetimeNode*>> task_id2intersected_nodes;
   for (RegstLifetimeNode* node : nodes) {
     for (int64_t task_id : node->lifetime_actor_ids()) {
@@ -46,25 +46,26 @@ void RegstLifetimeGraph::InitEdges(const std::list<RegstLifetimeNode*>& nodes) {
 }
 
 void RegstLifetimeGraph::ForEachSameColoredRegstDescs(
-    const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) const {
+    const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) const {
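+  // Greedy graph coloring: visit nodes in decreasing byte size and give each the smallest color
+  // not used by an intersecting node; regst descs with the same color never overlap in lifetime
+  // and are reported together so they can share memory.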
+  std::vector<const RegstLifetimeNode*> nodes;
+  ForEachNode([&](const RegstLifetimeNode* node) { nodes.push_back(node); });
+  std::sort(nodes.begin(), nodes.end(),
+            [&](const RegstLifetimeNode* lhs, const RegstLifetimeNode* rhs) {
+              return lhs->byte_size() > rhs->byte_size();
+            });
   HashMap<const RegstLifetimeNode*, std::set<int32_t>> node2excluded_color_ids;
   HashMap<const RegstLifetimeNode*, int32_t> node2color_id;
-  auto ForEachIntersected = &RegstLifetimeNode::ForEachNodeOnInOutEdge;
-  ForEachNode([&](const RegstLifetimeNode* start) {
-    if (node2color_id.find(start) != node2color_id.end()) { return; }
-    BfsForEachNode({start}, ForEachIntersected, [&](const RegstLifetimeNode* node) {
-      if (node2color_id.find(node) != node2color_id.end()) { return; }
-      int32_t color_id = 0;
-      const auto& excluded_color_ids = node2excluded_color_ids[node];
-      for (; excluded_color_ids.find(color_id) != excluded_color_ids.end(); ++color_id) {}
-      node2color_id[node] = color_id;
-      (node->*ForEachIntersected)([&](const RegstLifetimeNode* intersected) {
-        if (node2color_id.find(intersected) != node2color_id.end()) { return; }
-        node2excluded_color_ids[intersected].insert(color_id);
-      });
+  for (const RegstLifetimeNode* node : nodes) {
+    int32_t color_id = 0;
+    const auto& excluded_color_ids = node2excluded_color_ids[node];
+    for (; excluded_color_ids.find(color_id) != excluded_color_ids.end(); ++color_id) {}
+    node2color_id[node] = color_id;
+    node->ForEachNodeOnInOutEdge([&](const RegstLifetimeNode* intersected) {
+      if (node2color_id.find(intersected) != node2color_id.end()) { return; }
+      node2excluded_color_ids[intersected].insert(color_id);
     });
-  });
-  HashMap<int32_t, std::list<const RegstDescProto*>> color_id2regst_descs;
+  }
+  HashMap<int32_t, std::vector<const RegstDescProto*>> color_id2regst_descs;
   for (const auto& pair : node2color_id) {
     color_id2regst_descs[pair.second].push_back(&pair.first->regst_desc());
   }
diff --git a/oneflow/core/graph/regst_lifetime_graph.h b/oneflow/core/graph/regst_lifetime_graph.h
index d5aa92271bf6538ad9793c804ef407e81d835622..d60585c7aac5fb7a9fd9dc36f9dbb9848a6b1754 100644
--- a/oneflow/core/graph/regst_lifetime_graph.h
+++ b/oneflow/core/graph/regst_lifetime_graph.h
@@ -3,6 +3,7 @@
 
 #include "oneflow/core/graph/graph.h"
 #include "oneflow/core/register/register_desc.pb.h"
+#include "oneflow/core/register/runtime_register_desc.h"
 
 namespace oneflow {
 
@@ -20,37 +21,41 @@ class RegstLifetimeNode final : public Node<RegstLifetimeNode, RegstLifetimeEdge
   OF_DISALLOW_COPY_AND_MOVE(RegstLifetimeNode);
   RegstLifetimeNode(const RegstDescProto* regst_desc,
                     std::unique_ptr<HashSet<int64_t>>&& lifetime_actor_ids)
-      : regst_desc_(regst_desc), lifetime_actor_ids_(std::move(lifetime_actor_ids)) {}
+      : regst_desc_(regst_desc),
+        lifetime_actor_ids_(std::move(lifetime_actor_ids)),
+        byte_size_(RtRegstDesc(*regst_desc).TotalMainByteSize4AllRegst()) {
+    CHECK_EQ(regst_desc->register_num(), 1);
+  }
   ~RegstLifetimeNode() = default;
 
   int64_t regst_desc_id() const { return regst_desc().regst_desc_id(); }
   const RegstDescProto& regst_desc() const { return *regst_desc_; }
   const HashSet<int64_t>& lifetime_actor_ids() const { return *lifetime_actor_ids_; }
+  size_t byte_size() const { return byte_size_; }
 
  private:
   const RegstDescProto* regst_desc_;
   std::unique_ptr<HashSet<int64_t>> lifetime_actor_ids_;
+  size_t byte_size_;
 };
 
 class RegstLifetimeGraph final : public Graph<const RegstLifetimeNode, RegstLifetimeEdge> {
  public:
   OF_DISALLOW_COPY_AND_MOVE(RegstLifetimeGraph);
   RegstLifetimeGraph(
-      const std::list<const RegstDescProto*>& regst_descs,
+      const std::vector<const RegstDescProto*>& regst_descs,
       const std::function<void(const RegstDescProto*, HashSet<int64_t>*)>& ComputeLifetimeActorIds);
   ~RegstLifetimeGraph() = default;
 
   void ForEachSameColoredRegstDescs(
-      const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) const;
+      const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) const;
 
  private:
   void InitNodes(
-      const std::list<const RegstDescProto*>& regst_descs,
+      const std::vector<const RegstDescProto*>& regst_descs,
       const std::function<void(const RegstDescProto*, HashSet<int64_t>*)>& ComputeLifetimeActorIds,
-      std::list<RegstLifetimeNode*>* nodes);
-  void InitEdges(const std::list<RegstLifetimeNode*>& nodes);
-  HashMap<const RegstLifetimeNode*, HashSet<const RegstLifetimeNode*>>
-      regst_lifetime_node2intersected_nodes_;
+      std::vector<RegstLifetimeNode*>* nodes);
+  void InitEdges(const std::vector<RegstLifetimeNode*>& nodes);
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/graph/repeat_backward_compute_task_node.cpp b/oneflow/core/graph/repeat_backward_compute_task_node.cpp
index e6c32ce9d1558cb27bc77538f8444a3f6002c42e..7474ee42816767e35e3b80a003a99e10a73d1378 100644
--- a/oneflow/core/graph/repeat_backward_compute_task_node.cpp
+++ b/oneflow/core/graph/repeat_backward_compute_task_node.cpp
@@ -34,7 +34,7 @@ void RepeatBackwardCompTaskNode::InferProducedDataRegstTimeShape() {
   CHECK(this->logical_node()->SoleOp()->op_conf().has_repeat_conf());
   const RepeatOp* repeat_op = dynamic_cast<RepeatOp*>(this->logical_node()->SoleOp().get());
   CHECK_NOTNULL(repeat_op);
-  int32_t repeat_num = repeat_op->GetRepeatNum(parallel_ctx()->parallel_num());
+  int32_t repeat_num = repeat_op->GetRepeatNum();
   CHECK(!time_shape_dim_vec.empty());
   CHECK(time_shape_dim_vec.back() == repeat_num);
   time_shape_dim_vec.pop_back();
diff --git a/oneflow/core/graph/repeat_forward_compute_task_node.cpp b/oneflow/core/graph/repeat_forward_compute_task_node.cpp
index 5022bebc802b8712766bd49868a2a98318818e55..f52509380b4c1ef85bdc21269007a3bf41317052 100644
--- a/oneflow/core/graph/repeat_forward_compute_task_node.cpp
+++ b/oneflow/core/graph/repeat_forward_compute_task_node.cpp
@@ -33,7 +33,7 @@ void RepeatForwardCompTaskNode::InferProducedDataRegstTimeShape() {
       GetSoleConsumedRegst("in")->data_regst_time_shape()->dim_vec();
   const RepeatOp* repeat_op = dynamic_cast<RepeatOp*>(this->logical_node()->SoleOp().get());
   CHECK_NOTNULL(repeat_op);
-  int32_t repeat_num = repeat_op->GetRepeatNum(parallel_ctx()->parallel_num());
+  int32_t repeat_num = repeat_op->GetRepeatNum();
   time_shape_dim_vec.push_back(repeat_num);
   GetProducedRegst("out")->mut_data_regst_time_shape()->reset(new Shape(time_shape_dim_vec));
 }
diff --git a/oneflow/core/graph/sharable_mem_block_graph.cpp b/oneflow/core/graph/sharable_mem_block_graph.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..50f5f7205fa05ec497f09e62094583e0c2a792e7
--- /dev/null
+++ b/oneflow/core/graph/sharable_mem_block_graph.cpp
@@ -0,0 +1,81 @@
+#include "oneflow/core/graph/sharable_mem_block_graph.h"
+#include "oneflow/core/register/register_desc.h"
+#include "oneflow/core/register/runtime_register_desc.h"
+
+namespace oneflow {
+
+namespace {
+
+bool IsConsumersAndProducerInSameChain(const RegstDescProto& regst_desc,
+                                       const PlanTaskGraph& plan_task_graph) {
+  auto ChainId4TaskId = [&](int64_t task_id) {
+    return plan_task_graph.TaskProto4TaskId(task_id)->task_set_info().chain_id();
+  };
+  int64_t producer_chain_id = ChainId4TaskId(regst_desc.producer_task_id());
+  for (int64_t consumer_task_id : regst_desc.consumer_task_id()) {
+    if (ChainId4TaskId(consumer_task_id) != producer_chain_id) { return false; }
+  }
+  return true;
+}
+
+}  // namespace
+
+SharableMemBlockGraph::SharableMemBlockGraph(
+    const PlanTaskGraph& plan_task_gph,
+    const std::function<bool(const RegstDescProto&)>& IsSharable) {
+  auto ForEachSharableChainRegstDesc =
+      [&](const std::function<void(int64_t, const RegstDescProto&)>& Handler) {
+        for (const TaskProto& task : plan_task_gph.plan().task()) {
+          for (const auto& pair : task.produced_regst_desc()) {
+            if (IsConsumersAndProducerInSameChain(pair.second, plan_task_gph)
+                && IsSharable(pair.second)) {
+              Handler(task.task_set_info().chain_id(), pair.second);
+            }
+          }
+        }
+      };
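+  // Group the sharable regst descs by (chain_id, mem block): every group becomes one node below,
+  // and each entry in a regst's mem_block_hierarchy is connected to the previous entry by a
+  // parent -> child edge.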
+  HashMap<std::pair<int64_t, MemBlock>, HashSet<const RegstDescProto*>>
+      chain_id7mem_block2regst_descs;
+  HashSet<int64_t> mem_block_ids_check;
+  ForEachSharableChainRegstDesc([&](int64_t chain_id, const RegstDescProto& regst_desc) {
+    int32_t idx = 0;
+    for (const auto& mem_block : regst_desc.mem_block_hierarchy()) {
+      if (idx++ == 0) { CHECK(mem_block_ids_check.emplace(mem_block.mem_block_id()).second); }
+      auto& regst_descs = chain_id7mem_block2regst_descs[std::make_pair(chain_id, mem_block)];
+      CHECK(regst_descs.emplace(&regst_desc).second);
+    }
+  });
+  HashMap<std::pair<int64_t, MemBlock>, SharableMemBlockNode*> chain_id7mem_block2node;
+  for (const auto& pair : chain_id7mem_block2regst_descs) {
+    auto* node =
+        new SharableMemBlockNode(pair.first.first, pair.first.second, pair.second, plan_task_gph);
+    AddAllocatedNode(node);
+    CHECK(chain_id7mem_block2node.emplace(pair.first, node).second);
+  }
+  HashSet<const SharableMemBlockNode*> connected_children;
+  ForEachSharableChainRegstDesc([&](int64_t chain_id, const RegstDescProto& regst_desc) {
+    SharableMemBlockNode* child = nullptr;
+    for (const auto& mem_block : regst_desc.mem_block_hierarchy()) {
+      auto* parent = chain_id7mem_block2node.at(std::make_pair(chain_id, mem_block));
+      if (child != nullptr && connected_children.find(child) == connected_children.end()) {
+        Connect(parent, NewEdge(), child);
+        CHECK(connected_children.emplace(child).second);
+      }
+      child = parent;
+    }
+  });
+}
+
+void SharableMemBlockGraph::ForEachSourceNodeGroup(
+    const std::function<int64_t(const SharableMemBlockNode*)>& GroupBy,
+    const std::function<void(const std::vector<const SharableMemBlockNode*>&)>& Handler) const {
+  HashMap<int64_t, std::vector<const SharableMemBlockNode*>> group_key2source_nodes;
+  for (const SharableMemBlockNode* source : source_nodes()) {
+    group_key2source_nodes[GroupBy(source)].push_back(source);
+  }
+  for (const auto& pair : group_key2source_nodes) {
+    if (pair.second.size() > 1) { Handler(pair.second); }
+  }
+}
+
+}  // namespace oneflow
diff --git a/oneflow/core/graph/sharable_mem_block_graph.h b/oneflow/core/graph/sharable_mem_block_graph.h
new file mode 100644
index 0000000000000000000000000000000000000000..33354fddb6975dde37be6e1510fab5c0e12c3378
--- /dev/null
+++ b/oneflow/core/graph/sharable_mem_block_graph.h
@@ -0,0 +1,55 @@
+#ifndef ONEFLOW_CORE_GRAPH_SHARABLE_MEM_BLOCK_GRAPH_H_
+#define ONEFLOW_CORE_GRAPH_SHARABLE_MEM_BLOCK_GRAPH_H_
+
+#include "oneflow/core/graph/graph.h"
+#include "oneflow/core/register/register_desc.pb.h"
+#include "oneflow/core/graph/plan_task_graph.h"
+
+namespace oneflow {
+
+class SharableMemBlockEdge;
+
+class SharableMemBlockNode final : public Node<SharableMemBlockNode, SharableMemBlockEdge> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(SharableMemBlockNode);
+  SharableMemBlockNode(int64_t chain_id, const MemBlock& mem_block,
+                       const HashSet<const RegstDescProto*>& regst_descs,
+                       const PlanTaskGraph& plan_task_graph)
+      : chain_id_(chain_id),
+        mem_block_(mem_block),
+        regst_descs_(regst_descs.begin(), regst_descs.end()) {}
+
+  ~SharableMemBlockNode() = default;
+
+  int64_t chain_id() const { return chain_id_; }
+  const std::vector<const RegstDescProto*>& regst_descs() const { return regst_descs_; }
+  const MemBlock& mem_block() const { return mem_block_; }
+
+ private:
+  const int64_t chain_id_;
+  const MemBlock mem_block_;
+  const std::vector<const RegstDescProto*> regst_descs_;
+};
+
+class SharableMemBlockEdge final : public Edge<SharableMemBlockNode, SharableMemBlockEdge> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(SharableMemBlockEdge);
+  SharableMemBlockEdge() = default;
+  ~SharableMemBlockEdge() = default;
+};
+
+class SharableMemBlockGraph final : public Graph<const SharableMemBlockNode, SharableMemBlockEdge> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(SharableMemBlockGraph);
+  SharableMemBlockGraph(const PlanTaskGraph& plan_task_gph,
+                        const std::function<bool(const RegstDescProto&)>& IsSharable);
+  ~SharableMemBlockGraph() = default;
+
+  void ForEachSourceNodeGroup(
+      const std::function<int64_t(const SharableMemBlockNode*)>& GroupBy,
+      const std::function<void(const std::vector<const SharableMemBlockNode*>&)>& Handler) const;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_GRAPH_SHARABLE_MEM_BLOCK_GRAPH_H_
diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp
index e74e65afdc01befb740a23c506e03f81584b79a2..daa1efd3ddfebae1fe2b35f2a1b5eecf27d6661f 100644
--- a/oneflow/core/graph/task_graph.cpp
+++ b/oneflow/core/graph/task_graph.cpp
@@ -11,9 +11,83 @@
 #include "oneflow/core/graph/reduce_split_compute_task_node.h"
 #include "oneflow/core/register/runtime_blob_desc.h"
 #include "oneflow/core/job/thrd_id_generator.h"
+#include "oneflow/core/graph/reduce_identity_task_node.h"
+#include "oneflow/core/operator/variable_op.h"
+#include "oneflow/core/operator/constant_op.h"
 
 namespace oneflow {
 
+namespace {
+
+bool IsConnectToTickOp(const TaskNode* node) {
+  const auto* comp_task_node = dynamic_cast<const CompTaskNode*>(node);
+  if (comp_task_node == nullptr) { return false; }
+  if (comp_task_node->logical_node()->op_vec().size() != 1) { return false; }
+  const Operator* op = comp_task_node->logical_node()->SoleOp().get();
+  if (dynamic_cast<const VariableOp*>(op) != nullptr) { return true; }
+  if (dynamic_cast<const ConstantOp*>(op) != nullptr) { return true; }
+  return false;
+}
+
+template<typename ReduceType = ReduceIdentityLogicalNode>
+typename std::enable_if<std::is_same<ReduceType, ReduceIdentityLogicalNode>::value
+                            || std::is_same<ReduceType, ReduceSplitLogicalNode>::value,
+                        std::function<int32_t(const LogicalNode*)>>::type
+MakeGetterReduceTaskNodeCtrlOrder(const LogicalGraph& logical_graph) {
+  std::vector<const ReduceType*> logical_nodes;
+  logical_graph.ForEachNode([&](LogicalNode* node) {
+    auto* logical_node = dynamic_cast<ReduceType*>(node);
+    if (logical_node == nullptr) { return; }
+    logical_nodes.push_back(logical_node);
+  });
+  std::sort(logical_nodes.begin(), logical_nodes.end(),
+            [](const ReduceType* lhs, const ReduceType* rhs) {
+              return lhs->order_in_logical_graph() < rhs->order_in_logical_graph();
+            });
+  auto logical_node2ctrl_order = std::make_shared<HashMap<const LogicalNode*, int32_t>>();
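+  // Nodes with order index <= lazy_count keep the positive index as ctrl order; the rest get the
+  // negated index, so they sort to the front (in reverse logical order) and can be told apart by
+  // sign in AddReduceSequenceCtrlEdges.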
+  int32_t lazy_count = Global<JobDesc>::Get()->all_reduce_lazy_ratio() * logical_nodes.size();
+  for (int32_t i = 0; i < logical_nodes.size(); ++i) {
+    int32_t ctrl_order = 0;
+    if (i > lazy_count) {
+      ctrl_order = -i;
+    } else {
+      ctrl_order = i;
+    }
+    (*logical_node2ctrl_order)[logical_nodes[i]] = ctrl_order;
+  }
+  return [logical_node2ctrl_order](const LogicalNode* identity_node) {
+    return logical_node2ctrl_order->at(identity_node);
+  };
+}
+
+void ForEachDeviceSrcUntrainableNode(const std::vector<NormalForwardCompTaskNode*>& fw_nodes,
+                                     const std::function<void(CompTaskNode*)>& Handler) {
+  HashSet<const TaskNode*> fw_nodes_set(fw_nodes.begin(), fw_nodes.end());
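+  // "Source" here means the node has no predecessor inside fw_nodes; "untrainable" means its
+  // forward logical node has no backward counterpart.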
+  auto IsSourceTaskNode = [&](NormalForwardCompTaskNode* node) {
+    for (TaskEdge* edge : node->in_edges()) {
+      if (fw_nodes_set.find(edge->src_node()) != fw_nodes_set.end()) { return false; }
+    }
+    return true;
+  };
+  auto HasBwNode = [&](NormalForwardCompTaskNode* node) {
+    const auto* fw_logical_node = dynamic_cast<const ForwardLogicalNode*>(node->logical_node());
+    return fw_logical_node->bw_node() != nullptr;
+  };
+  for (NormalForwardCompTaskNode* fw_node : fw_nodes) {
+    if (IsSourceTaskNode(fw_node) && !HasBwNode(fw_node)) { Handler(fw_node); }
+  }
+}
+
+bool IsTimeShapeContain(const Shape& big_shape, const Shape& small_shape) {
+  if (big_shape.NumAxes() < small_shape.NumAxes()) { return false; }
+  FOR_RANGE(int, i, 0, small_shape.NumAxes()) {
+    if (big_shape.At(i) != small_shape.At(i)) { return false; }
+  }
+  return true;
+}
+
+}  // namespace
+
 TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) {
   logical_gph_ = std::move(logical_gph);
   HashMap<const LogicalNode*, std::vector<CompTaskNode*>> logical2sorted_comp_tasks;
@@ -105,13 +179,13 @@ void TaskGraph::AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAll
                                        std::function<void(TaskNode* node)> Handler) const {
   auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
     node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
-      if (IsBackEdge(node_on_in_edge, node)) return;
+      if (IsBackEdge(node_on_in_edge, node)) { return; }
       Handler(const_cast<TaskNode*>(node_on_in_edge));
     });
   };
   auto ForEachOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) {
     node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) {
-      if (IsBackEdge(node, node_on_out_edge)) return;
+      if (IsBackEdge(node, node_on_out_edge)) { return; }
       Handler(const_cast<TaskNode*>(node_on_out_edge));
     });
   };
@@ -159,7 +233,8 @@ void TaskGraph::MergeChainAndSetOrderInGraphForEachNode() {
 
 void TaskGraph::BuildCtrlRegstDescInSameChain() {
   HashMap<int64_t, TaskNode*> chain_id2node;
-  for (auto node : ordered_task_nodes_) {
+  for (auto* node : ordered_task_nodes_) {
+    if (IsConnectToTickOp(node)) { continue; }
     int64_t chain_id = node->chain_id();
     auto iter = chain_id2node.find(chain_id);
     if (iter == chain_id2node.end()) {
@@ -171,6 +246,115 @@ void TaskGraph::BuildCtrlRegstDescInSameChain() {
   }
 }
 
+void TaskGraph::AddReduceSequenceCtrlEdges() {
+  HashMap<int64_t, std::vector<ReduceSplitCompTaskNode*>> global_thrd_id2split_nodes;
+  for (auto* node : ordered_task_nodes_) {
+    auto* split_node = dynamic_cast<ReduceSplitCompTaskNode*>(node);
+    if (split_node == nullptr) { continue; }
+    int64_t global_thrd_id = Global<IDMgr>::Get()->GlobalThrdId4TaskId(split_node->task_id());
+    global_thrd_id2split_nodes[global_thrd_id].push_back(split_node);
+  }
+  auto GetCtrlOrder = MakeGetterReduceTaskNodeCtrlOrder<ReduceSplitLogicalNode>(*logical_gph_);
+  for (auto& pair : global_thrd_id2split_nodes) {
+    auto& split_nodes = pair.second;
+    std::sort(split_nodes.begin(), split_nodes.end(),
+              [&](ReduceSplitCompTaskNode* lhs, ReduceSplitCompTaskNode* rhs) {
+                return GetCtrlOrder(lhs->logical_node()) < GetCtrlOrder(rhs->logical_node());
+              });
+    ReduceSplitCompTaskNode* prev_split_node = split_nodes.at(0);
+    for (auto* split_node : split_nodes) {
+      if (prev_split_node != split_node) {
+        auto* to_node = split_node->GetPrevReduceTaskNode(TaskType::kReduceIdentity);
+        TaskNode* from_node = prev_split_node;
+        if (GetCtrlOrder(split_node->logical_node()) < 0) {
+          from_node = prev_split_node->GetPrevReduceTaskNode(TaskType::kReduceIdentity);
+        }
+        from_node->BuildCtrlRegstDescIfNeed(to_node);
+      }
+      prev_split_node = split_node;
+    }
+  }
+}
+
+void TaskGraph::AddMdUpdtCtrlEdgesWithinReduceSplitNode() {
+  auto GetOrderInReduceGroup = [&](NormalMdUpdtCompTaskNode* md_updt_node) {
+    const auto* logical_node =
+        dynamic_cast<const NormalMdUpdtLogicalNode*>(md_updt_node->logical_node());
+    return logical_node->order_in_reduce_group();
+  };
+  for (auto* node : ordered_task_nodes_) {
+    auto* split_node = dynamic_cast<ReduceSplitCompTaskNode*>(node);
+    if (split_node == nullptr) { continue; }
+    std::vector<NormalMdUpdtCompTaskNode*> md_updt_nodes;
+    split_node->ForEachNodeOnOutEdge([&](TaskNode* node) {
+      auto* md_updt_node = dynamic_cast<NormalMdUpdtCompTaskNode*>(node);
+      if (md_updt_node == nullptr) { return; }
+      md_updt_nodes.push_back(md_updt_node);
+    });
+    std::sort(md_updt_nodes.begin(), md_updt_nodes.end(),
+              [&](NormalMdUpdtCompTaskNode* lhs, NormalMdUpdtCompTaskNode* rhs) {
+                return GetOrderInReduceGroup(lhs) < GetOrderInReduceGroup(rhs);
+              });
+    NormalMdUpdtCompTaskNode* prev_md_updt = md_updt_nodes.at(0);
+    for (auto* md_updt_node : md_updt_nodes) {
+      if (md_updt_node != prev_md_updt) { prev_md_updt->BuildCtrlRegstDescIfNeed(md_updt_node); }
+      prev_md_updt = md_updt_node;
+    }
+  }
+}
+
+void TaskGraph::AddReduceNoBwForwardNodeOverlapingCtrlEdges() {
+  HashMap<int64_t, std::vector<ReduceIdentityCompTaskNode*>> global_thrd_id2identity_nodes;
+  HashMap<std::pair<int64_t, int64_t>, std::vector<NormalForwardCompTaskNode*>>
+      global_dev_phy_id2fw_nodes;
+  const auto* id_mgr = Global<IDMgr>::Get();
+  for (auto* node : ordered_task_nodes_) {
+    if (id_mgr->GetDeviceTypeFromThrdId(node->thrd_id()) == DeviceType::kCPU) { continue; }
+    int64_t global_thrd_id = id_mgr->GlobalThrdId4TaskId(node->task_id());
+    auto* identity_node = dynamic_cast<ReduceIdentityCompTaskNode*>(node);
+    auto* fw_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
+    if (identity_node != nullptr) {
+      global_thrd_id2identity_nodes[global_thrd_id].push_back(identity_node);
+    } else if (fw_node != nullptr) {
+      int64_t dev_phy_id = id_mgr->GetGpuPhyIdFromThrdId(node->thrd_id());
+      global_dev_phy_id2fw_nodes[std::make_pair(node->machine_id(), dev_phy_id)].push_back(fw_node);
+    } else {
+      // do nothing
+    }
+  }
+  auto GetIdentityNodeOrder = [&](const ReduceIdentityCompTaskNode* id_node) {
+    const auto* id_logical_node =
+        dynamic_cast<const ReduceIdentityLogicalNode*>(id_node->logical_node());
+    return id_logical_node->order_in_logical_graph();
+  };
+  for (auto& pair : global_thrd_id2identity_nodes) {
+    auto& identity_nodes = pair.second;
+    std::sort(identity_nodes.begin(), identity_nodes.end(),
+              [&](ReduceIdentityCompTaskNode* lhs, ReduceIdentityCompTaskNode* rhs) {
+                return GetIdentityNodeOrder(lhs) < GetIdentityNodeOrder(rhs);
+              });
+    auto* first_identity_node = identity_nodes.at(0);
+    int64_t machine_id = first_identity_node->machine_id();
+    int64_t dev_phy_id = id_mgr->GetGpuPhyIdFromThrdId(first_identity_node->thrd_id());
+    const auto& fw_nodes = global_dev_phy_id2fw_nodes.at(std::make_pair(machine_id, dev_phy_id));
+    const Shape& identity_time_shape =
+        *first_identity_node->GetProducedRegst("out")->data_regst_time_shape();
+    ForEachDeviceSrcUntrainableNode(fw_nodes, [&](CompTaskNode* node) {
+      std::shared_ptr<RegstDesc> regst_desc = node->GetProducedRegst("out");
+      if (!regst_desc) { return; }
+      const Shape& time_shape = *regst_desc->data_regst_time_shape();
+      if (!IsTimeShapeContain(time_shape, identity_time_shape)) { return; }
+      CHECK_EQ(time_shape.elem_cnt() % identity_time_shape.elem_cnt(), 0);
+      int regst_desc_num = time_shape.elem_cnt() / identity_time_shape.elem_cnt();
+      RegstDesc* ctrl_regst_desc = node->BuildCtrlRegstDesc(first_identity_node);
+      ctrl_regst_desc->UpdtMinRegstNumIfNeed(regst_desc_num);
+      ctrl_regst_desc->UpdtMaxRegstNumIfNeed(regst_desc_num);
+      ctrl_regst_desc->mut_regst_desc_type()->mutable_ctrl_regst_desc()->set_returned_regst_num(
+          regst_desc_num);
+    });
+  }
+}
+
 void TaskGraph::EnableMemSharingInReduceStruct() {
   auto GetPredReduceTaskNode = [](TaskNode* succ) {
     std::vector<TaskNode*> nodes;
@@ -223,6 +407,94 @@ void TaskGraph::EnableMemSharingInReduceStruct() {
   });
 }
 
+void TaskGraph::EnableMemSharingAfterAllManualSetForMdUpdt() {
+  ForEachNode([&](TaskNode* node) {
+    auto* updt = dynamic_cast<NormalMdUpdtCompTaskNode*>(node);
+    if (!updt) { return; }
+    updt->EnableMemSharingBetweenFirstInAndProcessedMdDiffRegst();
+  });
+}
+
+void TaskGraph::EnableMemSharingInVariableOp() {
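+  // For a variable op, let the forward "out" regst share memory with the slice of the "model"
+  // regst holding this variable, and the backward "out_diff" regst with the matching slice of
+  // "model_diff"; the op is then marked as inplace.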
+  ForEachNode([&](TaskNode* node) {
+    if (node->exec_gph().node_num() != 1) { return; }
+    auto* variable_op = dynamic_cast<const VariableOp*>(node->exec_gph().SoleNode()->op().get());
+    if (variable_op == nullptr) { return; }
+    std::string model_bn = variable_op->op_conf().variable_conf().model_name();
+    auto* fw_task_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
+    auto* bw_task_node = dynamic_cast<NormalBackwardCompTaskNode*>(node);
+    if (fw_task_node != nullptr) {
+      const LogicalBlobId& lbi = variable_op->BnInOp2Lbi(model_bn);
+      RegstDesc* model_regst = fw_task_node->GetSoleConsumedRegst("model").get();
+      CHECK_EQ(model_regst->min_register_num(), 1);
+      CHECK_EQ(model_regst->max_register_num(), 1);
+      model_regst->set_enable_mem_sharing(true);
+      if (model_regst->mem_shared_id() == -1) {
+        model_regst->set_mem_shared_id(Global<IDMgr>::Get()->NewMemSharedId());
+        model_regst->set_mem_shared_offset(0);
+      }
+      RegstDesc* out_regst = fw_task_node->GetProducedRegst("out").get();
+      CHECK_EQ(out_regst->min_register_num(), 1);
+      CHECK_EQ(out_regst->max_register_num(), 1);
+      CHECK_EQ(out_regst->NumOfLbi(), 1);
+      out_regst->set_enable_mem_sharing(true);
+      out_regst->set_mem_shared_id(model_regst->mem_shared_id());
+      out_regst->set_mem_shared_offset(model_regst->mem_shared_offset()
+                                       + model_regst->ByteOffsetInPackedBlobDescBody(lbi));
+      variable_op->set_is_fw_inplace(true);
+    } else if (bw_task_node != nullptr) {
+      const LogicalBlobId& lbi = variable_op->BnInOp2Lbi(GenDiffBn(model_bn));
+      RegstDesc* model_diff_regst = bw_task_node->GetProducedRegst("model_diff").get();
+      CHECK_EQ(model_diff_regst->min_register_num(), 1);
+      CHECK_EQ(model_diff_regst->max_register_num(), 1);
+      model_diff_regst->set_enable_mem_sharing(true);
+      if (model_diff_regst->mem_shared_id() == -1) {
+        model_diff_regst->set_mem_shared_id(Global<IDMgr>::Get()->NewMemSharedId());
+        model_diff_regst->set_mem_shared_offset(0);
+      }
+      RegstDesc* out_diff_regst = bw_task_node->GetSoleConsumedRegst("out_diff").get();
+      if (out_diff_regst->min_register_num() != 1) { return; }
+      if (out_diff_regst->max_register_num() != 1) { return; }
+      if (out_diff_regst->NumOfLbi() != 1) { return; }
+      out_diff_regst->set_enable_mem_sharing(true);
+      out_diff_regst->set_mem_shared_id(model_diff_regst->mem_shared_id());
+      out_diff_regst->set_mem_shared_offset(
+          model_diff_regst->mem_shared_offset()
+          + model_diff_regst->ByteOffsetInPackedBlobDescBody(lbi));
+      variable_op->set_is_bw_inplace(true);
+    } else {
+      // do nothing
+    }
+  });
+}
+
+void TaskGraph::EnableInplaceMemSharing() {
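+  // For single-op task nodes whose op declares forward/backward inplace, let the produced regst
+  // reuse the consumed regst's memory block (both regsts must hold exactly one lbi).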
+  AcyclicTopoForEachNode([&](TaskNode* node) {
+    if (node->exec_gph().node_num() != 1) { return; }
+    const Operator* op = node->exec_gph().SoleNode()->op().get();
+    auto* fw_task_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
+    auto* bw_task_node = dynamic_cast<NormalBackwardCompTaskNode*>(node);
+    RegstDesc* input_regst = nullptr;
+    RegstDesc* output_regst = nullptr;
+    if (op->IsForwardInplace() && fw_task_node) {
+      input_regst = fw_task_node->GetSoleConsumedRegst("in").get();
+      output_regst = fw_task_node->GetProducedRegst("out").get();
+    } else if (op->IsBackwardInplace() && bw_task_node) {
+      input_regst = bw_task_node->GetSoleConsumedRegst(GenDiffBn("out")).get();
+      output_regst = bw_task_node->GetProducedRegst(GenDiffBn("in")).get();
+    } else {
+      // do nothing
+      return;
+    }
+    if (input_regst->NumOfLbi() != 1) { return; }
+    if (output_regst->NumOfLbi() != 1) { return; }
+    if (input_regst->mem_shared_inplace_block_id() == -1) {
+      input_regst->set_mem_shared_inplace_block_id(Global<IDMgr>::Get()->NewMemBlockId());
+    }
+    output_regst->set_mem_shared_inplace_block_id(input_regst->mem_shared_inplace_block_id());
+  });
+}
+
 void TaskGraph::RmUselessConsumeRelationshipBetweenFwBw() {
   for (TaskNode* task_node : ordered_task_nodes_) {
     auto bw_node = dynamic_cast<NormalBackwardCompTaskNode*>(task_node);
@@ -331,6 +603,31 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne) {
   }
 }
 
+DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByTickToSource) {
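+  // Connect each machine's tick task to the destination tasks on that machine; destinations on
+  // machines without a local tick are reached from the first tick through a CopyCommNet task.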
+  CHECK(src_logical->SoleOp()->op_conf().has_tick_conf());
+  HashMap<size_t, CompTaskNode*> machine_id2tick_task;
+  HashMap<size_t, std::vector<CompTaskNode*>> machine_id2dst_tasks;
+  for (CompTaskNode* tick_node : sorted_src_comp_tasks) {
+    machine_id2tick_task[tick_node->machine_id()] = tick_node;
+  }
+  for (CompTaskNode* dst_node : sorted_dst_comp_tasks) {
+    machine_id2dst_tasks[dst_node->machine_id()].push_back(dst_node);
+  }
+
+  CompTaskNode* first_tick = sorted_src_comp_tasks.at(0);
+  for (const auto& pair : machine_id2dst_tasks) {
+    size_t machine_id = pair.first;
+    for (CompTaskNode* dst_node : pair.second) {
+      if (machine_id2tick_task.find(machine_id) != machine_id2tick_task.end()) {
+        Connect<TaskNode>(machine_id2tick_task.at(machine_id), NewEdge(), dst_node);
+      } else {
+        TaskNode* next_node = AddCopyCommNetTaskBetween(first_tick, dst_node);
+        Connect<TaskNode>(first_tick, NewEdge(), next_node);
+      }
+    }
+  }
+}
+
 DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphBySelectOneSourceToSoleSink) {
   CHECK_EQ(sorted_dst_comp_tasks.size(), 1);
   CompTaskNode* sole_dst_comp_task = sorted_dst_comp_tasks.front();
diff --git a/oneflow/core/graph/task_graph.h b/oneflow/core/graph/task_graph.h
index d7609a7e714f79822fae18da0db4e4faefe041b3..db74b89263ad8daa93ac111b63af7805e6b0119b 100644
--- a/oneflow/core/graph/task_graph.h
+++ b/oneflow/core/graph/task_graph.h
@@ -20,8 +20,14 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
   const char* TypeName() const override { return "TaskGraph"; }
   void RemoveEmptyRegsts();
   void AddOrderingCtrlEdgeInSameChain();
+  void AddReduceSequenceCtrlEdges();
+  void AddMdUpdtCtrlEdgesWithinReduceSplitNode();
+  void AddReduceNoBwForwardNodeOverlapingCtrlEdges();
 
   void EnableMemSharingInReduceStruct();
+  void EnableMemSharingAfterAllManualSetForMdUpdt();
+  void EnableMemSharingInVariableOp();
+  void EnableInplaceMemSharing();
 
   void AddOrderCtrlEdgeBetweenCopyAndMdUpdt();
   void RmUselessConsumeRelationshipBetweenFwBw();
@@ -32,6 +38,8 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
 
   DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing);
   DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne);
+  DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByRecordLoadToTick);
+  DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByTickToSource);
   DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphBySelectOneSourceToSoleSink);
   DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByReduceScatter2ReduceAdd);
   DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByReduceAdd2ReduceGather);
@@ -40,6 +48,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
  private:
   void AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode,
                               std::function<void(TaskNode* node)> Handler) const;
+
   void BuildTaskPath(
       CompTaskNode* src, CompTaskNode* dst,
       std::function<TaskNode**(CompTaskNode* src, int64_t machine_id, int32_t mem_zone_id)>
diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp
index 6b57c7ce18cff500ab304116bbcd51675cdc1463..674762770b5a716d9d2959982cf8546b400f1735 100644
--- a/oneflow/core/graph/task_node.cpp
+++ b/oneflow/core/graph/task_node.cpp
@@ -422,5 +422,5 @@ std::map<TaskType, std::string> task_type2color = {
     {kDecodeRandom, "1"},    {kPackForward, "11"},
     {kPackBackward, "12"},   {kUnpackForward, "11"},
     {kUnpackBackward, "12"}, {kRepeatForward, "2"},
-    {kRepeatBackward, "3"}};
+    {kRepeatBackward, "3"},  {kReduceIdentity, "2"}};
 }  // namespace oneflow
diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h
index 624de108fdb22cc9e025a9d6a78dae59e591f820..30034659e815aaa1e4982c746ec9ac32b1fb6e40 100644
--- a/oneflow/core/graph/task_node.h
+++ b/oneflow/core/graph/task_node.h
@@ -47,6 +47,7 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
   int64_t LocalWorkStreamId() const;
   int64_t GlobalWorkStreamId() const;
   int64_t GpuPhyId() const { return Global<IDMgr>::Get()->GetGpuPhyIdFromThrdId(thrd_id_); }
+  virtual int64_t AreaId4ChainMerge() const { return area_id(); }
 
   // Setters
   void set_machine_id(int64_t val);
diff --git a/oneflow/core/graph/unpack_forward_task_node.cpp b/oneflow/core/graph/unpack_forward_task_node.cpp
index 655453456477ed08b856d6f7f176c20295c25720..934068d74cbe9e8307185bc81286b44a6035f722 100644
--- a/oneflow/core/graph/unpack_forward_task_node.cpp
+++ b/oneflow/core/graph/unpack_forward_task_node.cpp
@@ -36,7 +36,7 @@ void UnpackForwardCompTaskNode::InferProducedDataRegstTimeShape() {
   const UnpackOp* op = dynamic_cast<UnpackOp*>(logical_node()->SoleOp().get());
   CHECK_NOTNULL(op);
   int64_t in_piece_size = in_regst->GetBlobDesc(op->BnInOp2Lbi("in"))->shape().At(0);
-  int64_t unpack_num = op->GetUnpackNum(parallel_ctx()->parallel_num());
+  int64_t unpack_num = op->GetUnpackNum();
   CHECK_EQ(0, in_piece_size % unpack_num);
   time_shape_dim_vec.push_back(unpack_num);
   *out_regst->mut_data_regst_time_shape() = std::make_shared<Shape>(std::move(time_shape_dim_vec));
diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp
index 65b551b1d969b2d3e6880a1875496ae65e14fbf7..cb4f2f38fc70300b9dd3e7d93d10d5c6349f7ccc 100644
--- a/oneflow/core/job/compiler.cpp
+++ b/oneflow/core/job/compiler.cpp
@@ -1,6 +1,7 @@
 #include "oneflow/core/job/compiler.h"
 #include "oneflow/core/persistence/tee_persistent_log_stream.h"
 #include "oneflow/core/device/cudnn_conv_ctx_cache.h"
+#include "oneflow/core/graph/op_graph.h"
 
 namespace oneflow {
 
@@ -95,7 +96,11 @@ Plan Compiler::DoCompile() {
 #ifdef WITH_CUDA
   Global<CudnnConvCtxCache>::New();
 #endif
+  Global<JobDesc>::Get()->FixAndOptimizeDLNet();
   const JobDesc* job_desc = Global<JobDesc>::Get();
+  TeePersistentLogStream::Create("optimized_job_conf")->Write(job_desc->job_conf());
+  Global<OpGraph>::New(job_desc);
+  Global<OpGraph>::Get()->ToDotWithFilePath("optimized_dlnet_op_graph.dot");
   auto logical_gph = std::make_unique<LogicalGraph>(job_desc->IsTrain());
   int64_t total_mbn_num = logical_gph->total_mbn_num();
   auto task_gph = std::make_unique<TaskGraph>(std::move(logical_gph));
@@ -104,14 +109,24 @@ Plan Compiler::DoCompile() {
   task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1));
   task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1));
   task_gph->MdUpdtDelayedTopoForEachNode(&TaskNode::Build);
+  if (job_desc->IsTrain()) {
+    task_gph->AddReduceSequenceCtrlEdges();
+    task_gph->AddMdUpdtCtrlEdgesWithinReduceSplitNode();
+  }
   task_gph->RemoveEmptyRegsts();
   task_gph->AddOrderingCtrlEdgeInSameChain();
   if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) {
     task_gph->EnableMemSharingInReduceStruct();
+    task_gph->EnableMemSharingAfterAllManualSetForMdUpdt();  // must come after all manual mem-sharing settings
   }
+  task_gph->EnableInplaceMemSharing();
   if (job_desc->IsTrain()) { task_gph->AddOrderCtrlEdgeBetweenCopyAndMdUpdt(); }
   if (job_desc->IsTrain()) { task_gph->RmUselessConsumeRelationshipBetweenFwBw(); }
   task_gph->MdUpdtDelayedTopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful);
+  if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) {
+    task_gph->EnableMemSharingInVariableOp();
+  }
+  if (job_desc->IsTrain()) { task_gph->AddReduceNoBwForwardNodeOverlapingCtrlEdges(); }
 
   Plan plan;
   task_gph->ForEachNode([&](TaskNode* task_node) {
@@ -121,6 +136,7 @@ Plan Compiler::DoCompile() {
   plan.set_total_mbn_num(total_mbn_num);
   GenNetTopo(&plan);
   ToDotFile(plan, "/dot/plan.dot");
+  Global<OpGraph>::Delete();
 #ifdef WITH_CUDA
   Global<CudnnConvCtxCache>::Delete();
 #endif
diff --git a/oneflow/core/job/id_manager.cpp b/oneflow/core/job/id_manager.cpp
index 4b3e16d9b3e3186e0e5a4099469742af3548619e..9e17203412d989696a15b34352d80bbf65be9252 100644
--- a/oneflow/core/job/id_manager.cpp
+++ b/oneflow/core/job/id_manager.cpp
@@ -16,9 +16,12 @@ int64_t IDMgr::GetGpuNcclGatherThrdId(int64_t dev_phy_id) const {
 int64_t IDMgr::GetGpuMixThrdId(int64_t dev_phy_id) const {
   return gpu_device_num_ * 5 + dev_phy_id;
 }
-int64_t IDMgr::GetGpuMdUpdtThrdId(int64_t dev_phy_id) const {
+int64_t IDMgr::GetGpuReduceCtrlThrdId(int64_t dev_phy_id) const {
   return gpu_device_num_ * 6 + dev_phy_id;
 }
+int64_t IDMgr::GetGpuMdUpdtThrdId(int64_t dev_phy_id) const {
+  return gpu_device_num_ * 7 + dev_phy_id;
+}
 int64_t IDMgr::GetCpuDeviceThrdId(int64_t dev_phy_id) const {
   return gpu_device_num_ * GetCudaWorkTypeSize() + dev_phy_id;
 }
@@ -76,6 +79,11 @@ int64_t IDMgr::GlobalWorkStreamId4ActorId(int64_t actor_id) const {
   return GlobalWorkStreamId4TaskId(actor_id);
 }
 
+int64_t IDMgr::GlobalThrdId4TaskId(int64_t task_id) const {
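+  // zero out the local_work_stream_id and task index bits, keeping only sign | machine_id | thrd_id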
+  int shift = local_work_stream_id_bit_num_ + task_id_bit_num_;
+  return (task_id >> shift) << shift;
+}
+
 int64_t IDMgr::LocalWorkStreamId4TaskId(int64_t task_id) const {
   int64_t tmp = (task_id << (machine_id_bit_num_ + thread_id_bit_num_));
   tmp &= ~(static_cast<int64_t>(1) << 63);
@@ -100,6 +108,7 @@ IDMgr::IDMgr() {
   CHECK_LT(gpu_device_num_ + cpu_device_num_, (static_cast<int64_t>(1) << thread_id_bit_num_) - 3);
   regst_desc_id_count_ = 0;
   mem_shared_id_count_ = 0;
+  mem_block_id_count_ = 0;
 }
 
 int64_t IDMgr::GetMachineThrdId(int64_t machine_id, int64_t thrd_id) {
diff --git a/oneflow/core/job/id_manager.h b/oneflow/core/job/id_manager.h
index 59dc7ff65ac29cb4e6b9cdb8217593b68841e91f..5cda9288a858deabb9728d75989eb2c136565c85 100644
--- a/oneflow/core/job/id_manager.h
+++ b/oneflow/core/job/id_manager.h
@@ -19,6 +19,7 @@ class IDMgr final {
   int64_t GetGpuNcclScatterThrdId(int64_t dev_phy_id) const;
   int64_t GetGpuNcclGatherThrdId(int64_t dev_phy_id) const;
   int64_t GetGpuMixThrdId(int64_t dev_phy_id) const;
+  int64_t GetGpuReduceCtrlThrdId(int64_t dev_phy_id) const;
   int64_t GetGpuMdUpdtThrdId(int64_t dev_phy_id) const;
   int64_t GetCpuDeviceThrdId(int64_t dev_phy_id) const;
   int64_t CommNetThrdId() const;
@@ -27,6 +28,7 @@ class IDMgr final {
   int64_t NewTaskId(int64_t machine_id, int64_t thrd_id, int64_t local_work_stream_id);
   int64_t NewRegstDescId() { return regst_desc_id_count_++; }
   int64_t NewMemSharedId() { return mem_shared_id_count_++; }
+  int64_t NewMemBlockId() { return mem_block_id_count_++; }
 
   // MemZoneId
   int64_t CpuMemZoneId() const { return Global<JobDesc>::Get()->GpuDeviceNum(); }
@@ -53,6 +55,10 @@ class IDMgr final {
   int64_t AllocateLocalWorkStreamId(int64_t machine_id, int64_t thrd_id);
   int64_t LocalWorkStreamId4TaskId(int64_t task_id) const;
   int64_t LocalWorkStreamId4ActorId(int64_t actor_id) const;
+  // global_thread_id
+  // sign | machine_id | thrd_id | 0  | 0
+  //  1   |     10     |   11    | 21 | 21
+  int64_t GlobalThrdId4TaskId(int64_t task_id) const;
   // global_work_stream_id
   // sign | machine_id | thrd_id | local_work_stream_id | 0
   //  1   |     10     |   11    |          21          | 21
@@ -69,6 +75,7 @@ class IDMgr final {
   int64_t cpu_device_num_;
   int64_t regst_desc_id_count_;
   int64_t mem_shared_id_count_;
+  int64_t mem_block_id_count_;
   HashMap<int64_t, int64_t> machine_thrd_id2num_of_tasks_;
   HashMap<int64_t, int64_t> machine_thrd_id2stream_id_cnt_;
   HashMap<int64_t, int64_t> stream_id2chain_cnt_;
diff --git a/oneflow/core/job/improver.cpp b/oneflow/core/job/improver.cpp
index 6c3e533644ad35cd92d811f7dd582269a3886edb..b95e8155792bd0e8c3a834ee2095df2bdeeac30a 100644
--- a/oneflow/core/job/improver.cpp
+++ b/oneflow/core/job/improver.cpp
@@ -7,7 +7,10 @@
 #include "oneflow/core/job/profiler.h"
 #include "oneflow/core/graph/plan_task_graph.h"
 #include "oneflow/core/graph/regst_lifetime_graph.h"
+#include "oneflow/core/graph/sharable_mem_block_graph.h"
 #include "oneflow/core/actor/act_event_logger.h"
+#include "oneflow/core/thread/thread_pool.h"
+#include "oneflow/core/common/blocking_counter.h"
 
 namespace oneflow {
 
@@ -27,16 +30,10 @@ bool IsConsumersAndProducerInSameChain(const RegstDescProto& regst_desc,
   return true;
 }
 
-bool IsSharableRegstWithConsumer(const RegstDescProto& regst_desc,
-                                 const std::function<int64_t(int64_t)>& ChainId4TaskId) {
-  return regst_desc.mem_shared_id() == -1 && regst_desc.consumer_task_id_size() > 0
-         && regst_desc.enable_mem_sharing() && regst_desc.register_num() == 1
-         && IsConsumersAndProducerInSameChain(regst_desc, ChainId4TaskId);
-}
-
 void ForEachSharableStreamRegstDescsWithoutConsumer(
-    const Plan& plan, const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) {
-  HashMap<int64_t, std::list<const RegstDescProto*>> global_work_stream_id2regst_descs;
+    const Plan& plan,
+    const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) {
+  HashMap<int64_t, std::vector<const RegstDescProto*>> global_work_stream_id2regst_descs;
   for (const auto& task : plan.task()) {
     int64_t global_work_stream_id = Global<IDMgr>::Get()->GlobalWorkStreamId4TaskId(task.task_id());
     for (const auto& pair : task.produced_regst_desc()) {
@@ -50,64 +47,125 @@ void ForEachSharableStreamRegstDescsWithoutConsumer(
   }
 }
 
-void ForEachSharableChainRegstDescsWithConsumer(
-    const Plan& plan, const std::function<int64_t(int64_t)>& ChainId4TaskId,
-    const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) {
-  HashMap<int64_t, std::list<const TaskProto*>> chain_id2task_proto;
-  for (const TaskProto& task : plan.task()) {
-    chain_id2task_proto[task.task_set_info().chain_id()].push_back(&task);
-  }
-  for (const auto& chain_tasks_pair : chain_id2task_proto) {
-    if (chain_tasks_pair.second.size() == 1) { continue; }
-    std::list<const RegstDescProto*> regst_descs;
-    for (const TaskProto* task : chain_tasks_pair.second) {
-      for (const auto& pair : task->produced_regst_desc()) {
-        if (IsSharableRegstWithConsumer(pair.second, ChainId4TaskId)) {
-          regst_descs.push_back(&pair.second);
-        }
-      }
-    }
-    if (regst_descs.size() > 1) { Handler(regst_descs); }
-  }
-}
-
 void ForEachSameColoredStreamRegstDescWithoutConsumer(
-    const Plan& plan, const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) {
+    const Plan& plan,
+    const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) {
   auto GetProducerTaskId = [](const RegstDescProto* regst_desc, HashSet<int64_t>* ret_actor_ids) {
     CHECK(regst_desc->enable_mem_sharing());
     ret_actor_ids->insert(regst_desc->producer_task_id());
   };
   ForEachSharableStreamRegstDescsWithoutConsumer(
-      plan, [&](const std::list<const RegstDescProto*>& regst_descs) {
+      plan, [&](const std::vector<const RegstDescProto*>& regst_descs) {
         RegstLifetimeGraph(regst_descs, GetProducerTaskId).ForEachSameColoredRegstDescs(Handler);
       });
 }
 
+void ForEachSameColoredChainRegstRegstDescs(
+    const SharableMemBlockGraph& sharable_mem_block_gph,
+    const std::function<std::vector<const RegstDescProto*>(
+        const std::vector<const SharableMemBlockNode*>&)>& GetRegstDescs,
+    const std::function<void(const RegstDescProto*, HashSet<int64_t>*)>&
+        ComputeLifetimeSameChainActorIds,
+    const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) {
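+  // Color each chain's regst_descs in parallel: one RegstLifetimeGraph per chain is run on a
+  // thread pool, and a BlockingCounter waits for all results before they are handed to Handler.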
+  std::vector<std::vector<const SharableMemBlockNode*>> sharable_mem_blocks_vec;
+  sharable_mem_block_gph.ForEachSourceNodeGroup(
+      &SharableMemBlockNode::chain_id,
+      [&](const std::vector<const SharableMemBlockNode*>& sharable_mem_blocks) {
+        sharable_mem_blocks_vec.push_back(sharable_mem_blocks);
+      });
+  std::vector<std::vector<std::vector<const RegstDescProto*>>> same_colored_regst_descs_vec(
+      sharable_mem_blocks_vec.size());
+  int64_t cpu_num = std::thread::hardware_concurrency();
+  int64_t thread_pool_size = std::min<int64_t>(sharable_mem_blocks_vec.size(), cpu_num);
+  BlockingCounter counter(sharable_mem_blocks_vec.size());
+  ThreadPool thread_pool(thread_pool_size);
+  FOR_RANGE(int64_t, i, 0, sharable_mem_blocks_vec.size()) {
+    thread_pool.AddWork([i, &GetRegstDescs, &ComputeLifetimeSameChainActorIds,
+                         &sharable_mem_blocks_vec, &same_colored_regst_descs_vec, &counter]() {
+      const auto& sharable_mem_blocks = sharable_mem_blocks_vec.at(i);
+      RegstLifetimeGraph(GetRegstDescs(sharable_mem_blocks), ComputeLifetimeSameChainActorIds)
+          .ForEachSameColoredRegstDescs([&](const std::vector<const RegstDescProto*>& regst_descs) {
+            same_colored_regst_descs_vec.at(i).push_back(regst_descs);
+          });
+      counter.Decrease();
+    });
+  }
+  counter.WaitUntilCntEqualZero();
+  for (const auto& regst_descs_vec : same_colored_regst_descs_vec) {
+    for (const auto& regst_descs : regst_descs_vec) { Handler(regst_descs); }
+  }
+}
+
 void ForEachSameColoredChainRegstDescWithConsumer(
     const PlanTaskGraph& plan_task_graph,
-    const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) {
+    const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) {
+  // construct SharableMemBlockGraph
+  auto ChainId4TaskId = [&](int64_t task_id) {
+    return plan_task_graph.TaskProto4TaskId(task_id)->task_set_info().chain_id();
+  };
+  auto IsSharableRegstWithConsumer = [&](const RegstDescProto& regst_desc) {
+    return regst_desc.mem_shared_id() == -1 && regst_desc.consumer_task_id_size() > 0
+           && regst_desc.enable_mem_sharing() && regst_desc.register_num() == 1
+           && IsConsumersAndProducerInSameChain(regst_desc, ChainId4TaskId);
+  };
+  SharableMemBlockGraph sharable_mem_block_gph(plan_task_graph, IsSharableRegstWithConsumer);
+  sharable_mem_block_gph.ForEachNode([&](const SharableMemBlockNode* sharable_mem_block) {
+    CHECK_EQ(sharable_mem_block->mem_block().mem_reduce_method(), MemReduceMethod::kMemMax);
+  });
+  // group regst_descs so that pre-colored regst_descs stay together.
+  // example:
+  // given dlnet: A -> B -> C -> D -> E -> F -> H -> I, where D is an inplace op.
+  // Regst(C) and Regst(D) are pre-colored with the same color as one group, which
+  // then shares memory with other regsts such as A, B, E, ...
+  HashMap<const RegstDescProto*, std::vector<const RegstDescProto*>> header2members;
+  for (const SharableMemBlockNode* sharable_mem_block : sharable_mem_block_gph.source_nodes()) {
+    auto regst_descs = sharable_mem_block->regst_descs();
+    HashMap<const RegstDescProto*, size_t> regst_desc2mem_size;
+    for (const RegstDescProto* regst_desc : regst_descs) {
+      size_t size = RtRegstDesc(*regst_desc).TotalMainByteSize4AllRegst();
+      CHECK(regst_desc2mem_size.emplace(regst_desc, size).second);
+    }
+    std::sort(regst_descs.begin(), regst_descs.end(),
+              [&](const RegstDescProto* lhs, const RegstDescProto* rhs) {
+                return regst_desc2mem_size.at(lhs) > regst_desc2mem_size.at(rhs);
+              });
+    header2members.emplace(regst_descs.at(0), regst_descs);
+  }
+  auto GetRegstDescs = [&](const std::vector<const SharableMemBlockNode*>& sharable_mem_blocks) {
+    std::vector<const RegstDescProto*> ret;
+    for (const SharableMemBlockNode* sharable_mem_block : sharable_mem_blocks) {
+      for (const RegstDescProto* regst_desc : sharable_mem_block->regst_descs()) {
+        if (header2members.find(regst_desc) != header2members.end()) {
+          ret.push_back(regst_desc);
+          break;
+        }
+      }
+    }
+    return ret;
+  };
   auto ComputeLifetimeSameChainActorIds = [&](const RegstDescProto* regst_desc,
                                               HashSet<int64_t>* ret_actor_ids) {
     CHECK(regst_desc->enable_mem_sharing());
     ret_actor_ids->clear();
-    plan_task_graph.ComputeLifetimeSameChainActorIds(regst_desc, ret_actor_ids);
+    for (const RegstDescProto* member : header2members.at(regst_desc)) {
+      plan_task_graph.ComputeLifetimeSameChainActorIds(member, ret_actor_ids);
+    }
   };
-  auto ChainId4TaskId = [&](int64_t task_id) {
-    return plan_task_graph.TaskProto4TaskId(task_id)->task_set_info().chain_id();
+  auto AppendGroupMembers = [&](const std::vector<const RegstDescProto*>& regst_descs) {
+    std::vector<const RegstDescProto*> members;
+    for (const auto* header : regst_descs) {
+      for (const auto* member : header2members.at(header)) { members.push_back(member); }
+    }
+    Handler(members);
   };
-  const Plan& plan = plan_task_graph.plan();
-  ForEachSharableChainRegstDescsWithConsumer(
-      plan, ChainId4TaskId, [&](const std::list<const RegstDescProto*>& regst_descs) {
-        RegstLifetimeGraph(regst_descs, ComputeLifetimeSameChainActorIds)
-            .ForEachSameColoredRegstDescs(Handler);
-      });
+  ForEachSameColoredChainRegstRegstDescs(sharable_mem_block_gph, GetRegstDescs,
+                                         ComputeLifetimeSameChainActorIds, AppendGroupMembers);
 }
 
 void ForEachImprovedMemSharedId(const PlanTaskGraph& plan_task_graph,
                                 const std::function<void(int64_t, int64_t)>& Handler) {
-  using RegstDescs = std::list<const RegstDescProto*>;
   const Plan& plan = plan_task_graph.plan();
-  auto HandleMemSharedId = [&](const RegstDescs& regst_descs) {
+  auto HandleMemSharedId = [&](const std::vector<const RegstDescProto*>& regst_descs) {
     int64_t mem_shared_id = Global<IDMgr>::Get()->NewMemSharedId();
     for (const RegstDescProto* regst_desc : regst_descs) {
       Handler(regst_desc->regst_desc_id(), mem_shared_id);
diff --git a/oneflow/core/job/init_op_conf.cpp b/oneflow/core/job/init_op_conf.cpp
index 9e6b04e2348e60ee088aba910eaba091bad3e212..9587584bab392e767d8ccfd3c872957d6125a3ba 100644
--- a/oneflow/core/job/init_op_conf.cpp
+++ b/oneflow/core/job/init_op_conf.cpp
@@ -219,6 +219,12 @@ void InitInitializerConf(InitializerConf* initializer, const InitializerConf::Ty
       initializer->set_allocated_random_normal_conf(random_normal_conf);
       break;
     }
+    case InitializerConf::kTruncatedNormalConf: {
+      TruncatedNormalInitializerConf* truncated_normal_conf = new TruncatedNormalInitializerConf();
+      truncated_normal_conf->set_std(param1);
+      initializer->set_allocated_truncated_normal_conf(truncated_normal_conf);
+      break;
+    }
     case InitializerConf::kXavierConf: {
       XavierInitializerConf* xavier_conf = new XavierInitializerConf();
       xavier_conf->set_variance_norm(static_cast<VarianceNorm>(static_cast<int>(param1)));
@@ -231,6 +237,20 @@ void InitInitializerConf(InitializerConf* initializer, const InitializerConf::Ty
       initializer->set_allocated_msra_conf(msra_conf);
       break;
     }
+    case InitializerConf::kRangeConf: {
+      RangeInitializerConf* range_conf = new RangeInitializerConf();
+      range_conf->set_start(param1);
+      range_conf->set_stride(param2);
+      initializer->set_allocated_range_conf(range_conf);
+      break;
+    }
+    case InitializerConf::kIntRangeConf: {
+      IntRangeInitializerConf* int_range_conf = new IntRangeInitializerConf();
+      int_range_conf->set_start(static_cast<int64_t>(param1));
+      int_range_conf->set_stride(static_cast<int64_t>(param2));
+      initializer->set_allocated_int_range_conf(int_range_conf);
+      break;
+    }
     case InitializerConf::TYPE_NOT_SET: {
       LOG(INFO) << "InitializerConf::TYPE_NOT_SET";
       break;
diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto
index 3bdad1aed56426fb36913070aa75132b3fa2835e..07ed54c52c9f196df7a2191a0e67a6ff5c38b104 100644
--- a/oneflow/core/job/job_conf.proto
+++ b/oneflow/core/job/job_conf.proto
@@ -67,18 +67,30 @@ message OtherConf {
   optional bool save_downloaded_file_to_local_fs = 109 [default = false];
   optional uint64 rdma_mem_block_mbyte = 110 [default = 8];
   optional uint64 rdma_recv_msg_buf_mbyte = 111 [default = 6];
+  required FileSystemConf data_fs_conf = 112;
+  required FileSystemConf snapshot_fs_conf = 113;
 
-  optional int64 reduce_group_size = 112 [default = 20];
-  optional bool collect_act_event = 113 [default = false];
-  optional bool enable_mem_sharing = 114 [default = true];
+  optional bool collect_act_event = 125 [default = false];
+  optional bool enable_mem_sharing = 126 [default = true];
   optional bool enable_write_snapshot = 130 [default = true];
   optional bool enable_blob_mem_sharing = 140 [default = true];
   optional bool enable_nccl = 142 [default = true];
   optional bool use_nccl_inter_node_communication = 143 [default = false];
   optional int64 cudnn_buf_limit_mbyte = 144 [default = 1024];  // 1GByte
-
-  required FileSystemConf data_fs_conf = 121;
-  required FileSystemConf snapshot_fs_conf = 122;
+  optional int64 all_reduce_group_num = 145 [default = 8];
+  // example:
+  //   all_reduce_lazy_ratio = 0.5
+  // It means that half of the all_reduce nodes overlap with the forward pass of the next batch
+  optional float all_reduce_lazy_ratio = 146 [default = 0.5];
+  optional int64 all_reduce_group_min_mbyte = 147 [default = 16];
+  // example:
+  //   total weight bytes is 1024M
+  //   all_reduce_group_num = 8
+  //   all_reduce_group_min_mbyte = 16
+  //   all_reduce_group_size_warmup = 2
+  // The weight sizes of the reduce groups are [16, 32, 64, 128, 128, 128, 128, 128, 128, 128, 16].
+  // Note that the actual number of reduce groups is slightly larger than all_reduce_group_num.
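+  // i.e. 16 + 32 + 64 + 128 * 7 + 16 = 1024, giving 11 groups rather than 8.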
+  optional float all_reduce_group_size_warmup = 148 [default = 2];
 
   oneof JobType {
     TrainConf train_conf = 200;
diff --git a/oneflow/core/job/job_desc.cpp b/oneflow/core/job/job_desc.cpp
index abc2ccec121cdf6b3def19f29dc84701366467cd..51d31f7e1237a45735d149bf3cdbd7917a277ef1 100644
--- a/oneflow/core/job/job_desc.cpp
+++ b/oneflow/core/job/job_desc.cpp
@@ -1,10 +1,146 @@
 #include "oneflow/core/job/job_desc.h"
 #include "oneflow/core/job/parallel_desc.h"
+#include "oneflow/core/operator/operator.h"
 #include "oneflow/core/common/protobuf.h"
 #include "oneflow/core/persistence/hadoop/hadoop_file_system.h"
+#include "oneflow/core/graph/graph.h"
+#include "oneflow/core/graph/op_graph.h"
 
 namespace oneflow {
 
+std::function<const ParallelConf*(const std::string&)> MakeGetterParallelConf4OpName(
+    const Placement& placement) {
+  auto op_name2parallel_conf = std::make_shared<HashMap<std::string, const ParallelConf*>>();
+  for (const auto& placement_group : placement.placement_group()) {
+    for (const std::string& op_name : placement_group.op_set().op_name()) {
+      const ParallelConf* parallel_conf = &placement_group.parallel_conf();
+      CHECK(op_name2parallel_conf->emplace(op_name, parallel_conf).second);
+    }
+  }
+  return [op_name2parallel_conf](const std::string& op_name) {
+    return op_name2parallel_conf->at(op_name);
+  };
+}
+
+std::function<ParallelConf*(const std::string&)> MakeGetterMutParallelConf4OpName(
+    Placement* placement) {
+  auto op_name2parallel_conf = std::make_shared<HashMap<std::string, ParallelConf*>>();
+  FOR_RANGE(int, idx, 0, placement->placement_group_size()) {
+    auto* placement_group = placement->mutable_placement_group(idx);
+    for (const std::string& op_name : placement_group->op_set().op_name()) {
+      ParallelConf* parallel_conf = placement_group->mutable_parallel_conf();
+      CHECK(op_name2parallel_conf->emplace(op_name, parallel_conf).second);
+    }
+  }
+  return [op_name2parallel_conf](const std::string& op_name) {
+    return op_name2parallel_conf->at(op_name);
+  };
+}
+
+namespace {
+
+std::function<OperatorConf*(const std::string&)> MakeMutableOperatorConf4OpName(
+    JobConf1* job_conf) {
+  auto op_name2op_conf = std::make_shared<HashMap<std::string, OperatorConf*>>();
+  FOR_RANGE(int, idx, 0, job_conf->net().op_size()) {
+    OperatorConf* op_conf = job_conf->mutable_net()->mutable_op(idx);
+    CHECK(op_name2op_conf->emplace(op_conf->name(), op_conf).second);
+  }
+  return [op_name2op_conf](const std::string& op_name) { return op_name2op_conf->at(op_name); };
+}
+
+void AddIdentityOp(const std::string& prefix, JobConf1* job_conf,
+                   const HashSet<LogicalBlobId>& input_lbis,
+                   HashMap<LogicalBlobId, LogicalBlobId>* old_lbi2new_lbi,
+                   const ParallelConf& parallel_conf) {
+  // add tuple identity op
+  OperatorConf* tuple_identity_op = job_conf->mutable_net()->add_op();
+  tuple_identity_op->set_name(prefix + NewUniqueId());
+  TupleIdentityOpConf* tuple_identity_op_conf = tuple_identity_op->mutable_tuple_identity_conf();
+  int32_t idx = 0;
+  for (const LogicalBlobId& lbi : input_lbis) {
+    std::string blob_name = std::string("out_") + std::to_string(idx++);
+    {
+      LogicalBlobId output_lbi;
+      output_lbi.set_op_name(tuple_identity_op->name());
+      output_lbi.set_blob_name(blob_name);
+      CHECK(old_lbi2new_lbi->emplace(lbi, output_lbi).second);
+    }
+    tuple_identity_op_conf->add_in(lbi.op_name() + "/" + lbi.blob_name());
+    tuple_identity_op_conf->add_out(blob_name);
+  }
+  // add placement of tuple identity op
+  PlacementGroup* p_group = job_conf->mutable_placement()->add_placement_group();
+  *(p_group->mutable_op_set()->add_op_name()) = tuple_identity_op->name();
+  *(p_group->mutable_parallel_conf()) = parallel_conf;
+}
+
+void SetPbMessageField(PbMessage* pb_msg, const std::string& field, const std::string& old_val,
+                       const std::string& new_val) {
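+  // Replace old_val with new_val in the named field; names that are not plain string fields are
+  // treated as "prefix_i", i.e. element i of a repeated string field.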
+  const PbFd* fd = pb_msg->GetDescriptor()->FindFieldByName(field);
+  if (fd) {
+    CHECK_EQ(GetValFromPbMessage<std::string>(*pb_msg, field), old_val);
+    SetValInPbMessage<std::string>(pb_msg, field, new_val);
+  } else {
+    const std::pair<std::string, int32_t> prefix_idx = GenUnRepeatedBn(field);
+    CHECK_EQ(GetPbRpfFromPbMessage<std::string>(*pb_msg, prefix_idx.first).Get(prefix_idx.second),
+             old_val);
+    PbRpf<std::string>* rpf = MutPbRpfFromPbMessage<std::string>(pb_msg, prefix_idx.first);
+    *rpf->Mutable(prefix_idx.second) = new_val;
+  }
+}
+
+void AddIdentityOpAndReconnect(
+    const std::string& identity_op_name_prefix, JobConf1* job_conf,
+    const std::vector<OpEdge*>& op_edges,
+    const std::function<OperatorConf*(const std::string&)>& MutOperatorConf4OpName,
+    const ParallelConf& parallel_conf) {
+  // add identity op
+  HashSet<LogicalBlobId> lbis;
+  for (OpEdge* edge : op_edges) { lbis.insert(edge->lbis().begin(), edge->lbis().end()); }
+  HashMap<LogicalBlobId, LogicalBlobId> old_lbi2new_lbi;
+  AddIdentityOp(identity_op_name_prefix, job_conf, lbis, &old_lbi2new_lbi, parallel_conf);
+  // reconnect to identity op
+  for (OpEdge* edge : op_edges) {
+    OperatorConf* op_conf = MutOperatorConf4OpName(edge->dst_node()->op().op_name());
+    PbMessage* op_type_conf = MutableMessageInPbMessage(op_conf, op_conf->op_type_case());
+    for (const LogicalBlobId& lbi : edge->lbis()) {
+      std::string lbn_check = GenLogicalBlobName(lbi);
+      std::string identity_out_lbn = GenLogicalBlobName(old_lbi2new_lbi.at(lbi));
+      for (const std::string& ibn : edge->lbi2ibns().at(lbi)) {
+        SetPbMessageField(op_type_conf, ibn, lbn_check, identity_out_lbn);
+      }
+    }
+  }
+}
+
+}  // namespace
+
+int64_t JobDesc::all_reduce_group_min_byte() const {
+  int64_t ret = job_conf_.other().all_reduce_group_min_mbyte() * 1024 * 1024;
+  CHECK_GT(ret, 0);
+  return ret;
+}
+
+float JobDesc::all_reduce_group_size_warmup() const {
+  float ret = job_conf_.other().all_reduce_group_size_warmup();
+  CHECK_GT(ret, 1);
+  return ret;
+}
+
+int64_t JobDesc::all_reduce_group_num() const {
+  int64_t ret = job_conf_.other().all_reduce_group_num();
+  CHECK_GT(ret, 0);
+  return ret;
+}
+
+float JobDesc::all_reduce_lazy_ratio() const {
+  float ratio = job_conf_.other().all_reduce_lazy_ratio();
+  CHECK_GE(ratio, 0.0);
+  CHECK_LE(ratio, 1.0);
+  return ratio;
+}
+
 int64_t JobDesc::piece_num_of_experiment_phase() const {
   return job_conf_.other().exp_run_conf().piece_num_of_experiment_phase();
 }
@@ -65,8 +201,8 @@ int64_t JobDesc::BatchSize() const {
 }
 int64_t JobDesc::NumOfPiecesInBatch() const {
   if (IsPredict()) { return 1; }
-  CHECK_EQ(BatchSize() % PieceSize(), 0);
-  return BatchSize() / PieceSize();
+  CHECK_EQ(BatchSize() % RecordPieceSize(), 0);
+  return BatchSize() / RecordPieceSize();
 }
 float JobDesc::primary_lr() const {
   CHECK(IsTrain());
@@ -195,6 +331,7 @@ void JobDesc::SplitDecodeOps() {
 void JobDesc::AddRecordLoadOps() {
   HashMap<std::pair<std::string, std::string>, std::vector<OperatorConf*>> data_info2decode_ops;
   HashMap<std::pair<std::string, std::string>, int32_t> data_info2suffix_length;
+  HashMap<std::pair<std::string, std::string>, const RandomShuffleConf*> data_info2shuffle_conf;
   size_t op_num = job_conf_.net().op_size();
   FOR_RANGE(size_t, idx, 0, op_num) {
     OperatorConf* op_conf = job_conf_.mutable_net()->mutable_op()->Mutable(idx);
@@ -210,6 +347,18 @@ void JobDesc::AddRecordLoadOps() {
     } else {
       data_info2suffix_length[data_info] = part_name_suffix_length;
     }
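+    // decode ops sharing the same data source must agree on whether and how to shuffle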
+    const RandomShuffleConf* shuffle_conf =
+        decode_conf.has_random_shuffle_conf() ? &decode_conf.random_shuffle_conf() : nullptr;
+    if (data_info2shuffle_conf.find(data_info) != data_info2shuffle_conf.end()) {
+      if (shuffle_conf == nullptr) {
+        CHECK(data_info2shuffle_conf.at(data_info) == nullptr);
+      } else {
+        CHECK(data_info2shuffle_conf.at(data_info) != nullptr);
+        CHECK_EQ(data_info2shuffle_conf.at(data_info)->buffer_size(), shuffle_conf->buffer_size());
+      }
+    } else {
+      CHECK(data_info2shuffle_conf.emplace(data_info, shuffle_conf).second);
+    }
   }
 
   HashMap<std::string, const ParallelConf*> name2parallel_conf;
@@ -246,6 +395,9 @@ void JobDesc::AddRecordLoadOps() {
       record_load_op->set_data_dir(pair.first.first);
       record_load_op->set_part_name_prefix(pair.first.second);
       record_load_op->set_part_name_suffix_length(data_info2suffix_length.at(pair.first));
+      if (data_info2shuffle_conf.at(pair.first) != nullptr) {
+        *record_load_op->mutable_random_shuffle_conf() = *data_info2shuffle_conf.at(pair.first);
+      }
       PlacementGroup* p_group = job_conf_.mutable_placement()->add_placement_group();
       *(p_group->mutable_op_set()->add_op_name()) = record_load_op_name;
       *(p_group->mutable_parallel_conf()) = *parallel_conf;
@@ -261,4 +413,128 @@ void JobDesc::AddRecordLoadOps() {
   }
 }
 
+void JobDesc::FixAndOptimizeDLNet() {
+  FixTickOpIfExists();
+  ConvertPseudoChainToChain();
+  if (IsTrain()) { AddIdentityOpForAllReduceOverlapingUntrainble(); }
+}
+
+void JobDesc::ConvertPseudoChainToChain() {
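+  // For each pseudo chain whose external source edges come from multiple producers and feed
+  // multiple chain nodes (CPU chains are skipped), insert a tuple identity op per has_diff group
+  // and reconnect the consumers to it, so each group enters the chain through a single op.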
+  auto GetSourceNodesAndEdges = [&](const HashSet<OpNode*>& chain_nodes,
+                                    HashSet<OpNode*>* source_nodes,
+                                    std::vector<OpEdge*>* source_edges) {
+    for (OpNode* node : chain_nodes) {
+      for (OpEdge* edge : node->in_edges()) {
+        if (chain_nodes.find(edge->src_node()) == chain_nodes.end()) {
+          source_edges->push_back(edge);
+          source_nodes->insert(node);
+        }
+      }
+    }
+  };
+  auto MutOperatorConf4OpName = MakeMutableOperatorConf4OpName(&job_conf_);
+  auto ParallelConf4OpName = MakeGetterParallelConf4OpName(job_conf_.placement());
+  OpGraph(this).ForEachPseudoChain([&](const HashSet<OpNode*>& chain_nodes) {
+    HashSet<OpNode*> source_nodes;
+    std::vector<OpEdge*> source_edges;
+    GetSourceNodesAndEdges(chain_nodes, &source_nodes, &source_edges);
+    if (source_edges.size() <= 1) { return; }
+    if (source_nodes.size() <= 1) { return; }
+    if (chain_nodes.size() - source_nodes.size() <= 2) { return; }
+    const OpNode* first_node = *source_nodes.begin();
+    if (first_node->parallel_desc().device_type() == DeviceType::kCPU) { return; }
+    HashMap<bool, std::vector<OpEdge*>> has_diff2source_edges;
+    for (OpEdge* edge : source_edges) { has_diff2source_edges[edge->has_diff()].push_back(edge); }
+    for (const auto& pair : has_diff2source_edges) {
+      HashSet<OpNode*> src_nodes;
+      HashSet<OpNode*> dst_nodes;
+      for (OpEdge* edge : pair.second) {
+        src_nodes.emplace(edge->src_node());
+        dst_nodes.emplace(edge->dst_node());
+      }
+      if (src_nodes.size() > 1 && dst_nodes.size() > 1) {
+        AddIdentityOpAndReconnect("pseudo_chain_header_", &job_conf_, pair.second,
+                                  MutOperatorConf4OpName,
+                                  *ParallelConf4OpName(first_node->op().op_name()));
+      }
+    }
+  });
+}
+
+void JobDesc::AddIdentityOpForAllReduceOverlapingUntrainble() {
+  auto MutOperatorConf4OpName = MakeMutableOperatorConf4OpName(&job_conf_);
+  auto ParallelConf4OpName = MakeGetterParallelConf4OpName(job_conf_.placement());
+  OpGraph(this).TopoForEachNode([&](OpNode* op_node) {
+    if (op_node->HasBackward()) { return; }
+    HashMap<bool, std::vector<OpEdge*>> has_bw2out_op_edges;
+    for (OpEdge* edge : op_node->out_edges()) {
+      has_bw2out_op_edges[edge->dst_node()->HasBackward()].push_back(edge);
+    }
+    if (has_bw2out_op_edges.size() <= 1) { return; }
+    // only handle op_nodes that:
+    // a) have no backward node;
+    // b) have both trainable and untrainable consumers;
+
+    // group out_edges by the trainable consumers' ParallelDesc
+    HashMap<ParallelDesc, std::vector<OpEdge*>> consumer_op_pr2edges;
+    for (OpEdge* edge : has_bw2out_op_edges.at(true)) {
+      ParallelDesc pr(*ParallelConf4OpName(edge->dst_node()->op().op_name()));
+      consumer_op_pr2edges[pr].push_back(edge);
+    }
+    for (const auto& pair : consumer_op_pr2edges) {
+      AddIdentityOpAndReconnect(
+          "all_reduce_overlapping_untrainable_", &job_conf_, pair.second, MutOperatorConf4OpName,
+          *ParallelConf4OpName(pair.second.at(0)->dst_node()->op().op_name()));
+    }
+  });
+}
+
+void JobDesc::FixTickOpIfExists() {
+  auto MutParallelConf4OpName = MakeGetterMutParallelConf4OpName(job_conf_.mutable_placement());
+  OperatorConf* tick_op_conf = nullptr;
+  FOR_RANGE(int, idx, 0, job_conf_.mutable_net()->op_size()) {
+    OperatorConf* op_conf = job_conf_.mutable_net()->mutable_op(idx);
+    if (op_conf->has_tick_conf()) {
+      CHECK(tick_op_conf == nullptr);
+      tick_op_conf = op_conf;
+    }
+  }
+  if (tick_op_conf == nullptr) { return; }
+  std::map<OperatorConf::OpTypeCase, std::vector<OperatorConf*>> op_type_case2source_op_confs;
+  FOR_RANGE(int, idx, 0, job_conf_.mutable_net()->op_size()) {
+    OperatorConf* op_conf = job_conf_.mutable_net()->mutable_op(idx);
+    if (op_conf == tick_op_conf) { continue; }
+    DeviceType device_type = ParallelDesc(*MutParallelConf4OpName(op_conf->name())).device_type();
+    if (ConstructOp(*op_conf, device_type)->input_bns().size() == 0) {
+      op_type_case2source_op_confs[op_conf->op_type_case()].push_back(op_conf);
+    }
+  }
+  if (op_type_case2source_op_confs.find(OperatorConf::kRecordLoadConf)
+      != op_type_case2source_op_confs.end()) {
+    CHECK_EQ(op_type_case2source_op_confs.size(), 1);
+  }
+  // set input of tick op
+  OperatorConf* source_op_conf = op_type_case2source_op_confs.cbegin()->second.at(0);
+  ParallelConf* source_parallel_conf = MutParallelConf4OpName(source_op_conf->name());
+  DeviceType device_type = ParallelDesc(*source_parallel_conf).device_type();
+  std::shared_ptr<Operator> source_op = ConstructOp(*source_op_conf, device_type);
+  CHECK_GE(source_op->output_bns().size(), 1);
+  LogicalBlobId src_first_output_lbi = source_op->BnInOp2Lbi(source_op->output_bns().Get(0));
+  std::string source_op_output_lbn = GenLogicalBlobName(src_first_output_lbi);
+  CHECK_EQ(tick_op_conf->tick_conf().has_in(), false);
+  tick_op_conf->mutable_tick_conf()->set_in(source_op_output_lbn);
+  // fix tick op placement
+  *MutParallelConf4OpName(tick_op_conf->name()) = *source_parallel_conf;
+  // add a log_counter op consuming the tick op's output, so the tick op always has a consumer
+  OperatorConf* tick_log_counter = job_conf_.mutable_net()->add_op();
+  tick_log_counter->set_name("tick_log_counter_" + NewUniqueId());
+  LogCounterOpConf* tick_log_counter_conf = tick_log_counter->mutable_log_counter_conf();
+  tick_log_counter_conf->set_in(tick_op_conf->name() + "/" + tick_op_conf->tick_conf().out());
+  tick_log_counter_conf->set_interval(MaxVal<int32_t>::value);
+  // add placement of tick_log_counter op
+  PlacementGroup* p_group = job_conf_.mutable_placement()->add_placement_group();
+  *(p_group->mutable_op_set()->add_op_name()) = tick_log_counter->name();
+  *(p_group->mutable_parallel_conf()) = *source_parallel_conf;
+}
+
 }  // namespace oneflow
diff --git a/oneflow/core/job/job_desc.h b/oneflow/core/job/job_desc.h
index a1593ccd7bc2d507d4f666085d91992a30a8ad54..0b1fc41fe2d53b7cc7e97d8c6c849c0e060e71cd 100644
--- a/oneflow/core/job/job_desc.h
+++ b/oneflow/core/job/job_desc.h
@@ -7,6 +7,7 @@
 #include "oneflow/core/job/placement.pb.h"
 #include "oneflow/core/job/resource.pb.h"
 #include "oneflow/core/persistence/file_system.h"
+#include "oneflow/core/register/logical_blob_id.pb.h"
 
 namespace oneflow {
 
@@ -17,6 +18,7 @@ class JobDesc final {
   ~JobDesc() = default;
 
   // Common
+  const JobConf1& job_conf() const { return job_conf_; }
   const DLNetConf& dlnet_conf() const { return job_conf_.net(); }
   const Resource& resource() const { return job_conf_.resource(); }
   const Placement& placement() const { return job_conf_.placement(); }
@@ -35,7 +37,7 @@ class JobDesc final {
   int32_t MaxMdSaveWorkerNum() const { return job_conf_.resource().max_mdsave_worker_num(); }
   bool IsTrain() const { return job_conf_.other().has_train_conf(); }
   bool IsPredict() const { return job_conf_.other().has_predict_conf(); }
-  int64_t PieceSize() const { return job_conf_.other().piece_size(); }
+  int64_t RecordPieceSize() const { return job_conf_.other().piece_size(); }
   int64_t piece_num_of_experiment_phase() const;
   bool enable_experiment_run() const;
   float available_zone_mem_ratio() const;
@@ -58,7 +60,10 @@ class JobDesc final {
   bool use_nccl_inter_node_communication() const {
     return job_conf_.other().use_nccl_inter_node_communication();
   }
-  int64_t reduce_group_size() const { return job_conf_.other().reduce_group_size(); }
+  int64_t all_reduce_group_num() const;
+  int64_t all_reduce_group_min_byte() const;
+  float all_reduce_group_size_warmup() const;
+  float all_reduce_lazy_ratio() const;
   int64_t cudnn_buf_limit_mbyte() const { return job_conf_.other().cudnn_buf_limit_mbyte(); }
   int64_t GetMachineId(const std::string& addr) const;
 
@@ -79,6 +84,9 @@ class JobDesc final {
   float bias_l2() const;
   int32_t DataPartNum() const;
 
+  // fix and optimize the DLNet
+  void FixAndOptimizeDLNet();
+
  private:
   friend class Global<JobDesc>;
   JobDesc(const std::string& job_conf_filepath);
@@ -87,10 +95,19 @@ class JobDesc final {
   void SanityCheck();
   void SplitDecodeOps();
   void AddRecordLoadOps();
+  void ConvertPseudoChainToChain();
+  void AddIdentityOpForChainMergeOptimization();
+  void AddIdentityOpForAllReduceOverlapingUntrainble();
+  void FixTickOpIfExists();
 
   JobConf1 job_conf_;
 };
 
+std::function<const ParallelConf*(const std::string&)> MakeGetterParallelConf4OpName(
+    const Placement& placement);
+std::function<ParallelConf*(const std::string&)> MakeGetterMutParallelConf4OpName(
+    Placement* placement);
+
 }  // namespace oneflow
 
 #endif  // ONEFLOW_CORE_JOB_JOB_DESC_H_
diff --git a/oneflow/core/job/parallel_desc.cpp b/oneflow/core/job/parallel_desc.cpp
index e1c5a0ef7284a731d46bd6617fef15c6a2e9c945..37756f65e16ca21e770efbd6e996a8bb2791cbeb 100644
--- a/oneflow/core/job/parallel_desc.cpp
+++ b/oneflow/core/job/parallel_desc.cpp
@@ -34,7 +34,9 @@ ParallelDesc::ParallelDesc(const ParallelConf& user_conf) {
       CHECK(device_type_ == DeviceType::kInvalidDevice || device_type_ == DeviceType::kGPU);
       device_type_ = DeviceType::kGPU;
     }
-    sorted_machine_ids_.push_back(mchn_id);
+    if (machine_id_set.find(mchn_id) == machine_id_set.end()) {
+      sorted_machine_ids_.push_back(mchn_id);
+    }
     int64_t minus_pos = device_id_str.find("-");
     if (minus_pos == std::string::npos) {
       device_id_str = device_id_str + "-" + device_id_str;
@@ -54,42 +56,6 @@ ParallelDesc::ParallelDesc(const ParallelConf& user_conf) {
   SanityCheck();
 }
 
-void ParallelDesc::RemoveNeedlessDevice(const std::string& op_name, int32_t max_device_num) {
-  if (max_device_num >= parallel_num_) { return; }
-  LOG_IF(WARNING, op_name != "") << "parallel_num of " << op_name
-                                 << " is greater than max_device_num " << max_device_num;
-  int32_t device_cnt = 0;
-  int64_t max_machine_id = -1;
-  for (int64_t machine_id : sorted_machine_ids_) {
-    auto it = machine_id2sorted_dev_phy_ids_.find(machine_id);
-    int32_t cur_device_num = it->second.size();
-    int32_t cur_device_max_num = max_device_num - device_cnt;
-    if (cur_device_num > cur_device_max_num) {
-      it->second.erase(it->second.begin() + cur_device_max_num, it->second.end());
-      if (it->second.empty()) {
-        max_machine_id = machine_id - 1;
-      } else {
-        max_machine_id = machine_id;
-      }
-      break;
-    } else {
-      device_cnt += cur_device_num;
-    }
-  }
-  CHECK_NE(max_machine_id, -1);
-  FOR_EACH(it, sorted_machine_ids_) {
-    if (*it > max_machine_id) {
-      sorted_machine_ids_.erase(it, sorted_machine_ids_.end());
-      break;
-    }
-  }
-  EraseIf<int64_t, std::vector<int64_t>>(&machine_id2sorted_dev_phy_ids_,
-                                         [&](HashMap<int64_t, std::vector<int64_t>>::iterator it) {
-                                           return it->first > max_machine_id;
-                                         });
-  parallel_num_ = max_device_num;
-}
-
 void ParallelDesc::RandomSelectOneDeviceAndRemoveTheOthers() {
   CHECK_GE(parallel_num_, 1);
   if (parallel_num_ == 1) { return; }
diff --git a/oneflow/core/job/parallel_desc.h b/oneflow/core/job/parallel_desc.h
index 224d24816a9180ab0f672628e840e522148a6841..b2d3561b5f0a258414d441f788c1b7f7c5bac225 100644
--- a/oneflow/core/job/parallel_desc.h
+++ b/oneflow/core/job/parallel_desc.h
@@ -11,12 +11,13 @@ namespace oneflow {
 void ParseDeviceNameConf(const std::string& device_name, int64_t* mchn_id, std::string* device_tag,
                          std::string* device_id_str);
 
-class ParallelDesc {
+class ParallelDesc final {
  public:
   // OF_DISALLOW_COPY_AND_MOVE(ParallelDesc);
   ParallelDesc() = delete;
   ~ParallelDesc() = default;
 
+  ParallelDesc(const ParallelDesc&) = default;
   ParallelDesc(const ParallelConf& user_conf);
 
   // Getters
@@ -32,13 +33,12 @@ class ParallelDesc {
   // Setters
   void set_policy(ParallelPolicy val) { policy_ = val; }
   void set_device_type(DeviceType device_type) { device_type_ = device_type; }
-  void RemoveNeedlessDevice(const std::string& op_name, int32_t max_device_num);
-  void RemoveNeedlessDevice(int32_t max_device_num) { RemoveNeedlessDevice("", max_device_num); }
   void RandomSelectOneDeviceAndRemoveTheOthers();
   void UseCPUDevicesOnMaster();
 
   //
   bool Equal(const ParallelDesc& rhs) const;
+  bool operator==(const ParallelDesc& rhs) const { return Equal(rhs); }
   bool Equal(const ParallelDesc* rhs) const { return Equal(*rhs); }
 
  private:
@@ -58,4 +58,20 @@ std::tuple<int32_t, int32_t> GetPartIdAndPartNumFromParallelCtx(
 
 }  // namespace oneflow
 
+namespace std {
+
+template<>
+struct hash<oneflow::ParallelDesc> {
+  size_t operator()(const oneflow::ParallelDesc& pr) const {
+    std::string str;
+    for (int machine_id : pr.sorted_machine_ids()) {
+      str += "::" + std::to_string(machine_id) + ":";
+      for (int dev_id : pr.sorted_dev_phy_ids(machine_id)) { str += std::to_string(dev_id) + ","; }
+    }
+    return hash<std::string>()(str);
+  }
+};
+
+}  // namespace std
+
 #endif  // ONEFLOW_CORE_JOB_PARALLEL_DESC_H_
diff --git a/oneflow/core/job/runtime.cpp b/oneflow/core/job/runtime.cpp
index da7e207b83f20003cf89e1cc8e970cd50836ea49..e886051661bf5c5ec0c0b33dcb3b27e4de42435f 100644
--- a/oneflow/core/job/runtime.cpp
+++ b/oneflow/core/job/runtime.cpp
@@ -7,6 +7,7 @@
 #include "oneflow/core/thread/thread_manager.h"
 #include "oneflow/core/actor/act_event_logger.h"
 #include "oneflow/core/graph/task_node.h"
+#include "oneflow/core/device/cuda_util.h"
 
 namespace oneflow {
 
@@ -103,6 +104,9 @@ void Runtime::NewAllGlobal(const Plan& plan, bool is_experiment_phase) {
     }
 #endif
   }
+#ifdef WITH_CUDA
+  InitGlobalCudaDeviceProp();
+#endif
   Global<SnapshotMgr>::New(plan);
   Global<MemoryAllocator>::New();
   Global<RegstMgr>::New(plan);
diff --git a/oneflow/core/job/sbp_infer_hint.cpp b/oneflow/core/job/sbp_infer_hint.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..22117b868227b2d6cd4a3b91d174114e259498b4
--- /dev/null
+++ b/oneflow/core/job/sbp_infer_hint.cpp
@@ -0,0 +1,40 @@
+#include "oneflow/core/job/sbp_infer_hint.h"
+#include "oneflow/core/common/util.h"
+#include "oneflow/core/common/protobuf.h"
+
+namespace oneflow {
+
+int64_t SbpInferHint::parallel_num() const {
+  CHECK_GT(parallel_num_, 0);
+  return parallel_num_;
+}
+
+int64_t SbpInferHint::num_axes() const {
+  CHECK_GT(num_axes_, 0);
+  return num_axes_;
+}
+
+bool SbpInferHint::has_split_axis() const {
+  bool ret = sbp_parallel_.has_split_parallel();
+  if (ret) {
+    CHECK_GE(sbp_parallel_.split_parallel().axis(), 0);
+    CHECK_LT(sbp_parallel_.split_parallel().axis(), num_axes());
+  }
+  return ret;
+}
+
+int64_t SbpInferHint::split_axis() const {
+  CHECK(has_split_axis());
+  return sbp_parallel_.split_parallel().axis();
+}
+
+bool SbpInferHint::is_model_split() const { return is_model_blob() && has_split_axis(); }
+bool SbpInferHint::is_model_broadcast() const {
+  return is_model_blob() && sbp_parallel_.has_broadcast_parallel();
+}
+bool SbpInferHint::is_data_split() const { return is_data_blob() && has_split_axis(); }
+bool SbpInferHint::is_data_partial_sum() const {
+  return is_data_blob() && sbp_parallel_.has_partial_sum_parallel();
+}
+
+}  // namespace oneflow
diff --git a/oneflow/core/job/sbp_infer_hint.h b/oneflow/core/job/sbp_infer_hint.h
new file mode 100644
index 0000000000000000000000000000000000000000..6e89a0651e11480c50749b961cb4ae94d7b8aaf4
--- /dev/null
+++ b/oneflow/core/job/sbp_infer_hint.h
@@ -0,0 +1,41 @@
+#ifndef ONEFLOW_CORE_JOB_SBP_INFER_HINT_H_
+#define ONEFLOW_CORE_JOB_SBP_INFER_HINT_H_
+
+#include "oneflow/core/job/sbp_parallel.pb.h"
+
+namespace oneflow {
+
+class SbpInferHint final {
+ public:
+  SbpInferHint(bool is_model_blob, int64_t parallel_num, int64_t num_axes,
+               const SbpParallel& sbp_parallel)
+      : is_model_blob_(is_model_blob),
+        parallel_num_(parallel_num),
+        num_axes_(num_axes),
+        sbp_parallel_(sbp_parallel) {}
+  SbpInferHint(const SbpInferHint&) = default;
+  ~SbpInferHint() = default;
+
+  // Getters
+  bool is_model_blob() const { return is_model_blob_; }
+  int64_t parallel_num() const;
+  int64_t num_axes() const;
+  int64_t split_axis() const;
+  bool has_split_axis() const;
+  bool is_model_split() const;
+  bool is_model_broadcast() const;
+  bool is_data_split() const;
+  bool is_data_partial_sum() const;
+  bool is_data_blob() const { return !is_model_blob(); }
+  const SbpParallel& sbp_parallel() const { return sbp_parallel_; }
+
+ private:
+  const bool is_model_blob_;
+  const int64_t parallel_num_;
+  const int64_t num_axes_;
+  const SbpParallel sbp_parallel_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_JOB_SBP_INFER_HINT_H_
diff --git a/oneflow/core/job/sbp_parallel.cpp b/oneflow/core/job/sbp_parallel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5eb6d089e190749276471f85c4da6448faa036a7
--- /dev/null
+++ b/oneflow/core/job/sbp_parallel.cpp
@@ -0,0 +1,29 @@
+#include "oneflow/core/job/sbp_parallel.h"
+#include "oneflow/core/common/protobuf.h"
+
+namespace oneflow {
+
+bool operator==(const SbpParallel& lhs, const SbpParallel& rhs) {
+  return PbMd().Equivalent(lhs, rhs);
+}
+
+bool operator!=(const SbpParallel& lhs, const SbpParallel& rhs) { return !(lhs == rhs); }
+
+//  S -> S
+//  P -> B
+//  B -> P
+SbpParallel GetDualSbpParallel(const SbpParallel& sbp_parallel) {
+  SbpParallel ret(sbp_parallel);
+  if (sbp_parallel.has_split_parallel()) {
+    //  do nothing
+  } else if (sbp_parallel.has_broadcast_parallel()) {
+    ret.mutable_partial_sum_parallel();
+  } else if (sbp_parallel.has_partial_sum_parallel()) {
+    ret.mutable_broadcast_parallel();
+  } else {
+    UNIMPLEMENTED();
+  }
+  return ret;
+}
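+// Usage sketch: S is self-dual while B and P are duals of each other, e.g.
+//   SbpParallel sbp;
+//   sbp.mutable_broadcast_parallel();
+//   CHECK(GetDualSbpParallel(sbp).has_partial_sum_parallel());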
+
+}  // namespace oneflow
diff --git a/oneflow/core/job/sbp_parallel.h b/oneflow/core/job/sbp_parallel.h
new file mode 100644
index 0000000000000000000000000000000000000000..100cfac3159254b61e92ad93f3db97c5ba30b748
--- /dev/null
+++ b/oneflow/core/job/sbp_parallel.h
@@ -0,0 +1,14 @@
+#ifndef ONEFLOW_CORE_JOB_SBP_PARALLEL_H_
+#define ONEFLOW_CORE_JOB_SBP_PARALLEL_H_
+
+#include "oneflow/core/job/sbp_parallel.pb.h"
+
+namespace oneflow {
+
+bool operator==(const SbpParallel& lhs, const SbpParallel& rhs);
+bool operator!=(const SbpParallel& lhs, const SbpParallel& rhs);
+SbpParallel GetDualSbpParallel(const SbpParallel&);
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_JOB_SBP_PARALLEL_H_
diff --git a/oneflow/core/job/sbp_parallel.proto b/oneflow/core/job/sbp_parallel.proto
new file mode 100644
index 0000000000000000000000000000000000000000..c89b23c7235a87c96cd6808b55c6284183f3cba4
--- /dev/null
+++ b/oneflow/core/job/sbp_parallel.proto
@@ -0,0 +1,20 @@
+syntax = "proto2";
+package oneflow;
+
+message SplitParallel {
+  required int64 axis = 1;
+}
+
+message BroadcastParallel {
+}
+
+message PartialSumParallel {
+}
+
+message SbpParallel {
+  oneof parallel_type {
+    SplitParallel split_parallel = 1;
+    BroadcastParallel broadcast_parallel = 2;
+    PartialSumParallel partial_sum_parallel = 3;
+  }
+}
diff --git a/oneflow/core/job/task.proto b/oneflow/core/job/task.proto
index 4f8d9b5432e0f95f2b2cec570eab8dfd8e76b7f6..7de6d3339f35ebbf52f8709719a93e892a410a5f 100644
--- a/oneflow/core/job/task.proto
+++ b/oneflow/core/job/task.proto
@@ -41,6 +41,7 @@ enum TaskType {
   kUnpackBackward = 33;
   kRepeatForward = 34;
   kRepeatBackward = 35;
+  kReduceIdentity = 36;
 };
 
 enum AreaType {
diff --git a/oneflow/core/kernel/accuracy_kernel.cpp b/oneflow/core/kernel/accuracy_kernel.cpp
index 60b818bbc05356c35b699cce033ff35cf0dc44e1..1f2a496e03eaed78849018a8e135cfb79b98693c 100644
--- a/oneflow/core/kernel/accuracy_kernel.cpp
+++ b/oneflow/core/kernel/accuracy_kernel.cpp
@@ -1,17 +1,30 @@
 #include "oneflow/core/kernel/accuracy_kernel.h"
+#include "oneflow/core/ndarray/ndarray_util.h"
 
 namespace oneflow {
 
 template<DeviceType device_type, typename PredType, typename LabelType>
 void AccuracyKernel<device_type, PredType, LabelType>::SetAccuracyInstanceNumBlob(
     const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const {
-  CHECK_GE(this->op_attribute().input_bns().size(), 2);
-  this->CheckSameDim0ValidNum(this->op_attribute().input_bns(), BnInOp2Blob);
-  int64_t dim0_valid_num_sum =
-      BnInOp2Blob(this->op_attribute().input_bns(0))->CalcDim0ValidNumSum();
-  KernelUtil<device_type, PredType>::Set(
-      ctx.device_ctx, static_cast<PredType>(dim0_valid_num_sum),
-      BnInOp2Blob("accuracy_instance_num")->mut_dptr<PredType>());
+  const Blob* weight = BnInOp2Blob("weight");
+  Blob* accuracy_instance_num = BnInOp2Blob("accuracy_instance_num");
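+  // When a "weight" blob is present, the instance count is the sum of the per-instance
+  // weights (reduced into accuracy_instance_num below); otherwise it is the number of
+  // valid instances across the inputs.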
+  if (weight == nullptr) {
+    CHECK_GE(this->op_attribute().input_bns().size(), 2);
+    this->CheckSameDim0ValidNum(this->op_attribute().input_bns(), BnInOp2Blob);
+    int64_t dim0_valid_num_sum =
+        BnInOp2Blob(this->op_attribute().input_bns(0))->CalcDim0ValidNumSum();
+    KernelUtil<device_type, PredType>::Set(ctx.device_ctx,
+                                           static_cast<PredType>(dim0_valid_num_sum),
+                                           accuracy_instance_num->mut_dptr<PredType>());
+  } else {
+    Blob* weight_reduce_tmp = BnInOp2Blob("weight_reduce_tmp");
+    CHECK_LE(weight->shape().elem_cnt(), weight_reduce_tmp->shape().elem_cnt());
+    const int64_t num_instance = weight->shape().elem_cnt();
+    NdarrayUtil<device_type, PredType>::ReduceSum(
+        ctx.device_ctx, XpuVarNdarray<PredType>({1}, accuracy_instance_num->mut_dptr<PredType>()),
+        XpuVarNdarray<const PredType>({num_instance}, weight->dptr<PredType>()),
+        XpuVarNdarray<PredType>({num_instance}, weight_reduce_tmp->mut_dptr<PredType>()));
+  }
 }
 
 template<DeviceType device_type, typename PredType, typename LabelType>
@@ -19,6 +32,8 @@ void AccuracyKernel<device_type, PredType, LabelType>::ForwardDataContent(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   const Blob* X = BnInOp2Blob("prediction");
   const Blob* label = BnInOp2Blob("label");
+  const Blob* weight = BnInOp2Blob("weight");
+  if (weight != nullptr) { CHECK_EQ(label->shape().elem_cnt(), weight->shape().elem_cnt()); }
   Blob* accuracy = BnInOp2Blob("accuracy");
   auto kernel_conf = this->kernel_conf();
   const int32_t top_k = kernel_conf.op_attribute().op_conf().accuracy_conf().top_k();
@@ -29,7 +44,7 @@ void AccuracyKernel<device_type, PredType, LabelType>::ForwardDataContent(
 
   AccuracyKernelUtil<device_type, PredType, LabelType>::Forward(
       ctx.device_ctx, N, D, top_k, X->dptr<PredType>(), label->dptr<LabelType>(),
-      accuracy->mut_dptr<PredType>());
+      weight ? weight->dptr<PredType>() : nullptr, accuracy->mut_dptr<PredType>());
   SetAccuracyInstanceNumBlob(ctx, BnInOp2Blob);
 }
 
@@ -48,8 +63,9 @@ void AccuracyKernel<device_type, PredType, LabelType>::ForwardRecordIdInDevicePi
 template<typename PredType, typename LabelType>
 struct AccuracyKernelUtil<DeviceType::kCPU, PredType, LabelType> {
   static void Forward(DeviceCtx* ctx, const int32_t N, const int32_t D, int32_t top_k,
-                      const PredType* XData, const LabelType* labelData, PredType* accuracyData) {
-    int correct = 0;
+                      const PredType* XData, const LabelType* labelData, const PredType* weight,
+                      PredType* accuracyData) {
+    PredType correct = 0;
     for (int i = 0; i < N; ++i) {
       auto label_i = labelData[i];
       auto label_pred = XData[i * D + label_i];
@@ -60,10 +76,9 @@ struct AccuracyKernelUtil<DeviceType::kCPU, PredType, LabelType> {
           if (++cnt > top_k) { break; }
         }
       }
-      if (cnt <= top_k) { ++correct; }
+      if (cnt <= top_k) { correct += weight ? weight[i] : OneVal<PredType>::value; }
     }
-    CHECK_LE(correct, N);
-    *accuracyData = static_cast<PredType>(correct);
+    *accuracyData = correct;
   }
 };
 
diff --git a/oneflow/core/kernel/accuracy_kernel.cu b/oneflow/core/kernel/accuracy_kernel.cu
index 46a6e51e25fe03523cd75082103eae4a1b7575f2..f5671d517652785d9f068af17c49e333eae5a99e 100644
--- a/oneflow/core/kernel/accuracy_kernel.cu
+++ b/oneflow/core/kernel/accuracy_kernel.cu
@@ -17,10 +17,10 @@ __global__ void AccuracySetZeroKernel(PredType* accuracy) {
 template<typename PredType, typename LabelType>
 __global__ void AccuracyComputeKernel(const int32_t N, const int32_t D, const int32_t top_k,
                                       const PredType* Xdata, const LabelType* labelData,
-                                      PredType* accuracy) {
+                                      const PredType* weight, PredType* accuracy) {
   typedef cub::BlockReduce<int32_t, kCudaThreadsNumPerBlock> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
-  int32_t correct = 0;
+  PredType correct = 0;
   for (int32_t row = blockIdx.x; row < N; row += gridDim.x) {
     const LabelType label = labelData[row];
     const PredType label_pred = Xdata[row * D + label];
@@ -30,20 +30,22 @@ __global__ void AccuracyComputeKernel(const int32_t N, const int32_t D, const in
       if (pred > label_pred || (pred == label_pred && col <= label)) { ++ngt; }
     }
     ngt = BlockReduce(temp_storage).Sum(ngt);
-    if (ngt <= top_k) { ++correct; }
+    if (ngt <= top_k) { correct += weight ? weight[row] : OneVal<PredType>::value; }
     __syncthreads();
   }
-  if (threadIdx.x == 0) { gpu_atomic_add(accuracy, static_cast<PredType>(correct)); }
+  if (threadIdx.x == 0) { gpu_atomic_add(accuracy, correct); }
 }
 }  // namespace
 
 template<typename PredType, typename LabelType>
 struct AccuracyKernelUtil<DeviceType::kGPU, PredType, LabelType> {
   static void Forward(DeviceCtx* ctx, const int32_t N, const int32_t D, int32_t top_k,
-                      const PredType* XData, const LabelType* labelData, PredType* accuracyData) {
+                      const PredType* XData, const LabelType* labelData, const PredType* weight,
+                      PredType* accuracyData) {
     AccuracySetZeroKernel<<<1, 1, 0, ctx->cuda_stream()>>>(accuracyData);
     AccuracyComputeKernel<<<BlocksNum4ThreadsNum(N), kCudaThreadsNumPerBlock, 0,
-                            ctx->cuda_stream()>>>(N, D, top_k, XData, labelData, accuracyData);
+                            ctx->cuda_stream()>>>(N, D, top_k, XData, labelData, weight,
+                                                  accuracyData);
   };
 };
 #define MAKE_ENTRY(data_type_pair, label_type_pair)                                      \
diff --git a/oneflow/core/kernel/accuracy_kernel.h b/oneflow/core/kernel/accuracy_kernel.h
index d3a47ba3a781499b0ba729c48f71a6e6aa482be5..c79a50f061af3616aa6861ce67d88f2777c846d1 100644
--- a/oneflow/core/kernel/accuracy_kernel.h
+++ b/oneflow/core/kernel/accuracy_kernel.h
@@ -27,7 +27,8 @@ class AccuracyKernel final : public KernelIf<device_type> {
 template<DeviceType device_type, typename PredType, typename LabelType>
 struct AccuracyKernelUtil {
   static void Forward(DeviceCtx* ctx, const int32_t N, const int32_t D, int32_t top_k,
-                      const PredType* XData, const LabelType* labelData, PredType* accuracyData);
+                      const PredType* XData, const LabelType* labelData, const PredType* weight,
+                      PredType* accuracyData);
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/kernel/adam_model_update_kernel.cpp b/oneflow/core/kernel/adam_model_update_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..498c2f688386543190262804f68f8b7de8d5b8d6
--- /dev/null
+++ b/oneflow/core/kernel/adam_model_update_kernel.cpp
@@ -0,0 +1,114 @@
+#include "oneflow/core/kernel/adam_model_update_kernel.h"
+#include "oneflow/core/kernel/normal_model_update_kernel.cuh"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T>
+void UpdateMomentEstimate(int64_t n, bool do_bias_correction, T beta, int32_t p,
+                          const T* model_diff, const T* beta_t, T* moment) {
+  FOR_RANGE(int64_t, i, 0, n) {
+    // Update biased moment estimate
+    moment[i] = beta * moment[i] + (1 - beta) * std::pow(model_diff[i], p);
+    if (do_bias_correction) {
+      // Bias-correct the moment estimate
+      moment[i] = moment[i] / (1 - *beta_t);
+    }
+  }
+}
+
+}  // namespace
+
+template<DeviceType device_type, typename T>
+void AdamMdUpdateKernel<device_type, T>::InitModelBlobsWithRandomSeed(
+    DeviceCtx* ctx, std::mt19937* random_seed_gen,
+    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const auto& adam_conf = this->op_conf().normal_mdupdt_conf().user_conf().adam_conf();
+  InitializerConf m_init_conf;
+  InitializerConf v_init_conf;
+  m_init_conf.mutable_constant_conf()->set_value(0.0f);
+  v_init_conf.mutable_constant_conf()->set_value(0.0f);
+  KernelUtil<device_type, T>::InitializeWithProperConf(ctx, &m_init_conf, 0, BnInOp2Blob("m"));
+  KernelUtil<device_type, T>::InitializeWithProperConf(ctx, &v_init_conf, 0, BnInOp2Blob("v"));
+  if (!adam_conf.do_bias_correction()) { return; }
+  InitializerConf beta1_init_conf;
+  InitializerConf beta2_init_conf;
+  beta1_init_conf.mutable_constant_conf()->set_value(adam_conf.beta1());
+  beta2_init_conf.mutable_constant_conf()->set_value(adam_conf.beta2());
+  KernelUtil<device_type, T>::InitializeWithProperConf(ctx, &beta1_init_conf, 0,
+                                                       BnInOp2Blob("beta1_t"));
+  KernelUtil<device_type, T>::InitializeWithProperConf(ctx, &beta2_init_conf, 0,
+                                                       BnInOp2Blob("beta2_t"));
+}
+
+template<DeviceType device_type, typename T>
+void AdamMdUpdateKernel<device_type, T>::InitModelBlobsWithDir(
+    DeviceCtx* ctx, int32_t part_id, int32_t part_num, const std::string& model_load_dir,
+    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const auto& adam_conf = this->op_conf().normal_mdupdt_conf().user_conf().adam_conf();
+  Blob* m_blob = BnInOp2Blob("m");
+  Blob* v_blob = BnInOp2Blob("v");
+  KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir, m_blob, "m",
+                                                m_blob->shape().At(0), m_blob->shape().Count(1));
+  KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir, v_blob, "v",
+                                                v_blob->shape().At(0), v_blob->shape().Count(1));
+  if (!adam_conf.do_bias_correction()) { return; }
+  Blob* beta1_t_blob = BnInOp2Blob("beta1_t");
+  Blob* beta2_t_blob = BnInOp2Blob("beta2_t");
+  KernelUtil<device_type, T>::InitializeWithDir(
+      ctx, part_id, part_num, model_load_dir, beta1_t_blob, "beta1_t", beta1_t_blob->shape().At(0),
+      beta1_t_blob->shape().Count(1));
+  KernelUtil<device_type, T>::InitializeWithDir(
+      ctx, part_id, part_num, model_load_dir, beta2_t_blob, "beta2_t", beta2_t_blob->shape().At(0),
+      beta2_t_blob->shape().Count(1));
+}
+
+template<DeviceType device_type, typename T>
+void AdamMdUpdateKernel<device_type, T>::UpdateModel(
+    DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2,
+    int64_t next_model_vid, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  Blob* model_blob = BnInOp2Blob("model");
+  Blob* m_blob = BnInOp2Blob("m");
+  Blob* v_blob = BnInOp2Blob("v");
+  Blob* beta1_t_blob = BnInOp2Blob("beta1_t");
+  Blob* beta2_t_blob = BnInOp2Blob("beta2_t");
+  const AdamModelUpdateConf& adam_conf =
+      this->op_conf().normal_mdupdt_conf().user_conf().adam_conf();
+  if ((next_model_vid != 1) && adam_conf.do_bias_correction()) {
+    KernelUtil<device_type, T>::Scal(ctx, 1, static_cast<T>(adam_conf.beta1()),
+                                     beta1_t_blob->mut_dptr<T>(), 1);
+    KernelUtil<device_type, T>::Scal(ctx, 1, static_cast<T>(adam_conf.beta2()),
+                                     beta2_t_blob->mut_dptr<T>(), 1);
+  }
+  AdamMdUpdateKernelUtil<device_type, T>::UpdateModel(
+      ctx, model_blob->shape().elem_cnt(), batch_instance_num_ptr, learning_rate, l1, l2,
+      static_cast<T>(adam_conf.beta1()), static_cast<T>(adam_conf.beta2()),
+      static_cast<T>(adam_conf.epsilon()), adam_conf.do_bias_correction(), next_model_vid,
+      (beta1_t_blob ? beta1_t_blob->dptr<T>() : nullptr),
+      (beta2_t_blob ? beta2_t_blob->dptr<T>() : nullptr), BnInOp2Blob("model_diff")->mut_dptr<T>(),
+      model_blob->mut_dptr<T>(), m_blob->mut_dptr<T>(), v_blob->mut_dptr<T>());
+}
+
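+// Per-element update performed by the kernel util below (g = model_diff):
+//   m <- beta1 * m + (1 - beta1) * g,    then m <- m / (1 - beta1^t) if bias correction is on
+//   v <- beta2 * v + (1 - beta2) * g^2,  then v <- v / (1 - beta2^t) if bias correction is on
+//   model <- model - learning_rate * RegularizeDiff(m / (sqrt(v) + epsilon), ...)
+// beta1^t and beta2^t are accumulated into the beta1_t/beta2_t blobs in UpdateModel above.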
+template<typename T>
+class AdamMdUpdateKernelUtil<DeviceType::kCPU, T> final {
+ public:
+  static void UpdateModel(DeviceCtx* ctx, int64_t n, const T* batch_instance_num_ptr,
+                          T learning_rate, T l1, T l2, T beta1, T beta2, T epsilon,
+                          bool do_bias_correction, int64_t next_model_vid, const T* beta1_t,
+                          const T* beta2_t, T* model_diff, T* model, T* m, T* v) {
+    // first-order moment
+    UpdateMomentEstimate<T>(n, do_bias_correction, beta1, 1, model_diff, beta1_t, m);
+    // second-order moment
+    UpdateMomentEstimate<T>(n, do_bias_correction, beta2, 2, model_diff, beta2_t, v);
+    FOR_RANGE(int64_t, i, 0, n) {
+      model_diff[i] = m[i] / (std::sqrt(v[i]) + epsilon);
+      T reg_diff = RegularizeDiff(model_diff[i], *batch_instance_num_ptr, l1, l2, model[i]);
+      model[i] = model[i] - learning_rate * reg_diff;
+    }
+  }
+};
+
+DEFINE_MDUPDT_KERNEL_CREATOR(Adam);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/adam_model_update_kernel.cu b/oneflow/core/kernel/adam_model_update_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a229180fef5bfee509881df0140cf0884616f3c4
--- /dev/null
+++ b/oneflow/core/kernel/adam_model_update_kernel.cu
@@ -0,0 +1,93 @@
+#include "oneflow/core/device/cuda_util.h"
+#include "oneflow/core/kernel/adam_model_update_kernel.h"
+#include "oneflow/core/kernel/normal_model_update_kernel.cuh"
+
+namespace oneflow {
+
+namespace {
+
+template<int32_t power>
+struct PowUtil;
+
+template<>
+struct PowUtil<1> final {
+  template<typename T>
+  __device__ static T pow(const T x) {
+    return x;
+  }
+};
+
+template<>
+struct PowUtil<2> final {
+  template<typename T>
+  __device__ static T pow(const T x) {
+    return x * x;
+  }
+};
+
+template<bool do_bias_correction, typename T>
+__device__ typename std::enable_if<do_bias_correction>::type ScaleMomentum(const T beta_t,
+                                                                           T* moment) {
+  *moment /= (1 - beta_t);
+}
+
+template<bool do_bias_correction, typename T>
+__device__ typename std::enable_if<!do_bias_correction>::type ScaleMomentum(const T beta_t,
+                                                                            T* moment) {}
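+// The enable_if overloads above select at compile time whether the bias-correction
+// division is applied; when do_bias_correction is false the call compiles to a no-op.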
+
+template<int32_t power, bool do_bias_correction, typename T>
+__device__ void UpdateMomentEstimate(T beta, const T* model_diff, const T* beta_t, T* moment) {
+  // Update biased moment estimate
+  *moment = beta * (*moment) + (1 - beta) * PowUtil<power>::pow(*model_diff);
+  // Bias-correct the moment estimate
+  ScaleMomentum<do_bias_correction>(*beta_t, moment);
+}
+
+template<typename T>
+__device__ void UpdateModel(const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, T epsilon,
+                            T* model_diff, T* model, T* m, T* v) {
+  *model_diff = *m / (sqrt(*v) + epsilon);
+  T reg_diff = RegularizeDiff(*model_diff, *batch_instance_num_ptr, l1, l2, *model);
+  *model = *model - learning_rate * reg_diff;
+}
+
+template<bool do_bias_correction, typename T>
+__global__ void UpdateModelGpu(int64_t n, const T* batch_instance_num_ptr, T learning_rate, T l1,
+                               T l2, T beta1, T beta2, T epsilon, const T* beta1_t,
+                               const T* beta2_t, T* model_diff, T* model, T* m, T* v) {
+  CUDA_1D_KERNEL_LOOP(i, n) {
+    UpdateMomentEstimate<1, do_bias_correction>(beta1, model_diff + i, beta1_t, m + i);
+    UpdateMomentEstimate<2, do_bias_correction>(beta2, model_diff + i, beta2_t, v + i);
+    UpdateModel(batch_instance_num_ptr, learning_rate, l1, l2, epsilon, model_diff + i, model + i,
+                m + i, v + i);
+  }
+}
+
+}  // namespace
+
+template<typename T>
+class AdamMdUpdateKernelUtil<DeviceType::kGPU, T> final {
+ public:
+  static void UpdateModel(DeviceCtx* ctx, int64_t n, const T* batch_instance_num_ptr,
+                          T learning_rate, T l1, T l2, T beta1, T beta2, T epsilon,
+                          bool do_bias_correction, int64_t next_model_vid, const T* beta1_t,
+                          const T* beta2_t, T* model_diff, T* model, T* m, T* v) {
+    if (do_bias_correction) {
+      UpdateModelGpu<true, T>
+          <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+              n, batch_instance_num_ptr, learning_rate, l1, l2, beta1, beta2, epsilon, beta1_t,
+              beta2_t, model_diff, model, m, v);
+    } else {
+      UpdateModelGpu<false, T>
+          <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+              n, batch_instance_num_ptr, learning_rate, l1, l2, beta1, beta2, epsilon, beta1_t,
+              beta2_t, model_diff, model, m, v);
+    }
+  }
+};
+
+#define INSTANTIATE_GPU_KERNEL_UTIL(type_cpp, type_proto) \
+  template class AdamMdUpdateKernelUtil<DeviceType::kGPU, type_cpp>;
+OF_PP_FOR_EACH_TUPLE(INSTANTIATE_GPU_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ)
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/adam_model_update_kernel.h b/oneflow/core/kernel/adam_model_update_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..3031f7bf2349babc574628dc6de3da6cff67f436
--- /dev/null
+++ b/oneflow/core/kernel/adam_model_update_kernel.h
@@ -0,0 +1,40 @@
+#ifndef ONEFLOW_CORE_KERNEL_ADAM_MODEL_UPDATE_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_ADAM_MODEL_UPDATE_KERNEL_H_
+
+#include "oneflow/core/kernel/normal_model_update_kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class AdamMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(AdamMdUpdateKernel);
+  AdamMdUpdateKernel() = default;
+  ~AdamMdUpdateKernel() = default;
+
+ private:
+  void InitModelBlobsWithRandomSeed(
+      DeviceCtx* ctx, std::mt19937* random_seed_gen,
+      std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  void InitModelBlobsWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num,
+                             const std::string& model_load_dir,
+                             std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2,
+                   int64_t next_model_vid,
+                   std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+};
+
+template<DeviceType device_type, typename T>
+class AdamMdUpdateKernelUtil final {
+ public:
+  static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, T learning_rate,
+                          T l1, T l2, T beta1, T beta2, T epsilon, bool do_bias_correction,
+                          int64_t next_model_vid, const T* beta1_t, const T* beta2_t, T* model_diff,
+                          T* model, T* m, T* v);
+};
+
+DECLARE_MDUPDT_KERNEL_CREATOR(Adam);
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_ADAM_MODEL_UPDATE_KERNEL_H_
diff --git a/oneflow/core/kernel/batch_gather_kernel.cpp b/oneflow/core/kernel/batch_gather_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..501252a85cd4ac1030474e0e054eaf0c8d414d08
--- /dev/null
+++ b/oneflow/core/kernel/batch_gather_kernel.cpp
@@ -0,0 +1,118 @@
+#include "oneflow/core/kernel/batch_gather_kernel.h"
+
+namespace oneflow {
+
+namespace {
+
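+// Flattens a shape around `axis`: (prod of dims before axis, dim at axis, prod of dims
+// after axis); e.g. shape (2, 3, 4, 5) with axis = 2 becomes (6, 4, 5).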
+Shape GetFlatShape(const Shape& shape, const int64_t axis) {
+  CHECK_GT(shape.NumAxes(), 0);
+  CHECK_GE(axis, 0);
+  CHECK_LT(axis, shape.NumAxes());
+  return Shape({shape.Count(0, axis), shape.At(axis), shape.Count(axis + 1)});
+}
+
+template<DeviceType device_type, typename T, typename K>
+void BatchGatherForward(DeviceCtx* ctx, const Blob* in, const Blob* indices, Blob* out) {
+  const int64_t axis = indices->shape().NumAxes() - 1;
+  const Shape flat_out_shape = GetFlatShape(out->shape(), axis);
+  BatchGatherKernelUtil<device_type, T, K>::Forward(ctx, in->dptr<T>(), indices->dptr<K>(),
+                                                    flat_out_shape, in->shape().At(axis),
+                                                    out->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T, typename K>
+void BatchGatherBackward(DeviceCtx* ctx, const Blob* out_diff, const Blob* indices, Blob* in_diff) {
+  Memset<device_type>(ctx, in_diff->mut_dptr<T>(), 0, in_diff->ByteSizeOfDataContentField());
+  const int64_t axis = indices->shape().NumAxes() - 1;
+  const Shape flat_out_diff_shape = GetFlatShape(out_diff->shape(), axis);
+  BatchGatherKernelUtil<device_type, T, K>::Backward(ctx, out_diff->dptr<T>(), indices->dptr<K>(),
+                                                     flat_out_diff_shape, in_diff->shape().At(axis),
+                                                     in_diff->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T>
+struct BatchGatherSwitchUtil final {
+#define MAKE_BATCH_GATHER_SWITCH_ENTRY(func_name, K) func_name<device_type, T, K>
+#define DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC(func_name)                    \
+  DEFINE_STATIC_SWITCH_FUNC(void, func_name, MAKE_BATCH_GATHER_SWITCH_ENTRY, \
+                            MAKE_DATA_TYPE_CTRV_SEQ(INT_DATA_TYPE_SEQ));
+  DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC(BatchGatherForward);
+  DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC(BatchGatherBackward);
+#undef DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC
+#undef MAKE_BATCH_GATHER_SWITCH_ENTRY
+};
+
+}  // namespace
+
+template<DeviceType device_type, typename T>
+const PbMessage& BatchGatherKernel<device_type, T>::GetCustomizedOpConf() const {
+  return this->op_conf().batch_gather_conf();
+}
+
+template<DeviceType device_type, typename T>
+void BatchGatherKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  BatchGatherSwitchUtil<device_type, T>::SwitchBatchGatherForward(
+      SwitchCase(BnInOp2Blob("indices")->data_type()), ctx.device_ctx, BnInOp2Blob("in"),
+      BnInOp2Blob("indices"), BnInOp2Blob("out"));
+}
+
+template<DeviceType device_type, typename T>
+void BatchGatherKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  BatchGatherSwitchUtil<device_type, T>::SwitchBatchGatherBackward(
+      SwitchCase(BnInOp2Blob("indices")->data_type()), ctx.device_ctx,
+      BnInOp2Blob(GenDiffBn("out")), BnInOp2Blob("indices"), BnInOp2Blob(GenDiffBn("in")));
+}
+
+template<typename T, typename K>
+struct BatchGatherKernelUtil<DeviceType::kCPU, T, K> final {
+  static void Forward(DeviceCtx* ctx, const T* in, const K* indices, const Shape& flat_out_shape,
+                      const int64_t gather_dim_size, T* out);
+  static void Backward(DeviceCtx* ctx, const T* out_diff, const K* indices,
+                       const Shape& flat_out_diff_shape, const int64_t gather_dim_size, T* in_diff);
+};
+
+template<typename T, typename K>
+void BatchGatherKernelUtil<DeviceType::kCPU, T, K>::Forward(DeviceCtx* ctx, const T* in,
+                                                            const K* indices,
+                                                            const Shape& flat_out_shape,
+                                                            const int64_t gather_dim_size, T* out) {
+  const int64_t batch_num = flat_out_shape.At(0);
+  const int64_t indices_num = flat_out_shape.At(1);
+  const int64_t instance_size = flat_out_shape.At(2);
+  FOR_RANGE(int64_t, batch_idx, 0, batch_num) {
+    FOR_RANGE(int64_t, i, 0, indices_num) {
+      const K idx = indices[batch_idx * indices_num + i];
+      CHECK(idx >= 0 && idx < gather_dim_size);
+      const T* from = in + batch_idx * gather_dim_size * instance_size + idx * instance_size;
+      T* to = out + batch_idx * indices_num * instance_size + i * instance_size;
+      std::copy(from, from + instance_size, to);
+    }
+  }
+}
+
+template<typename T, typename K>
+void BatchGatherKernelUtil<DeviceType::kCPU, T, K>::Backward(DeviceCtx* ctx, const T* out_diff,
+                                                             const K* indices,
+                                                             const Shape& flat_out_diff_shape,
+                                                             const int64_t gather_dim_size,
+                                                             T* in_diff) {
+  const int64_t batch_num = flat_out_diff_shape.At(0);
+  const int64_t indices_num = flat_out_diff_shape.At(1);
+  const int64_t instance_size = flat_out_diff_shape.At(2);
+  FOR_RANGE(int64_t, batch_idx, 0, batch_num) {
+    FOR_RANGE(int64_t, i, 0, indices_num) {
+      const int64_t idx = indices[batch_idx * indices_num + i];
+      CHECK(idx >= 0 && idx < gather_dim_size);
+      const T* from = out_diff + batch_idx * indices_num * instance_size + i * instance_size;
+      T* to = in_diff + batch_idx * gather_dim_size * instance_size + idx * instance_size;
+      std::transform(from, from + instance_size, to, to, std::plus<T>());
+    }
+  }
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kBatchGatherConf, BatchGatherKernel,
+                           FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/batch_gather_kernel.cu b/oneflow/core/kernel/batch_gather_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..4292f4fdfaa3d69587db6eda0dacb1d0b91d9fa5
--- /dev/null
+++ b/oneflow/core/kernel/batch_gather_kernel.cu
@@ -0,0 +1,87 @@
+#include "oneflow/core/kernel/batch_gather_kernel.h"
+#include "oneflow/core/kernel/kernel_util.cuh"
+#include <assert.h>
+
+namespace oneflow {
+
+namespace {
+
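+// Maps a flat output offset, decomposed as (batch_idx, indices_idx, inner_idx), to the
+// flat input offset of the gathered row indices[batch_idx * indices_num + indices_idx].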
+template<typename K>
+__device__ int64_t GetInOffset(const int64_t out_offset, const K* indices,
+                               const int64_t indices_num, const int64_t instance_size,
+                               const int64_t gather_dim_size) {
+  const int64_t batch_idx = out_offset / (indices_num * instance_size);
+  const int64_t indices_idx = out_offset % (indices_num * instance_size) / instance_size;
+  const int64_t inner_idx = out_offset % instance_size;
+  const int64_t idx = indices[batch_idx * indices_num + indices_idx];
+  assert(idx >= 0 && idx < gather_dim_size);
+  return batch_idx * gather_dim_size * instance_size + idx * instance_size + inner_idx;
+}
+
+template<typename T, typename K>
+__global__ void BatchGatherForwardGpu(const int64_t elem_cnt, const T* in, const K* indices,
+                                      const int64_t indices_num, const int64_t instance_size,
+                                      const int64_t gather_dim_size, T* out) {
+  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {
+    out[i] = in[GetInOffset<K>(i, indices, indices_num, instance_size, gather_dim_size)];
+  }
+}
+
+template<typename T, typename K>
+__global__ void BatchGatherBackwardGpu(const int64_t elem_cnt, const T* out_diff, const K* indices,
+                                       const int64_t indices_num, const int64_t instance_size,
+                                       const int64_t gather_dim_size, T* in_diff) {
+  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {
+    gpu_atomic_add(
+        in_diff + GetInOffset<K>(i, indices, indices_num, instance_size, gather_dim_size),
+        out_diff[i]);
+  }
+}
+
+}  // namespace
+
+template<typename T, typename K>
+struct BatchGatherKernelUtil<DeviceType::kGPU, T, K> final {
+  static void Forward(DeviceCtx* ctx, const T* in, const K* indices, const Shape& flat_out_shape,
+                      const int64_t gather_dim_size, T* out);
+  static void Backward(DeviceCtx* ctx, const T* out_diff, const K* indices,
+                       const Shape& flat_out_diff_shape, const int64_t gather_dim_size, T* in_diff);
+};
+
+template<typename T, typename K>
+void BatchGatherKernelUtil<DeviceType::kGPU, T, K>::Forward(DeviceCtx* ctx, const T* in,
+                                                            const K* indices,
+                                                            const Shape& flat_out_shape,
+                                                            const int64_t gather_dim_size, T* out) {
+  const int64_t batch_num = flat_out_shape.At(0);
+  const int64_t indices_num = flat_out_shape.At(1);
+  const int64_t instance_size = flat_out_shape.At(2);
+  const int64_t elem_cnt = batch_num * indices_num * instance_size;
+  BatchGatherForwardGpu<T, K>
+      <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+          elem_cnt, in, indices, indices_num, instance_size, gather_dim_size, out);
+}
+
+template<typename T, typename K>
+void BatchGatherKernelUtil<DeviceType::kGPU, T, K>::Backward(DeviceCtx* ctx, const T* out_diff,
+                                                             const K* indices,
+                                                             const Shape& flat_out_diff_shape,
+                                                             const int64_t gather_dim_size,
+                                                             T* in_diff) {
+  const int64_t batch_num = flat_out_diff_shape.At(0);
+  const int64_t indices_num = flat_out_diff_shape.At(1);
+  const int64_t instance_size = flat_out_diff_shape.At(2);
+  const int64_t elem_cnt = batch_num * indices_num * instance_size;
+  BatchGatherBackwardGpu<T, K>
+      <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+          elem_cnt, out_diff, indices, indices_num, instance_size, gather_dim_size, in_diff);
+}
+
+#define MAKE_BATCH_GATHER_KERNEL_UTIL_ENTRY(in_type_pair, index_type_pair)                \
+  template struct BatchGatherKernelUtil<DeviceType::kGPU, OF_PP_PAIR_FIRST(in_type_pair), \
+                                        OF_PP_PAIR_FIRST(index_type_pair)>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_BATCH_GATHER_KERNEL_UTIL_ENTRY, FLOATING_DATA_TYPE_SEQ,
+                                 INT_DATA_TYPE_SEQ);
+#undef MAKE_BATCH_GATHER_KERNEL_UTIL_ENTRY
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/batch_gather_kernel.h b/oneflow/core/kernel/batch_gather_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..57521b14ff054ba8fde0d6aee8457e433e0c964f
--- /dev/null
+++ b/oneflow/core/kernel/batch_gather_kernel.h
@@ -0,0 +1,33 @@
+#ifndef ONEFLOW_CORE_KERNEL_BATCH_GATHER_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_BATCH_GATHER_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class BatchGatherKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BatchGatherKernel);
+  BatchGatherKernel() = default;
+  ~BatchGatherKernel() override = default;
+
+ private:
+  const PbMessage& GetCustomizedOpConf() const override;
+  void ForwardDataContent(const KernelCtx& ctx,
+                          std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  void BackwardDataContent(const KernelCtx& ctx,
+                           std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+};
+
+template<DeviceType device_type, typename T, typename K>
+struct BatchGatherKernelUtil final {
+  static void Forward(DeviceCtx* ctx, const T* in, const K* indices, const Shape& flat_out_shape,
+                      const int64_t gather_dim_size, T* out);
+  static void Backward(DeviceCtx* ctx, const T* out_diff, const K* indices,
+                       const Shape& flat_out_diff_shape, const int64_t gather_dim_size, T* in_diff);
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_BATCH_GATHER_KERNEL_H_
diff --git a/oneflow/core/kernel/bias_add_kernel.cpp b/oneflow/core/kernel/bias_add_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f44094296033db870fec0804cb717d978e9056ef
--- /dev/null
+++ b/oneflow/core/kernel/bias_add_kernel.cpp
@@ -0,0 +1,56 @@
+#include "oneflow/core/kernel/bias_add_kernel.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void BiasAddKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* a_blob = BnInOp2Blob("a");
+  const Blob* b_blob = BnInOp2Blob("b");
+  const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier");
+  Blob* out_blob = BnInOp2Blob("out");
+
+  // out = bias_multiplier * b + a
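+  // bias_multiplier is a column of ones of length out->shape().At(0) (see InitConstBufBlobs),
+  // so the rank-1 GEMM below adds the bias row b to every row of the copy of a (beta == 1).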
+  Memcpy<device_type>(ctx.device_ctx, out_blob->mut_dptr<T>(), a_blob->dptr<T>(),
+                      a_blob->ByteSizeOfDataContentField());
+  KernelUtil<device_type, T>::OFGemm(ctx.device_ctx, CblasNoTrans, CblasNoTrans,
+                                     out_blob->shape().At(0), out_blob->shape().At(1), 1,
+                                     OneVal<T>::value, bias_mul_blob->dptr<T>(), b_blob->dptr<T>(),
+                                     OneVal<T>::value, out_blob->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T>
+void BiasAddKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* out_diff_blob = BnInOp2Blob("out_diff");
+  const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier");
+  Blob* a_diff_blob = BnInOp2Blob("a_diff");
+  Blob* b_diff_blob = BnInOp2Blob("b_diff");
+
+  Memcpy<device_type>(ctx.device_ctx, a_diff_blob->mut_dptr<T>(), out_diff_blob->dptr<T>(),
+                      out_diff_blob->ByteSizeOfDataContentField());
+  // b_diff = bias_multiplier * out_diff
+  KernelUtil<device_type, T>::OFGemm(
+      ctx.device_ctx, CblasTrans, CblasNoTrans, 1, b_diff_blob->shape().At(0),
+      out_diff_blob->shape().At(0), OneVal<T>::value, bias_mul_blob->dptr<T>(),
+      out_diff_blob->dptr<T>(), ZeroVal<T>::value, b_diff_blob->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T>
+void BiasAddKernel<device_type, T>::InitConstBufBlobs(
+    DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  InitializerConf bias_multiplier_initializer_conf;
+  bias_multiplier_initializer_conf.mutable_constant_conf()->set_value(1.0f);
+  KernelUtil<device_type, T>::InitializeWithConf(ctx, bias_multiplier_initializer_conf, 0,
+                                                 BnInOp2Blob("bias_multiplier"));
+}
+
+template<DeviceType device_type, typename T>
+const PbMessage& BiasAddKernel<device_type, T>::GetCustomizedOpConf() const {
+  return this->op_conf().bias_add_conf();
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kBiasAddConf, BiasAddKernel, FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/bias_add_kernel.h b/oneflow/core/kernel/bias_add_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..78b5c16cd7430c10c962d52017311d30bf70cfdb
--- /dev/null
+++ b/oneflow/core/kernel/bias_add_kernel.h
@@ -0,0 +1,27 @@
+#ifndef ONEFLOW_CORE_KERNEL_BIAS_ADD_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_BIAS_ADD_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class BiasAddKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BiasAddKernel);
+  BiasAddKernel() = default;
+  ~BiasAddKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+  void InitConstBufBlobs(DeviceCtx*,
+                         std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  const PbMessage& GetCustomizedOpConf() const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_BIAS_ADD_KERNEL_H_
diff --git a/oneflow/core/kernel/boxing_kernel.cpp b/oneflow/core/kernel/boxing_kernel.cpp
index 645110161ae2d064ce6211b80d5359a4c2096163..c35d42f0c5a720447465fd0eef9e4a4177c0edbb 100644
--- a/oneflow/core/kernel/boxing_kernel.cpp
+++ b/oneflow/core/kernel/boxing_kernel.cpp
@@ -351,6 +351,41 @@ void BoxingKernel<T>::SetMaxColId(const KernelCtx& ctx,
   }
 }
 
+template<typename T>
+void BoxingKernel<T>::ForwardLossInstanceNum(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
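+  // Concat-box inputs sum their loss instance counts; add-box inputs must agree (within
+  // epsilon). Split-box outputs each receive an equal share; clone-box outputs copy it.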
+  const PbRpf<std::string>& input_bns = op_attribute().input_bns();
+  const PbRpf<std::string>& output_bns = op_attribute().output_bns();
+  CHECK_GT(input_bns.size(), 0);
+  float in_loss_instance_num = BnInOp2Blob(input_bns.Get(0))->loss_instance_num();
+  const float loss_instance_num_epsilon = 1e-8;
+  const BoxingOpConf& conf = op_conf().boxing_conf();
+  if (conf.in_box_case() == BoxingOpConf::kConcatBox) {
+    FOR_RANGE(int32_t, i, 1, input_bns.size()) {
+      in_loss_instance_num += BnInOp2Blob(input_bns.Get(i))->loss_instance_num();
+    }
+  } else if (conf.in_box_case() == BoxingOpConf::kAddBox) {
+    FOR_RANGE(int32_t, i, 1, input_bns.size()) {
+      CHECK_LT(std::fabs(BnInOp2Blob(input_bns.Get(i))->loss_instance_num() - in_loss_instance_num),
+               loss_instance_num_epsilon);
+    }
+  } else {
+    UNIMPLEMENTED();
+  }
+  if (conf.out_box_case() == BoxingOpConf::kSplitBox) {
+    const float out_loss_instance_num = in_loss_instance_num / output_bns.size();
+    FOR_RANGE(int32_t, i, 0, output_bns.size()) {
+      BnInOp2Blob(output_bns.Get(i))->set_loss_instance_num(out_loss_instance_num);
+    }
+  } else if (conf.out_box_case() == BoxingOpConf::kCloneBox) {
+    FOR_RANGE(int32_t, i, 0, output_bns.size()) {
+      BnInOp2Blob(output_bns.Get(i))->set_loss_instance_num(in_loss_instance_num);
+    }
+  } else {
+    UNIMPLEMENTED();
+  }
+}
+
 ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kBoxingConf, BoxingKernel, ARITHMETIC_DATA_TYPE_SEQ);
 
 }  // namespace oneflow
diff --git a/oneflow/core/kernel/boxing_kernel.h b/oneflow/core/kernel/boxing_kernel.h
index f8440525908dd1ecc1cafaafdb756efd75699d08..2f332e89ddbe43f97000f0f0f9b634ccc68abeee 100644
--- a/oneflow/core/kernel/boxing_kernel.h
+++ b/oneflow/core/kernel/boxing_kernel.h
@@ -33,6 +33,8 @@ class BoxingKernel final : public KernelIf<DeviceType::kCPU> {
 
   void SetColId(const KernelCtx&, std::function<Blob*(const std::string&)>) const;
   void SetMaxColId(const KernelCtx&, std::function<Blob*(const std::string&)>) const;
+  void ForwardLossInstanceNum(const KernelCtx& ctx,
+                              std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
 
   PbRpf<std::string> ibn_0_;
   PbRpf<std::string> obn_0_;
diff --git a/oneflow/core/kernel/broadcast_add_kernel.cpp b/oneflow/core/kernel/broadcast_add_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c626c361a83968b37bfa2bb958406db3000d5590
--- /dev/null
+++ b/oneflow/core/kernel/broadcast_add_kernel.cpp
@@ -0,0 +1,43 @@
+#include "oneflow/core/kernel/broadcast_add_kernel.h"
+#include "oneflow/core/ndarray/binary_func.h"
+#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/common/preprocessor.h"
+#include "oneflow/core/ndarray/ndarray_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void BroadcastAddKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& kernel_ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* a_blob = BnInOp2Blob("a");
+  const Blob* b_blob = BnInOp2Blob("b");
+  Blob* out_blob = BnInOp2Blob("out");
+  size_t num_axes = out_blob->shape().NumAxes();
+  NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncAdd>(
+      kernel_ctx.device_ctx, XpuVarNdarray<T>(out_blob, num_axes),
+      XpuVarNdarray<const T>(a_blob, num_axes), XpuVarNdarray<const T>(b_blob, num_axes));
+}
+
+template<DeviceType device_type, typename T>
+void BroadcastAddKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& kernel_ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* out_diff_blob = BnInOp2Blob("out_diff");
+  Blob* a_diff_blob = BnInOp2Blob("a_diff");
+  Blob* b_diff_blob = BnInOp2Blob("b_diff");
+  Blob* bw_buf_blob = BnInOp2Blob("bw_buf");
+  size_t num_axes = out_diff_blob->shape().NumAxes();
+  if (a_diff_blob) {
+    NdarrayUtil<device_type, T>::ReduceSum(
+        kernel_ctx.device_ctx, XpuVarNdarray<T>(a_diff_blob, num_axes),
+        XpuVarNdarray<const T>(out_diff_blob, num_axes), XpuVarNdarray<T>(bw_buf_blob, num_axes));
+  }
+  if (b_diff_blob) {
+    NdarrayUtil<device_type, T>::ReduceSum(
+        kernel_ctx.device_ctx, XpuVarNdarray<T>(b_diff_blob, num_axes),
+        XpuVarNdarray<const T>(out_diff_blob, num_axes), XpuVarNdarray<T>(bw_buf_blob, num_axes));
+  }
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kBroadcastAddConf, BroadcastAddKernel,
+                           ARITHMETIC_DATA_TYPE_SEQ);
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/broadcast_add_kernel.h b/oneflow/core/kernel/broadcast_add_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..196f056b7707354a354d74da2756836bda7fe110
--- /dev/null
+++ b/oneflow/core/kernel/broadcast_add_kernel.h
@@ -0,0 +1,24 @@
+#ifndef ONEFLOW_CORE_KERNEL_BROADCAST_ADD_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_BROADCAST_ADD_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class BroadcastAddKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BroadcastAddKernel);
+  BroadcastAddKernel() = default;
+  ~BroadcastAddKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_BROADCAST_ADD_KERNEL_H_
diff --git a/oneflow/core/kernel/broadcast_div_kernel.cpp b/oneflow/core/kernel/broadcast_div_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..06f43a1a998bbe1af9bbb5d06a812a4b95ea3714
--- /dev/null
+++ b/oneflow/core/kernel/broadcast_div_kernel.cpp
@@ -0,0 +1,101 @@
+#include "oneflow/core/kernel/broadcast_div_kernel.h"
+#include "oneflow/core/ndarray/ndarray_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void BroadcastDivKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* a = BnInOp2Blob("a");
+  const Blob* b = BnInOp2Blob("b");
+  Blob* out = BnInOp2Blob("out");
+  int64_t n = out->shape().elem_cnt();
+  if (a->shape().elem_cnt() == 1) {
+    CHECK_EQ(n, b->shape().elem_cnt());
+    KernelUtil<device_type, T>::Replicate(ctx.device_ctx, n, out->mut_dptr<T>(), a->dptr<T>());
+    KernelUtil<device_type, T>::Div(ctx.device_ctx, n, out->dptr<T>(), b->dptr<T>(),
+                                    out->mut_dptr<T>());
+  } else if (b->shape().elem_cnt() == 1) {
+    CHECK_EQ(n, a->shape().elem_cnt());
+    KernelUtil<device_type, T>::Replicate(ctx.device_ctx, n, out->mut_dptr<T>(), b->dptr<T>());
+    KernelUtil<device_type, T>::Div(ctx.device_ctx, n, a->dptr<T>(), out->dptr<T>(),
+                                    out->mut_dptr<T>());
+  } else {
+    size_t num_axes = out->shape().NumAxes();
+    NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncDiv>(
+        ctx.device_ctx, XpuVarNdarray<T>(out, num_axes), XpuVarNdarray<const T>(a, num_axes),
+        XpuVarNdarray<const T>(b, num_axes));
+  }
+}
+
+template<DeviceType device_type, typename T>
+void BroadcastDivKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* a = BnInOp2Blob("a");
+  const Blob* b = BnInOp2Blob("b");
+  const Blob* out_diff = BnInOp2Blob(GenDiffBn("out"));
+  Blob* tmp = BnInOp2Blob("bw_buf");
+  Blob* a_diff = BnInOp2Blob(GenDiffBn("a"));
+  Blob* b_diff = BnInOp2Blob(GenDiffBn("b"));
+
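+  // Gradients computed below: d(a / b) / da = 1 / b and d(a / b) / db = -a / b^2 = -out / b,
+  // reduced over the broadcast axes when a or b is broadcast.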
+  int64_t n = out_diff->shape().elem_cnt();
+  KernelUtil<device_type, T>::Reciprocal(ctx.device_ctx, b->shape().elem_cnt(), b->dptr<T>(),
+                                         tmp->mut_dptr<T>());
+  if (a->shape().elem_cnt() == 1) {
+    if (a_diff) {
+      KernelUtil<device_type, T>::Dot(ctx.device_ctx, n, out_diff->dptr<T>(), 1, tmp->dptr<T>(), 1,
+                                      a_diff->mut_dptr<T>());
+    }
+    if (b_diff) {
+      KernelUtil<device_type, T>::Square(ctx.device_ctx, n, tmp->dptr<T>(), tmp->mut_dptr<T>());
+      KernelUtil<device_type, T>::MulByScalar(ctx.device_ctx, n, tmp->dptr<T>(), a->dptr<T>(),
+                                              tmp->mut_dptr<T>());
+      KernelUtil<device_type, T>::Axpy(ctx.device_ctx, n, static_cast<T>(-2), tmp->dptr<T>(), 1,
+                                       tmp->mut_dptr<T>(), 1);
+      KernelUtil<device_type, T>::Mul(ctx.device_ctx, n, out_diff->dptr<T>(), tmp->dptr<T>(),
+                                      b_diff->mut_dptr<T>());
+    }
+  } else if (b->shape().elem_cnt() == 1) {
+    if (a_diff) {
+      KernelUtil<device_type, T>::MulByScalar(ctx.device_ctx, n, out_diff->dptr<T>(),
+                                              tmp->dptr<T>(), a_diff->mut_dptr<T>());
+    }
+    if (b_diff) {
+      KernelUtil<device_type, T>::Square(ctx.device_ctx, 1, tmp->dptr<T>(), tmp->mut_dptr<T>());
+      KernelUtil<device_type, T>::Axpy(ctx.device_ctx, 1, static_cast<T>(-2), tmp->dptr<T>(), 1,
+                                       tmp->mut_dptr<T>(), 1);
+      KernelUtil<device_type, T>::Dot(ctx.device_ctx, n, out_diff->dptr<T>(), 1, a->dptr<T>(), 1,
+                                      b_diff->mut_dptr<T>());
+      KernelUtil<device_type, T>::Mul(ctx.device_ctx, 1, tmp->dptr<T>(), b_diff->dptr<T>(),
+                                      b_diff->mut_dptr<T>());
+    }
+  } else {
+    size_t num_axes = out_diff->shape().NumAxes();
+    XpuVarNdarray<const T> out_diff_tensor(out_diff, num_axes);
+    Blob* bw_buf_blob = BnInOp2Blob("bw_buf");
+    XpuVarNdarray<const T> const_tmp(out_diff_tensor.shape(), bw_buf_blob->dptr<T>());
+    XpuVarNdarray<T> tmp(out_diff_tensor.shape(), bw_buf_blob->mut_dptr<T>());
+    if (a_diff) {
+      NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncDiv>(
+          ctx.device_ctx, tmp, out_diff_tensor, XpuVarNdarray<const T>(b, num_axes));
+      NdarrayUtil<device_type, T>::ReduceSum(ctx.device_ctx, XpuVarNdarray<T>(a_diff, num_axes),
+                                             const_tmp, tmp);
+    }
+    if (b_diff) {
+      const Blob* out_blob = BnInOp2Blob("out");
+      NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncDiv>(
+          ctx.device_ctx, tmp, XpuVarNdarray<const T>(out_blob, num_axes),
+          XpuVarNdarray<const T>(b, num_axes));
+      NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncMul>(
+          ctx.device_ctx, tmp, out_diff_tensor, const_tmp);
+      NdarrayUtil<device_type, T>::ReduceSum(ctx.device_ctx, XpuVarNdarray<T>(b_diff, num_axes),
+                                             const_tmp, tmp);
+      NdarrayUtil<device_type, T>::template ImplaceApplyUnary<UnaryFuncMinus>(
+          ctx.device_ctx, XpuVarNdarray<T>(b_diff, num_axes));
+    }
+  }
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kBroadcastDivConf, BroadcastDivKernel,
+                           FLOATING_DATA_TYPE_SEQ);
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/broadcast_div_kernel.h b/oneflow/core/kernel/broadcast_div_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..7c8974f41fe4832a22ae36a620ebd03b4a02730e
--- /dev/null
+++ b/oneflow/core/kernel/broadcast_div_kernel.h
@@ -0,0 +1,24 @@
+#ifndef ONEFLOW_CORE_KERNEL_BROADCAST_DIV_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_BROADCAST_DIV_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class BroadcastDivKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BroadcastDivKernel);
+  BroadcastDivKernel() = default;
+  ~BroadcastDivKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_BROADCAST_DIV_KERNEL_H_
diff --git a/oneflow/core/kernel/broadcast_mul_kernel.cpp b/oneflow/core/kernel/broadcast_mul_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e1dd306f45d57d9b463fc657c66a90e60756ac77
--- /dev/null
+++ b/oneflow/core/kernel/broadcast_mul_kernel.cpp
@@ -0,0 +1,80 @@
+#include "oneflow/core/kernel/broadcast_mul_kernel.h"
+#include "oneflow/core/kernel/kernel_util.h"
+#include "oneflow/core/ndarray/ndarray_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void BroadcastMulKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* a = BnInOp2Blob("a");
+  const Blob* b = BnInOp2Blob("b");
+  Blob* out = BnInOp2Blob("out");
+  int64_t n = out->shape().elem_cnt();
+  if (a->shape().elem_cnt() == 1) {
+    CHECK_EQ(n, b->shape().elem_cnt());
+    KernelUtil<device_type, T>::MulByScalar(ctx.device_ctx, n, b->dptr<T>(), a->dptr<T>(),
+                                            out->mut_dptr<T>());
+  } else if (b->shape().elem_cnt() == 1) {
+    CHECK_EQ(n, a->shape().elem_cnt());
+    KernelUtil<device_type, T>::MulByScalar(ctx.device_ctx, n, a->dptr<T>(), b->dptr<T>(),
+                                            out->mut_dptr<T>());
+  } else {
+    size_t num_axes = out->shape().NumAxes();
+    NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncMul>(
+        ctx.device_ctx, XpuVarNdarray<T>(out, num_axes), XpuVarNdarray<const T>(a, num_axes),
+        XpuVarNdarray<const T>(b, num_axes));
+  }
+}
+
+template<DeviceType device_type, typename T>
+void BroadcastMulKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* a = BnInOp2Blob("a");
+  const Blob* b = BnInOp2Blob("b");
+  const Blob* out_diff = BnInOp2Blob(GenDiffBn("out"));
+  Blob* a_diff = BnInOp2Blob(GenDiffBn("a"));
+  Blob* b_diff = BnInOp2Blob(GenDiffBn("b"));
+  int64_t n = out_diff->shape().elem_cnt();
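+  // Gradients: d(a * b) / da = b and d(a * b) / db = a, reduced over the broadcast axes
+  // when the shapes differ.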
+  if (a->shape().elem_cnt() == 1) {
+    if (a_diff) {
+      KernelUtil<device_type, T>::Dot(ctx.device_ctx, n, out_diff->dptr<T>(), 1, b->dptr<T>(), 1,
+                                      a_diff->mut_dptr<T>());
+    }
+    if (b_diff) {
+      KernelUtil<device_type, T>::MulByScalar(ctx.device_ctx, n, out_diff->dptr<T>(), a->dptr<T>(),
+                                              b_diff->mut_dptr<T>());
+    }
+  } else if (b->shape().elem_cnt() == 1) {
+    if (a_diff) {
+      KernelUtil<device_type, T>::MulByScalar(ctx.device_ctx, n, out_diff->dptr<T>(), b->dptr<T>(),
+                                              a_diff->mut_dptr<T>());
+    }
+    if (b_diff) {
+      KernelUtil<device_type, T>::Dot(ctx.device_ctx, n, out_diff->dptr<T>(), 1, a->dptr<T>(), 1,
+                                      b_diff->mut_dptr<T>());
+    }
+  } else {
+    size_t num_axes = out_diff->shape().NumAxes();
+    XpuVarNdarray<const T> out_diff_tensor(out_diff, num_axes);
+    Blob* bw_buf_blob = BnInOp2Blob("bw_buf");
+    XpuVarNdarray<const T> const_tmp(out_diff_tensor.shape(), bw_buf_blob->dptr<T>());
+    XpuVarNdarray<T> tmp(out_diff_tensor.shape(), bw_buf_blob->mut_dptr<T>());
+    if (a_diff) {
+      NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncMul>(
+          ctx.device_ctx, tmp, out_diff_tensor, XpuVarNdarray<const T>(b, num_axes));
+      NdarrayUtil<device_type, T>::ReduceSum(ctx.device_ctx, XpuVarNdarray<T>(a_diff, num_axes),
+                                             const_tmp, tmp);
+    }
+    if (b_diff) {
+      NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncMul>(
+          ctx.device_ctx, tmp, out_diff_tensor, XpuVarNdarray<const T>(a, num_axes));
+      NdarrayUtil<device_type, T>::ReduceSum(ctx.device_ctx, XpuVarNdarray<T>(b_diff, num_axes),
+                                             const_tmp, tmp);
+    }
+  }
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kBroadcastMulConf, BroadcastMulKernel,
+                           FLOATING_DATA_TYPE_SEQ);
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/broadcast_mul_kernel.h b/oneflow/core/kernel/broadcast_mul_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..24f66a20ed1e8b6fa1bc06e5991a93ef4b7017e2
--- /dev/null
+++ b/oneflow/core/kernel/broadcast_mul_kernel.h
@@ -0,0 +1,24 @@
+#ifndef ONEFLOW_CORE_KERNEL_BROADCAST_MUL_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_BROADCAST_MUL_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class BroadcastMulKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BroadcastMulKernel);
+  BroadcastMulKernel() = default;
+  ~BroadcastMulKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_BROADCAST_MUL_KERNEL_H_
diff --git a/oneflow/core/kernel/broadcast_sub_kernel.cpp b/oneflow/core/kernel/broadcast_sub_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b0d1d51d899c9a4fa0187244b46fd2fc2948553e
--- /dev/null
+++ b/oneflow/core/kernel/broadcast_sub_kernel.cpp
@@ -0,0 +1,44 @@
+#include "oneflow/core/kernel/broadcast_sub_kernel.h"
+#include "oneflow/core/ndarray/binary_func.h"
+#include "oneflow/core/ndarray/unary_func.h"
+#include "oneflow/core/ndarray/ndarray_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void BroadcastSubKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& kernel_ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* a_blob = BnInOp2Blob("a");
+  const Blob* b_blob = BnInOp2Blob("b");
+  Blob* out_blob = BnInOp2Blob("out");
+  size_t num_axes = out_blob->shape().NumAxes();
+  NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncSub>(
+      kernel_ctx.device_ctx, XpuVarNdarray<T>(out_blob, num_axes),
+      XpuVarNdarray<const T>(a_blob, num_axes), XpuVarNdarray<const T>(b_blob, num_axes));
+}
+
+template<DeviceType device_type, typename T>
+void BroadcastSubKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& kernel_ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* out_diff_blob = BnInOp2Blob("out_diff");
+  Blob* bw_buf_blob = BnInOp2Blob("bw_buf");
+  Blob* a_diff_blob = BnInOp2Blob("a_diff");
+  Blob* b_diff_blob = BnInOp2Blob("b_diff");
+  size_t num_axes = out_diff_blob->shape().NumAxes();
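+  // Since out = a - b, a_diff is out_diff reduce-summed to a's shape and b_diff
+  // is the same reduction negated in place.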
+  if (a_diff_blob) {
+    NdarrayUtil<device_type, T>::ReduceSum(
+        kernel_ctx.device_ctx, XpuVarNdarray<T>(a_diff_blob, num_axes),
+        XpuVarNdarray<const T>(out_diff_blob, num_axes), XpuVarNdarray<T>(bw_buf_blob, num_axes));
+  }
+  if (b_diff_blob) {
+    NdarrayUtil<device_type, T>::ReduceSum(
+        kernel_ctx.device_ctx, XpuVarNdarray<T>(b_diff_blob, num_axes),
+        XpuVarNdarray<const T>(out_diff_blob, num_axes), XpuVarNdarray<T>(bw_buf_blob, num_axes));
+    NdarrayUtil<device_type, T>::template ImplaceApplyUnary<UnaryFuncMinus>(
+        kernel_ctx.device_ctx, XpuVarNdarray<T>(b_diff_blob, num_axes));
+  }
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kBroadcastSubConf, BroadcastSubKernel,
+                           FLOATING_DATA_TYPE_SEQ);
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/broadcast_sub_kernel.h b/oneflow/core/kernel/broadcast_sub_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..8f72f4a4bbb0293b73e9a40e6889c904b685c72e
--- /dev/null
+++ b/oneflow/core/kernel/broadcast_sub_kernel.h
@@ -0,0 +1,24 @@
+#ifndef ONEFLOW_CORE_KERNEL_BROADCAST_SUB_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_BROADCAST_SUB_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class BroadcastSubKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BroadcastSubKernel);
+  BroadcastSubKernel() = default;
+  ~BroadcastSubKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_BROADCAST_SUB_KERNEL_H_
diff --git a/oneflow/core/kernel/cast_kernel.cpp b/oneflow/core/kernel/cast_kernel.cpp
index 3125df6969ab62819d2d20474167d0de8e5e6f2d..2750b8233e3a4a63d8402723058da7bb2dab4daf 100644
--- a/oneflow/core/kernel/cast_kernel.cpp
+++ b/oneflow/core/kernel/cast_kernel.cpp
@@ -6,26 +6,47 @@ namespace oneflow {
 
 namespace {
 
-template<typename T, typename U>
-void CopyBlob(const Blob* src, Blob* dst) {
+template<DeviceType device_type, typename T, typename U>
+void CopyBlob(DeviceCtx* ctx, const Blob* src, Blob* dst) {
   CHECK_EQ(src->shape(), dst->shape());
-  CopyElem(src->dptr<T>(), dst->mut_dptr<U>(), src->shape().elem_cnt());
+  if (device_type == DeviceType::kCPU) {
+    CopyElem(src->dptr<T>(), dst->mut_dptr<U>(), src->shape().elem_cnt());
+  } else if (device_type == DeviceType::kGPU) {
+    CopyElemOnGpu(ctx, src->dptr<T>(), dst->mut_dptr<U>(), src->shape().elem_cnt());
+  } else {
+    UNIMPLEMENTED();
+  }
 }
 
-#define MAKE_CAST_SWITCH_ENTRY(func_name, T, U) func_name<T, U>
-DEFINE_STATIC_SWITCH_FUNC(void, CopyBlob, MAKE_CAST_SWITCH_ENTRY,
-                          MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ),
-                          MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ));
-
 }  // namespace
 
-void CastKernel::ForwardDataContent(const KernelCtx& ctx,
-                                    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+#define MAKE_CAST_SWITCH_ENTRY(func_name, T, U) func_name<device_type, T, U>
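+// CastUtil carries the device type so the generated SwitchCopyBlob can dispatch
+// on the (src, dst) data-type pair for either CPU or GPU.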
+template<DeviceType device_type>
+struct CastUtil final {
+  DEFINE_STATIC_SWITCH_FUNC(void, CopyBlob, MAKE_CAST_SWITCH_ENTRY,
+                            MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ),
+                            MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ));
+};
+
+template<DeviceType device_type>
+void CastKernel<device_type>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   const Blob* in_blob = BnInOp2Blob("in");
   Blob* out_blob = BnInOp2Blob("out");
-  SwitchCopyBlob(SwitchCase(in_blob->data_type(), out_blob->data_type()), in_blob, out_blob);
+  CastUtil<device_type>::SwitchCopyBlob(SwitchCase(in_blob->data_type(), out_blob->data_type()),
+                                        ctx.device_ctx, in_blob, out_blob);
+}
+
+template<DeviceType device_type>
+void CastKernel<device_type>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out"));
+  Blob* in_diff_blob = BnInOp2Blob(GenDiffBn("in"));
+  CastUtil<device_type>::SwitchCopyBlob(
+      SwitchCase(out_diff_blob->data_type(), in_diff_blob->data_type()), ctx.device_ctx,
+      out_diff_blob, in_diff_blob);
 }
 
-REGISTER_KERNEL(OperatorConf::kCastConf, CastKernel);
+ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kCastConf, CastKernel);
 
 }  // namespace oneflow
diff --git a/oneflow/core/kernel/cast_kernel.h b/oneflow/core/kernel/cast_kernel.h
index 381df7f1d61986df03100506606512a2cca886bf..259397c785ea1d0064ccc280e378e8ad742904c7 100644
--- a/oneflow/core/kernel/cast_kernel.h
+++ b/oneflow/core/kernel/cast_kernel.h
@@ -5,7 +5,8 @@
 
 namespace oneflow {
 
-class CastKernel final : public KernelIf<DeviceType::kCPU> {
+template<DeviceType device_type>
+class CastKernel final : public KernelIf<device_type> {
  public:
   OF_DISALLOW_COPY_AND_MOVE(CastKernel);
   CastKernel() = default;
@@ -15,9 +16,7 @@ class CastKernel final : public KernelIf<DeviceType::kCPU> {
   void ForwardDataContent(const KernelCtx&,
                           std::function<Blob*(const std::string&)>) const override;
   void BackwardDataContent(const KernelCtx&,
-                           std::function<Blob*(const std::string&)>) const override {
-    TODO();
-  }
+                           std::function<Blob*(const std::string&)>) const override;
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/kernel/clone_kernel.cpp b/oneflow/core/kernel/clone_kernel.cpp
index 3f1dedda242f93cfcd96f46d1c0e7fa9cd478900..a7e111f5e4a91519e889cb0d69547882543a58f2 100644
--- a/oneflow/core/kernel/clone_kernel.cpp
+++ b/oneflow/core/kernel/clone_kernel.cpp
@@ -21,7 +21,7 @@ void CloneKernel<device_type, T>::BackwardDataContent(
   size_t out_num = odbns.size();
   if (out_num == 0) return;
   Blob* in_diff_blob = BnInOp2Blob(this->op_attribute().input_diff_bns(0));
-  auto out_diff = [&](int32_t idx) {
+  auto out_diff = [=](int32_t idx) {
     return BnInOp2Blob(this->op_attribute().output_diff_bns(idx));
   };
   static const int kWidth = 8;
diff --git a/oneflow/core/kernel/constant_kernel.cpp b/oneflow/core/kernel/constant_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..91ae77632bacf04950b645c729b314a7aa873a31
--- /dev/null
+++ b/oneflow/core/kernel/constant_kernel.cpp
@@ -0,0 +1,23 @@
+#include "oneflow/core/kernel/constant_kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void ConstantKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  if (is_init_) { return; }
+  CHECK(this->kernel_conf().has_constant_conf());
+  const ConstantKernelConf& conf = this->kernel_conf().constant_conf();
+  KernelUtil<device_type, T>::InitializeWithConf(ctx.device_ctx, conf.initializer(),
+                                                 conf.random_seed(), BnInOp2Blob("out"));
+  is_init_ = true;
+}
+
+template<DeviceType device_type, typename T>
+const PbMessage& ConstantKernel<device_type, T>::GetCustomizedOpConf() const {
+  return this->op_conf().constant_conf();
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kConstantConf, ConstantKernel, ARITHMETIC_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/constant_kernel.h b/oneflow/core/kernel/constant_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..4895ae245338b7a0bac0e086d06d59b162f10323
--- /dev/null
+++ b/oneflow/core/kernel/constant_kernel.h
@@ -0,0 +1,25 @@
+#ifndef ONEFLOW_CORE_KERNEL_CONSTANT_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_CONSTANT_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class ConstantKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ConstantKernel);
+  ConstantKernel() : is_init_(false) {}
+  ~ConstantKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx& ctx,
+                          std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  const PbMessage& GetCustomizedOpConf() const override;
+
+  mutable bool is_init_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_CONSTANT_KERNEL_H_
diff --git a/oneflow/core/kernel/debug_kernel.cpp b/oneflow/core/kernel/debug_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d561fc9edda88cacb99d9ba6feba9aafba54dd0e
--- /dev/null
+++ b/oneflow/core/kernel/debug_kernel.cpp
@@ -0,0 +1,117 @@
+#include "oneflow/core/kernel/debug_kernel.h"
+#include "oneflow/core/operator/op_conf.pb.h"
+#include "oneflow/core/record/record.pb.h"
+#include "oneflow/core/record/ofrecord_raw_decoder.h"
+#include "oneflow/core/record/ofrecord_raw_encoder.h"
+
+namespace oneflow {
+
+namespace {
+
+size_t FeatureSize(const Feature& feature) {
+  if (feature.has_double_list()) { return feature.double_list().value_size(); }
+  if (feature.has_float_list()) { return feature.float_list().value_size(); }
+  if (feature.has_int32_list()) { return feature.int32_list().value_size(); }
+  UNIMPLEMENTED();
+  return 0;
+}
+
+template<typename T>
+void Decode(Blob* blob, const Feature& feature) {
+  OFRecordDecoderImpl<EncodeCase::kRaw, T> decoder;
+  CHECK_EQ(blob->shape().elem_cnt(), FeatureSize(feature));
+  decoder.ReadOneCol(nullptr, feature, BlobConf(), 0, blob->mut_dptr<T>(), blob->shape().elem_cnt(),
+                     []() { return 0; });
+}
+
+template<typename T>
+void EncodeAndDump(const Blob* blob, PersistentOutStream* out_stream) {
+  Feature feature;
+  OFRecordEncoderImpl<EncodeCase::kRaw, T> encoder;
+  encoder.EncodeBlob(nullptr, blob, &feature);
+  *out_stream << feature;
+  out_stream->Flush();
+}
+
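+// Reads a Feature proto dumped as an int64 byte-length prefix followed by the
+// serialized message.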
+void InitConstFeature(Feature* feature, const std::string& filepath) {
+  PersistentInStream in_stream(SnapshotFS(), filepath);
+  int64_t feature_size = -1;
+  CHECK_EQ(in_stream.Read(reinterpret_cast<char*>(&feature_size), sizeof(int64_t)), 0);
+  std::unique_ptr<char[]> buffer(new char[feature_size]);
+  CHECK_EQ(in_stream.Read(buffer.get(), feature_size), 0);
+  feature->ParseFromArray(buffer.get(), feature_size);
+}
+
+}  // namespace
+
+template<typename T>
+void DebugKernel<T>::InitOutStream(std::unique_ptr<PersistentOutStream>* out_stream,
+                                   const ParallelContext* parallel_ctx, const std::string& dir) {
+  const auto& conf = this->op_conf().debug_conf();
+  OfCallOnce(dir, SnapshotFS(), &fs::FileSystem::RecursivelyCreateDir);
+  int32_t part_name_suffix_length = conf.part_name_suffix_length();
+  std::string num = std::to_string(parallel_ctx->parallel_id());
+  int32_t zero_count = std::max(part_name_suffix_length - static_cast<int32_t>(num.length()), 0);
+  std::string file_path =
+      JoinPath(dir, conf.part_name_prefix() + std::string(zero_count, '0') + num);
+  out_stream->reset(new PersistentOutStream(SnapshotFS(), file_path));
+}
+
+template<typename T>
+void DebugKernel<T>::VirtualKernelInit(const ParallelContext* parallel_ctx) {
+  const auto& conf = this->op_conf().debug_conf();
+  if (conf.has_in_blob_dump_dir()) {
+    InitOutStream(&in_blob_out_stream_, parallel_ctx, conf.in_blob_dump_dir());
+  }
+  if (conf.has_out_diff_blob_dump_dir()) {
+    InitOutStream(&out_diff_blob_out_stream_, parallel_ctx, conf.out_diff_blob_dump_dir());
+  }
+  if (conf.has_const_out_feature_load_filepath()) {
+    InitConstFeature(&const_out_blob_feature_, conf.const_out_feature_load_filepath());
+  }
+  if (conf.has_const_in_diff_feature_load_filepath()) {
+    InitConstFeature(&const_in_diff_blob_feature_, conf.const_in_diff_feature_load_filepath());
+  }
+}
+
+template<typename T>
+void DebugKernel<T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* in_blob = BnInOp2Blob("in");
+  Blob* out_blob = BnInOp2Blob("out");
+  const auto& conf = this->op_conf().debug_conf();
+  if (conf.has_in_blob_dump_dir()) { EncodeAndDump<T>(in_blob, in_blob_out_stream_.get()); }
+  if (conf.const_out_case() == DebugOpConf::ConstOutCase::CONST_OUT_NOT_SET) {
+    out_blob->CopyDataContentFrom(ctx.device_ctx, in_blob);
+  } else if (conf.has_const_out_feature_load_filepath()) {
+    Decode<T>(out_blob, const_out_blob_feature_);
+  } else if (conf.has_const_out_feature()) {
+    Decode<T>(out_blob, conf.const_out_feature());
+  } else {
+    UNIMPLEMENTED();
+  }
+}
+
+template<typename T>
+void DebugKernel<T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  Blob* in_diff_blob = BnInOp2Blob("in_diff");
+  const Blob* out_diff_blob = BnInOp2Blob("out_diff");
+  const auto& conf = this->op_conf().debug_conf();
+  if (conf.has_out_diff_blob_dump_dir()) {
+    EncodeAndDump<T>(out_diff_blob, out_diff_blob_out_stream_.get());
+  }
+  if (conf.const_in_diff_case() == DebugOpConf::ConstInDiffCase::CONST_IN_DIFF_NOT_SET) {
+    in_diff_blob->CopyDataContentFrom(ctx.device_ctx, out_diff_blob);
+  } else if (conf.has_const_in_diff_feature_load_filepath()) {
+    Decode<T>(in_diff_blob, const_in_diff_blob_feature_);
+  } else if (conf.has_const_in_diff_feature()) {
+    Decode<T>(in_diff_blob, conf.const_in_diff_feature());
+  } else {
+    UNIMPLEMENTED();
+  }
+}
+
+ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kDebugConf, DebugKernel, ARITHMETIC_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/debug_kernel.h b/oneflow/core/kernel/debug_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..45dccfafe650d8ef426fd4ba962d07201eb2443d
--- /dev/null
+++ b/oneflow/core/kernel/debug_kernel.h
@@ -0,0 +1,33 @@
+#ifndef ONEFLOW_CORE_KERNEL_DEBUG_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_DEBUG_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+#include "oneflow/core/kernel/kernel_context.h"
+
+namespace oneflow {
+
+template<typename T>
+class DebugKernel final : public KernelIf<DeviceType::kCPU> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(DebugKernel);
+  DebugKernel() = default;
+  ~DebugKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+  void InitOutStream(std::unique_ptr<PersistentOutStream>* out_stream,
+                     const ParallelContext* parallel_ctx, const std::string& dir);
+  void VirtualKernelInit(const ParallelContext* parallel_ctx);
+
+  std::unique_ptr<PersistentOutStream> in_blob_out_stream_;
+  std::unique_ptr<PersistentOutStream> out_diff_blob_out_stream_;
+  Feature const_out_blob_feature_;
+  Feature const_in_diff_blob_feature_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_DEBUG_KERNEL_H_
diff --git a/oneflow/core/kernel/dot_kernel.cpp b/oneflow/core/kernel/dot_kernel.cpp
index 43c5ef3c818df80bd238f7c0e0e29c0dfd0dd4b0..c08acfcf8d467abe146d684a485993d2e19998d6 100644
--- a/oneflow/core/kernel/dot_kernel.cpp
+++ b/oneflow/core/kernel/dot_kernel.cpp
@@ -19,7 +19,7 @@ void DotKernel<device_type, T>::ForwardDataContent(
   KernelUtil<device_type, T>::RowSum(ctx.device_ctx, piece_size, dim, tmp_blob->dptr<T>(),
                                      out_blob->mut_dptr<T>(), tmp_storage_blob->mut_dptr<T>(),
                                      sizeof(T) * piece_size * dim);
-  if (this->op_conf().matmul_conf().has_bias()) {
+  if (this->op_conf().dot_conf().has_bias()) {
     const Blob* bias_blob = BnInOp2Blob("bias");
     // out += bias
     KernelUtil<device_type, T>::Axpy(ctx.device_ctx, piece_size, OneVal<T>::value,
@@ -49,7 +49,7 @@ void DotKernel<device_type, T>::BackwardDataContent(
                                   tmp_blob->dptr<T>(), weight_blob->dptr<T>(),
                                   in_diff_blob->mut_dptr<T>());
 
-  if (this->op_conf().matmul_conf().has_bias()) {
+  if (this->op_conf().dot_conf().has_bias()) {
     Blob* bias_diff_blob = BnInOp2Blob("bias_diff");
     // bias_diff = out_diff
     KernelUtil<device_type, T>::Copy(ctx.device_ctx, out_diff_blob->shape().elem_cnt(),
diff --git a/oneflow/core/kernel/dropout_kernel.cpp b/oneflow/core/kernel/dropout_kernel.cpp
index da2f118db72501687307c1dd4df68af12e8c8597..5bc54419e7282def5b54aecbac6f31f981cb6bf3 100644
--- a/oneflow/core/kernel/dropout_kernel.cpp
+++ b/oneflow/core/kernel/dropout_kernel.cpp
@@ -37,7 +37,7 @@ void DropoutKernel<device_type, T>::BackwardDataContent(
 }
 
 template<DeviceType device_type, typename T>
-void DropoutKernel<device_type, T>::Dropout(DeviceCtx* ctx, const int64_t n, double dropout_rate,
+void DropoutKernel<device_type, T>::Dropout(DeviceCtx* ctx, const int64_t n, float dropout_rate,
                                             const T* x, float* random_mask, T* y) const {
   random_generator_->Uniform(n, random_mask);
   DropoutKernelUtil<device_type, T>::MaskAndScale(ctx, n, dropout_rate, 1 / (1 - dropout_rate), x,
@@ -46,7 +46,7 @@ void DropoutKernel<device_type, T>::Dropout(DeviceCtx* ctx, const int64_t n, dou
 
 template<DeviceType device_type, typename T>
 void DropoutKernel<device_type, T>::DropoutBackward(DeviceCtx* ctx, const int64_t n,
-                                                    double dropout_rate, const T* dy,
+                                                    float dropout_rate, const T* dy,
                                                     const float* random_mask, T* dx) const {
   DropoutKernelUtil<device_type, T>::MaskAndScale(ctx, n, dropout_rate, 1 / (1 - dropout_rate), dy,
                                                   random_mask, dx);
@@ -54,7 +54,7 @@ void DropoutKernel<device_type, T>::DropoutBackward(DeviceCtx* ctx, const int64_
 
 template<typename T>
 struct DropoutKernelUtil<DeviceType::kCPU, T> final {
-  static void MaskAndScale(DeviceCtx* ctx, const int64_t n, double threshold, double scale,
+  static void MaskAndScale(DeviceCtx* ctx, const int64_t n, float threshold, float scale,
                            const T* x, const float* random_mask, T* y) {
     for (int64_t i = 0; i < n; ++i) { y[i] = x[i] * (random_mask[i] > threshold) * scale; }
   }
diff --git a/oneflow/core/kernel/dropout_kernel.cu b/oneflow/core/kernel/dropout_kernel.cu
index c33c535f739c2a3306daf20cd90bd78214017cea..10e6a4ff932d0efa302b1b14c108e036088db7ca 100644
--- a/oneflow/core/kernel/dropout_kernel.cu
+++ b/oneflow/core/kernel/dropout_kernel.cu
@@ -8,7 +8,7 @@ namespace oneflow {
 namespace {
 
 template<typename T>
-__global__ void MaskAndScaleGpu(const int64_t n, double threshold, double scale, const T* x,
+__global__ void MaskAndScaleGpu(const int64_t n, float threshold, float scale, const T* x,
                                 const float* random_mask, T* y) {
   CUDA_1D_KERNEL_LOOP(i, n) { y[i] = x[i] * (random_mask[i] > threshold) * scale; }
 }
@@ -17,7 +17,7 @@ __global__ void MaskAndScaleGpu(const int64_t n, double threshold, double scale,
 
 template<typename T>
 struct DropoutKernelUtil<DeviceType::kGPU, T> final {
-  static void MaskAndScale(DeviceCtx* ctx, const int64_t n, double threshold, double scale,
+  static void MaskAndScale(DeviceCtx* ctx, const int64_t n, float threshold, float scale,
                            const T* x, const float* random_mask, T* y) {
     MaskAndScaleGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
         n, threshold, scale, x, random_mask, y);
diff --git a/oneflow/core/kernel/dropout_kernel.h b/oneflow/core/kernel/dropout_kernel.h
index e181d1f1218fa8e0995976af28fe509c642f49e0..c40c7df50d9f6bddae7125aaa7f7aa0d4cb7860c 100644
--- a/oneflow/core/kernel/dropout_kernel.h
+++ b/oneflow/core/kernel/dropout_kernel.h
@@ -23,10 +23,10 @@ class DropoutKernel final : public KernelIf<device_type> {
 
   // random_mask = random_uniform(0, 1)
   // y = dropout(x, random_mask, dropout_rate)
-  void Dropout(DeviceCtx* ctx, const int64_t n, double dropout_rate, const T* x, float* random_mask,
+  void Dropout(DeviceCtx* ctx, const int64_t n, float dropout_rate, const T* x, float* random_mask,
                T* y) const;
   // y = dropout(x, random_mask)
-  void DropoutBackward(DeviceCtx* ctx, const int64_t n, double dropout_rate, const T* dy,
+  void DropoutBackward(DeviceCtx* ctx, const int64_t n, float dropout_rate, const T* dy,
                        const float* random_mask, T* dx) const;
 
   std::unique_ptr<RandomGenerator<device_type>> random_generator_;
@@ -34,7 +34,7 @@ class DropoutKernel final : public KernelIf<device_type> {
 
 template<DeviceType device_type, typename T>
 struct DropoutKernelUtil final {
-  static void MaskAndScale(DeviceCtx* ctx, const int64_t n, double threshold, double scale,
+  static void MaskAndScale(DeviceCtx* ctx, const int64_t n, float threshold, float scale,
                            const T* x, const float* random_mask, T* y);
 };
 
diff --git a/oneflow/core/kernel/gather_kernel.cpp b/oneflow/core/kernel/gather_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7e0181e9ae107ae95743a9537408a636fd2bd731
--- /dev/null
+++ b/oneflow/core/kernel/gather_kernel.cpp
@@ -0,0 +1,113 @@
+#include "oneflow/core/kernel/gather_kernel.h"
+
+namespace oneflow {
+
+namespace {
+
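+// Collapses an arbitrary-rank shape into (outer, gather_dim, inner) around the
+// gather axis so the kernel util can index with three dimensions.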
+Shape GetFlatShape(const Shape& shape, int64_t axis) {
+  CHECK_GT(shape.NumAxes(), 0);
+  CHECK_GE(axis, 0);
+  CHECK_LT(axis, shape.NumAxes());
+  return Shape({shape.Count(0, axis), shape.At(axis), shape.Count(axis + 1)});
+}
+
+template<DeviceType device_type, typename T, typename K>
+void GatherForward(DeviceCtx* ctx, const Blob* indices, const Blob* in, int64_t axis, Blob* out) {
+  const Shape flat_in_shape = GetFlatShape(in->shape(), axis);
+  GatherKernelUtil<device_type, T, K>::Forward(ctx, indices->dptr<K>(), indices->shape().elem_cnt(),
+                                               in->dptr<T>(), flat_in_shape, out->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T, typename K>
+void GatherBackward(DeviceCtx* ctx, const Blob* indices, const Blob* out_diff, int64_t axis,
+                    Blob* in_diff) {
+  Memset<device_type>(ctx, in_diff->mut_dptr<T>(), 0, in_diff->ByteSizeOfDataContentField());
+  const Shape flat_in_shape = GetFlatShape(in_diff->shape(), axis);
+  GatherKernelUtil<device_type, T, K>::Backward(ctx, indices->dptr<K>(),
+                                                indices->shape().elem_cnt(), out_diff->dptr<T>(),
+                                                flat_in_shape, in_diff->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T>
+struct GatherSwitchUtil final {
+#define MAKE_GATHER_SWITCH_ENTRY(func_name, K) func_name<device_type, T, K>
+#define DEFINE_GATHER_STATIC_SWITCH_FUNC(func_name)                    \
+  DEFINE_STATIC_SWITCH_FUNC(void, func_name, MAKE_GATHER_SWITCH_ENTRY, \
+                            MAKE_DATA_TYPE_CTRV_SEQ(INT_DATA_TYPE_SEQ));
+  DEFINE_GATHER_STATIC_SWITCH_FUNC(GatherForward);
+  DEFINE_GATHER_STATIC_SWITCH_FUNC(GatherBackward);
+#undef DEFINE_GATHER_STATIC_SWITCH_FUNC
+#undef MAKE_GATHER_SWITCH_ENTRY
+};
+
+}  // namespace
+
+template<DeviceType device_type, typename T>
+const PbMessage& GatherKernel<device_type, T>::GetCustomizedOpConf() const {
+  return this->op_conf().gather_conf();
+}
+
+template<DeviceType device_type, typename T>
+void GatherKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  GatherSwitchUtil<device_type, T>::SwitchGatherForward(
+      SwitchCase(BnInOp2Blob("indices")->data_type()), ctx.device_ctx, BnInOp2Blob("indices"),
+      BnInOp2Blob("in"), this->kernel_conf().gather_conf().axis(), BnInOp2Blob("out"));
+}
+
+template<DeviceType device_type, typename T>
+void GatherKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  GatherSwitchUtil<device_type, T>::SwitchGatherBackward(
+      SwitchCase(BnInOp2Blob("indices")->data_type()), ctx.device_ctx, BnInOp2Blob("indices"),
+      BnInOp2Blob(GenDiffBn("out")), this->kernel_conf().gather_conf().axis(),
+      BnInOp2Blob(GenDiffBn("in")));
+}
+
+template<typename T, typename K>
+struct GatherKernelUtil<DeviceType::kCPU, T, K> final {
+  static void Forward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* in,
+                      const Shape& flat_in_shape, T* out);
+  static void Backward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* out_diff,
+                       const Shape& flat_in_shape, T* in_diff);
+};
+
+template<typename T, typename K>
+void GatherKernelUtil<DeviceType::kCPU, T, K>::Forward(DeviceCtx* ctx, const K* indices,
+                                                       int64_t num_indices, const T* in,
+                                                       const Shape& flat_in_shape, T* out) {
+  const int64_t outer_dim_size = flat_in_shape.At(0);
+  const int64_t gather_dim_size = flat_in_shape.At(1);
+  const int64_t inner_dim_size = flat_in_shape.At(2);
+  FOR_RANGE(int64_t, outer_idx, 0, outer_dim_size) {
+    FOR_RANGE(int64_t, i, 0, num_indices) {
+      const int64_t idx = indices[i];
+      CHECK(idx >= 0 && idx < gather_dim_size);
+      const T* from = in + outer_idx * gather_dim_size * inner_dim_size + idx * inner_dim_size;
+      T* to = out + outer_idx * num_indices * inner_dim_size + i * inner_dim_size;
+      std::copy(from, from + inner_dim_size, to);
+    }
+  }
+}
+
+template<typename T, typename K>
+void GatherKernelUtil<DeviceType::kCPU, T, K>::Backward(DeviceCtx* ctx, const K* indices,
+                                                        int64_t num_indices, const T* out_diff,
+                                                        const Shape& flat_in_shape, T* in_diff) {
+  const int64_t outer_dim_size = flat_in_shape.At(0);
+  const int64_t gather_dim_size = flat_in_shape.At(1);
+  const int64_t inner_dim_size = flat_in_shape.At(2);
+  FOR_RANGE(int64_t, outer_idx, 0, outer_dim_size) {
+    FOR_RANGE(int64_t, i, 0, num_indices) {
+      const int64_t idx = indices[i];
+      CHECK(idx >= 0 && idx < gather_dim_size);
+      const T* from = out_diff + outer_idx * num_indices * inner_dim_size + i * inner_dim_size;
+      T* to = in_diff + outer_idx * gather_dim_size * inner_dim_size + idx * inner_dim_size;
+      std::transform(from, from + inner_dim_size, to, to, std::plus<T>());
+    }
+  }
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kGatherConf, GatherKernel, FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/gather_kernel.cu b/oneflow/core/kernel/gather_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..2f4bddfbab73cde764fce7088c2dec5555ded853
--- /dev/null
+++ b/oneflow/core/kernel/gather_kernel.cu
@@ -0,0 +1,79 @@
+#include "oneflow/core/kernel/gather_kernel.h"
+#include "oneflow/core/kernel/kernel_util.cuh"
+#include <assert.h>
+
+namespace oneflow {
+
+namespace {
+
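+// Maps a flat offset in the gathered output back to the corresponding flat
+// offset in the input using the (outer, gather_dim, inner) decomposition.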
+template<typename K>
+__device__ int64_t get_in_offset(int64_t out_offset, const K* indices, int64_t num_indices,
+                                 int64_t gather_dim_size, int64_t inner_dim_size) {
+  const int64_t outer_dim_elem_cnt = num_indices * inner_dim_size;
+  const int64_t outer_idx = out_offset / outer_dim_elem_cnt;
+  const int64_t indices_idx = out_offset % outer_dim_elem_cnt / inner_dim_size;
+  const int64_t inner_idx = out_offset % inner_dim_size;
+  const int64_t idx = indices[indices_idx];
+  assert(idx >= 0 && idx < gather_dim_size);
+  return outer_idx * gather_dim_size * inner_dim_size + idx * inner_dim_size + inner_idx;
+}
+
+template<typename T, typename K>
+__global__ void GatherForwardGpu(int64_t elem_cnt, const K* indices, int64_t num_indices,
+                                 const T* in, int64_t gather_dim_size, int64_t inner_dim_size,
+                                 T* out) {
+  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {
+    out[i] = in[get_in_offset<K>(i, indices, num_indices, gather_dim_size, inner_dim_size)];
+  }
+}
+
+template<typename T, typename K>
+__global__ void GatherBackwardGpu(int64_t elem_cnt, const K* indices, int64_t num_indices,
+                                  const T* out_diff, int64_t gather_dim_size,
+                                  int64_t inner_dim_size, T* in_diff) {
+  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {
+    gpu_atomic_add(
+        in_diff + get_in_offset<K>(i, indices, num_indices, gather_dim_size, inner_dim_size),
+        out_diff[i]);
+  }
+}
+
+}  // namespace
+
+template<typename T, typename K>
+struct GatherKernelUtil<DeviceType::kGPU, T, K> final {
+  static void Forward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* in,
+                      const Shape& flat_in_shape, T* out);
+  static void Backward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* out_diff,
+                       const Shape& flat_in_shape, T* in_diff);
+};
+
+template<typename T, typename K>
+void GatherKernelUtil<DeviceType::kGPU, T, K>::Forward(DeviceCtx* ctx, const K* indices,
+                                                       int64_t num_indices, const T* in,
+                                                       const Shape& flat_in_shape, T* out) {
+  const int64_t elem_cnt = flat_in_shape.At(0) * num_indices * flat_in_shape.At(2);
+  GatherForwardGpu<T, K>
+      <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+          elem_cnt, indices, num_indices, in, flat_in_shape.At(1), flat_in_shape.At(2), out);
+}
+
+template<typename T, typename K>
+void GatherKernelUtil<DeviceType::kGPU, T, K>::Backward(DeviceCtx* ctx, const K* indices,
+                                                        int64_t num_indices, const T* out_diff,
+                                                        const Shape& flat_in_shape, T* in_diff) {
+  const int64_t elem_cnt = flat_in_shape.At(0) * num_indices * flat_in_shape.At(2);
+  GatherBackwardGpu<T, K>
+      <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+          elem_cnt, indices, num_indices, out_diff, flat_in_shape.At(1), flat_in_shape.At(2),
+          in_diff);
+}
+
+#define MAKE_GATHER_KERNEL_UTIL_ENTRY(in_type_pair, index_type_pair)                 \
+  template struct GatherKernelUtil<DeviceType::kGPU, OF_PP_PAIR_FIRST(in_type_pair), \
+                                   OF_PP_PAIR_FIRST(index_type_pair)>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_GATHER_KERNEL_UTIL_ENTRY, FLOATING_DATA_TYPE_SEQ,
+                                 INT_DATA_TYPE_SEQ);
+#undef MAKE_GATHER_KERNEL_UTIL_ENTRY
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/gather_kernel.h b/oneflow/core/kernel/gather_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..ae711a7f64d274fbd2a40c6e37d357ab65d1b367
--- /dev/null
+++ b/oneflow/core/kernel/gather_kernel.h
@@ -0,0 +1,33 @@
+#ifndef ONEFLOW_CORE_KERNEL_GATHER_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_GATHER_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class GatherKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(GatherKernel);
+  GatherKernel() = default;
+  ~GatherKernel() override = default;
+
+ private:
+  const PbMessage& GetCustomizedOpConf() const override;
+  void ForwardDataContent(const KernelCtx& ctx,
+                          std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  void BackwardDataContent(const KernelCtx& ctx,
+                           std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+};
+
+template<DeviceType device_type, typename T, typename K>
+struct GatherKernelUtil final {
+  static void Forward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* in,
+                      const Shape& flat_in_shape, T* out);
+  static void Backward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* out_diff,
+                       const Shape& flat_in_shape, T* in_diff);
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_GATHER_KERNEL_H_
diff --git a/oneflow/core/kernel/gelu_kernel.cpp b/oneflow/core/kernel/gelu_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..be1b273a70af24914372e3e76ae967f034c9e818
--- /dev/null
+++ b/oneflow/core/kernel/gelu_kernel.cpp
@@ -0,0 +1,52 @@
+#include "oneflow/core/kernel/gelu_kernel.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void GeluKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* in_blob = BnInOp2Blob("in");
+  GeluKernelUtil<device_type, T>::GeluForward(ctx.device_ctx, in_blob->static_shape().elem_cnt(),
+                                              in_blob->dptr<T>(),
+                                              BnInOp2Blob("out")->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T>
+void GeluKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* in_blob = BnInOp2Blob("in");
+  GeluKernelUtil<device_type, T>::GeluBackward(
+      ctx.device_ctx, in_blob->static_shape().elem_cnt(), in_blob->dptr<T>(),
+      BnInOp2Blob("out_diff")->dptr<T>(), BnInOp2Blob("in_diff")->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T>
+const PbMessage& GeluKernel<device_type, T>::GetCustomizedOpConf() const {
+  return this->op_conf().gelu_conf();
+}
+
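+// GELU(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2))); the derivative is
+// Phi(x) + x * phi(x) with phi(x) = exp(-x^2 / 2) / sqrt(2 * pi), which is what
+// GeluBackward computes below.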
+template<typename T>
+struct GeluKernelUtil<DeviceType::kCPU, T> {
+  static void GeluForward(DeviceCtx* ctx, const int64_t n, const T* x, T* y) {
+    T inv_sqrt2 = std::sqrt(0.5);
+    FOR_RANGE(int64_t, i, 0, n) { y[i] = 0.5 * x[i] * (1.0 + std::erf(inv_sqrt2 * x[i])); }
+  }
+
+  static void GeluBackward(DeviceCtx* ctx, const int64_t n, const T* x, const T* dy, T* dx) {
+    T inv_sqrt2 = std::sqrt(0.5);
+    T coef = std::sqrt(2.0 / std::acos(-1.0));
+    FOR_RANGE(int64_t, i, 0, n) {
+      dx[i] = 0.5 * (1.0 + std::erf(inv_sqrt2 * x[i]) + x[i] * coef * std::exp(-0.5 * x[i] * x[i]))
+              * dy[i];
+    }
+  }
+};
+
+#define INSTANTIATE_GELU_KERNEL_UTIL(type_cpp, type_proto) \
+  template struct GeluKernelUtil<DeviceType::kCPU, type_cpp>;
+OF_PP_FOR_EACH_TUPLE(INSTANTIATE_GELU_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ)
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kGeluConf, GeluKernel, FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/gelu_kernel.cu b/oneflow/core/kernel/gelu_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a36e506eb07eb5b982d3fc82d523c87a633ee387
--- /dev/null
+++ b/oneflow/core/kernel/gelu_kernel.cu
@@ -0,0 +1,69 @@
+#include "oneflow/core/kernel/gelu_kernel.h"
+#include "oneflow/core/kernel/kernel_util.cuh"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T>
+__global__ void GeluForwardGpu(const int64_t n, const T* x, const T inv_sqrt2, T* y) {
+  UNIMPLEMENTED();
+}
+
+template<typename T>
+__global__ void GeluBackwardGpu(const int64_t n, const T* x, const T* dy, const T inv_sqrt2,
+                                const T coef, T* dx) {
+  UNIMPLEMENTED();
+}
+
+template<>
+__global__ void GeluForwardGpu(const int64_t n, const float* x, const float inv_sqrt2, float* y) {
+  CUDA_1D_KERNEL_LOOP(i, n) { y[i] = 0.5f * x[i] * (1.0f + erff(inv_sqrt2 * x[i])); }
+}
+
+template<>
+__global__ void GeluBackwardGpu(const int64_t n, const float* x, const float* dy,
+                                const float inv_sqrt2, const float coef, float* dx) {
+  CUDA_1D_KERNEL_LOOP(i, n) {
+    dx[i] =
+        0.5f * (1.0f + erff(inv_sqrt2 * x[i]) + x[i] * coef * expf(-0.5f * x[i] * x[i])) * dy[i];
+  }
+}
+
+template<>
+__global__ void GeluForwardGpu(const int64_t n, const double* x, const double inv_sqrt2,
+                               double* y) {
+  CUDA_1D_KERNEL_LOOP(i, n) { y[i] = 0.5 * x[i] * (1.0 + erf(inv_sqrt2 * x[i])); }
+}
+
+template<>
+__global__ void GeluBackwardGpu(const int64_t n, const double* x, const double* dy,
+                                const double inv_sqrt2, const double coef, double* dx) {
+  CUDA_1D_KERNEL_LOOP(i, n) {
+    dx[i] = 0.5 * (1.0 + erf(inv_sqrt2 * x[i]) + x[i] * coef * exp(-0.5 * x[i] * x[i])) * dy[i];
+  }
+}
+
+}  // namespace
+
+template<typename T>
+struct GeluKernelUtil<DeviceType::kGPU, T> {
+  static void GeluForward(DeviceCtx* ctx, const int64_t n, const T* x, T* y) {
+    const T inv_sqrt2 = sqrt(0.5);
+    GeluForwardGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+        n, x, inv_sqrt2, y);
+  }
+
+  static void GeluBackward(DeviceCtx* ctx, const int64_t n, const T* x, const T* dy, T* dx) {
+    const T inv_sqrt2 = sqrt(0.5);
+    const T coef = sqrt(2.0 / acos(-1.0));
+    GeluBackwardGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+        n, x, dy, inv_sqrt2, coef, dx);
+  }
+};
+
+#define INSTANTIATE_GELU_KERNEL_UTIL(type_cpp, type_proto) \
+  template struct GeluKernelUtil<DeviceType::kGPU, type_cpp>;
+OF_PP_FOR_EACH_TUPLE(INSTANTIATE_GELU_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ)
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/gelu_kernel.h b/oneflow/core/kernel/gelu_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..a6859c6009086c253ef792c0cc6b08585193f6f2
--- /dev/null
+++ b/oneflow/core/kernel/gelu_kernel.h
@@ -0,0 +1,32 @@
+#ifndef ONEFLOW_CORE_KERNEL_GELU_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_GELU_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class GeluKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(GeluKernel);
+  GeluKernel() = default;
+  ~GeluKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+  const PbMessage& GetCustomizedOpConf() const override;
+};
+
+template<DeviceType device_type, typename T>
+struct GeluKernelUtil {
+  static void GeluForward(DeviceCtx* ctx, const int64_t n, const T* x, T* y);
+
+  static void GeluBackward(DeviceCtx* ctx, const int64_t n, const T* x, const T* dy, T* dx);
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_GELU_KERNEL_H_
diff --git a/oneflow/core/kernel/hinge_loss_kernel.h b/oneflow/core/kernel/hinge_loss_kernel.h
index befa52b31f2e9dc80873fc3e04556071dd604d2f..606e996880d8cf20c4647e76a57f759ecb6fedd0 100644
--- a/oneflow/core/kernel/hinge_loss_kernel.h
+++ b/oneflow/core/kernel/hinge_loss_kernel.h
@@ -6,7 +6,7 @@
 namespace oneflow {
 
 template<DeviceType device_type, typename PredType, typename LabelType>
-class HingeLossKernel final : public LossKernel<device_type, PredType, LabelType> {
+class HingeLossKernel final : public LossKernel<device_type, PredType> {
  public:
   OF_DISALLOW_COPY_AND_MOVE(HingeLossKernel);
   HingeLossKernel() = default;
diff --git a/oneflow/core/kernel/identity_kernel.cpp b/oneflow/core/kernel/identity_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a76b5c8fabc1f191c99aa9a5c01581294af3ff07
--- /dev/null
+++ b/oneflow/core/kernel/identity_kernel.cpp
@@ -0,0 +1,19 @@
+#include "oneflow/core/kernel/identity_kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type>
+void IdentityKernel<device_type>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  BnInOp2Blob("out")->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob("in"));
+}
+
+template<DeviceType device_type>
+void IdentityKernel<device_type>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  BnInOp2Blob(GenDiffBn("in"))->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob(GenDiffBn("out")));
+}
+
+ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kParallelCastConf, IdentityKernel);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/identity_kernel.h b/oneflow/core/kernel/identity_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..22eef17ef6e6240d8389adc4d89904e478a1b498
--- /dev/null
+++ b/oneflow/core/kernel/identity_kernel.h
@@ -0,0 +1,25 @@
+#ifndef ONEFLOW_CORE_KERNEL_IDENTITY_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_IDENTITY_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+#include "oneflow/core/kernel/kernel_context.h"
+
+namespace oneflow {
+
+template<DeviceType device_type>
+class IdentityKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(IdentityKernel);
+  IdentityKernel() = default;
+  ~IdentityKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_IDENTITY_KERNEL_H_
diff --git a/oneflow/core/kernel/identity_loss_kernel.cpp b/oneflow/core/kernel/identity_loss_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..02345d8307c15896e21413213907c0f1b28efec3
--- /dev/null
+++ b/oneflow/core/kernel/identity_loss_kernel.cpp
@@ -0,0 +1,39 @@
+#include "oneflow/core/kernel/identity_loss_kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename PredType>
+void IdentityLossKernel<device_type, PredType>::InitConstBufBlobs(
+    DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  InitializerConf initializer;
+  initializer.mutable_constant_conf()->set_value(1.0);
+  KernelUtil<device_type, PredType>::InitializeWithConf(ctx, initializer, 0, BnInOp2Blob("ones"));
+}
+
+template<DeviceType device_type, typename PredType>
+void IdentityLossKernel<device_type, PredType>::VirtualLossForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* prediction = BnInOp2Blob("prediction");
+  const Blob* ones = BnInOp2Blob("ones");
+  Blob* loss = BnInOp2Blob("loss");
+  Blob* prediction_diff = BnInOp2Blob(GenDiffBn("prediction"));
+  loss->CopyDataContentFrom(ctx.device_ctx, prediction);
+  prediction_diff->CopyDataContentFrom(ctx.device_ctx, ones);
+}
+
+template<DeviceType device_type, typename PredType>
+const LossKernelConf& IdentityLossKernel<device_type, PredType>::GetLossKernelConf(
+    const KernelConf& kernel_conf) const {
+  return kernel_conf.identity_loss_conf().loss_conf();
+}
+
+template<DeviceType device_type, typename PredType>
+int64_t IdentityLossKernel<device_type, PredType>::CalcLossInstanceNum(
+    const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const {
+  return BnInOp2Blob("prediction")->shape().elem_cnt();
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kIdentityLossConf, IdentityLossKernel,
+                           FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/identity_loss_kernel.h b/oneflow/core/kernel/identity_loss_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..1f3ae8c62953e2551e1ce22130a321d3bee730e7
--- /dev/null
+++ b/oneflow/core/kernel/identity_loss_kernel.h
@@ -0,0 +1,27 @@
+#ifndef ONEFLOW_CORE_KERNEL_IDENTITY_LOSS_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_IDENTITY_LOSS_KERNEL_H_
+
+#include "oneflow/core/kernel/loss_kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename PredType>
+class IdentityLossKernel final : public LossKernel<device_type, PredType> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(IdentityLossKernel);
+  IdentityLossKernel() = default;
+  ~IdentityLossKernel() = default;
+
+ private:
+  void VirtualLossForwardDataContent(const KernelCtx&,
+                                     std::function<Blob*(const std::string&)>) const override;
+  const LossKernelConf& GetLossKernelConf(const KernelConf& kernel_conf) const override;
+  void InitConstBufBlobs(DeviceCtx* ctx,
+                         std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  int64_t CalcLossInstanceNum(const KernelCtx& ctx,
+                              const std::function<Blob*(const std::string&)>& BnInOp2Blob)
+      const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_IDENTITY_LOSS_KERNEL_H_
diff --git a/oneflow/core/kernel/kernel.cpp b/oneflow/core/kernel/kernel.cpp
index 95f0ab8095b9c0f8641723da3e1ad182a5d4aa89..ac473c14cad2b43a8280f3d0910959a973cba740 100644
--- a/oneflow/core/kernel/kernel.cpp
+++ b/oneflow/core/kernel/kernel.cpp
@@ -25,6 +25,44 @@ void ClearBlobDim0ValidNumIfNeed(const PbRpf<std::string>& bns,
   }
 }
 
+void CheckLossInstanceNumField(const PbRpf<std::string>& bns,
+                               const std::function<Blob*(const std::string&)>& BnInOp2Blob,
+                               bool expected) {
+  for (const std::string& bn : bns) {
+    const Blob* blob = BnInOp2Blob(bn);
+    if (blob != nullptr) { CHECK_EQ(blob->has_loss_instance_num_field(), expected); }
+  }
+}
+
+bool NeedCopyLossInstanceNum(const PbRpf<std::string>& from_bns, const PbRpf<std::string>& to_bns,
+                             const std::function<Blob*(const std::string&)>& BnInOp2Blob) {
+  const auto& first_bn_has_loss_instance_num_it =
+      std::find_if(from_bns.cbegin(), from_bns.cend(), [&BnInOp2Blob](const std::string& bn) {
+        const Blob* blob = BnInOp2Blob(bn);
+        return blob != nullptr && blob->has_loss_instance_num_field();
+      });
+  const bool need_copy_loss_instance_num = first_bn_has_loss_instance_num_it != from_bns.end();
+  CheckLossInstanceNumField(from_bns, BnInOp2Blob, need_copy_loss_instance_num);
+  CheckLossInstanceNumField(to_bns, BnInOp2Blob, need_copy_loss_instance_num);
+  return need_copy_loss_instance_num;
+}
+
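+// Copies loss_instance_num from the first source blob to every destination blob
+// after checking that all source blobs agree on the value within epsilon.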
+void NaiveCopyLossInstanceNum(const PbRpf<std::string>& from_bns, const PbRpf<std::string>& to_bns,
+                              const std::function<Blob*(const std::string&)>& BnInOp2Blob) {
+  CHECK_GT(from_bns.size(), 0);
+  CHECK(BnInOp2Blob(from_bns.Get(0))->has_loss_instance_num_field());
+  const float loss_instance_num = BnInOp2Blob(from_bns.Get(0))->loss_instance_num();
+  const float loss_instance_num_epsilon = 1e-8;
+  FOR_RANGE(int32_t, i, 1, from_bns.size()) {
+    CHECK_LT(std::fabs(BnInOp2Blob(from_bns.Get(i))->loss_instance_num() - loss_instance_num),
+             loss_instance_num_epsilon);
+  }
+  FOR_RANGE(int32_t, i, 0, to_bns.size()) {
+    Blob* blob = BnInOp2Blob(to_bns.Get(i));
+    if (blob != nullptr) { blob->set_loss_instance_num(loss_instance_num); }
+  }
+}
+
 }  // namespace
 
 void Kernel::Init(const ParallelContext* parallel_ctx, const KernelConf& kernel_conf,
@@ -94,6 +132,7 @@ void Kernel::Forward(const KernelCtx& ctx,
     CHECK(!kernel_conf_.need_do_opaque_header());
     ForwardDim0ValidNum(ctx, BnInOp2Blob);
   }
+  if (NeedForwardLossInstanceNum(ctx, BnInOp2Blob)) { ForwardLossInstanceNum(ctx, BnInOp2Blob); }
   if (HasEmptyShapeBlob(op_attribute().input_bns(), BnInOp2Blob) && !NeedForwardIfBlobEmpty()) {
     ClearBlobDim0ValidNumIfNeed(op_attribute().output_bns(), BnInOp2Blob);
     return;
@@ -135,6 +174,7 @@ void Kernel::Backward(const KernelCtx& ctx,
     CHECK(!kernel_conf_.need_do_opaque_header());
     BackwardInDiffDim0ValidNum(ctx, BnInOp2Blob);
   }
+  BackwardInDiffLossInstanceNum(ctx, BnInOp2Blob);
   if (HasEmptyShapeBlob(op_attribute().output_diff_bns(), BnInOp2Blob)
       && !NeedBackwardIfBlobEmpty()) {
     ClearBlobDim0ValidNumIfNeed(op_attribute().input_diff_bns(), BnInOp2Blob);
@@ -204,6 +244,19 @@ void KernelIf<device_type>::ForwardRecordIdInDevicePiece(
             op_attribute().output_bns(), &Blob::CopyRecordIdInDevicePieceFrom);
 }
 
+template<DeviceType device_type>
+void KernelIf<device_type>::ForwardLossInstanceNum(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  NaiveCopyLossInstanceNum(op_attribute().input_bns(), op_attribute().output_bns(), BnInOp2Blob);
+}
+
+template<DeviceType device_type>
+bool KernelIf<device_type>::NeedForwardLossInstanceNum(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  return NeedCopyLossInstanceNum(op_attribute().input_bns(), op_attribute().output_bns(),
+                                 BnInOp2Blob);
+}
+
 template<DeviceType device_type>
 void KernelIf<device_type>::BackwardModelDiffDim0ValidNum(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
@@ -233,6 +286,15 @@ void KernelIf<device_type>::BackwardInDiffDim0ValidNum(
             input_diff_bns, &Blob::CopyDim0ValidNumFrom);
 }
 
+template<DeviceType device_type>
+void KernelIf<device_type>::BackwardInDiffLossInstanceNum(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  CHECK(NeedCopyLossInstanceNum(op_attribute().output_diff_bns(), op_attribute().input_diff_bns(),
+                                BnInOp2Blob));
+  NaiveCopyLossInstanceNum(op_attribute().output_diff_bns(), op_attribute().input_diff_bns(),
+                           BnInOp2Blob);
+}
+
 template<DeviceType device_type>
 void KernelIf<device_type>::ForwardPackedHeader(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
@@ -257,10 +319,10 @@ template<DeviceType device_type, typename T>
 void KernelIfWithModel<device_type, T>::SetTotalInstanceNumDiffBlob(
     const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const {
   CHECK_GE(this->op_attribute().model_bns().size(), 2);
-  int64_t dim0_valid_num_sum =
-      BnInOp2Blob(this->op_attribute().output_diff_bns(0))->CalcDim0ValidNumSum();
+  const float loss_instance_num =
+      BnInOp2Blob(this->op_attribute().output_diff_bns(0))->loss_instance_num();
   Blob* total_instance_num_diff_blob = BnInOp2Blob("total_instance_num_diff");
-  KernelUtil<device_type, T>::Set(ctx.device_ctx, static_cast<T>(dim0_valid_num_sum),
+  KernelUtil<device_type, T>::Set(ctx.device_ctx, static_cast<T>(loss_instance_num),
                                   total_instance_num_diff_blob->mut_dptr<T>());
 }
 
diff --git a/oneflow/core/kernel/kernel.h b/oneflow/core/kernel/kernel.h
index 9a008a640acd11420cbe3efade654c45afaf8d7e..8884b1409b3aebc855dfcb6c235ae44f66d8e3a5 100644
--- a/oneflow/core/kernel/kernel.h
+++ b/oneflow/core/kernel/kernel.h
@@ -76,6 +76,14 @@ class Kernel {
       const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
     UNIMPLEMENTED();
   }
+  virtual bool NeedForwardLossInstanceNum(
+      const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+    UNIMPLEMENTED();
+  }
+  virtual void ForwardLossInstanceNum(const KernelCtx& ctx,
+                                      std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+    UNIMPLEMENTED();
+  }
   virtual void ForwardPackedHeader(const KernelCtx& ctx,
                                    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
     UNIMPLEMENTED();
@@ -100,6 +108,10 @@ class Kernel {
       const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
     UNIMPLEMENTED();
   }
+  virtual void BackwardInDiffLossInstanceNum(
+      const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+    UNIMPLEMENTED();
+  }
   virtual void BackwardModelDiffDim0ValidNum(
       const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
     UNIMPLEMENTED();
@@ -158,7 +170,10 @@ class KernelIf : public Kernel {
       const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
   virtual void ForwardRecordIdInDevicePiece(
       const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
-
+  virtual void ForwardLossInstanceNum(
+      const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  virtual bool NeedForwardLossInstanceNum(
+      const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
   virtual void ForwardPackedHeader(
       const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
   virtual void BackwardDataId(const KernelCtx& ctx,
@@ -167,6 +182,8 @@ class KernelIf : public Kernel {
                               std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
   virtual void BackwardInDiffDim0ValidNum(
       const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  virtual void BackwardInDiffLossInstanceNum(
+      const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
   virtual void BackwardModelDiffDim0ValidNum(
       const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
   void CopyField(DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob,
diff --git a/oneflow/core/kernel/kernel.proto b/oneflow/core/kernel/kernel.proto
index 7027c63408b716fafa1be8e6c8905734f5eb6bb3..ed19b1817491e0a0d6bd04736f45f9224f560c01 100644
--- a/oneflow/core/kernel/kernel.proto
+++ b/oneflow/core/kernel/kernel.proto
@@ -41,6 +41,14 @@ message SparseCrossEntropyLossKernelConf {
   required LossKernelConf loss_conf = 1;
 }
 
+message SigmoidCrossEntropyLossKernelConf {
+  required LossKernelConf loss_conf = 1;
+}
+
+message IdentityLossKernelConf {
+  required LossKernelConf loss_conf = 1;
+}
+
 message PoolingKernelConf {
   required ShapeProto in = 1;
   required ShapeProto out = 2;
@@ -61,7 +69,7 @@ message MaxPoolingKernelConf {
 }
 
 message ReduceSumKernelConf {
-  optional int32 axis = 1;
+  required ShapeProto kept_dims_shape = 1;
 }
 
 message SoftmaxKernelConf {
@@ -142,6 +150,28 @@ message HingeLossKernelConf {
   required LossKernelConf loss_conf = 1;
 }
 
+message SliceKernelConf {
+  required ShapeProto in_shape = 1;
+}
+
+message ConstantKernelConf {
+  required InitializerConf initializer = 1;
+  required uint32 random_seed = 2;
+}
+
+message GatherKernelConf {
+  required int64 axis = 1;
+}
+
+message VariableKernelConf {
+  required bool is_fw_inplace = 1;
+  required bool is_bw_inplace = 2;
+}
+
+message RecordLoadKernelConf {
+  required int64 device_piece_size = 1;
+}
+
 message KernelConf {
   required OpAttribute op_attribute = 1;
   required bool is_forward = 2;
@@ -177,5 +207,12 @@ message KernelConf {
     ReduceConcatKernelConf reduce_concat_conf = 351;
     ReduceSplitKernelConf reduce_split_conf = 352;
     AccuracyKernelConf accuracy_conf = 401;
+    SliceKernelConf slice_conf = 402;
+    ConstantKernelConf constant_conf = 403;
+    SigmoidCrossEntropyLossKernelConf sigmoid_cross_entropy_loss_conf = 404;
+    IdentityLossKernelConf identity_loss_conf = 405;
+    GatherKernelConf gather_conf = 406;
+    VariableKernelConf variable_conf = 407;
+    RecordLoadKernelConf record_load_conf = 408;
   }
 }
diff --git a/oneflow/core/kernel/kernel_util.cpp b/oneflow/core/kernel/kernel_util.cpp
index 86503c0b940c2d389498ea49a1fb008654177492..ca8eed6f5c7b4bf228d88a7463330005646c8cb2 100644
--- a/oneflow/core/kernel/kernel_util.cpp
+++ b/oneflow/core/kernel/kernel_util.cpp
@@ -13,7 +13,7 @@ void RngUniform(const int64_t elem_cnt, const T min, const T max, uint32_t rando
   CHECK(dptr);
   CHECK_LE(min, max);
   std::mt19937 generator(random_seed);
-  std::uniform_real_distribution<T> random_distribution(min, std::nextafter(max, MaxVal<T>()));
+  std::uniform_real_distribution<T> random_distribution(min, std::nextafter(max, GetMaxVal<T>()));
   for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = random_distribution(generator); }
 }
 
@@ -24,7 +24,7 @@ void RngIntUniform(const int64_t elem_cnt, const T min, const T max, uint32_t ra
   CHECK(dptr);
   CHECK_LE(min, max);
   std::mt19937 generator(random_seed);
-  std::uniform_int_distribution<T> random_distribution(min, std::nextafter(max, MaxVal<T>()));
+  std::uniform_int_distribution<T> random_distribution(min, std::nextafter(max, GetMaxVal<T>()));
   for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = random_distribution(generator); }
 }
 
@@ -39,6 +39,24 @@ void RngNormal(const int64_t elem_cnt, const T mean, const T std, uint32_t rando
   for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = random_distribution(generator); }
 }
 
+template<typename T>
+void RngTruncatedNormal(const int64_t elem_cnt, const T std, uint32_t random_seed, T* dptr) {
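+  // Rejection sampling: draw from N(0, std) and keep only samples with |val| < 2 * std.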
+  CHECK_GE(elem_cnt, 0);
+  CHECK(dptr);
+  CHECK_GT(std, 0.0);
+  T truncated_value = 2 * std;
+  std::mt19937 generator(random_seed);
+  std::normal_distribution<T> random_distribution(0, std);
+  int64_t index = 0;
+  while (true) {
+    T val = random_distribution(generator);
+    if (std::abs(val) < truncated_value) {
+      dptr[index++] = val;
+      if (index >= elem_cnt) { break; }
+    }
+  }
+}
+
 template<typename T>
 void ConstantInitializer(const T& value, Blob* blob) {
   T* dptr = blob->mut_dptr<T>();
@@ -71,6 +89,14 @@ void RandomNormalInitializer(const RandomNormalInitializerConf& initializer_conf
                static_cast<T>(initializer_conf.std()), random_seed, blob->mut_dptr<T>());
 }
 
+template<typename T>
+void TruncatedNormalInitializer(const TruncatedNormalInitializerConf& initializer_conf,
+                                uint32_t random_seed, Blob* blob) {
+  CHECK(blob->shape().elem_cnt());
+  RngTruncatedNormal<T>(blob->shape().elem_cnt(), static_cast<T>(initializer_conf.std()),
+                        random_seed, blob->mut_dptr<T>());
+}
+
 template<typename T>
 T GenInitialFan(VarianceNorm variance_norm, Blob* blob, const std::string& data_format) {
   T fan = ZeroVal<T>::value;
@@ -116,6 +142,44 @@ void MsraInitializer(const MsraInitializerConf& initializer_conf, uint32_t rando
                blob->mut_dptr<T>());
 }
 
+template<typename T>
+void RangeInitializer(int64_t outer_size, int64_t idx_dim_size, int64_t inner_size, T start,
+                      T stride, T* out) {
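+  // out is laid out as (outer_size, idx_dim_size, inner_size); each element gets start + j * stride.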
+  FOR_RANGE(int64_t, i, 0, outer_size) {
+    FOR_RANGE(int64_t, j, 0, idx_dim_size) {
+      FOR_RANGE(int64_t, k, 0, inner_size) {
+        *(out + i * idx_dim_size * inner_size + j * inner_size + k) = start + j * stride;
+      }
+    }
+  }
+}
+
+template<typename T, typename RangeInitializerConfT>
+void RangeInitializer(const RangeInitializerConfT& initializer_conf, uint32_t random_seed,
+                      Blob* blob) {
+  CHECK_GT(blob->shape().NumAxes(), 0);
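+  // A negative axis counts back from the last dimension; the blob is filled with start + j * stride along it.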
+  const int64_t axis = initializer_conf.axis() < 0
+                           ? blob->shape().NumAxes() + initializer_conf.axis()
+                           : initializer_conf.axis();
+  CHECK_GE(axis, 0);
+  CHECK_LT(axis, blob->shape().NumAxes());
+  RangeInitializer<T>(blob->shape().Count(0, axis), blob->shape().At(axis),
+                      blob->shape().Count(axis + 1), static_cast<T>(initializer_conf.start()),
+                      static_cast<T>(initializer_conf.stride()), blob->mut_dptr<T>());
+}
+
+template<typename T>
+void RangeInitializer(const RangeInitializerConf& initializer_conf, uint32_t random_seed,
+                      Blob* blob) {
+  RangeInitializer<T, RangeInitializerConf>(initializer_conf, random_seed, blob);
+}
+
+template<typename T>
+void IntSequenceInitializer(const IntRangeInitializerConf& initializer_conf, uint32_t random_seed,
+                            Blob* blob) {
+  RangeInitializer<T, IntRangeInitializerConf>(initializer_conf, random_seed, blob);
+}
+
 void ComputeOffset(const int32_t num_axes, const int64_t* shape, const int32_t* permutation,
                    std::vector<int64_t>& offset) {
   offset.resize(num_axes);
@@ -157,6 +221,7 @@ void Memcpy<DeviceType::kCPU>(DeviceCtx* ctx, void* dst, const void* src, size_t
 #endif
 
 ) {
+  if (dst == src) { return; }
   memcpy(dst, src, sz);
 }
 
@@ -271,6 +336,15 @@ KU_IF_METHOD InitializeWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num
   in_stream.Read(blob->mut_dptr<char>(), blob_size);
 }
 KU_IF_METHOD Set(DeviceCtx* ctx, const T value, T* addr) { *addr = value; }
+KU_IF_METHOD Replicate(DeviceCtx* ctx, const int64_t n, T* y, const T* x) {
+  for (int64_t i = 0; i < n; ++i) { y[i] = *x; }
+}
+KU_IF_METHOD AddByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z) {
+  for (int64_t i = 0; i < n; ++i) { z[i] = x[i] + y; }
+}
+KU_IF_METHOD MulByScalarPara(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z) {
+  for (int64_t i = 0; i < n; ++i) { z[i] = x[i] * y; }
+}
 
 #define KU_FLOATING_METHOD \
   template<typename T>     \
@@ -307,6 +381,19 @@ KU_FLOATING_METHOD Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order,
                         const int ldc) {
   cblas_gemm<T>(order, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
 }
+KU_FLOATING_METHOD BatchedGemm(DeviceCtx* ctx, const enum CBLAS_ORDER order,
+                               const enum CBLAS_TRANSPOSE trans_a,
+                               const enum CBLAS_TRANSPOSE trans_b, int batch_size, int m, int n,
+                               int k, const T alpha, const T* a, const T* b, const T beta, T* c,
+                               T** buf) {
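+  // CPU fallback: run batch_size independent GEMMs; buf (device pointer scratch used by the GPU path) is ignored.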
+  const int a_stride = m * k;
+  const int b_stride = k * n;
+  const int c_stride = m * n;
+  FOR_RANGE(int32_t, i, 0, batch_size) {
+    KernelUtil<DeviceType::kCPU, T>::OFGemm(ctx, trans_a, trans_b, m, n, k, alpha, a + i * a_stride,
+                                            b + i * b_stride, beta, c + i * c_stride);
+  }
+}
 
 KU_FLOATING_METHOD Exp(DeviceCtx* ctx, const int64_t n, const T* x, T* y) {
   for (int64_t i = 0; i < n; ++i) { y[i] = std::exp(x[i]); }
@@ -314,12 +401,27 @@ KU_FLOATING_METHOD Exp(DeviceCtx* ctx, const int64_t n, const T* x, T* y) {
 KU_FLOATING_METHOD Div(DeviceCtx* ctx, const int64_t n, T* x, const T* alpha) {
   for (int64_t i = 0; i < n; ++i) { x[i] = x[i] / (*alpha); }
 }
-KU_FLOATING_METHOD Div(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) {
-  for (int64_t i = 0; i < n; ++i) { z[i] = x[i] / y[i]; }
+KU_FLOATING_METHOD Div(DeviceCtx* ctx, const int64_t n, T* x, const T alpha) {
+  for (int64_t i = 0; i < n; ++i) { x[i] = x[i] / alpha; }
 }
 KU_FLOATING_METHOD Mul(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) {
   for (int64_t i = 0; i < n; ++i) { z[i] = x[i] * y[i]; }
 }
+KU_FLOATING_METHOD Div(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) {
+  for (int64_t i = 0; i < n; ++i) { z[i] = x[i] / y[i]; }
+}
+KU_FLOATING_METHOD Square(DeviceCtx* ctx, const int64_t n, const T* x, T* y) {
+  for (int64_t i = 0; i < n; ++i) { y[i] = x[i] * x[i]; }
+}
+KU_FLOATING_METHOD Sqrt(DeviceCtx* ctx, const int64_t n, const T* x, T* y) {
+  for (int64_t i = 0; i < n; ++i) { y[i] = std::sqrt(x[i]); }
+}
+KU_FLOATING_METHOD MulByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) {
+  for (int64_t i = 0; i < n; ++i) { z[i] = x[i] * y[0]; }
+}
+KU_FLOATING_METHOD Reciprocal(DeviceCtx* ctx, const int n, const T* x, T* y) {
+  for (int64_t i = 0; i < n; ++i) { y[i] = static_cast<T>(1.0) / x[i]; }
+}
 KU_FLOATING_METHOD Rsqrt(DeviceCtx* ctx, const int64_t n, T* x, const float epsilon) {
   for (int64_t i = 0; i < n; ++i) { x[i] = 1.0 / std::sqrt(x[i] + epsilon); }
 }
@@ -411,10 +513,14 @@ KU_FLOATING_METHOD InitializeWithConf(DeviceCtx* ctx, const InitializerConf& ini
     RandomUniformInitializer<T>(initializer_conf.random_uniform_conf(), random_seed, blob);
   } else if (initializer_conf.has_random_normal_conf()) {
     RandomNormalInitializer<T>(initializer_conf.random_normal_conf(), random_seed, blob);
+  } else if (initializer_conf.has_truncated_normal_conf()) {
+    TruncatedNormalInitializer<T>(initializer_conf.truncated_normal_conf(), random_seed, blob);
   } else if (initializer_conf.has_xavier_conf()) {
     XavierInitializer<T>(initializer_conf.xavier_conf(), random_seed, blob, data_format);
   } else if (initializer_conf.has_msra_conf()) {
     MsraInitializer<T>(initializer_conf.msra_conf(), random_seed, blob, data_format);
+  } else if (initializer_conf.has_range_conf()) {
+    RangeInitializer<T>(initializer_conf.range_conf(), random_seed, blob);
   } else {
     UNIMPLEMENTED();
   }
@@ -438,6 +544,8 @@ KU_INTEGRAL_METHOD InitializeWithConf(DeviceCtx* ctx, const InitializerConf& ini
     ConstantInitializer<T>(static_cast<T>(initializer_conf.constant_int_conf().value()), blob);
   } else if (initializer_conf.has_random_uniform_int_conf()) {
     RandomIntUniformInitializer<T>(initializer_conf.random_uniform_int_conf(), random_seed, blob);
+  } else if (initializer_conf.has_int_range_conf()) {
+    IntSequenceInitializer<T>(initializer_conf.int_range_conf(), random_seed, blob);
   } else {
     UNIMPLEMENTED();
   }
diff --git a/oneflow/core/kernel/kernel_util.cu b/oneflow/core/kernel/kernel_util.cu
index b0ef96d5623bdeebf6dc001fa6b06da24527d1e0..bc2677ce1fd68887624a600f741f601a59ccabd8 100644
--- a/oneflow/core/kernel/kernel_util.cu
+++ b/oneflow/core/kernel/kernel_util.cu
@@ -20,15 +20,60 @@ __global__ void ExpGpu(const int64_t n, const T* x, T* y) {
 }
 
 template<typename T>
-__global__ void DivGpu(const int64_t n, T* x, const T* alpha_ptr) {
+__global__ void DivByConstParaPtrGpu(const int64_t n, T* x, const T* alpha_ptr) {
   CUDA_1D_KERNEL_LOOP(i, n) { x[i] = x[i] / (*alpha_ptr); }
 }
 
+template<typename T>
+__global__ void DivGpu(const int64_t n, const T* x, const T* y, T* z) {
+  CUDA_1D_KERNEL_LOOP(i, n) { z[i] = x[i] / y[i]; }
+}
+
+template<typename T>
+__global__ void DivByConstParaGpu(const int64_t n, T* x, const T alpha) {
+  CUDA_1D_KERNEL_LOOP(i, n) { x[i] = x[i] / alpha; }
+}
+
+template<typename T>
+__global__ void ReplicateGpu(const int64_t n, T* y, const T* x) {
+  CUDA_1D_KERNEL_LOOP(i, n) { y[i] = *x; }
+}
+
 template<typename T>
 __global__ void MulGpu(const int64_t n, const T* x, const T* y, T* z) {
   CUDA_1D_KERNEL_LOOP(i, n) { z[i] = x[i] * y[i]; }
 }
 
+template<typename T>
+__global__ void MulByScalarGpu(const int64_t n, const T* x, const T* y, T* z) {
+  CUDA_1D_KERNEL_LOOP(i, n) { z[i] = x[i] * y[0]; }
+}
+
+template<typename T>
+__global__ void AddByScalarGpu(const int64_t n, const T* x, const T y, T* z) {
+  CUDA_1D_KERNEL_LOOP(i, n) { z[i] = x[i] + y; }
+}
+
+template<typename T>
+__global__ void MulByScalarParaGpu(const int64_t n, const T* x, const T y, T* z) {
+  CUDA_1D_KERNEL_LOOP(i, n) { z[i] = x[i] * y; }
+}
+
+template<typename T>
+__global__ void ReciprocalGpu(const int64_t n, const T* x, T* y) {
+  CUDA_1D_KERNEL_LOOP(i, n) { y[i] = static_cast<T>(1.0) / x[i]; }
+}
+
+template<typename T>
+__global__ void SquareGpu(const int64_t n, const T* x, T* y) {
+  CUDA_1D_KERNEL_LOOP(i, n) { y[i] = x[i] * x[i]; }
+}
+
+template<typename T>
+__global__ void SqrtGpu(const int64_t n, const T* x, T* y) {
+  CUDA_1D_KERNEL_LOOP(i, n) { y[i] = std::sqrt(x[i]); }
+}
+
 template<typename T>
 __global__ void AxpyGpu(const int n, const T alpha, const T* x, const int incx, T* y,
                         const int incy) {
@@ -142,37 +187,11 @@ cublasOperation_t CblasTrans2CublasTrans(CBLAS_TRANSPOSE trans) {
   return cublas_trans;
 }
 
-const int32_t kMaxDim = OF_PP_SEQ_SIZE(DIM_SEQ);
-
+template<int32_t NDIMS>
 struct Int32Array {
-  int32_t val[kMaxDim];
-};
-
-struct Int64Array {
-  int64_t val[kMaxDim];
+  int32_t val[NDIMS];
 };
 
-__device__ void ComputeOffset(const int32_t num_axis, const int64_t* x_dims,
-                              const int32_t* permutation, int64_t* x_strides) {
-  int64_t buff[kMaxDim];
-  int64_t cur_stride = 1;
-  for (int32_t i = num_axis - 1; i >= 0; --i) {
-    buff[i] = cur_stride;
-#if __CUDA_ARCH__ >= 350
-    cur_stride *= __ldg(x_dims + i);
-#else
-    cur_stride *= x_dims[i];
-#endif
-  }
-  for (int32_t i = 0; i < num_axis; ++i) {
-#if __CUDA_ARCH__ >= 350
-    x_strides[i] = buff[__ldg(permutation + i)];
-#else
-    x_strides[i] = buff[permutation[i]];
-#endif
-  }
-}
-
 template<typename T>
 __global__ void CopyColsRegionGpu(const int64_t row_num, const int64_t col_num, const T* x,
                                   const int64_t x_col_offset, const int64_t x_lda, T* y,
@@ -184,35 +203,29 @@ __global__ void CopyColsRegionGpu(const int64_t row_num, const int64_t col_num,
   }
 }
 
-__device__ int64_t GetXIndex(const int32_t num_axis, const int64_t* y_shape,
-                             const int64_t* x_strides, int64_t y_idx) {
-  int64_t x_idx = 0;
-  for (int32_t i = num_axis - 1; i >= 0 && y_idx > 0; --i) {
+template<int32_t NDIMS>
+__device__ int32_t GetXIndex(const int32_t* y_shape, const int32_t* x_strides, int32_t y_idx) {
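+  // Decompose the linear y index into per-dimension coordinates and re-linearize with x's permuted strides.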
+  int32_t x_idx = 0;
+  for (int32_t i = NDIMS - 1; i >= 0; --i) {
     x_idx += (y_idx % y_shape[i]) * x_strides[i];
     y_idx /= y_shape[i];
   }
   return x_idx;
 }
 
-template<typename T>
-__global__ void TransposeGpu(const int32_t num_axis, const Int64Array x_shape,
-                             const Int64Array y_shape, const Int32Array permutation,
-                             const int64_t elem_cnt, const T* x, T* y) {
-  __shared__ int64_t x_strides[kMaxDim];
-  __shared__ int64_t x_dims_shared[kMaxDim];
-  __shared__ int64_t y_dims_shared[kMaxDim];
-  __shared__ int32_t perm_shared[kMaxDim];
+template<int32_t NDIMS, typename T>
+__global__ void TransposeGpu(const Int32Array<NDIMS> y_shape, const Int32Array<NDIMS> x_strides,
+                             const int32_t elem_cnt, const T* x, T* y) {
+  __shared__ int32_t x_strides_shared[NDIMS];
+  __shared__ int32_t y_dims_shared[NDIMS];
   const int32_t tid = threadIdx.x;
-  if (tid < num_axis) {
-    x_dims_shared[tid] = x_shape.val[tid];
+  if (tid < NDIMS) {
     y_dims_shared[tid] = y_shape.val[tid];
-    perm_shared[tid] = permutation.val[tid];
+    x_strides_shared[tid] = x_strides.val[tid];
   }
   __syncthreads();
-  if (tid == 0) { ComputeOffset(num_axis, x_dims_shared, perm_shared, x_strides); }
-  __syncthreads();
   CUDA_1D_KERNEL_LOOP(y_idx, elem_cnt) {
-    const int64_t x_idx = GetXIndex(num_axis, y_dims_shared, x_strides, y_idx);
+    const int32_t x_idx = GetXIndex<NDIMS>(y_dims_shared, x_strides_shared, y_idx);
 #if __CUDA_ARCH__ >= 350
     y[y_idx] = __ldg(x + x_idx);
 #else
@@ -221,6 +234,32 @@ __global__ void TransposeGpu(const int32_t num_axis, const Int64Array x_shape,
   }
 }
 
+template<int32_t NDIMS, typename T>
+void Transpose(DeviceCtx* ctx, const Shape& x_shape, const Shape& y_shape,
+               const PbRf<int32_t>& permutation, const int64_t elem_cnt, const T* x, T* y) {
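+  // Build x's strides, reorder them into y's dimension order via permutation, then launch TransposeGpu.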
+  CHECK_LE(y_shape.elem_cnt(), MaxVal<int32_t>::value);
+  Int32Array<NDIMS> y_shape_struct;
+  FOR_RANGE(int32_t, i, 0, NDIMS) { y_shape_struct.val[i] = y_shape.At(i); }
+  Int32Array<NDIMS> x_strides;
+  int32_t buff[NDIMS];
+  int32_t cur_stride = 1;
+  for (int32_t i = NDIMS - 1; i >= 0; --i) {
+    buff[i] = cur_stride;
+    cur_stride *= x_shape.At(i);
+  }
+  for (int32_t i = 0; i < NDIMS; ++i) { x_strides.val[i] = buff[permutation[i]]; }
+  TransposeGpu<NDIMS, T>
+      <<<SMBlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+          y_shape_struct, x_strides, elem_cnt, x, y);
+}
+
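+// TransposeUtil dispatches Transpose<NDIMS, T> on the runtime number of axes.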
+template<typename T>
+struct TransposeUtil final {
+#define MAKE_TRANSPOSE_SWITCH_ENTRY(func_name, NDIMS) func_name<NDIMS, T>
+  DEFINE_STATIC_SWITCH_FUNC(void, Transpose, MAKE_TRANSPOSE_SWITCH_ENTRY,
+                            MAKE_NDIM_CTRV_SEQ(DIM_SEQ));
+};
+
 template<typename T, T (*reduce_core_func)(const T, const T)>
 __device__ void MatrixShrinkCols(const size_t row_num, const size_t thread_col_num, const T* x,
                                  const size_t x_col_num, const size_t x_lda, T* y,
@@ -274,11 +313,25 @@ void MatrixRowReduce(DeviceCtx* ctx, const size_t row_num, const size_t col_num,
          ctx->cuda_stream()>>>(row_num, col_num, x, y, static_cast<T*>(temp_storage), temp_col_num);
 }
 
+template<typename T>
+__global__ void AssignStridedAddrGpu(T** dev_ptrs, T* start_ptr, int32_t stride_len,
+                                     int32_t stride_num) {
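+  // Fill dev_ptrs[i] = start_ptr + i * stride_len on the device.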
+  CUDA_1D_KERNEL_LOOP(i, stride_num) { dev_ptrs[i] = start_ptr + i * stride_len; }
+}
+
+template<typename T>
+void AssignStridedAddr(DeviceCtx* ctx, T** dev_ptrs, T* start_ptr, int stride_len, int stride_num) {
+  AssignStridedAddrGpu<T>
+      <<<BlocksNum4ThreadsNum(stride_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+          dev_ptrs, start_ptr, stride_len, stride_num);
+}
+
 }  // namespace
 
 template<>
 void Memcpy<DeviceType::kGPU>(DeviceCtx* ctx, void* dst, const void* src, size_t sz,
                               cudaMemcpyKind kind) {
+  if (dst == src) { return; }
   CudaCheck(cudaMemcpyAsync(dst, src, sz, kind, ctx->cuda_stream()));
 }
 
@@ -299,21 +352,6 @@ size_t GetTmpSizeForReduceSum(DataType data_type, int64_t sum_elem_num) {
 
 #undef MAKE_CUB_DEVICE_REDUCE_SWITCH_ENTRY
 
-// create temporary host blob store initializer result
-#define BEFORE_CPU_INITIALIZE()                                     \
-  RtBlobDesc blob_desc(blob->blob_desc().blob_desc_proto());        \
-  char* host_raw_dptr = nullptr;                                    \
-  CudaCheck(cudaMallocHost(&host_raw_dptr, blob->TotalByteSize())); \
-  std::unique_ptr<Blob> host_blob;                                  \
-  host_blob.reset(new Blob(nullptr, &blob_desc, host_raw_dptr));
-
-// asynchronous copy to device
-#define AFTER_CPU_INITIALIZE()                                                          \
-  Memcpy<DeviceType::kGPU>(ctx, blob->mut_dptr(), host_blob->dptr(),                    \
-                           blob->ByteSizeOfDataContentField(), cudaMemcpyHostToDevice); \
-  CudaCheck(cudaStreamSynchronize(ctx->cuda_stream()));                                 \
-  CudaCheck(cudaFreeHost(host_raw_dptr));
-
 #define KU_IF_METHOD                     \
   template<typename T, typename Derived> \
   void GpuKernelUtilIf<T, Derived>::
@@ -343,21 +381,15 @@ KU_IF_METHOD RowSum(DeviceCtx* ctx, const int64_t row_num, const int64_t col_num
                     void* temp_storage, const size_t temp_storage_bytes) {
   MatrixRowReduce<T, ReduceCoreAdd>(ctx, row_num, col_num, x, y, temp_storage, temp_storage_bytes);
 }
+
 KU_IF_METHOD Transpose(DeviceCtx* ctx, const int32_t num_axis, const Shape& x_shape,
                        const Shape& y_shape, const PbRf<int32_t>& permutation,
                        const int64_t elem_cnt, const T* x, T* y) {
-  CHECK_LE(num_axis, kMaxDim);
-  Int64Array x_shape_struct;
-  Int64Array y_shape_struct;
-  Int32Array perm_struct;
-  FOR_RANGE(int32_t, i, 0, num_axis) {
-    x_shape_struct.val[i] = x_shape.At(i);
-    y_shape_struct.val[i] = y_shape.At(i);
-    perm_struct.val[i] = permutation[i];
-  }
-  TransposeGpu<T>
-      <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
-          num_axis, x_shape_struct, y_shape_struct, perm_struct, elem_cnt, x, y);
+  CHECK_LE(y_shape.elem_cnt(), MaxVal<int32_t>::value);
+  CHECK_EQ(num_axis, y_shape.NumAxes());
+  CHECK_EQ(num_axis, x_shape.NumAxes());
+  TransposeUtil<T>::SwitchTranspose(SwitchCase(num_axis), ctx, x_shape, y_shape, permutation,
+                                    elem_cnt, x, y);
 }
 
 KU_IF_METHOD InitializeWithConf(DeviceCtx* ctx, const InitializerConf& initializer_conf,
@@ -388,6 +420,18 @@ KU_IF_METHOD InitializeWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num
 KU_IF_METHOD Set(DeviceCtx* ctx, const T value, T* addr) {
   gpu_set<T><<<1, 1, 0, ctx->cuda_stream()>>>(value, addr);
 }
+KU_IF_METHOD Replicate(DeviceCtx* ctx, const int64_t n, T* y, const T* x) {
+  ReplicateGpu<T>
+      <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, y, x);
+}
+KU_IF_METHOD AddByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z) {
+  AddByScalarGpu<T>
+      <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y, z);
+}
+KU_IF_METHOD MulByScalarPara(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z) {
+  MulByScalarParaGpu<T>
+      <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y, z);
+}
 
 #define KU_FLOATING_METHOD \
   template<typename T>     \
@@ -432,18 +476,75 @@ KU_FLOATING_METHOD Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order,
   cublas_gemm<T>(ctx->cublas_pmh_handle(), cublas_trans_b, cublas_trans_a, n, m, k, &alpha, b, ldb,
                  a, lda, &beta, c, ldc);
 }
+KU_FLOATING_METHOD BatchedGemm(DeviceCtx* ctx, const enum CBLAS_ORDER order,
+                               const enum CBLAS_TRANSPOSE trans_a,
+                               const enum CBLAS_TRANSPOSE trans_b, int batch_size, int m, int n,
+                               int k, const T alpha, const T* a, const T* b, const T beta, T* c,
+                               T** buf) {
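+  // buf provides 3 * batch_size device pointer slots: per-batch addresses of a, b and c.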
+  const int a_stride = m * k;
+  const int b_stride = k * n;
+  const int c_stride = m * n;
+  const int lda = (trans_a == CblasNoTrans) ? k : m;
+  const int ldb = (trans_b == CblasNoTrans) ? n : k;
+  const int ldc = n;
+  cublasOperation_t cublas_trans_a = CblasTrans2CublasTrans(trans_a);
+  cublasOperation_t cublas_trans_b = CblasTrans2CublasTrans(trans_b);
+  T** dev_a_ptrs = buf;
+  T** dev_b_ptrs = buf + batch_size;
+  T** dev_c_ptrs = buf + 2 * batch_size;
+  AssignStridedAddr<T>(ctx, dev_a_ptrs, const_cast<T*>(a), a_stride, batch_size);
+  AssignStridedAddr<T>(ctx, dev_b_ptrs, const_cast<T*>(b), b_stride, batch_size);
+  AssignStridedAddr<T>(ctx, dev_c_ptrs, c, c_stride, batch_size);
+#if CUDA_VERSION >= 9010
+  cudaDataType_t data_type = CudaDataType<T>::value;
+  cublasGemmBatchedEx(ctx->cublas_pmh_handle(), cublas_trans_b, cublas_trans_a, n, m, k,
+                      reinterpret_cast<const void*>(&alpha),
+                      reinterpret_cast<const void**>(const_cast<const T**>(dev_b_ptrs)), data_type,
+                      ldb, reinterpret_cast<const void**>(const_cast<const T**>(dev_a_ptrs)),
+                      data_type, lda, reinterpret_cast<const void*>(&beta),
+                      reinterpret_cast<void**>(dev_c_ptrs), data_type, ldc, batch_size, data_type,
+                      CUBLAS_GEMM_DEFAULT);
+#else
+  cublas_gemmBatched<T>(ctx->cublas_pmh_handle(), cublas_trans_b, cublas_trans_a, n, m, k, &alpha,
+                        const_cast<const T**>(dev_b_ptrs), ldb, const_cast<const T**>(dev_a_ptrs),
+                        lda, &beta, dev_c_ptrs, ldc, batch_size);
+#endif
+}
 
 KU_FLOATING_METHOD Exp(DeviceCtx* ctx, const int64_t n, const T* x, T* y) {
   ExpGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y);
 }
 KU_FLOATING_METHOD Div(DeviceCtx* ctx, const int64_t n, T* x, const T* alpha) {
-  DivGpu<T>
+  DivByConstParaPtrGpu<T>
+      <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, alpha);
+}
+KU_FLOATING_METHOD Div(DeviceCtx* ctx, const int64_t n, T* x, const T alpha) {
+  DivByConstParaGpu<T>
       <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, alpha);
 }
+KU_FLOATING_METHOD Div(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) {
+  DivGpu<T>
+      <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y, z);
+}
 KU_FLOATING_METHOD Mul(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) {
   MulGpu<T>
       <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y, z);
 }
+KU_FLOATING_METHOD MulByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) {
+  MulByScalarGpu<T>
+      <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y, z);
+}
+KU_FLOATING_METHOD Reciprocal(DeviceCtx* ctx, const int n, const T* x, T* y) {
+  ReciprocalGpu<T>
+      <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y);
+}
+KU_FLOATING_METHOD Square(DeviceCtx* ctx, const int64_t n, const T* x, T* y) {
+  SquareGpu<T>
+      <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y);
+}
+KU_FLOATING_METHOD Sqrt(DeviceCtx* ctx, const int64_t n, const T* x, T* y) {
+  SqrtGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y);
+}
 KU_FLOATING_METHOD Rsqrt(DeviceCtx* ctx, const int64_t n, T* x, const float epsilon) {
   RsqrtGpu<T>
       <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, epsilon);
@@ -594,4 +695,28 @@ __device__ double gpu_atomic_max(double* address, const double val) {
   return __longlong_as_double(old);
 }
 
+template<typename T, typename U>
+__global__ void CastOnGpu(const T* in, U* out, int64_t elem_num) {
+  CUDA_1D_KERNEL_LOOP(i, elem_num) { out[i] = static_cast<U>(in[i]); }
+}
+
+template<typename T, typename U>
+void CopyElemOnGpu(DeviceCtx* ctx, const T* in_dptr, U* out_dptr, int64_t elem_num) {
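+  // Identical element types: plain device copy; otherwise cast element-wise on the GPU.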
+  if (std::is_same<T, U>::value) {
+    Memcpy<DeviceType::kGPU>(ctx, out_dptr, in_dptr, elem_num * sizeof(T));
+  } else {
+    CastOnGpu<T, U>
+        <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+            in_dptr, out_dptr, elem_num);
+  }
+}
+
+#define INSTANTIATE_COPY_ELEM_ON_GPU(T, U) \
+  template void CopyElemOnGpu(DeviceCtx* ctx, const T* in_dptr, U* out_dptr, int64_t elem_num);
+
+#define MAKE_COPY_ELEM_ON_GPU_ENTRY(TPair, UPair) \
+  INSTANTIATE_COPY_ELEM_ON_GPU(OF_PP_PAIR_FIRST(TPair), OF_PP_PAIR_FIRST(UPair))
+
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_COPY_ELEM_ON_GPU_ENTRY, POD_DATA_TYPE_SEQ, POD_DATA_TYPE_SEQ)
+
 }  // namespace oneflow
diff --git a/oneflow/core/kernel/kernel_util.h b/oneflow/core/kernel/kernel_util.h
index 158c2dc0efc71a1c28a5267193d071bdb2340892..ebab13fe3e7147eec9cb137aba44f65097e26f1d 100644
--- a/oneflow/core/kernel/kernel_util.h
+++ b/oneflow/core/kernel/kernel_util.h
@@ -41,7 +41,7 @@ template<DeviceType device_type>
 void Memset(DeviceCtx*, void* dst, const char value, size_t sz);
 
 #if defined(__CUDACC__)
-#define OF_DEVICE_FUNC __device__
+#define OF_DEVICE_FUNC __device__ __host__ __forceinline__
 #else
 #define OF_DEVICE_FUNC
 #endif
@@ -88,6 +88,14 @@ struct KernelUtilIf {
            c->mut_dptr<T>());
   }
 
+  static void OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a,
+                            enum CBLAS_TRANSPOSE trans_b, const int batch_size, const int m,
+                            const int n, const int k, const T alpha, const T* a, const T* b,
+                            const T beta, T* c, T** buf) {
+    Derived::BatchedGemm(ctx, CblasRowMajor, trans_a, trans_b, batch_size, m, n, k, alpha, a, b,
+                         beta, c, buf);
+  }
+
   static void InitializeWithProperConf(DeviceCtx* ctx, const InitializerConf* initializer_conf,
                                        uint32_t random_seed, Blob* blob,
                                        const std::string& data_format = "") {
@@ -141,6 +149,9 @@ struct CpuKernelUtilIf {
                                 const std::string& bn_in_op, int32_t dim_num,
                                 int64_t num_in_each_dim);
   static void Set(DeviceCtx* ctx, const T value, T* addr);
+  static void Replicate(DeviceCtx* ctx, const int64_t n, T* y, const T* x);
+  static void AddByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z);
+  static void MulByScalarPara(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z);
 };
 
 // CPU, Floating
@@ -162,11 +173,20 @@ struct KernelUtil<DeviceType::kCPU, T, typename std::enable_if<IsFloating<T>::va
                    const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k,
                    const T alpha, const T* a, const int lda, const T* b, const int ldb,
                    const T beta, T* c, const int ldc);
+  static void BatchedGemm(DeviceCtx* ctx, const enum CBLAS_ORDER order,
+                          const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b,
+                          int batch_size, int m, int n, int k, const T alpha, const T* a,
+                          const T* b, const T beta, T* c, T** buf);
 
   static void Exp(DeviceCtx* ctx, const int64_t n, const T* x, T* y);
   static void Div(DeviceCtx* ctx, const int64_t n, T* x, const T* alpha);
+  static void Div(DeviceCtx* ctx, const int64_t n, T* x, const T alpha);
   static void Div(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z);
   static void Mul(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z);
+  static void MulByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z);
+  static void Reciprocal(DeviceCtx* ctx, const int n, const T* x, T* y);
+  static void Square(DeviceCtx* ctx, const int64_t n, const T* x, T* y);
+  static void Sqrt(DeviceCtx* ctx, const int64_t n, const T* x, T* y);
   static void Rsqrt(DeviceCtx* ctx, const int64_t n, T* x, const float epsilon);
   static void Powx(DeviceCtx* ctx, const int64_t n, const T* x, const float power, T* y);
 
@@ -247,6 +267,9 @@ struct GpuKernelUtilIf {
                                 const std::string& bn_in_op, int32_t dim_num,
                                 int64_t num_in_each_dim);
   static void Set(DeviceCtx* ctx, const T value, T* addr);
+  static void Replicate(DeviceCtx* ctx, const int64_t n, T* y, const T* x);
+  static void AddByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z);
+  static void MulByScalarPara(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z);
 };
 
 // GPU, Floating
@@ -270,10 +293,20 @@ struct KernelUtil<DeviceType::kGPU, T, typename std::enable_if<IsFloating<T>::va
                    const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k,
                    const T alpha, const T* a, const int lda, const T* b, const int ldb,
                    const T beta, T* c, const int ldc);
+  static void BatchedGemm(DeviceCtx* ctx, const enum CBLAS_ORDER order,
+                          const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b,
+                          int batch_size, int m, int n, int k, const T alpha, const T* a,
+                          const T* b, const T beta, T* c, T** buf);
 
   static void Exp(DeviceCtx* ctx, const int64_t n, const T* x, T* y);
   static void Div(DeviceCtx* ctx, const int64_t n, T* x, const T* alpha);
+  static void Div(DeviceCtx* ctx, const int64_t n, T* x, const T alpha);
+  static void Div(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z);
   static void Mul(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z);
+  static void MulByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z);
+  static void Reciprocal(DeviceCtx* ctx, const int n, const T* x, T* y);
+  static void Square(DeviceCtx* ctx, const int64_t n, const T* x, T* y);
+  static void Sqrt(DeviceCtx* ctx, const int64_t n, const T* x, T* y);
   static void Rsqrt(DeviceCtx* ctx, const int64_t n, T* x, const float epsilon);
 
   static void Sigmoid(DeviceCtx* ctx, int64_t n, const T* x, T* y);
@@ -495,6 +528,24 @@ typename std::enable_if<!std::is_same<T, U>::value>::type CopyElem(const T* in_d
   FOR_RANGE(int64_t, i, 0, elem_num) { *(out_dptr++) = static_cast<U>(*(in_dptr++)); }
 }
 
+template<typename T, typename U>
+void CopyElemOnGpu(DeviceCtx* ctx, const T* in_dptr, U* out_dptr, int64_t elem_num);
+
+// create a temporary host blob to store the initializer result
+#define BEFORE_CPU_INITIALIZE()                                     \
+  RtBlobDesc blob_desc(blob->blob_desc().blob_desc_proto());        \
+  char* host_raw_dptr = nullptr;                                    \
+  CudaCheck(cudaMallocHost(&host_raw_dptr, blob->TotalByteSize())); \
+  std::unique_ptr<Blob> host_blob;                                  \
+  host_blob.reset(new Blob(nullptr, &blob_desc, host_raw_dptr));
+
+// asynchronous copy to device
+#define AFTER_CPU_INITIALIZE()                                                          \
+  Memcpy<DeviceType::kGPU>(ctx, blob->mut_dptr(), host_blob->dptr(),                    \
+                           blob->ByteSizeOfDataContentField(), cudaMemcpyHostToDevice); \
+  CudaCheck(cudaStreamSynchronize(ctx->cuda_stream()));                                 \
+  CudaCheck(cudaFreeHost(host_raw_dptr));
+
 }  // namespace oneflow
 
 #endif  // ONEFLOW_CORE_KERNEL_KERNEL_UTIL_H_
diff --git a/oneflow/core/kernel/layer_norm_kernel.cpp b/oneflow/core/kernel/layer_norm_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..35d28c095071ce7ac324182c4b83e0cd4408bcbe
--- /dev/null
+++ b/oneflow/core/kernel/layer_norm_kernel.cpp
@@ -0,0 +1,171 @@
+#include "oneflow/core/kernel/layer_norm_kernel.h"
+#include "oneflow/core/kernel/kernel_util.h"
+#include "oneflow/core/ndarray/ndarray_util.h"
+
+namespace oneflow {
+
+namespace {
+
+InitializerConf ConstantInitializerConf(float val) {
+  InitializerConf conf;
+  conf.mutable_constant_conf()->set_value(val);
+  return conf;
+}
+
+InitializerConf OnesInitializerConf() { return ConstantInitializerConf(1.0f); }
+
+InitializerConf ZerosInitializerConf() { return ConstantInitializerConf(0.0f); }
+
+}  // namespace
+
+template<DeviceType device_type, typename T>
+void LayerNormKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
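+  // Normalize via the cuDNN batch-norm helper, then optionally broadcast-apply gamma (scale) and beta (center).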
+  const LayerNormOpConf& conf = this->op_conf().layer_norm_conf();
+  const Blob* in = BnInOp2Blob("in");
+  const Blob* bn_scale = BnInOp2Blob("cudnn_bn_scale_ones");
+  const Blob* bn_bias = BnInOp2Blob("cudnn_bn_bias_zeros");
+  Blob* out = BnInOp2Blob("out");
+  Blob* normalize_out = conf.scale() ? BnInOp2Blob("normalize_out") : out;
+  Blob* mean = BnInOp2Blob("cudnn_bn_mean");
+  Blob* inv_variance = BnInOp2Blob("cudnn_bn_inv_variance");
+  LayerNormKernelUtil<device_type, T>::NormalizeForward(
+      ctx.device_ctx, in, bn_scale, bn_bias, conf.epsilon(), normalize_out, mean, inv_variance);
+  if (conf.scale()) {
+    const Blob* gamma = BnInOp2Blob("gamma");
+    const int64_t m = gamma->shape().elem_cnt();
+    CHECK_EQ(out->shape().elem_cnt() % m, 0);
+    const int64_t n = out->shape().elem_cnt() / m;
+    NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncMul>(
+        ctx.device_ctx, XpuVarNdarray<T>({n, m}, out->mut_dptr<T>()),
+        XpuVarNdarray<const T>({n, m}, normalize_out->dptr<T>()),
+        XpuVarNdarray<const T>({1, m}, gamma->dptr<T>()));
+  }
+  if (conf.center()) {
+    const Blob* beta = BnInOp2Blob("beta");
+    const int64_t m = beta->shape().elem_cnt();
+    CHECK_EQ(out->shape().elem_cnt() % m, 0);
+    const int64_t n = out->shape().elem_cnt() / m;
+    NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncAdd>(
+        ctx.device_ctx, XpuVarNdarray<T>({n, m}, out->mut_dptr<T>()),
+        XpuVarNdarray<const T>({n, m}, out->dptr<T>()),
+        XpuVarNdarray<const T>({1, m}, beta->dptr<T>()));
+  }
+}
+
+template<DeviceType device_type, typename T>
+void LayerNormKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
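+  // beta_diff = row-sum of out_diff; gamma_diff = row-sum of normalize_out * out_diff;
+  // the cuDNN backward pass then produces in_diff.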
+  const LayerNormOpConf& conf = this->op_conf().layer_norm_conf();
+  const Blob* out_diff = BnInOp2Blob(GenDiffBn("out"));
+  Blob* in_diff = BnInOp2Blob(GenDiffBn("in"));
+  Blob* bw_buf = BnInOp2Blob("bw_reduce_buf");
+  if (conf.center() && this->op_conf().trainable()) {
+    Blob* beta_diff = BnInOp2Blob(GenDiffBn("beta"));
+    const int64_t m = beta_diff->shape().elem_cnt();
+    CHECK_EQ(out_diff->shape().elem_cnt() % m, 0);
+    const int64_t n = out_diff->shape().elem_cnt() / m;
+    NdarrayUtil<device_type, T>::ReduceSum(ctx.device_ctx,
+                                           XpuVarNdarray<T>({1, m}, beta_diff->mut_dptr<T>()),
+                                           XpuVarNdarray<const T>({n, m}, out_diff->dptr<T>()),
+                                           XpuVarNdarray<T>({n, m}, bw_buf->mut_dptr<T>()));
+  }
+  if (conf.scale()) {
+    Blob* normalize_out = BnInOp2Blob("normalize_out");
+    const Blob* gamma = BnInOp2Blob("gamma");
+    Blob* gamma_diff = BnInOp2Blob(GenDiffBn("gamma"));
+    const int64_t m = gamma_diff->shape().elem_cnt();
+    CHECK_EQ(out_diff->shape().elem_cnt() % m, 0);
+    const int64_t n = out_diff->shape().elem_cnt() / m;
+    if (this->op_conf().trainable()) {
+      NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncMul>(
+          ctx.device_ctx, XpuVarNdarray<T>({n, m}, bw_buf->mut_dptr<T>()),
+          XpuVarNdarray<const T>({n, m}, normalize_out->dptr<T>()),
+          XpuVarNdarray<const T>({n, m}, out_diff->dptr<T>()));
+      NdarrayUtil<device_type, T>::ReduceSum(ctx.device_ctx,
+                                             XpuVarNdarray<T>({1, m}, gamma_diff->mut_dptr<T>()),
+                                             XpuVarNdarray<const T>({n, m}, bw_buf->dptr<T>()),
+                                             XpuVarNdarray<T>({n, m}, bw_buf->mut_dptr<T>()));
+    }
+    if (in_diff) {
+      NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncMul>(
+          ctx.device_ctx, XpuVarNdarray<T>({n, m}, normalize_out->mut_dptr<T>()),
+          XpuVarNdarray<const T>({n, m}, out_diff->dptr<T>()),
+          XpuVarNdarray<const T>({1, m}, gamma->dptr<T>()));
+    }
+  }
+  if (in_diff) {
+    const Blob* in = BnInOp2Blob("in");
+    const Blob* normalize_out_diff = conf.scale() ? BnInOp2Blob("normalize_out") : out_diff;
+    const Blob* bn_scale = BnInOp2Blob("cudnn_bn_scale_ones");
+    const Blob* mean = BnInOp2Blob("cudnn_bn_mean");
+    const Blob* inv_variance = BnInOp2Blob("cudnn_bn_inv_variance");
+    Blob* bn_scale_diff = BnInOp2Blob("cudnn_bn_scale_diff_buf");
+    Blob* bn_bias_diff = BnInOp2Blob("cudnn_bn_bias_diff_buf");
+    LayerNormKernelUtil<device_type, T>::NormalizeBackward(
+        ctx.device_ctx, in, bn_scale, mean, inv_variance, normalize_out_diff, conf.epsilon(),
+        in_diff, bn_scale_diff, bn_bias_diff);
+  }
+}
+
+template<DeviceType device_type, typename T>
+void LayerNormKernel<device_type, T>::InitModelBlobsWithRandomSeed(
+    DeviceCtx* ctx, std::mt19937*, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const LayerNormOpConf& conf = this->op_conf().layer_norm_conf();
+  if (conf.scale()) {
+    InitializerConf ones_initializer = OnesInitializerConf();
+    KernelUtil<device_type, T>::InitializeWithConf(ctx, ones_initializer, 0, BnInOp2Blob("gamma"));
+  }
+  if (conf.center()) {
+    InitializerConf zeros_initializer = ZerosInitializerConf();
+    KernelUtil<device_type, T>::InitializeWithConf(ctx, zeros_initializer, 0, BnInOp2Blob("beta"));
+  }
+}
+
+template<DeviceType device_type, typename T>
+void LayerNormKernel<device_type, T>::InitModelBlobsWithDir(
+    DeviceCtx* ctx, int32_t part_id, int32_t part_num, const std::string& model_load_dir,
+    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const LayerNormOpConf& conf = this->op_conf().layer_norm_conf();
+  if (conf.scale()) {
+    Blob* gamma = BnInOp2Blob("gamma");
+    KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir, gamma,
+                                                  "gamma", gamma->shape().At(0),
+                                                  gamma->shape().Count(1));
+  }
+  if (conf.center()) {
+    Blob* beta = BnInOp2Blob("beta");
+    KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir, beta,
+                                                  "beta", beta->shape().At(0),
+                                                  beta->shape().Count(1));
+  }
+}
+
+template<DeviceType device_type, typename T>
+void LayerNormKernel<device_type, T>::InitConstBufBlobs(
+    DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  InitializerConf ones_initializer = OnesInitializerConf();
+  KernelUtil<device_type, T>::InitializeWithConf(ctx, ones_initializer, 0,
+                                                 BnInOp2Blob("cudnn_bn_scale_ones"));
+  InitializerConf zeros_initializer = ZerosInitializerConf();
+  KernelUtil<device_type, T>::InitializeWithConf(ctx, zeros_initializer, 0,
+                                                 BnInOp2Blob("cudnn_bn_bias_zeros"));
+}
+
+template<typename T>
+struct LayerNormKernelUtil<DeviceType::kCPU, T> {
+  static void NormalizeForward(const DeviceCtx* ctx, const Blob* in, const Blob* scale,
+                               const Blob* bias, double epsilon, Blob* out, Blob* mean,
+                               Blob* inv_variance) {
+    UNIMPLEMENTED();
+  }
+  static void NormalizeBackward(const DeviceCtx* ctx, const Blob* in, const Blob* scale,
+                                const Blob* mean, const Blob* inv_variance, const Blob* out_diff,
+                                double epsilon, Blob* in_diff, Blob* scale_diff, Blob* bias_diff) {
+    UNIMPLEMENTED();
+  }
+};
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kLayerNormConf, LayerNormKernel, FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/layer_norm_kernel.cu b/oneflow/core/kernel/layer_norm_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..5fe0da7e86ec0a746aa2a256c7ff049983175b3a
--- /dev/null
+++ b/oneflow/core/kernel/layer_norm_kernel.cu
@@ -0,0 +1,85 @@
+#include "oneflow/core/kernel/layer_norm_kernel.h"
+
+namespace oneflow {
+
+namespace {
+
+class LayerNormCudnnBnCtx final {
+ public:
+  LayerNormCudnnBnCtx(const Shape& data_shape, const Shape& param_shape, DataType data_type) {
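+    // Map layer norm onto cuDNN batch norm: view the data as NCHW (1, C, 1, W) with one "channel"
+    // per mean/inv_variance entry, so spatial batch norm computes per-instance statistics over W elements.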
+    const int64_t cudnn_c = param_shape.elem_cnt();
+    CHECK_EQ(data_shape.elem_cnt() % cudnn_c, 0);
+    const int64_t cudnn_w = data_shape.elem_cnt() / cudnn_c;
+    CHECK_LT(cudnn_c, MaxVal<int32_t>::value);
+    CHECK_LT(cudnn_w, MaxVal<int32_t>::value);
+    data_tensor_desc_.reset(new CudnnTensorDesc(CUDNN_TENSOR_NCHW, data_type, 1,
+                                                static_cast<int32_t>(cudnn_c), 1,
+                                                static_cast<int32_t>(cudnn_w)));
+    param_tensor_desc_.reset(
+        new CudnnTensorDesc(CUDNN_TENSOR_NCHW, data_type, 1, static_cast<int32_t>(cudnn_c), 1, 1));
+#if (CUDNN_VERSION >= 7000)
+    mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
+#else
+    mode_ = CUDNN_BATCHNORM_SPATIAL;
+#endif
+  }
+  ~LayerNormCudnnBnCtx() = default;
+
+  const cudnnTensorDescriptor_t& data_tensor_desc() const { return data_tensor_desc_->Get(); }
+  const cudnnTensorDescriptor_t& param_tensor_desc() const { return param_tensor_desc_->Get(); }
+  cudnnBatchNormMode_t mode() const { return mode_; }
+
+ private:
+  std::unique_ptr<CudnnTensorDesc> data_tensor_desc_;
+  std::unique_ptr<CudnnTensorDesc> param_tensor_desc_;
+  cudnnBatchNormMode_t mode_;
+};
+
+}  // namespace
+
+template<typename T>
+struct LayerNormKernelUtil<DeviceType::kGPU, T> {
+  static void NormalizeForward(const DeviceCtx* ctx, const Blob* in, const Blob* scale,
+                               const Blob* bias, double epsilon, Blob* out, Blob* mean,
+                               Blob* inv_variance);
+  static void NormalizeBackward(const DeviceCtx* ctx, const Blob* in, const Blob* scale,
+                                const Blob* mean, const Blob* inv_variance, const Blob* out_diff,
+                                double epsilon, Blob* in_diff, Blob* scale_diff, Blob* bias_diff);
+};
+
+template<typename T>
+void LayerNormKernelUtil<DeviceType::kGPU, T>::NormalizeForward(const DeviceCtx* ctx,
+                                                                const Blob* in, const Blob* scale,
+                                                                const Blob* bias, double epsilon,
+                                                                Blob* out, Blob* mean,
+                                                                Blob* inv_variance) {
+  CHECK_GE(epsilon, CUDNN_BN_MIN_EPSILON);
+  LayerNormCudnnBnCtx bn_ctx(in->static_shape(), mean->shape(), in->data_type());
+  CudaCheck(cudnnBatchNormalizationForwardTraining(
+      ctx->cudnn_handle(), bn_ctx.mode(), OnePtr<T>::value, ZeroPtr<T>::value,
+      bn_ctx.data_tensor_desc(), in->dptr<T>(), bn_ctx.data_tensor_desc(), out->mut_dptr<T>(),
+      bn_ctx.param_tensor_desc(), scale->dptr<T>(), bias->dptr<T>(), 1.0, nullptr, nullptr, epsilon,
+      mean->mut_dptr<T>(), inv_variance->mut_dptr<T>()));
+}
+
+template<typename T>
+void LayerNormKernelUtil<DeviceType::kGPU, T>::NormalizeBackward(
+    const DeviceCtx* ctx, const Blob* in, const Blob* scale, const Blob* mean,
+    const Blob* inv_variance, const Blob* out_diff, double epsilon, Blob* in_diff, Blob* scale_diff,
+    Blob* bias_diff) {
+  CHECK_GE(epsilon, CUDNN_BN_MIN_EPSILON);
+  LayerNormCudnnBnCtx bn_ctx(in->static_shape(), mean->shape(), in->data_type());
+  CudaCheck(cudnnBatchNormalizationBackward(
+      ctx->cudnn_handle(), bn_ctx.mode(), OnePtr<T>::value, ZeroPtr<T>::value, OnePtr<T>::value,
+      ZeroPtr<T>::value, bn_ctx.data_tensor_desc(), in->dptr<T>(), bn_ctx.data_tensor_desc(),
+      out_diff->dptr<T>(), bn_ctx.data_tensor_desc(), in_diff->mut_dptr<T>(),
+      bn_ctx.param_tensor_desc(), scale->dptr<T>(), scale_diff->mut_dptr<T>(),
+      bias_diff->mut_dptr<T>(), epsilon, mean->dptr<T>(), inv_variance->dptr<T>()));
+}
+
+#define INSTANTIATE_LAYER_NORM_KERNEL_UTIL_GPU(type_cpp, type_proto) \
+  template struct LayerNormKernelUtil<DeviceType::kGPU, type_cpp>;
+OF_PP_FOR_EACH_TUPLE(INSTANTIATE_LAYER_NORM_KERNEL_UTIL_GPU, FLOATING_DATA_TYPE_SEQ)
+#undef INSTANTIATE_LAYER_NORM_KERNEL_UTIL_GPU
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/layer_norm_kernel.h b/oneflow/core/kernel/layer_norm_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..1241934fbb498974afaa5062ac11adbe0fd2c65e
--- /dev/null
+++ b/oneflow/core/kernel/layer_norm_kernel.h
@@ -0,0 +1,45 @@
+#ifndef ONEFLOW_CORE_KERNEL_LAYER_NORM_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_LAYER_NORM_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class LayerNormKernel final : public KernelIfWithModel<device_type, T>,
+                              public KernelIfWithActivation<device_type, T> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(LayerNormKernel);
+  LayerNormKernel() = default;
+  ~LayerNormKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+  void InitModelBlobsWithRandomSeed(DeviceCtx*, std::mt19937*,
+                                    std::function<Blob*(const std::string&)>) const override;
+  void InitModelBlobsWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num,
+                             const std::string& model_load_dir,
+                             std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  void InitConstBufBlobs(DeviceCtx* ctx,
+                         std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  const PbMessage& GetCustomizedOpConf() const override {
+    return this->op_conf().layer_norm_conf();
+  }
+};
+
+template<DeviceType device_type, typename T>
+struct LayerNormKernelUtil {
+  static void NormalizeForward(const DeviceCtx* ctx, const Blob* in, const Blob* scale,
+                               const Blob* bias, double epsilon, Blob* out, Blob* mean,
+                               Blob* inv_variance);
+  static void NormalizeBackward(const DeviceCtx* ctx, const Blob* in, const Blob* scale,
+                                const Blob* mean, const Blob* inv_variance, const Blob* out_diff,
+                                double epsilon, Blob* in_diff, Blob* scale_diff, Blob* bias_diff);
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_LAYER_NORM_KERNEL_H_
diff --git a/oneflow/core/kernel/loss_kernel.cpp b/oneflow/core/kernel/loss_kernel.cpp
index 7f84396a62f395a393f094f18688568484530857..bea078dead4fcc77e0be7640aac1677215c42bf9 100644
--- a/oneflow/core/kernel/loss_kernel.cpp
+++ b/oneflow/core/kernel/loss_kernel.cpp
@@ -1,23 +1,24 @@
 #include "oneflow/core/kernel/loss_kernel.h"
+#include "oneflow/core/ndarray/ndarray_util.h"
 
 namespace oneflow {
 
-template<DeviceType device_type, typename PredType, typename LabelType>
-void LossKernel<device_type, PredType, LabelType>::SetLossInstanceNumBlob(
+template<DeviceType device_type, typename PredType>
+void LossKernel<device_type, PredType>::SetLossInstanceNum(
     const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const {
-  CHECK_GE(this->op_attribute().input_bns().size(), 2);
-  // already did CheckSameDim0ValidNum in Kernel::Forward
-  int64_t dim0_valid_num_sum =
-      BnInOp2Blob(this->op_attribute().input_bns(0))->CalcDim0ValidNumSum();
-  KernelUtil<device_type, PredType>::Set(ctx.device_ctx, static_cast<PredType>(dim0_valid_num_sum),
+  const int64_t loss_instance_num = CalcLossInstanceNum(ctx, BnInOp2Blob);
+  KernelUtil<device_type, PredType>::Set(ctx.device_ctx, static_cast<PredType>(loss_instance_num),
                                          BnInOp2Blob("loss_instance_num")->mut_dptr<PredType>());
+  CHECK(BnInOp2Blob(GenDiffBn("prediction"))->has_loss_instance_num_field());
+  BnInOp2Blob(GenDiffBn("prediction"))
+      ->set_loss_instance_num(static_cast<float>(loss_instance_num));
 }
 
-template<DeviceType device_type, typename PredType, typename LabelType>
-void LossKernel<device_type, PredType, LabelType>::ForwardDataContent(
+template<DeviceType device_type, typename PredType>
+void LossKernel<device_type, PredType>::ForwardDataContent(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   VirtualLossForwardDataContent(ctx, BnInOp2Blob);
-  SetLossInstanceNumBlob(ctx, BnInOp2Blob);
+  SetLossInstanceNum(ctx, BnInOp2Blob);
 
   const LossKernelConf& conf = GetLossKernelConf(this->kernel_conf());
   int64_t n = BnInOp2Blob("prediction")->shape().At(0);
@@ -30,8 +31,11 @@ void LossKernel<device_type, PredType, LabelType>::ForwardDataContent(
     if (weight_blob != nullptr) {
       PredType* weight = weight_blob->mut_dptr<PredType>();
       if (weight_blob->shape().elem_cnt() == n) {
-        KernelUtil<device_type, PredType>::Mul(ctx.device_ctx, n, weight, prediction_diff,
-                                               prediction_diff);
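+        // Per-example weights: broadcast-multiply each row of the prediction diff by its weight.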
+        const int64_t m = prediction_diff_blob->shape().Count(1);
+        NdarrayUtil<device_type, PredType>::template BroadcastApply<BinaryFuncMul>(
+            ctx.device_ctx, XpuVarNdarray<PredType>({n, m}, prediction_diff),
+            XpuVarNdarray<const PredType>({n, 1}, weight),
+            XpuVarNdarray<const PredType>({n, m}, prediction_diff));
       } else if (weight_blob->shape().elem_cnt() == 1) {
         KernelUtil<device_type, PredType>::Scal(ctx.device_ctx, n, weight, prediction_diff, 1);
       } else {
@@ -53,32 +57,38 @@ void LossKernel<device_type, PredType, LabelType>::ForwardDataContent(
   }
 }
 
-template<DeviceType device_type, typename PredType, typename LabelType>
-void LossKernel<device_type, PredType, LabelType>::ForwardDataId(
+template<DeviceType device_type, typename PredType>
+void LossKernel<device_type, PredType>::ForwardDataId(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   BnInOp2Blob("loss")->CopyDataIdFrom(ctx.device_ctx, BnInOp2Blob("prediction"));
 }
 
-template<DeviceType device_type, typename PredType, typename LabelType>
-void LossKernel<device_type, PredType, LabelType>::ForwardColNum(
+template<DeviceType device_type, typename PredType>
+void LossKernel<device_type, PredType>::ForwardColNum(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   BnInOp2Blob(GenDiffBn("prediction"))->CopyColNumFrom(ctx.device_ctx, BnInOp2Blob("prediction"));
 }
 
-template<DeviceType device_type, typename PredType, typename LabelType>
-void LossKernel<device_type, PredType, LabelType>::ForwardDim0ValidNum(
+template<DeviceType device_type, typename PredType>
+void LossKernel<device_type, PredType>::ForwardDim0ValidNum(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   BnInOp2Blob(GenDiffBn("prediction"))
       ->CopyDim0ValidNumFrom(ctx.device_ctx, BnInOp2Blob("prediction"));
   BnInOp2Blob("loss")->CopyDim0ValidNumFrom(ctx.device_ctx, BnInOp2Blob("prediction"));
 }
 
-template<DeviceType device_type, typename PredType, typename LabelType>
-void LossKernel<device_type, PredType, LabelType>::ForwardRecordIdInDevicePiece(
+template<DeviceType device_type, typename PredType>
+void LossKernel<device_type, PredType>::ForwardRecordIdInDevicePiece(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   // do nothing
 }
 
+template<DeviceType device_type, typename PredType>
+int64_t LossKernel<device_type, PredType>::CalcLossInstanceNum(
+    const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const {
+  return BnInOp2Blob("prediction")->CalcDim0ValidNumSum();
+}
+
 template<typename T>
 struct LossKernelUtil<DeviceType::kCPU, T> {
   static void ComputeReductionCoefficient(DeviceCtx* ctx, int64_t data_num, int64_t weight_length,
@@ -120,11 +130,9 @@ struct LossKernelUtil<DeviceType::kCPU, T> {
 
 OF_PP_FOR_EACH_TUPLE(MAKE_LOSS_KERNEL_UTIL_ENTRY, FLOATING_DATA_TYPE_SEQ)
 
-#define MAKE_LOSS_ENTRY(device_type, data_type_pair, label_type_pair)      \
-  template class LossKernel<device_type, OF_PP_PAIR_FIRST(data_type_pair), \
-                            OF_PP_PAIR_FIRST(label_type_pair)>;
+#define MAKE_LOSS_ENTRY(device_type, data_type_pair) \
+  template class LossKernel<device_type, OF_PP_PAIR_FIRST(data_type_pair)>;
 
-OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_LOSS_ENTRY, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ,
-                                 ARITHMETIC_DATA_TYPE_SEQ)
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_LOSS_ENTRY, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ)
 
 }  // namespace oneflow
diff --git a/oneflow/core/kernel/loss_kernel.h b/oneflow/core/kernel/loss_kernel.h
index 3e42b906aadc6359252981ef00ff36587a1c0291..7905587308f5ca37a74b9c953acc074c81eaed6f 100644
--- a/oneflow/core/kernel/loss_kernel.h
+++ b/oneflow/core/kernel/loss_kernel.h
@@ -5,7 +5,7 @@
 
 namespace oneflow {
 
-template<DeviceType device_type, typename PredType, typename LabelType>
+template<DeviceType device_type, typename PredType>
 class LossKernel : public KernelIf<device_type> {
  public:
   OF_DISALLOW_COPY_AND_MOVE(LossKernel);
@@ -16,6 +16,8 @@ class LossKernel : public KernelIf<device_type> {
   virtual void VirtualLossForwardDataContent(
       const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const = 0;
   virtual const LossKernelConf& GetLossKernelConf(const KernelConf& kernel_conf) const = 0;
+  virtual int64_t CalcLossInstanceNum(
+      const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const;
 
  private:
   void ForwardDataContent(const KernelCtx& ctx,
@@ -28,8 +30,8 @@ class LossKernel : public KernelIf<device_type> {
                            std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
   void ForwardRecordIdInDevicePiece(
       const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
-  void SetLossInstanceNumBlob(const KernelCtx& ctx,
-                              const std::function<Blob*(const std::string&)>& BnInOp2Blob) const;
+  void SetLossInstanceNum(const KernelCtx& ctx,
+                          const std::function<Blob*(const std::string&)>& BnInOp2Blob) const;
 };
 
 template<DeviceType device_type, typename T>
diff --git a/oneflow/core/kernel/loss_print_kernel.cpp b/oneflow/core/kernel/loss_print_kernel.cpp
index 7b6b26875403d49b884bea9575b57f39582a5fbb..c7cdfce1807928ae12f1c48ae75ff457604c6c02 100644
--- a/oneflow/core/kernel/loss_print_kernel.cpp
+++ b/oneflow/core/kernel/loss_print_kernel.cpp
@@ -20,7 +20,15 @@ void LossPrintKernel<T>::Forward(const KernelCtx& kernel_ctx,
   }
   loss_reduced /= reduction_coefficient;
   const char* loss_op_name = op_conf().name().c_str() + LossPrintPrefix.length();
-  LOG(INFO) << loss_op_name << ":" << loss_reduced;
+  double* prev_ts = static_cast<double*>(kernel_ctx.other);
+  const double cur_ts = GetCurTime() / 1e9;
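+  // On calls after the first, also report the elapsed seconds since the previous loss print.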
+  if (*prev_ts == 0) {
+    LOG(INFO) << loss_op_name << ":" << loss_reduced;
+  } else {
+    LOG(INFO) << loss_op_name << ":" << std::fixed << std::setprecision(3) << loss_reduced << " ("
+              << (cur_ts - *prev_ts) << " sec)";
+  }
+  *prev_ts = cur_ts;
 }
 
 template<typename T>
diff --git a/oneflow/core/kernel/matmul_kernel.cpp b/oneflow/core/kernel/matmul_kernel.cpp
index e47f6e0bfd6c066ad90cd82df7382dff5f6abcb3..4a0130e566109f08e9e71f8ab257400fa40cf262 100644
--- a/oneflow/core/kernel/matmul_kernel.cpp
+++ b/oneflow/core/kernel/matmul_kernel.cpp
@@ -6,59 +6,98 @@ namespace oneflow {
 template<DeviceType device_type, typename T>
 void MatmulKernel<device_type, T>::ForwardDataContent(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  const Blob* in_blob = BnInOp2Blob("in");
-  const Blob* weight_blob = BnInOp2Blob("weight");
+  const Blob* a_blob = BnInOp2Blob("a");
+  const Blob* b_blob = BnInOp2Blob("b");
+  Blob* fw_buf_blob = BnInOp2Blob("fw_buf");
   Blob* out_blob = BnInOp2Blob("out");
-  // out = in * weight'
-  KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasNoTrans, CblasTrans, OneVal<T>::value,
-                                       ZeroVal<T>::value, in_blob, weight_blob, out_blob);
-  if (this->op_conf().matmul_conf().has_bias()) {
-    const Blob* bias_blob = BnInOp2Blob("bias");
-    const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier");
-    // out = bias_multiplier * bias + out
-    KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasNoTrans, CblasNoTrans,
-                                         OneVal<T>::value, OneVal<T>::value, bias_mul_blob,
-                                         bias_blob, out_blob);
+  bool transpose_a = this->op_conf().matmul_conf().transpose_a();
+  bool transpose_b = this->op_conf().matmul_conf().transpose_b();
+  if (a_blob->static_shape().dim_vec().size() == 2) {
+    Calc2DMatMul(ctx.device_ctx, a_blob, transpose_a, b_blob, transpose_b, out_blob, false);
+  } else {
+    CalcBatchMatMul(ctx.device_ctx, a_blob, transpose_a, b_blob, transpose_b, out_blob, fw_buf_blob,
+                    false);
   }
 }
 
 template<DeviceType device_type, typename T>
 void MatmulKernel<device_type, T>::BackwardDataContent(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* a_blob = BnInOp2Blob("a");
+  const Blob* b_blob = BnInOp2Blob("b");
   const Blob* out_diff_blob = BnInOp2Blob("out_diff");
-  const Blob* in_blob = BnInOp2Blob("in");
-  Blob* in_diff_blob = BnInOp2Blob("in_diff");
-  const Blob* weight_blob = BnInOp2Blob("weight");
-  Blob* weight_diff_blob = BnInOp2Blob("weight_diff");
-  // weight_diff = out_diff * in'
-  KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasTrans, CblasNoTrans, OneVal<T>::value,
-                                       ZeroVal<T>::value, out_diff_blob, in_blob, weight_diff_blob);
-  // in_diff = out_diff * weight
-  KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasNoTrans, CblasNoTrans, OneVal<T>::value,
-                                       ZeroVal<T>::value, out_diff_blob, weight_blob, in_diff_blob);
-  if (this->op_conf().matmul_conf().has_bias()) {
-    const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier");
-    Blob* bias_diff_blob = BnInOp2Blob("bias_diff");
-    // bias_diff = bias_multiplier' * out_diff
-    KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasTrans, CblasNoTrans, OneVal<T>::value,
-                                         ZeroVal<T>::value, bias_mul_blob, out_diff_blob,
-                                         bias_diff_blob);
+  Blob* a_diff_blob = BnInOp2Blob("a_diff");
+  Blob* b_diff_blob = BnInOp2Blob("b_diff");
+  Blob* bw_buf_blob = BnInOp2Blob("bw_buf");
+  bool transpose_a = this->op_conf().matmul_conf().transpose_a();
+  bool transpose_b = this->op_conf().matmul_conf().transpose_b();
+  // trans_a  trans_b  a_diff  b_diff
+  //   T        T       b'g'    g'a'
+  //   T        F       bg'     ag
+  //   F        T       gb      g'a
+  //   F        F       gb'     a'g
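+  // (g = out_diff; the swap_in flag below reverses the GEMM operand order so one call covers each case above)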
+  if (a_blob->static_shape().dim_vec().size() == 2) {
+    Calc2DMatMul(ctx.device_ctx, b_blob, !(transpose_a ^ transpose_b), out_diff_blob, transpose_a,
+                 a_diff_blob, !transpose_a);
+    Calc2DMatMul(ctx.device_ctx, a_blob, !(transpose_a ^ transpose_b), out_diff_blob, transpose_b,
+                 b_diff_blob, transpose_b);
+  } else {
+    CalcBatchMatMul(ctx.device_ctx, b_blob, !(transpose_a ^ transpose_b), out_diff_blob,
+                    transpose_a, a_diff_blob, bw_buf_blob, !transpose_a);
+    CalcBatchMatMul(ctx.device_ctx, a_blob, !(transpose_a ^ transpose_b), out_diff_blob,
+                    transpose_b, b_diff_blob, bw_buf_blob, transpose_b);
   }
 }
 
 template<DeviceType device_type, typename T>
-void MatmulKernel<device_type, T>::InitConstBufBlobs(
-    DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  if (!this->op_conf().matmul_conf().has_bias()) { return; }
-  InitializerConf bias_multiplier_initializer_conf;
-  bias_multiplier_initializer_conf.mutable_constant_conf()->set_value(1.0f);
-  KernelUtil<device_type, T>::InitializeWithConf(ctx, bias_multiplier_initializer_conf, 0,
-                                                 BnInOp2Blob("bias_multiplier"));
+const PbMessage& MatmulKernel<device_type, T>::GetCustomizedOpConf() const {
+  return this->op_conf().matmul_conf();
 }
 
 template<DeviceType device_type, typename T>
-const PbMessage& MatmulKernel<device_type, T>::GetCustomizedOpConf() const {
-  return this->op_conf().matmul_conf();
+void MatmulKernel<device_type, T>::Calc2DMatMul(DeviceCtx* ctx, const Blob* a, bool trans_a,
+                                                const Blob* b, bool trans_b, Blob* c,
+                                                bool swap_in) const {
+  CBLAS_TRANSPOSE blas_trans_a = trans_a ? CblasTrans : CblasNoTrans;
+  CBLAS_TRANSPOSE blas_trans_b = trans_b ? CblasTrans : CblasNoTrans;
+  if (swap_in) {
+    KernelUtil<device_type, T>::BlobGemm(ctx, blas_trans_b, blas_trans_a, OneVal<T>::value,
+                                         ZeroVal<T>::value, b, a, c);
+  } else {
+    KernelUtil<device_type, T>::BlobGemm(ctx, blas_trans_a, blas_trans_b, OneVal<T>::value,
+                                         ZeroVal<T>::value, a, b, c);
+  }
+}
+
+template<DeviceType device_type, typename T>
+void MatmulKernel<device_type, T>::CalcBatchMatMul(DeviceCtx* ctx, const Blob* a, bool trans_a,
+                                                   const Blob* b, bool trans_b, Blob* c, Blob* buf,
+                                                   bool swap_in) const {
+  if (swap_in) {
+    CalcBatchMatMul(ctx, b, trans_b, a, trans_a, c, buf);
+  } else {
+    CalcBatchMatMul(ctx, a, trans_a, b, trans_b, c, buf);
+  }
+}
+
+template<DeviceType device_type, typename T>
+void MatmulKernel<device_type, T>::CalcBatchMatMul(DeviceCtx* ctx, const Blob* a, bool trans_a,
+                                                   const Blob* b, bool trans_b, Blob* c,
+                                                   Blob* buf) const {
+  CBLAS_TRANSPOSE blas_trans_a = trans_a ? CblasTrans : CblasNoTrans;
+  CBLAS_TRANSPOSE blas_trans_b = trans_b ? CblasTrans : CblasNoTrans;
+  int32_t dim_num = a->shape().dim_vec().size();
+  int32_t batch_size = a->shape().Count(0, dim_num - 2);
+  int m = trans_a ? a->shape().At(dim_num - 1) : a->shape().At(dim_num - 2);
+  int k = trans_a ? a->shape().At(dim_num - 2) : a->shape().At(dim_num - 1);
+  int n = trans_b ? b->shape().At(dim_num - 2) : b->shape().At(dim_num - 1);
+  const T* a_dptr = a->dptr<T>();
+  const T* b_dptr = b->dptr<T>();
+  T* c_dptr = c->mut_dptr<T>();
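+  // buf is reinterpreted as scratch space for the per-batch pointer arrays handed to the batched GEMM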
+  T** buf_dptr = reinterpret_cast<T**>(buf->mut_dptr<int64_t>());
+  KernelUtil<device_type, T>::OFBatchedGemm(ctx, blas_trans_a, blas_trans_b, batch_size, m, n, k,
+                                            OneVal<T>::value, a_dptr, b_dptr, ZeroVal<T>::value,
+                                            c_dptr, buf_dptr);
 }
 
 ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kMatmulConf, MatmulKernel, FLOATING_DATA_TYPE_SEQ);
diff --git a/oneflow/core/kernel/matmul_kernel.h b/oneflow/core/kernel/matmul_kernel.h
index 4a82c3da53b4dab006fb322335c2549cc84f3a0b..6917a10b6f3f2cdbd9947d77efbe58d282378f18 100644
--- a/oneflow/core/kernel/matmul_kernel.h
+++ b/oneflow/core/kernel/matmul_kernel.h
@@ -16,8 +16,12 @@ class MatmulKernel final : public KernelIfWithModel<device_type, T> {
                           std::function<Blob*(const std::string&)>) const override;
   void BackwardDataContent(const KernelCtx&,
                            std::function<Blob*(const std::string&)>) const override;
-  void InitConstBufBlobs(DeviceCtx*,
-                         std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  void Calc2DMatMul(DeviceCtx* ctx, const Blob* a, bool trans_a, const Blob* b, bool trans_b,
+                    Blob* c, bool swap_in) const;
+  void CalcBatchMatMul(DeviceCtx* ctx, const Blob* a, bool trans_a, const Blob* b, bool trans_b,
+                       Blob* c, Blob* buf, bool swap_in) const;
+  void CalcBatchMatMul(DeviceCtx* ctx, const Blob* a, bool trans_a, const Blob* b, bool trans_b,
+                       Blob* c, Blob* buf) const;
   const PbMessage& GetCustomizedOpConf() const override;
 };
 
diff --git a/oneflow/core/kernel/max_pooling_kernel.cpp b/oneflow/core/kernel/max_pooling_kernel.cpp
index bf32b21b39b13780ac601f4a3c66959a8241a5dd..0ea7e8f2fd0fae727b82bce1e22a07dd5ac07909 100644
--- a/oneflow/core/kernel/max_pooling_kernel.cpp
+++ b/oneflow/core/kernel/max_pooling_kernel.cpp
@@ -4,7 +4,7 @@ namespace oneflow {
 
 template<typename T>
 T MaxPoolingKernel<DeviceType::kCPU, T>::ForwardInitialize() const {
-  return MinVal<T>();
+  return GetMinVal<T>();
 }
 
 template<typename T>
diff --git a/oneflow/core/kernel/mean_kernel.cpp b/oneflow/core/kernel/mean_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..79f5e84c32cd701403069b56fb65643dfb14dffe
--- /dev/null
+++ b/oneflow/core/kernel/mean_kernel.cpp
@@ -0,0 +1,45 @@
+#include "oneflow/core/kernel/mean_kernel.h"
+#include "oneflow/core/kernel/kernel_util.h"
+#include "oneflow/core/ndarray/ndarray_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void MeanKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* in_blob = BnInOp2Blob("in");
+  Blob* out_blob = BnInOp2Blob("out");
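+  // count: number of input elements averaged into each output element (the divisor of the mean)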
+  size_t count = in_blob->shape().elem_cnt() / out_blob->shape().elem_cnt();
+  Blob* fw_tmp_blob = BnInOp2Blob("fw_tmp");
+  size_t num_axes = in_blob->shape().NumAxes();
+  NdarrayUtil<device_type, T>::ReduceSum(ctx.device_ctx, XpuVarNdarray<T>(out_blob, num_axes),
+                                         XpuVarNdarray<const T>(in_blob, num_axes),
+                                         XpuVarNdarray<T>(fw_tmp_blob, num_axes));
+
+  KernelUtil<device_type, T>::Div(ctx.device_ctx, out_blob->shape().elem_cnt(),
+                                  out_blob->mut_dptr<T>(), static_cast<T>(count));
+}
+
+template<DeviceType device_type, typename T>
+void MeanKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out"));
+  CHECK_EQ(1, out_diff_blob->shape().dim_vec().back());
+  Blob* in_diff_blob = BnInOp2Blob(GenDiffBn("in"));
+  Blob* bw_tmp_blob = BnInOp2Blob("bw_tmp");
+
+  Memcpy<device_type>(ctx.device_ctx, bw_tmp_blob->mut_dptr(), out_diff_blob->dptr(),
+                      out_diff_blob->ByteSizeOfDataContentField());
+  size_t count = in_diff_blob->shape().elem_cnt() / out_diff_blob->shape().elem_cnt();
+  // bw_tmp = out_diff / count
+  KernelUtil<device_type, T>::Div(ctx.device_ctx, bw_tmp_blob->shape().elem_cnt(),
+                                  bw_tmp_blob->mut_dptr<T>(), static_cast<T>(count));
+  size_t num_axes = in_diff_blob->shape().NumAxes();
+  NdarrayUtil<device_type, T>::template BroadcastApply<UnaryFuncIdentity>(
+      ctx.device_ctx, XpuVarNdarray<T>(in_diff_blob, num_axes),
+      XpuVarNdarray<const T>(out_diff_blob->shape(), bw_tmp_blob->dptr<T>()));
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kMeanConf, MeanKernel, FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/mean_kernel.h b/oneflow/core/kernel/mean_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..2f4d45d39054d726e57d26feee194f4f2e9f6e52
--- /dev/null
+++ b/oneflow/core/kernel/mean_kernel.h
@@ -0,0 +1,25 @@
+#ifndef ONEFLOW_CORE_KERNEL_MEAN_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_MEAN_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class MeanKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(MeanKernel);
+  MeanKernel() = default;
+  ~MeanKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+  const PbMessage& GetCustomizedOpConf() const override { return this->op_conf().mean_conf(); }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_MEAN_KERNEL_H_
diff --git a/oneflow/core/kernel/normal_model_update_kernel.cpp b/oneflow/core/kernel/normal_model_update_kernel.cpp
index 1857bde209b69520f4ef1ed654191f5d3d6707b5..1d6c1cedeac1e5f4567861a68bd27844e0139455 100644
--- a/oneflow/core/kernel/normal_model_update_kernel.cpp
+++ b/oneflow/core/kernel/normal_model_update_kernel.cpp
@@ -3,6 +3,7 @@
 #include "oneflow/core/kernel/momentum_model_update_kernel.h"
 #include "oneflow/core/kernel/rmsprop_model_update_kernel.h"
 #include "oneflow/core/kernel/lars_model_update_kernel.h"
+#include "oneflow/core/kernel/adam_model_update_kernel.h"
 
 namespace oneflow {
 
@@ -13,13 +14,17 @@ void NormalMdUpdateKernel<device_type, T>::Forward(
   int64_t cur_batch_num = next_model_vid - 1;
   const NormalModelUpdateOpUserConf& conf = this->op_conf().normal_mdupdt_conf().user_conf();
   float learning_rate = this->op_conf().normal_mdupdt_conf().learning_rate();
+  const T* batch_instance_num_ptr = BnInOp2Blob("total_instance_num_diff")->dptr<T>();
+  if (conf.has_clip_conf()) {
+    ClipGradient(ctx.device_ctx, cur_batch_num, conf.clip_conf(), batch_instance_num_ptr,
+                 BnInOp2Blob);
+  }
   if (TriggerWarmup(conf, learning_rate, cur_batch_num)) {
     learning_rate = GetWarmupLearningRate(conf.warmup_conf(), learning_rate, cur_batch_num);
   } else if (conf.has_learning_rate_decay()) {
     learning_rate =
         GetDecayedLearningRate(conf.learning_rate_decay(), learning_rate, cur_batch_num);
   }
-  const T* batch_instance_num_ptr = BnInOp2Blob("total_instance_num_diff")->dptr<T>();
   float l1 = this->op_conf().normal_mdupdt_conf().l1();
   float l2 = this->op_conf().normal_mdupdt_conf().l2();
   UpdateModel(ctx.device_ctx, batch_instance_num_ptr, static_cast<T>(learning_rate),
@@ -43,6 +48,8 @@ Kernel* CreateMdUpdtKernel(const KernelConf& kernel_conf) {
     return CreateRMSPropMdUpdtKernel(kernel_conf);
   } else if (user_conf.has_lars_conf()) {
     return CreateLARSMdUpdtKernel(kernel_conf);
+  } else if (user_conf.has_adam_conf()) {
+    return CreateAdamMdUpdtKernel(kernel_conf);
   } else {
     UNIMPLEMENTED();
   }
@@ -138,7 +145,7 @@ double ConstantWarmupLearningRate(const ConstantWarmupConf& conf, double lr,
 
 double LinearWarmupLearningRate(const LinearWarmupConf& conf, double lr, int64_t cur_batch_num) {
   CHECK_GT(conf.warmup_batches(), 0);
-  CHECK_GT(conf.start_multiplier(), 0);
+  CHECK_GE(conf.start_multiplier(), 0);
   CHECK_LT(conf.start_multiplier(), 1);
   double start_multiplier = conf.start_multiplier();
   double multiplier = 1.0;
@@ -149,9 +156,32 @@ double LinearWarmupLearningRate(const LinearWarmupConf& conf, double lr, int64_t
   return lr * multiplier;
 }
 
+template<DeviceType device_type, typename T>
+void ClipByGlobalNorm(DeviceCtx* ctx, const int64_t cur_batch_num, const ClipByGlobalNormConf& conf,
+                      const T* batch_instance_num_ptr,
+                      std::function<Blob*(const std::string&)> BnInOp2Blob) {
+  int64_t n = BnInOp2Blob("model_diff")->shape().elem_cnt();
+  T* model_diff = BnInOp2Blob("model_diff")->mut_dptr<T>();
+  T* global_norm_ptr = BnInOp2Blob("data_tmp")->mut_dptr<T>();
+  if (conf.has_global_norm()) {
+    KernelUtil<device_type, T>::Set(ctx, static_cast<T>(conf.global_norm()), global_norm_ptr);
+  } else {
+    // Dot only writes its output, so global_norm does not need to be initialized beforehand.
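+    // global_norm = sqrt(dot(model_diff, model_diff)) / batch_instance_num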
+    KernelUtil<device_type, T>::Dot(ctx, n, model_diff, 1, model_diff, 1, global_norm_ptr);
+    KernelUtil<device_type, T>::Sqrt(ctx, 1, global_norm_ptr, global_norm_ptr);
+    KernelUtil<device_type, T>::Div(ctx, 1, global_norm_ptr, batch_instance_num_ptr);
+  }
+  T* ratio_ptr = BnInOp2Blob("data_tmp")->mut_dptr<T>();
+  NormalMdUpdateKernelUtil<device_type, T>::CmptClipRatioByGlobalNorm(
+      ctx, global_norm_ptr, static_cast<T>(conf.clip_norm()), ratio_ptr);
+  KernelUtil<device_type, T>::Scal(ctx, n, ratio_ptr, model_diff, 1);
+}
+
 }  // namespace
 
-bool TriggerWarmup(const NormalModelUpdateOpUserConf& conf, double lr, int64_t cur_batch_num) {
+template<DeviceType device_type, typename T>
+bool NormalMdUpdateKernel<device_type, T>::TriggerWarmup(const NormalModelUpdateOpUserConf& conf,
+                                                         double lr, int64_t cur_batch_num) const {
   if (!conf.has_warmup_conf()) { return false; }
   const WarmupConf& warmup_conf = conf.warmup_conf();
   if (warmup_conf.has_constant_conf()) {
@@ -163,7 +193,10 @@ bool TriggerWarmup(const NormalModelUpdateOpUserConf& conf, double lr, int64_t c
   }
 }
 
-double GetWarmupLearningRate(const WarmupConf& conf, double lr, int64_t cur_batch_num) {
+template<DeviceType device_type, typename T>
+double NormalMdUpdateKernel<device_type, T>::GetWarmupLearningRate(const WarmupConf& conf,
+                                                                   double lr,
+                                                                   int64_t cur_batch_num) const {
   if (conf.has_constant_conf()) {
     return ConstantWarmupLearningRate(conf.constant_conf(), lr, cur_batch_num);
   } else if (conf.has_linear_conf()) {
@@ -173,7 +206,21 @@ double GetWarmupLearningRate(const WarmupConf& conf, double lr, int64_t cur_batc
   }
 }
 
-double GetDecayedLearningRate(const LearningRateDecayConf& conf, double lr, int64_t cur_batch_num) {
+template<DeviceType device_type, typename T>
+void NormalMdUpdateKernel<device_type, T>::ClipGradient(
+    DeviceCtx* ctx, const int64_t cur_batch_num, const ClipConf& conf,
+    const T* batch_instance_num_ptr, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  if (conf.has_clip_by_global_norm()) {
+    ClipByGlobalNorm<device_type, T>(ctx, cur_batch_num, conf.clip_by_global_norm(),
+                                     batch_instance_num_ptr, BnInOp2Blob);
+  } else {
+    UNIMPLEMENTED();
+  }
+}
+
+template<DeviceType device_type, typename T>
+double NormalMdUpdateKernel<device_type, T>::GetDecayedLearningRate(
+    const LearningRateDecayConf& conf, double lr, int64_t cur_batch_num) const {
   if (conf.has_exponential_conf()) {
     return ExponentialDecayedLearningRate(conf.exponential_conf(), lr, cur_batch_num);
   } else if (conf.has_inverse_time_conf()) {
@@ -193,6 +240,15 @@ double GetDecayedLearningRate(const LearningRateDecayConf& conf, double lr, int6
   }
 }
 
+template<typename T>
+class NormalMdUpdateKernelUtil<DeviceType::kCPU, T> final {
+ public:
+  static void CmptClipRatioByGlobalNorm(DeviceCtx* ctx, const T* global_norm_ptr, T clip_norm,
+                                        T* ratio_ptr) {
+    *ratio_ptr = clip_norm / std::max(*global_norm_ptr, clip_norm);
+  }
+};
+
 REGISTER_KERNEL_CREATOR(OperatorConf::kNormalMdupdtConf, CreateMdUpdtKernel);
 
 }  // namespace oneflow
diff --git a/oneflow/core/kernel/normal_model_update_kernel.cu b/oneflow/core/kernel/normal_model_update_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..d8ee75da9cae581ae9b71cb540e956da04849041
--- /dev/null
+++ b/oneflow/core/kernel/normal_model_update_kernel.cu
@@ -0,0 +1,29 @@
+#include "oneflow/core/device/cuda_util.h"
+#include "oneflow/core/kernel/normal_model_update_kernel.h"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T>
+__global__ void CmptClipRatioByGlobalNormGpu(const T* global_norm_ptr, T clip_norm, T* ratio_ptr) {
+  *ratio_ptr = clip_norm / max(*global_norm_ptr, clip_norm);
+}
+
+}  // namespace
+
+template<typename T>
+class NormalMdUpdateKernelUtil<DeviceType::kGPU, T> final {
+ public:
+  static void CmptClipRatioByGlobalNorm(DeviceCtx* ctx, const T* global_norm_ptr, T clip_norm,
+                                        T* ratio_ptr) {
+    CmptClipRatioByGlobalNormGpu<T>
+        <<<1, 1, 0, ctx->cuda_stream()>>>(global_norm_ptr, clip_norm, ratio_ptr);
+  }
+};
+
+#define INSTANTIATE_GPU_KERNEL_UTIL(type_cpp, type_proto) \
+  template class NormalMdUpdateKernelUtil<DeviceType::kGPU, type_cpp>;
+OF_PP_FOR_EACH_TUPLE(INSTANTIATE_GPU_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ)
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/normal_model_update_kernel.h b/oneflow/core/kernel/normal_model_update_kernel.h
index f3f8ff1901a21fba79518bf312bd0d296801dcb3..ae5accae68144baf2db3ac891b0818c5614a6623 100644
--- a/oneflow/core/kernel/normal_model_update_kernel.h
+++ b/oneflow/core/kernel/normal_model_update_kernel.h
@@ -19,11 +19,24 @@ class NormalMdUpdateKernel : public KernelIf<device_type> {
   virtual void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1,
                            T l2, int64_t next_model_vid,
                            std::function<Blob*(const std::string&)> BnInOp2Blob) const = 0;
+
+ private:
+  bool TriggerWarmup(const NormalModelUpdateOpUserConf& conf, double lr,
+                     int64_t cur_batch_num) const;
+  double GetWarmupLearningRate(const WarmupConf&, double lr, int64_t cur_batch_num) const;
+  double GetDecayedLearningRate(const LearningRateDecayConf&, double lr,
+                                int64_t cur_batch_num) const;
+  void ClipGradient(DeviceCtx* ctx, const int64_t cur_batch_num, const ClipConf& conf,
+                    const T* batch_instance_num_ptr,
+                    std::function<Blob*(const std::string&)> BnInOp2Blob) const;
 };
 
-bool TriggerWarmup(const NormalModelUpdateOpUserConf& conf, double lr, int64_t cur_batch_num);
-double GetWarmupLearningRate(const WarmupConf&, double lr, int64_t cur_batch_num);
-double GetDecayedLearningRate(const LearningRateDecayConf&, double lr, int64_t cur_batch_num);
+template<DeviceType device_type, typename T>
+class NormalMdUpdateKernelUtil final {
+ public:
+  static void CmptClipRatioByGlobalNorm(DeviceCtx* ctx, const T* global_norm_ptr, T clip_norm,
+                                        T* ratio_ptr);
+};
 
 #define DECLARE_MDUPDT_KERNEL_CREATOR(x) Kernel* Create##x##MdUpdtKernel(const KernelConf&);
 
diff --git a/oneflow/core/kernel/one_hot_kernel.cpp b/oneflow/core/kernel/one_hot_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..dd9ba25925680945dd01a3ac3161ff50ba0ca7a6
--- /dev/null
+++ b/oneflow/core/kernel/one_hot_kernel.cpp
@@ -0,0 +1,57 @@
+#include "oneflow/core/kernel/one_hot_kernel.h"
+
+namespace oneflow {
+
+namespace {
+
+template<DeviceType device_type, typename T, typename K>
+void OneHot(DeviceCtx* ctx, const Blob* indices, Blob* out) {
+  const int64_t depth = out->shape().At(out->shape().NumAxes() - 1);
+  OneHotKernelUtil<device_type, T, K>::Encode(ctx, indices->dptr<K>(), indices->shape().elem_cnt(),
+                                              depth, out->mut_dptr<T>());
+}
+
+}  // namespace
+
+template<DeviceType device_type, typename T>
+struct OneHotUtil final {
+#define MAKE_ONE_HOT_SWITCH_ENTRY(func_name, K) func_name<device_type, T, K>
+  DEFINE_STATIC_SWITCH_FUNC(void, OneHot, MAKE_ONE_HOT_SWITCH_ENTRY,
+                            MAKE_DATA_TYPE_CTRV_SEQ(INT_DATA_TYPE_SEQ));
+#undef MAKE_ONE_HOT_SWITCH_ENTRY
+};
+
+template<DeviceType device_type, typename T>
+const PbMessage& OneHotKernel<device_type, T>::GetCustomizedOpConf() const {
+  return this->op_conf().one_hot_conf();
+}
+
+template<DeviceType device_type, typename T>
+void OneHotKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* indices = BnInOp2Blob("indices");
+  Blob* out = BnInOp2Blob("out");
+  OneHotUtil<device_type, T>::SwitchOneHot(SwitchCase(indices->data_type()), ctx.device_ctx,
+                                           indices, out);
+}
+
+template<typename T, typename K>
+struct OneHotKernelUtil<DeviceType::kCPU, T, K> final {
+  static void Encode(DeviceCtx* ctx, const K* indices, int64_t num_indices, int64_t depth, T* out);
+};
+
+template<typename T, typename K>
+void OneHotKernelUtil<DeviceType::kCPU, T, K>::Encode(DeviceCtx* ctx, const K* indices,
+                                                      int64_t num_indices, int64_t depth, T* out) {
+  Memset<kCPU>(ctx, out, 0, num_indices * depth * sizeof(T));
+  FOR_RANGE(int64_t, i, 0, num_indices) {
+    const K idx = indices[i];
+    CHECK_GE(idx, 0);
+    CHECK_LT(idx, depth);
+    out[i * depth + idx] = OneVal<T>::value;
+  }
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kOneHotConf, OneHotKernel, ARITHMETIC_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/one_hot_kernel.cu b/oneflow/core/kernel/one_hot_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..c996d9cea850659dcb2fb5ccd09e3529dc01f6b7
--- /dev/null
+++ b/oneflow/core/kernel/one_hot_kernel.cu
@@ -0,0 +1,43 @@
+#include "oneflow/core/kernel/one_hot_kernel.h"
+#include "oneflow/core/kernel/kernel_util.cuh"
+#include <assert.h>
+
+namespace oneflow {
+
+namespace {
+
+template<typename T, typename K>
+__global__ void OneHotEncodeGpu(int64_t elem_cnt, const K* indices, int64_t depth, T* out) {
+  CUDA_1D_KERNEL_LOOP(i, elem_cnt) {
+    const int64_t row = i / depth;
+    const int64_t col = i % depth;
+    const int64_t idx = indices[row];
+    assert(idx >= 0 && idx < depth);
+    out[i] = (idx == col);
+  }
+}
+
+}  // namespace
+
+template<typename T, typename K>
+struct OneHotKernelUtil<DeviceType::kGPU, T, K> final {
+  static void Encode(DeviceCtx* ctx, const K* indices, int64_t num_indices, int64_t depth, T* out);
+};
+
+template<typename T, typename K>
+void OneHotKernelUtil<DeviceType::kGPU, T, K>::Encode(DeviceCtx* ctx, const K* indices,
+                                                      int64_t num_indices, int64_t depth, T* out) {
+  const int64_t elem_cnt = num_indices * depth;
+  OneHotEncodeGpu<T, K>
+      <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+          elem_cnt, indices, depth, out);
+}
+
+#define INSTANTIATE_ONE_HOT_KERNEL_UTIL_GPU(data_type_pair, index_type_pair)           \
+  template struct OneHotKernelUtil<DeviceType::kGPU, OF_PP_PAIR_FIRST(data_type_pair), \
+                                   OF_PP_PAIR_FIRST(index_type_pair)>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ONE_HOT_KERNEL_UTIL_GPU, ARITHMETIC_DATA_TYPE_SEQ,
+                                 INT_DATA_TYPE_SEQ);
+#undef INSTANTIATE_ONE_HOT_KERNEL_UTIL_GPU
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/one_hot_kernel.h b/oneflow/core/kernel/one_hot_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..5d5c8bd7fe7d744125d0b09b85f4097705faf5f4
--- /dev/null
+++ b/oneflow/core/kernel/one_hot_kernel.h
@@ -0,0 +1,28 @@
+#ifndef ONEFLOW_CORE_KERNEL_ONE_HOT_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_ONE_HOT_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class OneHotKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(OneHotKernel);
+  OneHotKernel() = default;
+  ~OneHotKernel() = default;
+
+ private:
+  const PbMessage& GetCustomizedOpConf() const override;
+  void ForwardDataContent(const KernelCtx& ctx,
+                          std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+};
+
+template<DeviceType device_type, typename T, typename K>
+struct OneHotKernelUtil final {
+  static void Encode(DeviceCtx* ctx, const K* indices, int64_t num_indices, int64_t depth, T* out);
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_ONE_HOT_KERNEL_H_
diff --git a/oneflow/core/kernel/opkernel_test_case.cpp b/oneflow/core/kernel/opkernel_test_case.cpp
index 798d18796c4f1b83893eccebc0747747d31b8fd6..5c0544bdd3f3494d3b7e20ec1e68c12b3b734a37 100644
--- a/oneflow/core/kernel/opkernel_test_case.cpp
+++ b/oneflow/core/kernel/opkernel_test_case.cpp
@@ -420,8 +420,8 @@ void OpKernelTestCase::InferBlobDesc(std::shared_ptr<Operator>* op, OpContext**
   op_conf_.set_device_type(default_device_type_);
   *op = ConstructOp(op_conf_);
   if (NeedInferBlobDescs(op->get())) {
-    (*op)->InferBlobDescs(MakeGetterBnInOp2BlobDesc(), &parallel_ctx_,
-                          [&](OpContext* op_ctx) { *op_context = op_ctx; });
+    (*op)->InferBlobDescsIf(MakeGetterBnInOp2BlobDesc(), &parallel_ctx_, 1,
+                            [&](OpContext* op_ctx) { *op_context = op_ctx; });
   }
 }
 
diff --git a/oneflow/core/kernel/pack_kernel.h b/oneflow/core/kernel/pack_kernel.h
index 66e5c0be9e69a4d08a4b1d439651d86f275a4309..0c73cc66bbbcacab1722e7347dc32d64254c0cf9 100644
--- a/oneflow/core/kernel/pack_kernel.h
+++ b/oneflow/core/kernel/pack_kernel.h
@@ -31,6 +31,10 @@ class PackKernel final : public KernelIf<device_type> {
       const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
   void ForwardDataId(const KernelCtx& ctx,
                      std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  void BackwardInDiffLossInstanceNum(
+      const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override {
+    UNIMPLEMENTED();
+  }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/kernel/random_generator.cpp b/oneflow/core/kernel/random_generator.cpp
index f47526032f26d766e2b64756d7228fc4a9ad24f8..c939d075b5164015e5e610ed35b1a64d7b525bc2 100644
--- a/oneflow/core/kernel/random_generator.cpp
+++ b/oneflow/core/kernel/random_generator.cpp
@@ -14,7 +14,7 @@ void RandomGenerator<DeviceType::kCPU>::Uniform(const int64_t elem_cnt, const T
   CHECK_GE(elem_cnt, 0);
   CHECK(dptr);
   CHECK_LE(min, max);
-  std::uniform_real_distribution<T> random_distribution(min, std::nextafter(max, MaxVal<T>()));
+  std::uniform_real_distribution<T> random_distribution(min, std::nextafter(max, GetMaxVal<T>()));
   for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = random_distribution(mt19937_generator_); }
 }
 
diff --git a/oneflow/core/kernel/record_load_kernel.cpp b/oneflow/core/kernel/record_load_kernel.cpp
index 02c479a5c7795731d2652cd84a92cc3e8b60bec4..7cd3f403023a2a937b71dffab45d4e1ccf6af16b 100644
--- a/oneflow/core/kernel/record_load_kernel.cpp
+++ b/oneflow/core/kernel/record_load_kernel.cpp
@@ -21,24 +21,39 @@ void RecordLoadKernel::VirtualKernelInit(const ParallelContext* parallel_ctx) {
     int32_t zero_count = std::max(part_name_suffix_length - static_cast<int32_t>(num.length()), 0);
     data_paths.push_back(JoinPath(data_dir, part_name_prefix + std::string(zero_count, '0') + num));
   }
+  piece_size_in_one_loader_ = kernel_conf().record_load_conf().device_piece_size();
   if (Global<JobDesc>::Get()->IsTrain()) {
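+    // num_max_read bounds the total number of records this loader reads over the whole training run.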
+    const size_t num_max_read =
+        static_cast<size_t>(piece_size_in_one_loader_ * Global<JobDesc>::Get()->TotalBatchNum()
+                            * Global<JobDesc>::Get()->NumOfPiecesInBatch());
     in_stream_.reset(new PersistentInStream(
         DataFS(), data_paths, true, Global<JobDesc>::Get()->save_downloaded_file_to_local_fs()));
+    if (record_load_conf.has_random_shuffle_conf()) {
+      const int32_t shuffle_buffer_size = record_load_conf.random_shuffle_conf().buffer_size();
+      CHECK_GT(shuffle_buffer_size, 0);
+      record_reader_.reset(new RandomShuffleOFRecordReader(
+          in_stream_.get(), static_cast<size_t>(shuffle_buffer_size), num_max_read));
+    } else {
+      record_reader_.reset(new NaiveOFRecordReader(in_stream_.get(), num_max_read));
+    }
   } else {
     in_stream_.reset(new PersistentInStream(DataFS(), data_paths, false, false));
+    if (record_load_conf.has_random_shuffle_conf()) {
+      const int32_t shuffle_buffer_size = record_load_conf.random_shuffle_conf().buffer_size();
+      CHECK_GT(shuffle_buffer_size, 0);
+      record_reader_.reset(new RandomShuffleOFRecordReader(
+          in_stream_.get(), static_cast<size_t>(shuffle_buffer_size)));
+    } else {
+      record_reader_.reset(new NaiveOFRecordReader(in_stream_.get()));
+    }
   }
-  int64_t global_piece_size = Global<JobDesc>::Get()->PieceSize();
-  CHECK_EQ(global_piece_size % parallel_ctx->parallel_num(), 0);
-  piece_size_in_one_loader_ = global_piece_size / parallel_ctx->parallel_num();
 }
 
 void RecordLoadKernel::Forward(const KernelCtx& ctx,
                                std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   auto status = static_cast<RecordLoadStatus*>(ctx.other);
-  Blob* out_blob = BnInOp2Blob("out");
-  RecordBlob<OFRecord> record_blob(out_blob);
-  record_blob.ReadFrom(in_stream_.get());
-  status->record_num = record_blob.record_num();
+  status->record_num = record_reader_->Read(static_cast<size_t>(piece_size_in_one_loader_),
+                                            BnInOp2Blob("out")->mut_dptr<OFRecord>());
   if (status->record_num < piece_size_in_one_loader_) { status->is_eof = true; }
 }
 
diff --git a/oneflow/core/kernel/record_load_kernel.h b/oneflow/core/kernel/record_load_kernel.h
index d21b26cf8cc80fb6d31f93d4ba7af50dbff880f7..42472836da1db1daff10ed4725cd3de0e5b0d37b 100644
--- a/oneflow/core/kernel/record_load_kernel.h
+++ b/oneflow/core/kernel/record_load_kernel.h
@@ -3,6 +3,7 @@
 
 #include "oneflow/core/kernel/kernel.h"
 #include "oneflow/core/persistence/persistent_in_stream.h"
+#include "oneflow/core/record/ofrecord_reader.h"
 
 namespace oneflow {
 
@@ -15,7 +16,7 @@ class RecordLoadKernel final : public KernelIf<DeviceType::kCPU> {
  public:
   OF_DISALLOW_COPY_AND_MOVE(RecordLoadKernel);
   RecordLoadKernel() = default;
-  ~RecordLoadKernel() = default;
+  ~RecordLoadKernel() override = default;
 
  private:
   void VirtualKernelInit(const ParallelContext*) override;
@@ -23,6 +24,7 @@ class RecordLoadKernel final : public KernelIf<DeviceType::kCPU> {
                std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
 
   std::unique_ptr<PersistentInStream> in_stream_;
+  std::unique_ptr<OFRecordReader> record_reader_;
   int64_t piece_size_in_one_loader_;
 };
 
diff --git a/oneflow/core/kernel/reduce_concat_kernel.cpp b/oneflow/core/kernel/reduce_concat_kernel.cpp
index f2884fe55973eab00a51c144ce7e90331a8f708d..f4baf582ad985b0cc24734399758ef0e65eca343 100644
--- a/oneflow/core/kernel/reduce_concat_kernel.cpp
+++ b/oneflow/core/kernel/reduce_concat_kernel.cpp
@@ -5,16 +5,15 @@ namespace oneflow {
 template<DeviceType device_type>
 void ReduceConcatKernel<device_type>::ForwardDataContent(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  const auto* other_val = static_cast<std::pair<int64_t, bool>*>(ctx.other);
-  int64_t in_bn_id = other_val->first;
-  bool is_inplace = other_val->second;
-  if (is_inplace) { return; }
+  if (device_type == DeviceType::kGPU && Global<JobDesc>::Get()->enable_mem_sharing()) { return; }
   Blob* out_blob = BnInOp2Blob("out");
   char* dst_cur_dptr = out_blob->mut_dptr<char>();
-  dst_cur_dptr += this->kernel_conf().reduce_concat_conf().data_offset().Get(in_bn_id);
-  Blob* in_blob = BnInOp2Blob(this->op_attribute().input_bns().Get(in_bn_id));
-  size_t in_byte_size = in_blob->ByteSizeOfDataContentField();
-  Memcpy<device_type>(ctx.device_ctx, dst_cur_dptr, in_blob->dptr<char>(), in_byte_size);
+  FOR_RANGE(int, in_bn_id, 0, this->op_attribute().input_bns().size()) {
+    dst_cur_dptr += this->kernel_conf().reduce_concat_conf().data_offset().Get(in_bn_id);
+    Blob* in_blob = BnInOp2Blob(this->op_attribute().input_bns().Get(in_bn_id));
+    size_t in_byte_size = in_blob->ByteSizeOfDataContentField();
+    Memcpy<device_type>(ctx.device_ctx, dst_cur_dptr, in_blob->dptr<char>(), in_byte_size);
+  }
 }
 
 ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kReduceConcatConf, ReduceConcatKernel);
diff --git a/oneflow/core/kernel/reduce_identity_kernel.cpp b/oneflow/core/kernel/reduce_identity_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b6dab91c4cec96ac725ac8a1291de5a2465349a4
--- /dev/null
+++ b/oneflow/core/kernel/reduce_identity_kernel.cpp
@@ -0,0 +1,17 @@
+#include "oneflow/core/kernel/reduce_identity_kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type>
+void ReduceIdentityKernel<device_type>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* in_blob = BnInOp2Blob("in");
+  Blob* out_blob = BnInOp2Blob("out");
+  CHECK_EQ(out_blob->ByteSizeOfDataContentField(), in_blob->ByteSizeOfDataContentField());
+  Memcpy<device_type>(ctx.device_ctx, out_blob->mut_dptr(), in_blob->dptr(),
+                      out_blob->ByteSizeOfDataContentField());
+}
+
+ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kReduceIdentityConf, ReduceIdentityKernel);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/reduce_identity_kernel.h b/oneflow/core/kernel/reduce_identity_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..1994fa1d4e98becf5adcde1eadcc64cd36312376
--- /dev/null
+++ b/oneflow/core/kernel/reduce_identity_kernel.h
@@ -0,0 +1,25 @@
+#ifndef ONEFLOW_CORE_KERNEL_REDUCE_IDENTITY_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_REDUCE_IDENTITY_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type>
+class ReduceIdentityKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ReduceIdentityKernel);
+  ReduceIdentityKernel() = default;
+  ~ReduceIdentityKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx& ctx,
+                          std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  const PbMessage& GetCustomizedOpConf() const override {
+    return this->op_conf().reduce_identity_conf();
+  }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_REDUCE_IDENTITY_KERNEL_H_
diff --git a/oneflow/core/kernel/reduce_mean_kernel.cpp b/oneflow/core/kernel/reduce_mean_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..1b42148e1c0adde893e716a93c023d7c1ab3f855
--- /dev/null
+++ b/oneflow/core/kernel/reduce_mean_kernel.cpp
@@ -0,0 +1,43 @@
+#include "oneflow/core/kernel/reduce_mean_kernel.h"
+#include "oneflow/core/ndarray/ndarray_util.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void ReduceMeanKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* in_blob = BnInOp2Blob("in");
+  Blob* out_blob = BnInOp2Blob("out");
+  Blob* fw_tmp_blob = BnInOp2Blob("fw_tmp");
+  size_t count = in_blob->shape().elem_cnt() / out_blob->shape().elem_cnt();
+  NdarrayUtil<device_type, T>::ReduceSum(
+      ctx.device_ctx,
+      XpuVarNdarray<T>(Shape(this->kernel_conf().reduce_sum_conf().kept_dims_shape()),
+                       out_blob->mut_dptr<T>()),
+      XpuVarNdarray<const T>(in_blob, in_blob->shape().NumAxes()),
+      XpuVarNdarray<T>(fw_tmp_blob, in_blob->shape().NumAxes()));
+  KernelUtil<device_type, T>::Div(ctx.device_ctx, out_blob->shape().elem_cnt(),
+                                  out_blob->mut_dptr<T>(), static_cast<T>(count));
+}
+
+template<DeviceType device_type, typename T>
+void ReduceMeanKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* out_diff_blob = BnInOp2Blob("out_diff");
+  Blob* in_diff_blob = BnInOp2Blob("in_diff");
+  Blob* bw_tmp_blob = BnInOp2Blob("bw_tmp");
+  size_t count = in_diff_blob->shape().elem_cnt() / out_diff_blob->shape().elem_cnt();
+  Memcpy<device_type>(ctx.device_ctx, bw_tmp_blob->mut_dptr(), out_diff_blob->dptr(),
+                      out_diff_blob->ByteSizeOfDataContentField());
+  KernelUtil<device_type, T>::Div(ctx.device_ctx, bw_tmp_blob->shape().elem_cnt(),
+                                  bw_tmp_blob->mut_dptr<T>(), static_cast<T>(count));
+  NdarrayUtil<device_type, T>::BroadcastTo(
+      ctx.device_ctx, XpuVarNdarray<T>(in_diff_blob, in_diff_blob->shape().NumAxes()),
+      XpuVarNdarray<const T>(Shape(this->kernel_conf().reduce_sum_conf().kept_dims_shape()),
+                             bw_tmp_blob->dptr<T>()));
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kReduceMeanConf, ReduceMeanKernel, FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/reduce_mean_kernel.h b/oneflow/core/kernel/reduce_mean_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..e3541bf50fb697fcf21db1d3ac46b93a2ecdb128
--- /dev/null
+++ b/oneflow/core/kernel/reduce_mean_kernel.h
@@ -0,0 +1,25 @@
+#ifndef ONEFLOW_CORE_KERNEL_REDUCE_MEAN_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_REDUCE_MEAN_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class ReduceMeanKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ReduceMeanKernel);
+  ReduceMeanKernel() = default;
+  ~ReduceMeanKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx& ctx,
+                          std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+
+  void BackwardDataContent(const KernelCtx& ctx,
+                           std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_REDUCE_MEAN_KERNEL_H_
diff --git a/oneflow/core/kernel/reduce_split_kernel.cpp b/oneflow/core/kernel/reduce_split_kernel.cpp
index 847c7b2ead2bbcd94992831efda0694e813353b8..1e2fefc57a3a23f0ee59fa0db4b3f86c8dcf65e2 100644
--- a/oneflow/core/kernel/reduce_split_kernel.cpp
+++ b/oneflow/core/kernel/reduce_split_kernel.cpp
@@ -5,8 +5,7 @@ namespace oneflow {
 template<DeviceType device_type>
 void ReduceSplitKernel<device_type>::ForwardDataContent(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  bool is_inplace = *static_cast<bool*>(ctx.other);
-  if (is_inplace) { return; }
+  if (device_type == DeviceType::kGPU && Global<JobDesc>::Get()->enable_mem_sharing()) { return; }
   const Blob* in_blob = BnInOp2Blob("in");
   const char* src_cur_dptr = in_blob->dptr<char>();
   for (const std::string& obn : this->op_attribute().output_bns()) {
diff --git a/oneflow/core/kernel/reduce_sum_kernel.cpp b/oneflow/core/kernel/reduce_sum_kernel.cpp
index 0da31f15a962ea43609b6b6d730d4df9c6526a5f..ff3ea04a6b18a416aa97c8e15b34e6f58c3606fd 100644
--- a/oneflow/core/kernel/reduce_sum_kernel.cpp
+++ b/oneflow/core/kernel/reduce_sum_kernel.cpp
@@ -1,4 +1,5 @@
 #include "oneflow/core/kernel/reduce_sum_kernel.h"
+#include "oneflow/core/ndarray/ndarray_util.h"
 
 namespace oneflow {
 
@@ -8,25 +9,12 @@ void ReduceSumKernel<device_type, T>::ForwardDataContent(
   const Blob* in_blob = BnInOp2Blob("in");
   Blob* out_blob = BnInOp2Blob("out");
   Blob* fw_tmp_blob = BnInOp2Blob("fw_tmp");
-  if (this->kernel_conf().reduce_sum_conf().has_axis() == false) {
-    KernelUtil<device_type, T>::Sum(ctx.device_ctx, in_blob->shape().elem_cnt(), in_blob->dptr<T>(),
-                                    out_blob->mut_dptr<T>(), fw_tmp_blob->mut_dptr<T>(),
-                                    fw_tmp_blob->ByteSizeOfDataContentField());
-    return;
-  }
-  int32_t axis = this->kernel_conf().reduce_sum_conf().axis();
-  int64_t lhs_num = in_blob->shape().Count(0, axis);
-  int64_t middle_num = in_blob->shape().At(axis);
-  int64_t rhs_num = in_blob->shape().Count(axis + 1);
-  FOR_RANGE(int64_t, lhs_i, 0, lhs_num) {
-    const T* src_ptr = in_blob->dptr<T>() + lhs_i * middle_num * rhs_num;
-    T* dst_ptr = out_blob->mut_dptr<T>() + lhs_i * rhs_num;
-    Memcpy<device_type>(ctx.device_ctx, dst_ptr, src_ptr, rhs_num * sizeof(T));
-    FOR_RANGE(int64_t, middle_i, 1, middle_num) {
-      KernelUtil<device_type, T>::Axpy(ctx.device_ctx, rhs_num, 1.0f, src_ptr + middle_i * rhs_num,
-                                       1, dst_ptr, 1);
-    }
-  }
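+  // Reduce "in" down to the kept-dims shape recorded in the kernel conf; fw_tmp is the reduction workspace.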
+  NdarrayUtil<device_type, T>::ReduceSum(
+      ctx.device_ctx,
+      XpuVarNdarray<T>(Shape(this->kernel_conf().reduce_sum_conf().kept_dims_shape()),
+                       out_blob->mut_dptr<T>()),
+      XpuVarNdarray<const T>(in_blob, in_blob->shape().NumAxes()),
+      XpuVarNdarray<T>(fw_tmp_blob, in_blob->shape().NumAxes()));
 }
 
 template<DeviceType device_type, typename T>
@@ -34,28 +22,10 @@ void ReduceSumKernel<device_type, T>::BackwardDataContent(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   const Blob* out_diff_blob = BnInOp2Blob("out_diff");
   Blob* in_diff_blob = BnInOp2Blob("in_diff");
-
-  if (this->kernel_conf().reduce_sum_conf().has_axis() == false) {
-    T* dst_ptr = in_diff_blob->mut_dptr<T>();
-    const T* src_ptr = out_diff_blob->dptr<T>();
-    FOR_RANGE(int64_t, i, 0, in_diff_blob->shape().Count(0)) {
-      Memcpy<device_type>(ctx.device_ctx, dst_ptr++, src_ptr, sizeof(T));
-    }
-    return;
-  }
-
-  int32_t axis = this->kernel_conf().reduce_sum_conf().axis();
-  int64_t lhs_num = in_diff_blob->shape().Count(0, axis);
-  int64_t middle_num = in_diff_blob->shape().At(axis);
-  int64_t rhs_num = in_diff_blob->shape().Count(axis + 1);
-  FOR_RANGE(int64_t, lhs_i, 0, lhs_num) {
-    const T* src_ptr = out_diff_blob->dptr<T>() + lhs_i * rhs_num;
-    T* dst_ptr = in_diff_blob->mut_dptr<T>() + lhs_i * middle_num * rhs_num;
-    FOR_RANGE(int64_t, middle_i, 0, middle_num) {
-      Memcpy<device_type>(ctx.device_ctx, dst_ptr, src_ptr, rhs_num * sizeof(T));
-      dst_ptr += rhs_num;
-    }
-  }
+  NdarrayUtil<device_type, T>::BroadcastTo(
+      ctx.device_ctx, XpuVarNdarray<T>(in_diff_blob, in_diff_blob->shape().NumAxes()),
+      XpuVarNdarray<const T>(Shape(this->kernel_conf().reduce_sum_conf().kept_dims_shape()),
+                             out_diff_blob->dptr<T>()));
 }
 
 ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kReduceSumConf, ReduceSumKernel, ARITHMETIC_DATA_TYPE_SEQ);
diff --git a/oneflow/core/kernel/repeat_kernel.h b/oneflow/core/kernel/repeat_kernel.h
index ea4a1864e27756a70e2397d0ab1ba60930d5df8d..54116c5d3e4ceb737bd4f8cb9d76b17960551255 100644
--- a/oneflow/core/kernel/repeat_kernel.h
+++ b/oneflow/core/kernel/repeat_kernel.h
@@ -17,6 +17,10 @@ class RepeatKernel final : public KernelIf<device_type> {
                           std::function<Blob*(const std::string&)>) const override;
   void BackwardDataContent(const KernelCtx&,
                            std::function<Blob*(const std::string&)>) const override;
+  void BackwardInDiffLossInstanceNum(
+      const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override {
+    UNIMPLEMENTED();
+  }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/kernel/rsqrt_kernel.cpp b/oneflow/core/kernel/rsqrt_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..1489deeb653ea5f735ff2420fc15326a493b36c7
--- /dev/null
+++ b/oneflow/core/kernel/rsqrt_kernel.cpp
@@ -0,0 +1,19 @@
+#include "oneflow/core/kernel/rsqrt_kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void RsqrtKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  TODO();
+}
+
+template<DeviceType device_type, typename T>
+void RsqrtKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  TODO();
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kRsqrtConf, RsqrtKernel, FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/rsqrt_kernel.h b/oneflow/core/kernel/rsqrt_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..618de4c693040e9f1e2ab80778de78e48bab4a50
--- /dev/null
+++ b/oneflow/core/kernel/rsqrt_kernel.h
@@ -0,0 +1,25 @@
+#ifndef ONEFLOW_CORE_KERNEL_RSQRT_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_RSQRT_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+#include "oneflow/core/kernel/kernel_context.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class RsqrtKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(RsqrtKernel);
+  RsqrtKernel() = default;
+  ~RsqrtKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_RSQRT_KERNEL_H_
diff --git a/oneflow/core/kernel/scalar_add_kernel.cpp b/oneflow/core/kernel/scalar_add_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..dbee81b59086d4288a9112e6dd50a005204da441
--- /dev/null
+++ b/oneflow/core/kernel/scalar_add_kernel.cpp
@@ -0,0 +1,37 @@
+#include "oneflow/core/kernel/scalar_add_kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void ScalarAddKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* in_blob = BnInOp2Blob("in");
+  Blob* out_blob = BnInOp2Blob("out");
+  T scalar_operand = 0;
+  const auto& conf = this->op_conf().scalar_add_conf();
+  if (IsIntegral<T>::value) {
+    CHECK(conf.has_int_operand());
+    scalar_operand = static_cast<T>(conf.int_operand());
+  } else if (IsFloating<T>::value) {
+    CHECK(conf.has_float_operand());
+    scalar_operand = static_cast<T>(conf.float_operand());
+  } else {
+    UNIMPLEMENTED();
+  }
+  KernelUtil<device_type, T>::AddByScalar(ctx.device_ctx, out_blob->shape().elem_cnt(),
+                                          in_blob->dptr<T>(), scalar_operand,
+                                          out_blob->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T>
+void ScalarAddKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out"));
+  Blob* in_diff_blob = BnInOp2Blob(GenDiffBn("in"));
+  Memcpy<device_type>(ctx.device_ctx, in_diff_blob->mut_dptr<T>(), out_diff_blob->dptr<T>(),
+                      out_diff_blob->ByteSizeOfDataContentField());
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kScalarAddConf, ScalarAddKernel, ARITHMETIC_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/scalar_add_kernel.h b/oneflow/core/kernel/scalar_add_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..9a34ec190805571047931bd838dd76205aa5980d
--- /dev/null
+++ b/oneflow/core/kernel/scalar_add_kernel.h
@@ -0,0 +1,27 @@
+#ifndef ONEFLOW_CORE_KERNEL_SCALAR_ADD_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_SCALAR_ADD_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class ScalarAddKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ScalarAddKernel);
+  ScalarAddKernel() = default;
+  ~ScalarAddKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx& ctx,
+                          std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  void BackwardDataContent(const KernelCtx& ctx,
+                           std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  const PbMessage& GetCustomizedOpConf() const override {
+    return this->op_conf().scalar_add_conf();
+  }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_SCALAR_ADD_KERNEL_H_
diff --git a/oneflow/core/kernel/scalar_mul_kernel.cpp b/oneflow/core/kernel/scalar_mul_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e7ed70acb5ac26a4bddba91c90cdfdc6829551f0
--- /dev/null
+++ b/oneflow/core/kernel/scalar_mul_kernel.cpp
@@ -0,0 +1,49 @@
+#include "oneflow/core/kernel/scalar_mul_kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void ScalarMulKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* in_blob = BnInOp2Blob("in");
+  Blob* out_blob = BnInOp2Blob("out");
+  T scalar_operand = 0;
+  const auto& conf = this->op_conf().scalar_mul_conf();
+  if (IsIntegral<T>::value) {
+    CHECK(conf.has_int_operand());
+    scalar_operand = static_cast<T>(conf.int_operand());
+  } else if (IsFloating<T>::value) {
+    CHECK(conf.has_float_operand());
+    scalar_operand = static_cast<T>(conf.float_operand());
+  } else {
+    UNIMPLEMENTED();
+  }
+  KernelUtil<device_type, T>::MulByScalarPara(ctx.device_ctx, out_blob->shape().elem_cnt(),
+                                              in_blob->dptr<T>(), scalar_operand,
+                                              out_blob->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T>
+void ScalarMulKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out"));
+  Blob* in_diff_blob = BnInOp2Blob(GenDiffBn("in"));
+  T scalar_operand = 0;
+  const auto& conf = this->op_conf().scalar_mul_conf();
+  if (IsIntegral<T>::value) {
+    CHECK(conf.has_int_operand());
+    scalar_operand = static_cast<T>(conf.int_operand());
+  } else if (IsFloating<T>::value) {
+    CHECK(conf.has_float_operand());
+    scalar_operand = static_cast<T>(conf.float_operand());
+  } else {
+    UNIMPLEMENTED();
+  }
+  KernelUtil<device_type, T>::MulByScalarPara(ctx.device_ctx, in_diff_blob->shape().elem_cnt(),
+                                              out_diff_blob->dptr<T>(), scalar_operand,
+                                              in_diff_blob->mut_dptr<T>());
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kScalarMulConf, ScalarMulKernel, ARITHMETIC_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/scalar_mul_kernel.h b/oneflow/core/kernel/scalar_mul_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..2dca5e04c661c9118741bad419ebb5971af8fb47
--- /dev/null
+++ b/oneflow/core/kernel/scalar_mul_kernel.h
@@ -0,0 +1,27 @@
+#ifndef ONEFLOW_CORE_KERNEL_SCALAR_MUL_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_SCALAR_MUL_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class ScalarMulKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ScalarMulKernel);
+  ScalarMulKernel() = default;
+  ~ScalarMulKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx& ctx,
+                          std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  void BackwardDataContent(const KernelCtx& ctx,
+                           std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  const PbMessage& GetCustomizedOpConf() const override {
+    return this->op_conf().scalar_mul_conf();
+  }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_SCALAR_MUL_KERNEL_H_
diff --git a/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.cpp b/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2203faa82cae07340423d04182f53518a05127a8
--- /dev/null
+++ b/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.cpp
@@ -0,0 +1,110 @@
+#include "oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.h"
+#include "oneflow/core/kernel/sparse_cross_entropy_loss_kernel.h"
+#include "oneflow/core/kernel/softmax_kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename PredType, typename LabelType>
+void SigmoidCrossEntropyLossKernel<device_type, PredType, LabelType>::VirtualLossForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const SigmoidCrossEntropyLossOpConf& conf = this->op_conf().sigmoid_cross_entropy_loss_conf();
+  const Blob* prediction = BnInOp2Blob("prediction");
+  const Blob* label = BnInOp2Blob("label");
+  Blob* loss_buf = BnInOp2Blob("loss_buf");
+  Blob* tmp_storage = BnInOp2Blob("sum_buf");
+  const size_t tmp_storage_byte_size = tmp_storage->ByteSizeOfDataContentField();
+  Blob* count = BnInOp2Blob("count");
+  Blob* label_num = BnInOp2Blob("label_num");
+  Blob* loss = BnInOp2Blob("loss");
+  Blob* prediction_diff = BnInOp2Blob(GenDiffBn("prediction"));
+  int64_t data_dim = label->shape().Count(1);
+  int64_t data_offset = 0;
+  if (prediction_diff != nullptr) {
+    Memset<device_type>(ctx.device_ctx, prediction_diff->mut_dptr<PredType>(), 0,
+                        prediction_diff->ByteSizeOfDataContentField());
+  }
+  FOR_RANGE(int64_t, data_index, 0, prediction->shape().At(0)) {
+    data_offset = data_dim * data_index;
+    const PredType* prediction_offset = prediction->dptr<PredType>() + data_offset;
+    const LabelType* label_offset = label->dptr<LabelType>() + data_offset;
+    SigmoidCrossEntropyLossKernelUtil<device_type, PredType, LabelType>::Forward(
+        ctx.device_ctx, conf, data_dim, prediction_offset, label_offset,
+        loss_buf->mut_dptr<PredType>(), tmp_storage->mut_dptr<PredType>(), tmp_storage_byte_size,
+        count->mut_dptr<PredType>(), label_num->mut_dptr<PredType>(),
+        loss->mut_dptr<PredType>() + data_index);
+    if (prediction_diff != nullptr) {
+      SigmoidCrossEntropyLossKernelUtil<device_type, PredType, LabelType>::Backward(
+          ctx.device_ctx, conf, data_dim, prediction_offset, label_offset,
+          label_num->dptr<PredType>(), prediction_diff->mut_dptr<PredType>() + data_offset);
+    }
+  }
+}
+
+template<DeviceType device_type, typename PredType, typename LabelType>
+const LossKernelConf&
+SigmoidCrossEntropyLossKernel<device_type, PredType, LabelType>::GetLossKernelConf(
+    const KernelConf& kernel_conf) const {
+  return kernel_conf.sigmoid_cross_entropy_loss_conf().loss_conf();
+}
+
+template<typename PredType, typename LabelType>
+struct SigmoidCrossEntropyLossKernelUtil<DeviceType::kCPU, PredType, LabelType> {
+  static void Forward(DeviceCtx* ctx, const SigmoidCrossEntropyLossOpConf& conf, const int64_t n,
+                      const PredType* prediction, const LabelType* label, PredType* loss_buf,
+                      PredType* tmp_storage, const size_t tmp_storage_byte_size, PredType* count,
+                      PredType* label_num, PredType* loss) {
+    loss_buf[0] = 0;
+    loss[0] = 0;
+    label_num[0] = 0;
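+    // Numerically stable binary cross entropy with logits:
+    // -x * (z - (x >= 0)) + log(1 + exp(x - 2 * x * (x >= 0))) == max(x, 0) - x * z + log(1 + exp(-|x|)).
+    // Entries labeled -1 are ignored and contribute neither to the loss nor to the label count.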
+    FOR_RANGE(int64_t, index, 0, n) {
+      if (label[index] != -1) {
+        loss_buf[0] +=
+            -1 * prediction[index] * (label[index] - (prediction[index] >= 0))
+            + logf(1 + expf(prediction[index] - 2 * prediction[index] * (prediction[index] >= 0)));
+        label_num[0] += 1;
+      }
+    }
+    loss_buf[0] *= static_cast<PredType>(conf.scale());
+    if (conf.normalize()) {
+      if (label_num[0] == 0) { label_num[0] = 1e-5; }
+      loss[0] = loss_buf[0] / label_num[0];
+    } else {
+      // Without normalization, still report the scaled sum (matches the GPU path).
+      loss[0] = loss_buf[0];
+    }
+  }
+
+  static void Backward(DeviceCtx* ctx, const SigmoidCrossEntropyLossOpConf& conf, const int64_t n,
+                       const PredType* prediction, const LabelType* label,
+                       const PredType* label_num, PredType* pred_diff) {
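+    // dL/dx = sigmoid(x) - z for labeled entries; entries labeled -1 keep the zero diff
+    // written by the caller's Memset.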
+    FOR_RANGE(int64_t, index, 0, n) {
+      if (label[index] != -1) {
+        pred_diff[index] = 1.f / (1.f + expf(-prediction[index])) - label[index];
+        pred_diff[index] *= static_cast<PredType>(conf.scale());
+        if (conf.normalize()) { pred_diff[index] /= label_num[0]; }
+      }
+    }
+  }
+};
+
+namespace {
+
+Kernel* CreateSigmoidCrossEntropyLossKernel(const KernelConf& kernel_conf) {
+  static const HashMap<std::string, std::function<Kernel*()>> creators = {
+#define SIGMOID_CROSS_ENTROPY_LOSS_KERNEL_ENTRY(device_type, pred_type_pair, label_type_pair)      \
+  {GetHashKey(device_type, OF_PP_PAIR_SECOND(pred_type_pair), OF_PP_PAIR_SECOND(label_type_pair)), \
+   []() {                                                                                          \
+     return new SigmoidCrossEntropyLossKernel<device_type, OF_PP_PAIR_FIRST(pred_type_pair),       \
+                                              OF_PP_PAIR_FIRST(label_type_pair)>();                \
+   }},
+      OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(SIGMOID_CROSS_ENTROPY_LOSS_KERNEL_ENTRY, DEVICE_TYPE_SEQ,
+                                       FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ)};
+  return creators.at(
+      GetHashKey(kernel_conf.op_attribute().op_conf().device_type(),
+                 kernel_conf.sigmoid_cross_entropy_loss_conf().loss_conf().prediction_type(),
+                 kernel_conf.sigmoid_cross_entropy_loss_conf().loss_conf().label_type()))();
+}
+
+}  // namespace
+
+REGISTER_KERNEL_CREATOR(OperatorConf::kSigmoidCrossEntropyLossConf,
+                        CreateSigmoidCrossEntropyLossKernel);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.cu b/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..fe7355e06195a3f586d25c4a2033d9c6e5bb8a76
--- /dev/null
+++ b/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.cu
@@ -0,0 +1,83 @@
+#include "oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.h"
+
+namespace oneflow {
+
+namespace {
+template<typename PredType>
+__global__ void NoSmallerThan(const int n, PredType* x, const float floor_val) {
+  CUDA_1D_KERNEL_LOOP(index, n) { x[index] = (x[index] > floor_val) ? x[index] : floor_val; }
+}
+
+template<typename PredType, typename LabelType>
+__global__ void SigmoidCrossEntropyLossForward(const int64_t n, const PredType* prediction,
+                                               const LabelType* label, PredType* loss_buf,
+                                               PredType* count) {
+  CUDA_1D_KERNEL_LOOP(index, n) {
+    if (label[index] == -1) {
+      loss_buf[index] = 0.f;
+      count[index] = 0.f;
+    } else {
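+      // Same numerically stable form as the CPU path: max(x, 0) - x * z + log(1 + exp(-|x|)).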
+      loss_buf[index] =
+          -1.f * prediction[index] * (label[index] - (prediction[index] >= 0))
+          + logf(1 + expf(prediction[index] - 2 * prediction[index] * (prediction[index] >= 0)));
+      count[index] = 1.f;
+    }
+  }
+}
+
+template<typename PredType, typename LabelType>
+__global__ void SigmoidCrossEntropyLossBackward(const int64_t n, const PredType* prediction,
+                                                const LabelType* label, PredType* pred_diff) {
+  CUDA_1D_KERNEL_LOOP(index, n) {
+    if (label[index] == -1) {
+      pred_diff[index] = 0.f;
+    } else {
+      pred_diff[index] = 1.f / (1.f + expf(-prediction[index])) - label[index];
+    }
+  }
+}
+}  // namespace
+
+template<typename PredType, typename LabelType>
+struct SigmoidCrossEntropyLossKernelUtil<DeviceType::kGPU, PredType, LabelType> {
+  static void Forward(DeviceCtx* ctx, const SigmoidCrossEntropyLossOpConf& conf, const int64_t n,
+                      const PredType* prediction, const LabelType* label, PredType* loss_buf,
+                      PredType* tmp_storage, const size_t tmp_storage_byte_size, PredType* count,
+                      PredType* label_num, PredType* loss) {
+    SigmoidCrossEntropyLossForward<PredType>
+        <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+            n, prediction, label, loss_buf, count);
+    KernelUtil<DeviceType::kGPU, PredType>::Sum(ctx, n, loss_buf, loss, tmp_storage,
+                                                tmp_storage_byte_size);
+    if (conf.normalize()) {
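+      // Normalize by the number of labeled entries, clamped to at least 1e-5 so an
+      // all-ignored batch does not divide by zero.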
+      KernelUtil<DeviceType::kGPU, PredType>::Sum(ctx, n, count, label_num, tmp_storage,
+                                                  tmp_storage_byte_size);
+      NoSmallerThan<PredType>
+          <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+              1, label_num, 1e-5);
+      KernelUtil<DeviceType::kGPU, PredType>::Div(ctx, 1, loss, label_num);
+    }
+    KernelUtil<DeviceType::kGPU, PredType>::Scal(ctx, 1, static_cast<PredType>(conf.scale()), loss,
+                                                 1);
+  }
+
+  static void Backward(DeviceCtx* ctx, const SigmoidCrossEntropyLossOpConf& conf, const int64_t n,
+                       const PredType* prediction, const LabelType* label,
+                       const PredType* label_num, PredType* pred_diff) {
+    SigmoidCrossEntropyLossBackward<PredType>
+        <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+            n, prediction, label, pred_diff);
+    KernelUtil<DeviceType::kGPU, PredType>::Scal(ctx, n, static_cast<PredType>(conf.scale()),
+                                                 pred_diff, 1);
+    if (conf.normalize()) {
+      KernelUtil<DeviceType::kGPU, PredType>::Div(ctx, n, pred_diff, label_num);
+    }
+  }
+};
+
+#define INSTANTIATE_SIGMOID_CROSS_ENTROPY_LOSS_KERNEL_UTIL(data_type_pair, label_type_pair) \
+  template struct SigmoidCrossEntropyLossKernelUtil<                                        \
+      DeviceType::kGPU, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(label_type_pair)>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SIGMOID_CROSS_ENTROPY_LOSS_KERNEL_UTIL,
+                                 FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ)
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.h b/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..13435a3ab52719ff2e11b5ad36ec576829339546
--- /dev/null
+++ b/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.h
@@ -0,0 +1,34 @@
+#ifndef ONEFLOW_CORE_KERNEL_SIGMOID_CROSS_ENTROPY_LOSS_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_SIGMOID_CROSS_ENTROPY_LOSS_KERNEL_H_
+
+#include "oneflow/core/kernel/loss_kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename PredType, typename LabelType>
+class SigmoidCrossEntropyLossKernel final : public LossKernel<device_type, PredType> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(SigmoidCrossEntropyLossKernel);
+  SigmoidCrossEntropyLossKernel() = default;
+  ~SigmoidCrossEntropyLossKernel() = default;
+
+ private:
+  void VirtualLossForwardDataContent(const KernelCtx&,
+                                     std::function<Blob*(const std::string&)>) const override;
+  const LossKernelConf& GetLossKernelConf(const KernelConf& kernel_conf) const override;
+};
+
+template<DeviceType device_type, typename PredType, typename LabelType>
+struct SigmoidCrossEntropyLossKernelUtil {
+  static void Forward(DeviceCtx* ctx, const SigmoidCrossEntropyLossOpConf& conf, const int64_t n,
+                      const PredType* prediction, const LabelType* label, PredType* loss_buf,
+                      PredType* tmp_storage, const size_t tmp_storage_byte_size, PredType* count,
+                      PredType* label_num, PredType* loss);
+  static void Backward(DeviceCtx* ctx, const SigmoidCrossEntropyLossOpConf& conf, const int64_t n,
+                       const PredType* prediction, const LabelType* label,
+                       const PredType* label_num, PredType* pred_diff);
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_SIGMOID_CROSS_ENTROPY_LOSS_KERNEL_H_
diff --git a/oneflow/core/kernel/slice_kernel.cpp b/oneflow/core/kernel/slice_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..3ddbbecd65f426a65c59f68f604adbdbd711293a
--- /dev/null
+++ b/oneflow/core/kernel/slice_kernel.cpp
@@ -0,0 +1,127 @@
+#include "oneflow/core/kernel/slice_kernel.h"
+#include "oneflow/core/ndarray/ndarray_helper.h"
+
+namespace oneflow {
+
+namespace {
+
+int64_t GetStart(const DimSliceConf& conf) {
+  CHECK_GT(conf.stride(), 0);
+  return conf.has_start() ? conf.start() : Slice::kStart;
+}
+
+int64_t GetEnd(const DimSliceConf& conf) {
+  CHECK_GT(conf.stride(), 0);
+  return conf.has_end() ? conf.end() : Slice::kEnd;
+}
+
+int64_t GetStride(const DimSliceConf& conf) { return conf.stride(); }
+
+}  // namespace
+
+template<typename T>
+void SliceKernel<DeviceType::kCPU, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const SliceOpConf& conf = this->op_conf().slice_conf();
+  const Blob* in_blob = BnInOp2Blob("in");
+  Blob* out_blob = BnInOp2Blob("out");
+  CHECK_EQ(in_blob->shape().NumAxes(), out_blob->shape().NumAxes());
+
+  switch (out_blob->shape().NumAxes()) {
+// clang-format off
+#define MAKE_CASE(num_axes)                                                                   \
+    case num_axes: {                                                                          \
+      NdArraySliceUtil<T, num_axes>::Forward(ctx.device_ctx, conf.dim_slice_conf(), in_blob,  \
+                                             out_blob);                                       \
+      break;                                                                                  \
+    }
+    MAKE_CASE(2);
+    MAKE_CASE(3);
+#undef MAKE_CASE
+    // clang-format on
+    default: { UNIMPLEMENTED(); }
+  }
+}
+
+template<typename T>
+void SliceKernel<DeviceType::kCPU, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const SliceOpConf& conf = this->op_conf().slice_conf();
+  const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out"));
+  Blob* in_diff_blob = BnInOp2Blob(GenDiffBn("in"));
+  CHECK_EQ(out_diff_blob->shape().NumAxes(), in_diff_blob->shape().NumAxes());
+
+  Memset<DeviceType::kCPU>(ctx.device_ctx, in_diff_blob->mut_dptr<T>(), 0,
+                           in_diff_blob->ByteSizeOfDataContentField());
+
+  switch (in_diff_blob->shape().NumAxes()) {
+// clang-format off
+#define MAKE_CASE(num_axes)                                                           \
+    case num_axes: {                                                                  \
+      NdArraySliceUtil<T, num_axes>::Backward(ctx.device_ctx, conf.dim_slice_conf(),  \
+                                              out_diff_blob, in_diff_blob);           \
+      break;                                                                          \
+    }
+    MAKE_CASE(2);
+    MAKE_CASE(3);
+#undef MAKE_CASE
+    // clang-format on
+    default: { UNIMPLEMENTED(); }
+  }
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kSliceConf, SliceKernel, ARITHMETIC_DATA_TYPE_SEQ);
+
+template<typename T>
+struct NdArraySliceUtil<T, 2> final {
+  static void Forward(DeviceCtx* device_ctx, const PbRpf<DimSliceConf>& rep_dim_slice,
+                      const Blob* in_blob, Blob* out_blob) {
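+    // Dim 0 is copied whole ({}); rep_dim_slice.Get(0) describes the slice along dim 1.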
+    NdArrayHelper<T, 2> ndarray;
+    auto&& in_ndarray = ndarray.Var(in_blob->shape(), const_cast<T*>(in_blob->dptr<T>()));
+    auto&& out_ndarray = ndarray.Var(out_blob->shape(), out_blob->mut_dptr<T>());
+    out_ndarray.Assign(in_ndarray({}, {GetStart(rep_dim_slice.Get(0)), GetEnd(rep_dim_slice.Get(0)),
+                                       GetStride(rep_dim_slice.Get(0))}));
+  }
+
+  static void Backward(DeviceCtx* device_ctx, const PbRpf<DimSliceConf>& rep_dim_slice,
+                       const Blob* out_diff_blob, Blob* in_diff_blob) {
+    NdArrayHelper<T, 2> ndarray;
+    auto&& out_diff_ndarray =
+        ndarray.Var(out_diff_blob->shape(), const_cast<T*>(out_diff_blob->dptr<T>()));
+    auto&& in_diff_ndarray = ndarray.Var(in_diff_blob->shape(), in_diff_blob->mut_dptr<T>());
+    in_diff_ndarray({}, {GetStart(rep_dim_slice.Get(0)), GetEnd(rep_dim_slice.Get(0)),
+                         GetStride(rep_dim_slice.Get(0))})
+        .Assign(out_diff_ndarray({}, {}));
+  }
+};
+
+template<typename T>
+struct NdArraySliceUtil<T, 3> final {
+  static void Forward(DeviceCtx* device_ctx, const PbRpf<DimSliceConf>& rep_dim_slice,
+                      const Blob* in_blob, Blob* out_blob) {
+    NdArrayHelper<T, 3> ndarray;
+    auto&& in_ndarray = ndarray.Var(in_blob->shape(), const_cast<T*>(in_blob->dptr<T>()));
+    auto&& out_ndarray = ndarray.Var(out_blob->shape(), out_blob->mut_dptr<T>());
+    out_ndarray.Assign(in_ndarray({},
+                                  {GetStart(rep_dim_slice.Get(0)), GetEnd(rep_dim_slice.Get(0)),
+                                   GetStride(rep_dim_slice.Get(0))},
+                                  {GetStart(rep_dim_slice.Get(1)), GetEnd(rep_dim_slice.Get(1)),
+                                   GetStride(rep_dim_slice.Get(1))}));
+  }
+
+  static void Backward(DeviceCtx* device_ctx, const PbRpf<DimSliceConf>& rep_dim_slice,
+                       const Blob* out_diff_blob, Blob* in_diff_blob) {
+    NdArrayHelper<T, 3> ndarray;
+    auto&& out_diff_ndarray =
+        ndarray.Var(out_diff_blob->shape(), const_cast<T*>(out_diff_blob->dptr<T>()));
+    auto&& in_diff_ndarray = ndarray.Var(in_diff_blob->shape(), in_diff_blob->mut_dptr<T>());
+    in_diff_ndarray({},
+                    {GetStart(rep_dim_slice.Get(0)), GetEnd(rep_dim_slice.Get(0)),
+                     GetStride(rep_dim_slice.Get(0))},
+                    {GetStart(rep_dim_slice.Get(1)), GetEnd(rep_dim_slice.Get(1)),
+                     GetStride(rep_dim_slice.Get(1))})
+        .Assign(out_diff_ndarray({}, {}, {}));
+  }
+};
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/slice_kernel.cu b/oneflow/core/kernel/slice_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..515b36e6864079b9d5e5d0192ff3d2c03ba577a9
--- /dev/null
+++ b/oneflow/core/kernel/slice_kernel.cu
@@ -0,0 +1,87 @@
+#include "oneflow/core/kernel/slice_kernel.h"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T>
+__global__ void SliceForwardGpu(const int64_t n, const int64_t* offset, const T* entire, T* slice) {
+  CUDA_1D_KERNEL_LOOP(i, n) { slice[i] = entire[offset[i]]; }
+}
+
+template<typename T>
+__global__ void SliceBackwardGpu(const int64_t n, const int64_t* offset, const T* slice,
+                                 T* entire) {
+  CUDA_1D_KERNEL_LOOP(i, n) { entire[offset[i]] = slice[i]; }
+}
+
+}  // namespace
+
+template<typename T>
+void SliceKernel<DeviceType::kGPU, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* in_blob = BnInOp2Blob("in");
+  const Blob* offset_blob = BnInOp2Blob("out_to_in_offset");
+  Blob* out_blob = BnInOp2Blob("out");
+  const int64_t num_output = out_blob->shape().elem_cnt();
+  SliceForwardGpu<T><<<BlocksNum4ThreadsNum(num_output), kCudaThreadsNumPerBlock, 0,
+                       ctx.device_ctx->cuda_stream()>>>(
+      num_output, offset_blob->dptr<int64_t>(), in_blob->dptr<T>(), out_blob->mut_dptr<T>());
+}
+
+template<typename T>
+void SliceKernel<DeviceType::kGPU, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out"));
+  const Blob* offset_blob = BnInOp2Blob("out_to_in_offset");
+  Blob* in_diff_blob = BnInOp2Blob(GenDiffBn("in"));
+  const int64_t num_output = out_diff_blob->shape().elem_cnt();
+  Memset<DeviceType::kGPU>(ctx.device_ctx, in_diff_blob->mut_dptr<T>(), 0,
+                           in_diff_blob->ByteSizeOfDataContentField());
+  SliceBackwardGpu<T><<<BlocksNum4ThreadsNum(num_output), kCudaThreadsNumPerBlock, 0,
+                        ctx.device_ctx->cuda_stream()>>>(num_output, offset_blob->dptr<int64_t>(),
+                                                         out_diff_blob->dptr<T>(),
+                                                         in_diff_blob->mut_dptr<T>());
+}
+
+template<typename T>
+void SliceKernel<DeviceType::kGPU, T>::InitConstBufBlobs(
+    DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  Shape in_shape(this->kernel_conf().slice_conf().in_shape());
+  InitOut2InOffsetFromHost(ctx, in_shape, BnInOp2Blob("out_to_in_offset"));
+}
+
+template<typename T>
+void SliceKernel<DeviceType::kGPU, T>::InitOut2InOffsetFromHost(DeviceCtx* ctx,
+                                                                const Shape& in_shape,
+                                                                Blob* blob) const {
+  const SliceOpConf& conf = op_conf().slice_conf();
+  BEFORE_CPU_INITIALIZE();
+  int64_t* host_blob_ptr = host_blob->mut_dptr<int64_t>();
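+  // For each output element, precompute the flat offset of the input element it reads from.
+  // Dim 0 is never sliced (start 0, stride 1); dim j > 0 uses dim_slice_conf(j - 1).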
+  FOR_RANGE(int64_t, i, 0, host_blob->shape().elem_cnt()) {
+    int64_t offset = 0;
+    int64_t index = i;
+    FOR_RANGE(int64_t, j, 0, host_blob->shape().NumAxes()) {
+      const int64_t dim_elem_cnt = host_blob->shape().Count(j + 1);
+      const int64_t dim_i = index / dim_elem_cnt;
+      index = index % dim_elem_cnt;
+      int64_t start = 0;
+      int64_t stride = 1;
+      if (j > 0) {
+        const DimSliceConf& dim_slice_conf = conf.dim_slice_conf(j - 1);
+        if (dim_slice_conf.has_start()) { start = dim_slice_conf.start(); }
+        if (start < 0) { start += in_shape.At(j); }
+        stride = dim_slice_conf.stride();
+      }
+      offset += (start + dim_i * stride) * in_shape.Count(j + 1);
+    }
+    host_blob_ptr[i] = offset;
+  }
+  AFTER_CPU_INITIALIZE();
+}
+
+#define INSTANTIATE_GPU_SLICE_KERNEL(type_cpp, type_proto) \
+  template struct SliceKernel<DeviceType::kGPU, type_cpp>;
+OF_PP_FOR_EACH_TUPLE(INSTANTIATE_GPU_SLICE_KERNEL, ARITHMETIC_DATA_TYPE_SEQ)
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/slice_kernel.h b/oneflow/core/kernel/slice_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..aa438f10fd1110566889b0cd12f881dcf9477c60
--- /dev/null
+++ b/oneflow/core/kernel/slice_kernel.h
@@ -0,0 +1,48 @@
+#ifndef ONEFLOW_CORE_KERNEL_SLICE_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_SLICE_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class SliceKernel;
+
+template<typename T>
+class SliceKernel<DeviceType::kCPU, T> final : public KernelIf<DeviceType::kCPU> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(SliceKernel);
+  SliceKernel() = default;
+  ~SliceKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+};
+
+template<typename T>
+class SliceKernel<DeviceType::kGPU, T> final : public KernelIf<DeviceType::kGPU> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(SliceKernel);
+  SliceKernel() = default;
+  ~SliceKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+  void InitConstBufBlobs(DeviceCtx*,
+                         std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+
+  void InitOut2InOffsetFromHost(DeviceCtx* ctx, const Shape& in_shape, Blob* blob) const;
+};
+
+template<typename T, size_t NDIMS>
+struct NdArraySliceUtil;
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_SLICE_KERNEL_H_
diff --git a/oneflow/core/kernel/softmax_kernel.cpp b/oneflow/core/kernel/softmax_kernel.cpp
index 896a7d5ccf094961d1b547231fa0a81189777fc2..f68327c908f83f386d4de9451ce99e0e9b88fbd4 100644
--- a/oneflow/core/kernel/softmax_kernel.cpp
+++ b/oneflow/core/kernel/softmax_kernel.cpp
@@ -1,6 +1,7 @@
 #include "oneflow/core/kernel/softmax_kernel.h"
 #include "oneflow/core/kernel/kernel.h"
 #include "oneflow/core/kernel/transpose_kernel.h"
+#include "oneflow/core/ndarray/ndarray_util.h"
 
 namespace oneflow {
 
@@ -14,7 +15,10 @@ void SoftmaxComputeDiff(DeviceCtx* ctx, const int64_t n, const int64_t w, const
   // dot product | get dot product sum_vec[i] from out[i] * out_diff[i]
   T* tmp = in_diff;
   KernelUtil<device_type, T>::Mul(ctx, n * w, out, out_diff, tmp);
-  KernelUtil<device_type, T>::RowSum(ctx, n, w, tmp, sum_vec, temp_storage, temp_storage_bytes);
+  NdarrayUtil<device_type, T>::ReduceSum(
+      ctx, XpuVarNdarray<T>({n, 1}, sum_vec), XpuVarNdarray<const T>({n, w}, tmp),
+      XpuVarNdarray<T>({static_cast<int64_t>(temp_storage_bytes / sizeof(T))},
+                       reinterpret_cast<T*>(temp_storage)));
   // copy out_diff to in_diff
   KernelUtil<device_type, T>::Copy(ctx, n * w, out_diff, 1, in_diff, 1);
   // sub | in_diff[i][j] -= sum_vec[i]
@@ -25,6 +29,32 @@ void SoftmaxComputeDiff(DeviceCtx* ctx, const int64_t n, const int64_t w, const
 
 }  // namespace
 
+template<DeviceType device_type, typename T>
+void SoftmaxComputeProb(DeviceCtx* ctx, const int64_t n, const int64_t w, const T* in, T* tmp,
+                        T* prob, void* temp_storage, const size_t temp_storage_bytes) {
+  // copy in blob to prob blob
+  KernelUtil<device_type, T>::Copy(ctx, n * w, in, 1, prob, 1);
+  // max | compute the max of every sample vector prob[i] and store it in tmp[i];
+  //       prob[i] still holds the data of in[i] at this point
+  NdarrayUtil<device_type, T>::ReduceMax(
+      ctx, XpuVarNdarray<T>({n, 1}, tmp), XpuVarNdarray<const T>({n, w}, prob),
+      XpuVarNdarray<T>({static_cast<int64_t>(temp_storage_bytes / sizeof(T))},
+                       reinterpret_cast<T*>(temp_storage)));
+  // sub | subtract the per-sample max from every element of the prob blob
+  SoftmaxKernelUtil<device_type, T>::Sub(ctx, n, w, prob, tmp);
+  // exp | exponentiate every element
+  KernelUtil<device_type, T>::Exp(ctx, n * w, prob, prob);
+  // sum | compute the sum of every sample vector prob[i] and store it in tmp[i];
+  //       prob[i] now holds the exponentiated values
+  NdarrayUtil<device_type, T>::ReduceSum(
+      ctx, XpuVarNdarray<T>({n, 1}, tmp), XpuVarNdarray<const T>({n, w}, prob),
+      XpuVarNdarray<T>({static_cast<int64_t>(temp_storage_bytes / sizeof(T))},
+                       reinterpret_cast<T*>(temp_storage)));
+  // div | divide every element of prob[i] by tmp[i] (the per-sample sum)
+  SoftmaxKernelUtil<device_type, T>::Div(ctx, n, w, prob, tmp);
+}
+
 template<DeviceType device_type, typename T>
 void SoftmaxKernel<device_type, T>::ForwardDataContent(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
diff --git a/oneflow/core/kernel/softmax_kernel.h b/oneflow/core/kernel/softmax_kernel.h
index 8f1903aa6d5440c0a6831b1f4359d37ad697d7f7..c7f6bf15db54549f31d3d143f7c9cd10f1998319 100644
--- a/oneflow/core/kernel/softmax_kernel.h
+++ b/oneflow/core/kernel/softmax_kernel.h
@@ -32,23 +32,7 @@ struct SoftmaxKernelUtil {
 
 template<DeviceType device_type, typename T>
 void SoftmaxComputeProb(DeviceCtx* ctx, const int64_t n, const int64_t w, const T* in, T* tmp,
-                        T* prob, void* temp_storage, const size_t temp_storage_bytes) {
-  // copy in blob to prob blob
-  KernelUtil<device_type, T>::Copy(ctx, n * w, in, 1, prob, 1);
-  // max | calculate max of every sample vector prob[i], store in tmp[i]
-  //       the prob[i] now is store the data of in[i]
-  KernelUtil<device_type, T>::RowMax(ctx, n, w, prob, tmp, temp_storage, temp_storage_bytes);
-  // sub | every element of prob blob subract the max value of the same sample
-  SoftmaxKernelUtil<device_type, T>::Sub(ctx, n, w, prob, tmp);
-  // exp | exponentiation every element
-  KernelUtil<device_type, T>::Exp(ctx, n * w, prob, prob);
-  // sum | calculate sum of every sample vector prob[i], store in tmp[i]
-  //       the prob[i] now is store the tmp data after exp
-  KernelUtil<device_type, T>::RowSum(ctx, n, w, prob, tmp, temp_storage, temp_storage_bytes);
-  // div | every element of prob[i] divided by the data of tmp[i] (the sum
-  // value)
-  SoftmaxKernelUtil<device_type, T>::Div(ctx, n, w, prob, tmp);
-}
+                        T* prob, void* temp_storage, const size_t temp_storage_bytes);
 
 }  // namespace oneflow
 
diff --git a/oneflow/core/kernel/sparse_cross_entropy_kernel.cpp b/oneflow/core/kernel/sparse_cross_entropy_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e5adee9ea9f54e3e712fb97e57b8095e8c6f6d6b
--- /dev/null
+++ b/oneflow/core/kernel/sparse_cross_entropy_kernel.cpp
@@ -0,0 +1,66 @@
+#include "oneflow/core/kernel/sparse_cross_entropy_kernel.h"
+#include "oneflow/core/kernel/sparse_cross_entropy_kernel_util.h"
+
+namespace oneflow {
+
+namespace {
+
+template<DeviceType device_type, typename T, typename K>
+void Forward(DeviceCtx* ctx, const Blob* prediction, const Blob* label, Blob* out) {
+  const int64_t num_instances = label->shape().elem_cnt();
+  CHECK_EQ(prediction->shape().elem_cnt() % num_instances, 0);
+  const int64_t num_classes = prediction->shape().elem_cnt() / num_instances;
+  SparseCrossEntropyKernelUtil<device_type, T, K>::ComputeEntropy(
+      ctx, num_instances, num_classes, prediction->dptr<T>(), label->dptr<K>(), out->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T, typename K>
+void Backward(DeviceCtx* ctx, const Blob* prediction, const Blob* label, const Blob* out_diff,
+              Blob* prediction_diff) {
+  const int64_t num_instances = label->shape().elem_cnt();
+  CHECK_EQ(prediction->shape().elem_cnt() % num_instances, 0);
+  const int64_t num_classes = prediction->shape().elem_cnt() / num_instances;
+  Memset<device_type>(ctx, prediction_diff->mut_dptr<T>(), 0,
+                      prediction_diff->ByteSizeOfDataContentField());
+  SparseCrossEntropyKernelUtil<device_type, T, K>::ComputeDiff(
+      ctx, num_instances, num_classes, prediction->dptr<T>(), label->dptr<K>(), out_diff->dptr<T>(),
+      prediction_diff->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T>
+struct SparseCrossEntropyUtil final {
+#define MAKE_CROSS_ENTROPY_SWITCH_ENTRY(func_name, K) func_name<device_type, T, K>
+  DEFINE_STATIC_SWITCH_FUNC(void, Forward, MAKE_CROSS_ENTROPY_SWITCH_ENTRY,
+                            MAKE_DATA_TYPE_CTRV_SEQ(INT_DATA_TYPE_SEQ));
+  DEFINE_STATIC_SWITCH_FUNC(void, Backward, MAKE_CROSS_ENTROPY_SWITCH_ENTRY,
+                            MAKE_DATA_TYPE_CTRV_SEQ(INT_DATA_TYPE_SEQ));
+#undef MAKE_CROSS_ENTROPY_SWITCH_ENTRY
+};
+
+}  // namespace
+
+template<DeviceType device_type, typename T>
+void SparseCrossEntropyKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* prediction = BnInOp2Blob("prediction");
+  const Blob* label = BnInOp2Blob("label");
+  Blob* out = BnInOp2Blob("out");
+  SparseCrossEntropyUtil<device_type, T>::SwitchForward(SwitchCase(label->data_type()),
+                                                        ctx.device_ctx, prediction, label, out);
+}
+
+template<DeviceType device_type, typename T>
+void SparseCrossEntropyKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* prediction = BnInOp2Blob("prediction");
+  const Blob* label = BnInOp2Blob("label");
+  const Blob* out_diff = BnInOp2Blob(GenDiffBn("out"));
+  Blob* prediction_diff = BnInOp2Blob(GenDiffBn("prediction"));
+  SparseCrossEntropyUtil<device_type, T>::SwitchBackward(
+      SwitchCase(label->data_type()), ctx.device_ctx, prediction, label, out_diff, prediction_diff);
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kSparseCrossEntropyConf, SparseCrossEntropyKernel,
+                           FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/sparse_cross_entropy_kernel.h b/oneflow/core/kernel/sparse_cross_entropy_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..6dca30ebb54e0e8c546e72f835fe73028485b9ed
--- /dev/null
+++ b/oneflow/core/kernel/sparse_cross_entropy_kernel.h
@@ -0,0 +1,24 @@
+#ifndef ONEFLOW_CORE_KERNEL_SPARSE_CROSS_ENTROPY_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_SPARSE_CROSS_ENTROPY_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class SparseCrossEntropyKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(SparseCrossEntropyKernel);
+  SparseCrossEntropyKernel() = default;
+  ~SparseCrossEntropyKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_SPARSE_CROSS_ENTROPY_KERNEL_H_
diff --git a/oneflow/core/kernel/sparse_cross_entropy_kernel_util.cpp b/oneflow/core/kernel/sparse_cross_entropy_kernel_util.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c80393b76e1c3d0f8448bd98cb33e6a39e249e64
--- /dev/null
+++ b/oneflow/core/kernel/sparse_cross_entropy_kernel_util.cpp
@@ -0,0 +1,46 @@
+#include "oneflow/core/kernel/sparse_cross_entropy_kernel_util.h"
+#include "oneflow/core/kernel/kernel_util.cuh"
+
+namespace oneflow {
+
+template<typename T, typename K>
+struct SparseCrossEntropyKernelUtil<DeviceType::kCPU, T, K> {
+  static void ComputeEntropy(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x,
+                             const K* labels, T* y) {
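+    // y[i] is the negative log-likelihood of the labeled class for instance i.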
+    FOR_RANGE(int64_t, i, 0, num_instances) {
+      K label = labels[i];
+      CHECK_GE(label, 0);
+      CHECK_LT(label, num_classes);
+      y[i] = -SafeLog(x[i * num_classes + label]);
+    }
+  }
+
+  static void ComputeDiff(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x,
+                          const K* labels, T* dx) {
+    FOR_RANGE(int64_t, i, 0, num_instances) {
+      K label = labels[i];
+      CHECK_GE(label, 0);
+      CHECK_LT(label, num_classes);
+      dx[i * num_classes + label] = -1 / MaxWithLogThreshold(x[i * num_classes + label]);
+    }
+  }
+
+  static void ComputeDiff(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x,
+                          const K* labels, const T* dy, T* dx) {
+    FOR_RANGE(int64_t, i, 0, num_instances) {
+      K label = labels[i];
+      CHECK_GE(label, 0);
+      CHECK_LT(label, num_classes);
+      dx[i * num_classes + label] = -dy[i] / MaxWithLogThreshold(x[i * num_classes + label]);
+    }
+  }
+};
+
+#define INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CPU(data_type_pair, index_type_pair)          \
+  template struct SparseCrossEntropyKernelUtil<DeviceType::kCPU, OF_PP_PAIR_FIRST(data_type_pair), \
+                                               OF_PP_PAIR_FIRST(index_type_pair)>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CPU,
+                                 FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ);
+#undef INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CPU
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/sparse_cross_entropy_kernel_util.cu b/oneflow/core/kernel/sparse_cross_entropy_kernel_util.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b62ac198af9dcf922a3413e81f6a37d3b0c5c437
--- /dev/null
+++ b/oneflow/core/kernel/sparse_cross_entropy_kernel_util.cu
@@ -0,0 +1,71 @@
+#include "oneflow/core/kernel/sparse_cross_entropy_kernel_util.h"
+#include "oneflow/core/kernel/kernel_util.cuh"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T, typename K>
+__global__ void ComputeEntropyGpu(int64_t num_instances, int64_t num_classes, const T* x,
+                                  const K* labels, T* y) {
+  CUDA_1D_KERNEL_LOOP(i, num_instances) {
+    K label = labels[i];
+    assert(label >= 0);
+    assert(label < num_classes);
+    y[i] = -SafeLog(x[i * num_classes + label]);
+  }
+}
+
+template<typename T, typename K>
+__global__ void ComputeDiffGpu(int64_t num_instances, int64_t num_classes, const T* x,
+                               const K* labels, T* dx) {
+  CUDA_1D_KERNEL_LOOP(i, num_instances) {
+    K label = labels[i];
+    assert(label >= 0);
+    assert(label < num_classes);
+    dx[i * num_classes + label] = -1 / MaxWithLogThreshold(x[i * num_classes + label]);
+  }
+}
+
+template<typename T, typename K>
+__global__ void ComputeDiffGpu(int64_t num_instances, int64_t num_classes, const T* x,
+                               const K* labels, const T* dy, T* dx) {
+  CUDA_1D_KERNEL_LOOP(i, num_instances) {
+    K label = labels[i];
+    assert(label >= 0);
+    assert(label < num_classes);
+    dx[i * num_classes + label] = -dy[i] / MaxWithLogThreshold(x[i * num_classes + label]);
+  }
+}
+
+}  // namespace
+
+template<typename T, typename K>
+struct SparseCrossEntropyKernelUtil<DeviceType::kGPU, T, K> {
+  static void ComputeEntropy(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x,
+                             const K* labels, T* y) {
+    ComputeEntropyGpu<<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock, 0,
+                        ctx->cuda_stream()>>>(num_instances, num_classes, x, labels, y);
+  }
+
+  static void ComputeDiff(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x,
+                          const K* labels, T* dx) {
+    ComputeDiffGpu<<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock, 0,
+                     ctx->cuda_stream()>>>(num_instances, num_classes, x, labels, dx);
+  }
+
+  static void ComputeDiff(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x,
+                          const K* labels, const T* dy, T* dx) {
+    ComputeDiffGpu<<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock, 0,
+                     ctx->cuda_stream()>>>(num_instances, num_classes, x, labels, dy, dx);
+  }
+};
+
+#define INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_GPU(data_type_pair, index_type_pair)          \
+  template struct SparseCrossEntropyKernelUtil<DeviceType::kGPU, OF_PP_PAIR_FIRST(data_type_pair), \
+                                               OF_PP_PAIR_FIRST(index_type_pair)>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_GPU,
+                                 FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ);
+#undef INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_GPU
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/sparse_cross_entropy_kernel_util.h b/oneflow/core/kernel/sparse_cross_entropy_kernel_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..4676e813b4c146d11f8e8ae7d7c30f095c496304
--- /dev/null
+++ b/oneflow/core/kernel/sparse_cross_entropy_kernel_util.h
@@ -0,0 +1,20 @@
+#ifndef ONEFLOW_CORE_KERNEL_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_H_
+#define ONEFLOW_CORE_KERNEL_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_H_
+
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T, typename K>
+struct SparseCrossEntropyKernelUtil {
+  static void ComputeEntropy(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x,
+                             const K* labels, T* y);
+  static void ComputeDiff(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x,
+                          const K* labels, T* dx);
+  static void ComputeDiff(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x,
+                          const K* labels, const T* dy, T* dx);
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_H_
diff --git a/oneflow/core/kernel/sparse_cross_entropy_loss_kernel.h b/oneflow/core/kernel/sparse_cross_entropy_loss_kernel.h
index 5ea7cf1d5fcbb4f62f63e8f598100e184d0e209d..f52a09b0392a4d939a29a0d1e70860f1043193c8 100644
--- a/oneflow/core/kernel/sparse_cross_entropy_loss_kernel.h
+++ b/oneflow/core/kernel/sparse_cross_entropy_loss_kernel.h
@@ -6,7 +6,7 @@
 namespace oneflow {
 
 template<DeviceType device_type, typename PredType, typename LabelType>
-class SparseCrossEntropyLossKernel final : public LossKernel<device_type, PredType, LabelType> {
+class SparseCrossEntropyLossKernel final : public LossKernel<device_type, PredType> {
  public:
   OF_DISALLOW_COPY_AND_MOVE(SparseCrossEntropyLossKernel);
   SparseCrossEntropyLossKernel() = default;
diff --git a/oneflow/core/kernel/sparse_softmax_cross_entropy_loss_kernel.h b/oneflow/core/kernel/sparse_softmax_cross_entropy_loss_kernel.h
index c0f7d03d1a0dd5d311c19d027917a1d6d0fbbb16..60c34cdb6091bcd4b2922c046d962bb1581a6cbb 100644
--- a/oneflow/core/kernel/sparse_softmax_cross_entropy_loss_kernel.h
+++ b/oneflow/core/kernel/sparse_softmax_cross_entropy_loss_kernel.h
@@ -6,8 +6,7 @@
 namespace oneflow {
 
 template<DeviceType device_type, typename PredType, typename LabelType>
-class SparseSoftmaxCrossEntropyLossKernel final
-    : public LossKernel<device_type, PredType, LabelType> {
+class SparseSoftmaxCrossEntropyLossKernel final : public LossKernel<device_type, PredType> {
  public:
   OF_DISALLOW_COPY_AND_MOVE(SparseSoftmaxCrossEntropyLossKernel);
   SparseSoftmaxCrossEntropyLossKernel() = default;
diff --git a/oneflow/core/kernel/sqrt_kernel.cpp b/oneflow/core/kernel/sqrt_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..aa51d2ce0f0dcfc2b16c85d3df6f916105a76e6f
--- /dev/null
+++ b/oneflow/core/kernel/sqrt_kernel.cpp
@@ -0,0 +1,29 @@
+#include "oneflow/core/kernel/sqrt_kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void SqrtKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  Blob* in_blob = BnInOp2Blob("in");
+  Blob* out_blob = BnInOp2Blob("out");
+  KernelUtil<device_type, T>::Sqrt(ctx.device_ctx, in_blob->static_shape().elem_cnt(),
+                                   in_blob->dptr<T>(), out_blob->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T>
+void SqrtKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  Blob* out_blob = BnInOp2Blob("out");
+  Blob* out_diff_blob = BnInOp2Blob("out_diff");
+  Blob* in_diff_blob = BnInOp2Blob("in_diff");
+  KernelUtil<device_type, T>::Div(ctx.device_ctx, out_blob->static_shape().elem_cnt(),
+                                  out_diff_blob->dptr<T>(), out_blob->dptr<T>(),
+                                  in_diff_blob->mut_dptr<T>());
+  KernelUtil<device_type, T>::Scal(ctx.device_ctx, out_blob->static_shape().elem_cnt(),
+                                   static_cast<T>(0.5), in_diff_blob->mut_dptr<T>(), 1);
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kSqrtConf, SqrtKernel, FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/sqrt_kernel.h b/oneflow/core/kernel/sqrt_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..ae206ef4bf56c98a00d443489c2d53d81356e445
--- /dev/null
+++ b/oneflow/core/kernel/sqrt_kernel.h
@@ -0,0 +1,25 @@
+#ifndef ONEFLOW_CORE_KERNEL_SQRT_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_SQRT_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+#include "oneflow/core/kernel/kernel_context.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class SqrtKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(SqrtKernel);
+  SqrtKernel() = default;
+  ~SqrtKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_SQRT_KERNEL_H_
diff --git a/oneflow/core/kernel/square_kernel.cpp b/oneflow/core/kernel/square_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f0dd0eb7d5dd3d7fb8d10b5277ce38f588a44b95
--- /dev/null
+++ b/oneflow/core/kernel/square_kernel.cpp
@@ -0,0 +1,29 @@
+#include "oneflow/core/kernel/square_kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void SquareKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  Blob* in_blob = BnInOp2Blob("in");
+  Blob* out_blob = BnInOp2Blob("out");
+  KernelUtil<device_type, T>::Square(ctx.device_ctx, in_blob->static_shape().elem_cnt(),
+                                     in_blob->dptr<T>(), out_blob->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T>
+void SquareKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  Blob* in_blob = BnInOp2Blob("in");
+  Blob* out_diff_blob = BnInOp2Blob("out_diff");
+  Blob* in_diff_blob = BnInOp2Blob("in_diff");
+  KernelUtil<device_type, T>::Mul(ctx.device_ctx, in_blob->static_shape().elem_cnt(),
+                                  out_diff_blob->dptr<T>(), in_blob->dptr<T>(),
+                                  in_diff_blob->mut_dptr<T>());
+  KernelUtil<device_type, T>::Scal(ctx.device_ctx, in_blob->static_shape().elem_cnt(),
+                                   static_cast<T>(2), in_diff_blob->mut_dptr<T>(), 1);
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kSquareConf, SquareKernel, FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/square_kernel.h b/oneflow/core/kernel/square_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..af578ccf48a342c42fd20b50a9be7ad32de8d205
--- /dev/null
+++ b/oneflow/core/kernel/square_kernel.h
@@ -0,0 +1,25 @@
+#ifndef ONEFLOW_CORE_KERNEL_SQUARE_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_SQUARE_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+#include "oneflow/core/kernel/kernel_context.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class SquareKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(SquareKernel);
+  SquareKernel() = default;
+  ~SquareKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_SQUARE_KERNEL_H_
diff --git a/oneflow/core/kernel/tick_kernel.cpp b/oneflow/core/kernel/tick_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..28dee87c66b46eb9be8b57fcebb9d8278da26ddb
--- /dev/null
+++ b/oneflow/core/kernel/tick_kernel.cpp
@@ -0,0 +1,7 @@
+#include "oneflow/core/kernel/tick_kernel.h"
+
+namespace oneflow {
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kTickConf, TickKernel, ARITHMETIC_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/tick_kernel.h b/oneflow/core/kernel/tick_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..dfe8f36201b528da0bc16b55e1a4442e6fc8053e
--- /dev/null
+++ b/oneflow/core/kernel/tick_kernel.h
@@ -0,0 +1,27 @@
+#ifndef ONEFLOW_CORE_KERNEL_TICK_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_TICK_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class TickKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(TickKernel);
+  TickKernel() = default;
+  ~TickKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx& ctx,
+                          std::function<Blob*(const std::string&)> BnInOp2Blob) const override {}
+  void BackwardDataContent(const KernelCtx& ctx,
+                           std::function<Blob*(const std::string&)> BnInOp2Blob) const override {
+    UNIMPLEMENTED();
+  }
+  const PbMessage& GetCustomizedOpConf() const override { return this->op_conf().tick_conf(); }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_TICK_KERNEL_H_
diff --git a/oneflow/core/kernel/top_k_kernel.cpp b/oneflow/core/kernel/top_k_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b83dfb12675d67e91f8d1910717de9fda010607f
--- /dev/null
+++ b/oneflow/core/kernel/top_k_kernel.cpp
@@ -0,0 +1,40 @@
+#include "oneflow/core/kernel/top_k_kernel.h"
+
+namespace oneflow {
+
+template<typename T>
+void TopKKernel<T>::ForwardDataContent(const KernelCtx& ctx,
+                                       std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* in_blob = BnInOp2Blob("in");
+  Blob* fw_buf_blob = BnInOp2Blob("fw_buf");
+  Blob* out_blob = BnInOp2Blob("out");
+
+  CHECK_LE(in_blob->shape().elem_cnt(), GetMaxVal<int32_t>());
+  const int32_t instance_size = static_cast<int32_t>(in_blob->shape().dim_vec().back());
+  const int32_t instance_num = static_cast<int32_t>(in_blob->shape().elem_cnt() / instance_size);
+  const T* in = in_blob->dptr<T>();
+  int32_t* fw_buf = fw_buf_blob->mut_dptr<int32_t>();
+  int32_t* out = out_blob->mut_dptr<int32_t>();
+  const auto& conf = this->op_conf().top_k_conf();
+  const int32_t k = conf.k();
+  FOR_RANGE(int32_t, i, 0, instance_num) {
+    std::iota(fw_buf, fw_buf + instance_size, 0);
+    const int32_t offset = i * instance_size;
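+    // Order indices by descending value; equal values keep the smaller index first for stability.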
+    auto comp = [&](const int32_t lhs, const int32_t rhs) {
+      const T l = in[offset + lhs];
+      const T r = in[offset + rhs];
+      if (l == r) {
+        return lhs < rhs;
+      } else {
+        return l > r;
+      }
+    };
+    std::nth_element(fw_buf, fw_buf + k, fw_buf + instance_size, comp);
+    if (k > 1 && conf.sorted()) { std::sort(fw_buf, fw_buf + k, comp); }
+    std::copy(fw_buf, fw_buf + k, out + i * k);
+  }
+}
+
+ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kTopKConf, TopKKernel, FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/top_k_kernel.h b/oneflow/core/kernel/top_k_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..98fb94d467d5ef77fb3d4b72694c79e63c2afde7
--- /dev/null
+++ b/oneflow/core/kernel/top_k_kernel.h
@@ -0,0 +1,23 @@
+#ifndef ONEFLOW_CORE_KERNEL_TOP_K_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_TOP_K_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+#include "oneflow/core/kernel/kernel_context.h"
+
+namespace oneflow {
+
+template<typename T>
+class TopKKernel final : public KernelIf<DeviceType::kCPU> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(TopKKernel);
+  TopKKernel() = default;
+  ~TopKKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_TOP_K_KERNEL_H_
diff --git a/oneflow/core/kernel/tuple_identity_kernel.cpp b/oneflow/core/kernel/tuple_identity_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..442f4cb716d490148edef4ce489ab13476715d0e
--- /dev/null
+++ b/oneflow/core/kernel/tuple_identity_kernel.cpp
@@ -0,0 +1,31 @@
+#include "oneflow/core/kernel/tuple_identity_kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type>
+void TupleIdentityKernel<device_type>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const auto& input_bns = this->op_attribute().input_bns();
+  const auto& output_bns = this->op_attribute().output_bns();
+  CHECK_EQ(input_bns.size(), output_bns.size());
+  FOR_RANGE(int, i, 0, input_bns.size()) {
+    Blob* out_blob = BnInOp2Blob(output_bns.Get(i));
+    out_blob->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob(input_bns.Get(i)));
+  }
+}
+
+template<DeviceType device_type>
+void TupleIdentityKernel<device_type>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const auto& input_diff_bns = this->op_attribute().input_diff_bns();
+  const auto& output_diff_bns = this->op_attribute().output_diff_bns();
+  CHECK_EQ(input_diff_bns.size(), output_diff_bns.size());
+  FOR_RANGE(int, i, 0, output_diff_bns.size()) {
+    Blob* in_diff_blob = BnInOp2Blob(input_diff_bns.Get(i));
+    in_diff_blob->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob(output_diff_bns.Get(i)));
+  }
+}
+
+ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kTupleIdentityConf, TupleIdentityKernel);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/tuple_identity_kernel.h b/oneflow/core/kernel/tuple_identity_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..d2d953e6fe2041604c2d82dc855c64200dbbf57f
--- /dev/null
+++ b/oneflow/core/kernel/tuple_identity_kernel.h
@@ -0,0 +1,25 @@
+#ifndef ONEFLOW_CORE_KERNEL_TUPLE_IDENTITY_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_TUPLE_IDENTITY_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+#include "oneflow/core/kernel/kernel_context.h"
+
+namespace oneflow {
+
+template<DeviceType device_type>
+class TupleIdentityKernel final : public KernelIf<device_type> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(TupleIdentityKernel);
+  TupleIdentityKernel() = default;
+  ~TupleIdentityKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_TUPLE_IDENTITY_KERNEL_H_
diff --git a/oneflow/core/kernel/unpack_kernel.h b/oneflow/core/kernel/unpack_kernel.h
index c82b085ef3f27625148bee306db85e1507d891c6..1411672182860813239cbda53d93c04eeb03836b 100644
--- a/oneflow/core/kernel/unpack_kernel.h
+++ b/oneflow/core/kernel/unpack_kernel.h
@@ -31,6 +31,10 @@ class UnpackKernel final : public KernelIf<device_type> {
       const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
   void BackwardInDiffDim0ValidNum(
       const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  void BackwardInDiffLossInstanceNum(
+      const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override {
+    UNIMPLEMENTED();
+  }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/kernel/variable_kernel.cpp b/oneflow/core/kernel/variable_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b32360299a1deb6e45fd4962402cdc90ab136134
--- /dev/null
+++ b/oneflow/core/kernel/variable_kernel.cpp
@@ -0,0 +1,70 @@
+#include "oneflow/core/kernel/variable_kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void VariableKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* model_blob = BnInOp2Blob(ModelName());
+  Blob* out_blob = BnInOp2Blob("out");
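+  // Refresh "out" from the model blob only at the start of a batch (trainable) or on the very first tick (non-trainable); otherwise the previous output is kept.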
+  if ((this->op_conf().trainable() && *tick_ % Global<JobDesc>::Get()->NumOfPiecesInBatch() == 0)
+      || (this->op_conf().trainable() == false && *tick_ == 0)) {
+    if (this->kernel_conf().variable_conf().is_fw_inplace()) {
+      CHECK_EQ(out_blob->dptr(), model_blob->dptr());
+    } else {
+      CHECK_NE(out_blob->dptr(), model_blob->dptr());
+      out_blob->CopyDataContentFrom(ctx.device_ctx, model_blob);
+    }
+  } else {
+    // do nothing
+  }
+  ++(*tick_);
+}
+
+template<DeviceType device_type, typename T>
+void VariableKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  CHECK(this->op_conf().trainable());
+  const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out"));
+  Blob* model_diff_blob = BnInOp2Blob(GenDiffBn(ModelName()));
+  if (this->kernel_conf().variable_conf().is_bw_inplace()) {
+    CHECK_EQ(out_diff_blob->dptr(), model_diff_blob->dptr());
+  } else {
+    CHECK_NE(out_diff_blob->dptr(), model_diff_blob->dptr());
+    model_diff_blob->CopyDataContentFrom(ctx.device_ctx, out_diff_blob);
+  }
+}
+
+template<DeviceType device_type, typename T>
+void VariableKernel<device_type, T>::InitModelBlobsWithRandomSeed(
+    DeviceCtx* ctx, std::mt19937* random_seed_gen,
+    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  KernelUtil<device_type, T>::InitializeWithProperConf(
+      ctx, GetMsgPtrFromPbMessage(this->op_conf().variable_conf(), "initializer"),
+      (*random_seed_gen)(), BnInOp2Blob(ModelName()));
+}
+
+template<DeviceType device_type, typename T>
+void VariableKernel<device_type, T>::InitModelBlobsWithDir(
+    DeviceCtx* ctx, int32_t part_id, int32_t part_num, const std::string& model_load_dir,
+    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const std::string& model_name = ModelName();
+  Blob* model_blob = BnInOp2Blob(model_name);
+  KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir, model_blob,
+                                                model_name, model_blob->shape().At(0),
+                                                model_blob->shape().Count(1));
+}
+
+template<DeviceType device_type, typename T>
+const PbMessage& VariableKernel<device_type, T>::GetCustomizedOpConf() const {
+  return this->op_conf().variable_conf();
+}
+
+template<DeviceType device_type, typename T>
+const std::string& VariableKernel<device_type, T>::ModelName() const {
+  return this->op_conf().variable_conf().model_name();
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kVariableConf, VariableKernel, FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/variable_kernel.h b/oneflow/core/kernel/variable_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..82185acddcd96d8743380a658b69bc572789b2a7
--- /dev/null
+++ b/oneflow/core/kernel/variable_kernel.h
@@ -0,0 +1,33 @@
+#ifndef ONEFLOW_CORE_KERNEL_VARIABLE_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_VARIABLE_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class VariableKernel final : public KernelIfWithModel<device_type, T> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(VariableKernel);
+  VariableKernel() : tick_(new int64_t(0)) {}
+  ~VariableKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+  void InitModelBlobsWithRandomSeed(DeviceCtx*, std::mt19937*,
+                                    std::function<Blob*(const std::string&)>) const override;
+  void InitModelBlobsWithDir(DeviceCtx*, int32_t part_id, int32_t part_num,
+                             const std::string& model_load_dir,
+                             std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  const PbMessage& GetCustomizedOpConf() const override;
+  const std::string& ModelName() const;
+
+  std::unique_ptr<int64_t> tick_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_VARIABLE_KERNEL_H_
diff --git a/oneflow/core/ndarray/binary_func.h b/oneflow/core/ndarray/binary_func.h
new file mode 100644
index 0000000000000000000000000000000000000000..c1c8ffee8505ff933bc4b3d5a9056d4fef9eb450
--- /dev/null
+++ b/oneflow/core/ndarray/binary_func.h
@@ -0,0 +1,70 @@
+#ifndef ONEFLOW_CORE_NDARRAY_BINARY_FUNC_H_
+#define ONEFLOW_CORE_NDARRAY_BINARY_FUNC_H_
+
+#include <cstdint>
+#include <climits>
+#include <cfloat>
+#include "oneflow/core/kernel/kernel_util.h"
+#include "oneflow/core/common/util.h"
+
+namespace oneflow {
+
+template<typename T>
+OF_DEVICE_FUNC const T BinaryFuncAdd(const T x, const T y) {
+  return x + y;
+}
+
+template<typename T>
+OF_DEVICE_FUNC const T BinaryFuncSub(const T x, const T y) {
+  return x - y;
+}
+
+template<typename T>
+OF_DEVICE_FUNC const T BinaryFuncMul(const T x, const T y) {
+  return x * y;
+}
+
+template<typename T>
+OF_DEVICE_FUNC const T BinaryFuncDiv(const T x, const T y) {
+  return x / y;
+}
+
+template<typename T>
+OF_DEVICE_FUNC const T BinaryFuncMax(const T x, const T y) {
+  return x > y ? x : y;
+}
+
+template<typename T>
+OF_DEVICE_FUNC const T BinaryFuncMin(const T x, const T y) {
+  return x < y ? x : y;
+}
+
+#define ARITHMETIC_BINARY_FUNC_SEQ    \
+  OF_PP_MAKE_TUPLE_SEQ(BinaryFuncAdd) \
+  OF_PP_MAKE_TUPLE_SEQ(BinaryFuncSub) \
+  OF_PP_MAKE_TUPLE_SEQ(BinaryFuncMul) \
+  OF_PP_MAKE_TUPLE_SEQ(BinaryFuncDiv)
+
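+// UnitOfBinaryFunc yields the identity element of a reduction functor: 0 for Add, 1 for Mul, the type's minimum for Max and maximum for Min.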
+template<typename T, const T (*binary_func)(const T, const T), typename Enable = void>
+struct UnitOfBinaryFunc;
+
+#define SPECIALIZE_UNIT_OF_BINARY_FUNC(binary_func, val_trait)                               \
+  template<typename T, const T (*bfunc)(const T, const T)>                                   \
+  struct UnitOfBinaryFunc<T, bfunc, typename std::enable_if<bfunc == &binary_func<T>>::type> \
+      final {                                                                                \
+    constexpr static T value = val_trait<T>::value;                                          \
+  };
+SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncAdd, ZeroVal);
+SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncMul, OneVal);
+SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncMax, MinVal);
+SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncMin, MaxVal);
+#undef SPECIALIZE_UNIT_OF_BINARY_FUNC
+
+#define REDUCE_BINARY_FUNC_SEQ        \
+  OF_PP_MAKE_TUPLE_SEQ(BinaryFuncAdd) \
+  OF_PP_MAKE_TUPLE_SEQ(BinaryFuncMax) \
+  OF_PP_MAKE_TUPLE_SEQ(BinaryFuncMin)
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_BINARY_FUNC_H_
diff --git a/oneflow/core/ndarray/cpu_ndarray_assign.h b/oneflow/core/ndarray/cpu_ndarray_assign.h
new file mode 100644
index 0000000000000000000000000000000000000000..38cac781358960ef4bb21c8701289680b50fb6fa
--- /dev/null
+++ b/oneflow/core/ndarray/cpu_ndarray_assign.h
@@ -0,0 +1,18 @@
+#ifndef ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_ASSIGN_H_
+#define ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_ASSIGN_H_
+
+#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/device/cuda_util.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+template<int NDIMS, typename T, typename X>
+OF_DEVICE_FUNC void CpuNdArrayAssign(XpuVarNdarray<T>* y, const X& x) {
+  size_t n = y->shape().ElemNum();
+  FOR_RANGE(int, i, 0, n) { *(y->template Mut<NDIMS>(i)) = x.template Get<NDIMS>(i); }
+}
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_ASSIGN_H_
diff --git a/oneflow/core/ndarray/exec_shape.cpp b/oneflow/core/ndarray/exec_shape.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..cd5712afd79fc9ff233ee0386dca686cbcf220ce
--- /dev/null
+++ b/oneflow/core/ndarray/exec_shape.cpp
@@ -0,0 +1,37 @@
+#include "oneflow/core/ndarray/xpu_shape.h"
+
+namespace oneflow {
+
+XpuShape::XpuShape(const int64_t dim[], int num_axes) {
+  num_axes_ = num_axes;
+  int i = 0;
+  for (; i < num_axes_; ++i) { dim_[i] = dim[i]; }
+  UpdateDimElemNumAndElemNum();
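+  // Fill the unused trailing slots with 1 so they act as size-1 dimensions.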
+  for (; i < sizeof(dim_) / sizeof(dim_[0]); ++i) {
+    dim_[i] = 1;
+    dim_elem_num_[i] = 1;
+  }
+}
+
+XpuShape::XpuShape(const Shape& shape) {
+  num_axes_ = shape.NumAxes();
+  int i = 0;
+  for (; i < num_axes_; ++i) { dim_[i] = shape.At(i); }
+  UpdateDimElemNumAndElemNum();
+  for (; i < sizeof(dim_) / sizeof(dim_[0]); ++i) {
+    dim_[i] = 1;
+    dim_elem_num_[i] = 1;
+  }
+}
+
+bool XpuShape::operator==(const XpuShape& rhs) const {
+  if (num_axes_ != rhs.num_axes_) { return false; }
+  if (elem_num_ != rhs.elem_num_) { return false; }
+  for (int i = 0; i < num_axes_; ++i) {
+    if (dim_[i] != rhs.dim_[i]) { return false; }
+    if (dim_elem_num_[i] != rhs.dim_elem_num_[i]) { return false; }
+  }
+  return true;
+}
+
+}  // namespace oneflow
diff --git a/oneflow/core/ndarray/gpu_ndarray_assign.h b/oneflow/core/ndarray/gpu_ndarray_assign.h
new file mode 100644
index 0000000000000000000000000000000000000000..42a6faa9e16f8e57ec905e453c5e78ed79dd53b3
--- /dev/null
+++ b/oneflow/core/ndarray/gpu_ndarray_assign.h
@@ -0,0 +1,18 @@
+#ifndef ONEFLOW_CORE_NDARRAY_GPU_NDARRAY_ASSIGN_H_
+#define ONEFLOW_CORE_NDARRAY_GPU_NDARRAY_ASSIGN_H_
+
+#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/device/cuda_util.h"
+#include "oneflow/core/kernel/kernel_util.cuh"
+
+namespace oneflow {
+
+template<int NDIMS, typename T, typename X>
+__device__ void GpuNdArrayAssign(XpuVarNdarray<T>* y, const X& x) {
+  size_t n = y->shape().ElemNum();
+  CUDA_1D_KERNEL_LOOP(i, n) { *(y->template Mut<NDIMS>(i)) = x.template Get<NDIMS>(i); }
+}
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_GPU_NDARRAY_ASSIGN_H_
diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h b/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h
new file mode 100644
index 0000000000000000000000000000000000000000..3abcbfd051fde57def8927aa998d470afcfbd96b
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h
@@ -0,0 +1,44 @@
+#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_H_
+#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_H_
+
+#include "oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h"
+#include "oneflow/core/common/util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T, int NDIMS, const T (*binary_func)(const T, const T)>
+struct NdArrayApplyBroadcastBinary final {
+  static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& a,
+                    const XpuVarNdarray<const T>& b) {
+    using NdarrayAssign = XpuNdArrayAssign<device_type, T>;
+    using BroadcastBinary =
+        NdArrayApplyBroadcastBinaryCoreWrapper<device_type, T, NDIMS, binary_func>;
+    CheckBroadcastable(y, a, b);
+    return BroadcastBinary::Apply(ctx, y, a, b);
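+    // NOTE: the unconditional Apply above returns early, so the in-place dispatch below is currently unreachable.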
+    if (a.shape() == y.shape()) {
+      NdarrayAssign::Assign(ctx, y, a);
+      BroadcastBinary::ImplaceApply(ctx, y, b);
+    } else if (b.shape() == y.shape()) {
+      NdarrayAssign::Assign(ctx, y, b);
+      BroadcastBinary::ImplaceApply(ctx, y, a);
+    } else {
+      BroadcastBinary::Apply(ctx, y, a, b);
+    }
+  }
+
+  static void CheckBroadcastable(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& a,
+                                 const XpuVarNdarray<const T>& b) {
+    CHECK_EQ(y.shape().NumAxes(), a.shape().NumAxes());
+    CHECK_EQ(y.shape().NumAxes(), b.shape().NumAxes());
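+    // Every output dim must equal the larger input dim; where the inputs differ, the smaller one must be 1.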
+    for (int i = 0; i < y.shape().NumAxes(); ++i) {
+      CHECK_EQ(y.shape().At(i), std::max(a.shape().At(i), b.shape().At(i)));
+      if (a.shape().At(i) != b.shape().At(i)) {
+        CHECK(a.shape().At(i) == 1 || b.shape().At(i) == 1);
+      }
+    }
+  }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_H_
diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cpp b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..9a0fcdaec7bd0a2b9dd8c5e67dc78c4febdcd223
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cpp
@@ -0,0 +1,22 @@
+#include "oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h"
+
+namespace oneflow {
+
+template<typename T, int NDIMS, const T (*binary_func)(const T, const T)>
+struct NdArrayApplyBroadcastBinaryCoreWrapper<DeviceType::kCPU, T, NDIMS, binary_func> final {
+  static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& a,
+                    const XpuVarNdarray<const T>& b) {
+    NdArrayApplyBroadcastBinaryCore<T, NDIMS, binary_func>::Apply(y, a, b);
+  }
+  static void ImplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
+                           const XpuVarNdarray<const T>& x) {
+    NdArrayApplyBroadcastBinaryCore<T, NDIMS, binary_func>::ImplaceApply(y, x);
+  }
+};
+
+#define INSTANTIATE_BROADCAST_BINARY_FUNC(dtype_pair, NDIMS, binary_func) \
+  template struct NdArrayApplyBroadcastBinaryCoreWrapper<                 \
+      DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, binary_func>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BROADCAST_BINARY_FUNC, ARITHMETIC_DATA_TYPE_SEQ,
+                                 DIM_SEQ, ARITHMETIC_BINARY_FUNC_SEQ)
+}  // namespace oneflow
diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu
new file mode 100644
index 0000000000000000000000000000000000000000..94522f200083a04cf8188e7aa26b572c6604c9ce
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu
@@ -0,0 +1,38 @@
+#include "oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T, int NDIMS, const T (*binary_func)(const T, const T)>
+__global__ void GpuBroadcastBinaryFunc(const XpuVarNdarray<T> y, const XpuVarNdarray<const T> a,
+                                       const XpuVarNdarray<const T> b) {
+  NdArrayApplyBroadcastBinaryCore<T, NDIMS, binary_func>::Apply(y, a, b);
+}
+template<typename T, int NDIMS, const T (*binary_func)(const T, const T)>
+__global__ void GpuBroadcastBinaryFunc(const XpuVarNdarray<T> y, const XpuVarNdarray<const T> x) {
+  NdArrayApplyBroadcastBinaryCore<T, NDIMS, binary_func>::ImplaceApply(y, x);
+}
+
+}  // namespace
+
+template<typename T, int NDIMS, const T (*binary_func)(const T, const T)>
+struct NdArrayApplyBroadcastBinaryCoreWrapper<DeviceType::kGPU, T, NDIMS, binary_func> final {
+  static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& a,
+                    const XpuVarNdarray<const T>& b) {
+    size_t n = y.host_shape().HostElemNum();
+    RUN_CUDA_KERNEL((GpuBroadcastBinaryFunc<T, NDIMS, binary_func>), ctx, n, y, a, b);
+  }
+  static void ImplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
+                           const XpuVarNdarray<const T>& x) {
+    size_t n = y.host_shape().HostElemNum();
+    RUN_CUDA_KERNEL((GpuBroadcastBinaryFunc<T, NDIMS, binary_func>), ctx, n, y, x);
+  }
+};
+
+#define INSTANTIATE_BROADCAST_BINARY_FUNC(dtype_pair, NDIMS, binary_func) \
+  template struct NdArrayApplyBroadcastBinaryCoreWrapper<                 \
+      DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, binary_func>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BROADCAST_BINARY_FUNC, ARITHMETIC_DATA_TYPE_SEQ,
+                                 DIM_SEQ, ARITHMETIC_BINARY_FUNC_SEQ)
+}  // namespace oneflow
diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h
new file mode 100644
index 0000000000000000000000000000000000000000..ab67762ed03ce40f994ffab991bca1b2aaf9205c
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h
@@ -0,0 +1,37 @@
+#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_CORE_H_
+#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_CORE_H_
+
+#include "oneflow/core/ndarray/xpu_util.h"
+#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/ndarray/xpu_broadcast_ndarray.h"
+#include "oneflow/core/ndarray/xpu_binary_func_ndarray.h"
+#include "oneflow/core/ndarray/binary_func.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T, int NDIMS, const T (*binary_func)(const T, const T)>
+struct NdArrayApplyBroadcastBinaryCoreWrapper final {
+  static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& a,
+                    const XpuVarNdarray<const T>& b);
+  static void ImplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
+                           const XpuVarNdarray<const T>& x);
+};
+
+template<typename T, int NDIMS, const T (*binary_func)(const T, const T)>
+struct NdArrayApplyBroadcastBinaryCore final {
+  OF_DEVICE_FUNC static void Apply(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& a,
+                                   const XpuVarNdarray<const T>& b) {
+    const auto& ret =
+        a.Broadcast(y.shape()).template BinaryFunc<binary_func>(b.Broadcast(y.shape()));
+    y.template Assign<NDIMS>(ret);
+  }
+  OF_DEVICE_FUNC static void ImplaceApply(const XpuVarNdarray<T>& y,
+                                          const XpuVarNdarray<const T>& x) {
+    y.template BinaryAssign<binary_func, NDIMS>(x.Broadcast(y.shape()));
+  }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_CORE_H_
diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_unary.h b/oneflow/core/ndarray/ndarray_apply_broadcast_unary.h
new file mode 100644
index 0000000000000000000000000000000000000000..432b966ad37c75d46f06505ae548ca0a3bb24828
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_apply_broadcast_unary.h
@@ -0,0 +1,26 @@
+#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_H_
+#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_H_
+
+#include "oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h"
+#include "oneflow/core/common/util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T, int NDIMS, const T (*unary_func)(const T)>
+struct NdArrayApplyBroadcastUnary final {
+  static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {
+    CheckBroadcastable(y, x);
+    NdArrayApplyBroadcastUnaryCoreWrapper<device_type, T, NDIMS, unary_func>::Apply(ctx, y, x);
+  }
+
+  static void CheckBroadcastable(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {
+    CHECK_EQ(y.shape().NumAxes(), x.shape().NumAxes());
+    for (int i = 0; i < y.shape().NumAxes(); ++i) {
+      CHECK(x.shape().At(i) == 1 || x.shape().At(i) == y.shape().At(i));
+    }
+  }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_H_
diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cpp b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..842c1d20a5c972c95f66b49fd4f87d952f1dac5b
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cpp
@@ -0,0 +1,17 @@
+#include "oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h"
+
+namespace oneflow {
+
+template<typename T, int NDIMS, const T (*unary_func)(const T)>
+struct NdArrayApplyBroadcastUnaryCoreWrapper<DeviceType::kCPU, T, NDIMS, unary_func> final {
+  static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {
+    NdArrayApplyBroadcastUnaryCore<T, NDIMS, unary_func>::Apply(y, x);
+  }
+};
+
+#define INSTANTIATE_BROADCAST_UNARY_FUNC(dtype_pair, NDIMS, unary_func) \
+  template struct NdArrayApplyBroadcastUnaryCoreWrapper<                \
+      DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, unary_func>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BROADCAST_UNARY_FUNC, ARITHMETIC_DATA_TYPE_SEQ,
+                                 DIM_SEQ, ARITHMETIC_UNARY_FUNC_SEQ)
+}  // namespace oneflow
diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu
new file mode 100644
index 0000000000000000000000000000000000000000..f705f8b069a3fc43c9bf720262800db328651506
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu
@@ -0,0 +1,27 @@
+#include "oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T, int NDIMS, const T (*unary_func)(const T)>
+__global__ void GpuBroadcastUnaryFunc(const XpuVarNdarray<T> y, const XpuVarNdarray<const T> x) {
+  NdArrayApplyBroadcastUnaryCore<T, NDIMS, unary_func>::Apply(y, x);
+}
+
+}  // namespace
+
+template<typename T, int NDIMS, const T (*unary_func)(const T)>
+struct NdArrayApplyBroadcastUnaryCoreWrapper<DeviceType::kGPU, T, NDIMS, unary_func> final {
+  static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {
+    size_t n = y.host_shape().HostElemNum();
+    RUN_CUDA_KERNEL((GpuBroadcastUnaryFunc<T, NDIMS, unary_func>), ctx, n, y, x);
+  }
+};
+
+#define INSTANTIATE_BROADCAST_UNARY_FUNC(dtype_pair, NDIMS, unary_func) \
+  template struct NdArrayApplyBroadcastUnaryCoreWrapper<                \
+      DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, unary_func>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BROADCAST_UNARY_FUNC, ARITHMETIC_DATA_TYPE_SEQ,
+                                 DIM_SEQ, ARITHMETIC_UNARY_FUNC_SEQ)
+}  // namespace oneflow
diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h
new file mode 100644
index 0000000000000000000000000000000000000000..18213b7bd858515ba67f2ccd4ab884a5908c546d
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h
@@ -0,0 +1,27 @@
+#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_CORE_H_
+#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_CORE_H_
+
+#include "oneflow/core/ndarray/xpu_util.h"
+#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/ndarray/xpu_broadcast_ndarray.h"
+#include "oneflow/core/ndarray/xpu_unary_func_ndarray.h"
+#include "oneflow/core/ndarray/unary_func.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T, int NDIMS, const T (*unary_func)(const T)>
+struct NdArrayApplyBroadcastUnaryCoreWrapper final {
+  static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x);
+};
+
+template<typename T, int NDIMS, const T (*unary_func)(const T)>
+struct NdArrayApplyBroadcastUnaryCore final {
+  OF_DEVICE_FUNC static void Apply(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {
+    y.template Assign<NDIMS>(x.Broadcast(y.shape()).template UnaryFunc<unary_func>());
+  }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_CORE_H_
diff --git a/oneflow/core/ndarray/ndarray_apply_unary.h b/oneflow/core/ndarray/ndarray_apply_unary.h
new file mode 100644
index 0000000000000000000000000000000000000000..774cdaa2b63ecd842a4bd6a873dd5197fbfb9fd2
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_apply_unary.h
@@ -0,0 +1,18 @@
+#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_H_
+#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_H_
+
+#include "oneflow/core/common/util.h"
+#include "oneflow/core/ndarray/ndarray_apply_unary_core.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T, const T (*unary_func)(const T)>
+struct NdArrayApplyUnary final {
+  static void ImplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y) {
+    NdArrayApplyUnaryCoreWrapper<device_type, T, unary_func>::ImplaceApply(ctx, y);
+  }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_H_
diff --git a/oneflow/core/ndarray/ndarray_apply_unary_core.cpp b/oneflow/core/ndarray/ndarray_apply_unary_core.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..dda45c5e8ff56ff03b76acf99a83cf5e198d5916
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_apply_unary_core.cpp
@@ -0,0 +1,19 @@
+#include "oneflow/core/ndarray/ndarray_apply_unary_core.h"
+#include "oneflow/core/ndarray/unary_func.h"
+
+namespace oneflow {
+
+template<typename T, const T (*unary_func)(const T)>
+struct NdArrayApplyUnaryCoreWrapper<DeviceType::kCPU, T, unary_func> final {
+  static void ImplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y) {
+    NdArrayApplyUnaryCore<T, unary_func>::ImplaceApply(y.ptr(), y.shape().ElemNum());
+  }
+};
+
+#define INSTANTIATE_NDARRAY_APPLY_UNARY_CORE(dtype_pair, unary_func)                           \
+  template struct NdArrayApplyUnaryCoreWrapper<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype_pair), \
+                                               unary_func>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_APPLY_UNARY_CORE, ARITHMETIC_DATA_TYPE_SEQ,
+                                 ARITHMETIC_UNARY_FUNC_SEQ)
+
+}  // namespace oneflow
diff --git a/oneflow/core/ndarray/ndarray_apply_unary_core.cu b/oneflow/core/ndarray/ndarray_apply_unary_core.cu
new file mode 100644
index 0000000000000000000000000000000000000000..256bf5deb81a757abc0a3794b483eea69a28e609
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_apply_unary_core.cu
@@ -0,0 +1,29 @@
+#include "oneflow/core/ndarray/ndarray_apply_unary_core.h"
+#include "oneflow/core/ndarray/unary_func.h"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T, const T (*unary_func)(const T)>
+__global__ void NdArrayApplyUnaryImplaceApplyGpu(T* ptr, size_t n) {
+  NdArrayApplyUnaryCore<T, unary_func>::ImplaceApply(ptr, n);
+}
+
+}  // namespace
+
+template<typename T, const T (*unary_func)(const T)>
+struct NdArrayApplyUnaryCoreWrapper<DeviceType::kGPU, T, unary_func> final {
+  static void ImplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y) {
+    size_t n = y.host_shape().HostElemNum();
+    RUN_CUDA_KERNEL((NdArrayApplyUnaryImplaceApplyGpu<T, unary_func>), ctx, n, y.host_ptr(), n);
+  }
+};
+
+#define INSTANTIATE_NDARRAY_APPLY_UNARY_CORE(dtype_pair, unary_func)                           \
+  template struct NdArrayApplyUnaryCoreWrapper<DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype_pair), \
+                                               unary_func>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_APPLY_UNARY_CORE, ARITHMETIC_DATA_TYPE_SEQ,
+                                 ARITHMETIC_UNARY_FUNC_SEQ)
+
+}  // namespace oneflow
diff --git a/oneflow/core/ndarray/ndarray_apply_unary_core.h b/oneflow/core/ndarray/ndarray_apply_unary_core.h
new file mode 100644
index 0000000000000000000000000000000000000000..ddec5689e837d9d2cccf7f378195c2799e42cc7d
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_apply_unary_core.h
@@ -0,0 +1,25 @@
+#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_CORE_H_
+#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_CORE_H_
+
+#include "oneflow/core/kernel/kernel_util.h"
+#include "oneflow/core/ndarray/xpu_unary_func_ndarray.h"
+#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/ndarray/xpu_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T, const T (*unary_func)(const T)>
+struct NdArrayApplyUnaryCoreWrapper final {
+  static void ImplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y);
+};
+
+template<typename T, const T (*unary_func)(const T)>
+struct NdArrayApplyUnaryCore final {
+  OF_DEVICE_FUNC static void ImplaceApply(T* y, size_t n) {
+    XPU_1D_KERNEL_LOOP(i, n) { y[i] = unary_func(y[i]); }
+  }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_CORE_H_
diff --git a/oneflow/core/ndarray/ndarray_assign_core.cpp b/oneflow/core/ndarray/ndarray_assign_core.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..351edf63db93da80ca07ac972ef13a6b828ec51e
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_assign_core.cpp
@@ -0,0 +1,18 @@
+#include "oneflow/core/ndarray/ndarray_assign_core.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+template<typename T, int NDIMS>
+struct NdArrayAssignCoreWrapper<DeviceType::kCPU, T, NDIMS> final {
+  static void Assign(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
+                     const XpuReducedNdarray<T, NDIMS>& reduced) {
+    NdArrayAssignCore<T, NDIMS>::Assign(y, reduced);
+  }
+};
+
+#define INSTANTIATE_NDARRAY_ASSIGN(dtype_pair, NDIMS) \
+  template struct NdArrayAssignCoreWrapper<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, ARITHMETIC_DATA_TYPE_SEQ, DIM_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/ndarray/ndarray_assign_core.cu b/oneflow/core/ndarray/ndarray_assign_core.cu
new file mode 100644
index 0000000000000000000000000000000000000000..e7e05d7c7336eca7b8ac5e66fecf21202a4079b8
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_assign_core.cu
@@ -0,0 +1,29 @@
+#include "oneflow/core/ndarray/ndarray_assign_core.h"
+#include "oneflow/core/device/cuda_util.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T, int NDIMS>
+__global__ void NdArrayAssignGpu(XpuVarNdarray<T> y, const XpuReducedNdarray<T, NDIMS> reduced) {
+  NdArrayAssignCore<T, NDIMS>::Assign(y, reduced);
+}
+
+}  // namespace
+
+template<typename T, int NDIMS>
+struct NdArrayAssignCoreWrapper<DeviceType::kGPU, T, NDIMS> final {
+  static void Assign(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
+                     const XpuReducedNdarray<T, NDIMS>& reduced) {
+    size_t n = y.host_shape().HostElemNum();
+    RUN_CUDA_KERNEL((NdArrayAssignGpu<T, NDIMS>), ctx, n, y, reduced);
+  }
+};
+
+#define INSTANTIATE_NDARRAY_ASSIGN(dtype_pair, NDIMS) \
+  template struct NdArrayAssignCoreWrapper<DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, ARITHMETIC_DATA_TYPE_SEQ, DIM_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/ndarray/ndarray_assign_core.h b/oneflow/core/ndarray/ndarray_assign_core.h
new file mode 100644
index 0000000000000000000000000000000000000000..4a822b46bdb031ff29e5c93fcdd2394a68ef16f4
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_assign_core.h
@@ -0,0 +1,25 @@
+#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_ASSIGN_CORE_H_
+#define ONEFLOW_CORE_NDARRAY_NDARRAY_ASSIGN_CORE_H_
+
+#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/ndarray/xpu_reduced_ndarray.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T, int NDIMS>
+struct NdArrayAssignCoreWrapper final {
+  static void Assign(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
+                     const XpuReducedNdarray<T, NDIMS>& reduced);
+};
+
+template<typename T, int NDIMS>
+struct NdArrayAssignCore final {
+  OF_DEVICE_FUNC static void Assign(const XpuVarNdarray<T>& y,
+                                    const XpuReducedNdarray<T, NDIMS>& reduced) {
+    y.template Assign<NDIMS>(reduced);
+  }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_ASSIGN_CORE_H_
diff --git a/oneflow/core/ndarray/ndarray_reduce.h b/oneflow/core/ndarray/ndarray_reduce.h
new file mode 100644
index 0000000000000000000000000000000000000000..f508c5af381042c6d4c74fd2d40874274286eeb6
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_reduce.h
@@ -0,0 +1,31 @@
+#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_H_
+#define ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_H_
+
+#include "oneflow/core/device/cuda_util.h"
+#include "oneflow/core/common/util.h"
+#include "oneflow/core/ndarray/ndarray_reduce_impl.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T, const T (*binary_func)(const T, const T)>
+struct NdArrayReduce final {
+  static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x,
+                     const XpuVarNdarray<T>& tmp_storage) {
+    CHECK_EQ(y.shape().NumAxes(), x.shape().NumAxes());
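+    // Try the specialized fast paths in order (plain copy, reduce-to-scalar, row reduce, column reduce) before falling back to the generic axis-by-axis reducer.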
+    if (NdarrayNoReduce<device_type, T, binary_func>::Matched(y, x)) {
+      NdarrayNoReduce<device_type, T, binary_func>::Reduce(ctx, y, x, tmp_storage);
+    } else if (NdarrayScalarReduce<device_type, T, binary_func>::Matched(y, x)) {
+      NdarrayScalarReduce<device_type, T, binary_func>::Reduce(ctx, y, x, tmp_storage);
+    } else if (NdarrayMatrixRowReduce<device_type, T, binary_func>::Matched(y, x)) {
+      NdarrayMatrixRowReduce<device_type, T, binary_func>::Reduce(ctx, y, x, tmp_storage);
+    } else if (NdarrayMatrixColReduce<device_type, T, binary_func>::Matched(y, x)) {
+      NdarrayMatrixColReduce<device_type, T, binary_func>::Reduce(ctx, y, x, tmp_storage);
+    } else {
+      NdarrayDefaultReduce<device_type, T, binary_func>::Reduce(ctx, y, x, tmp_storage);
+    }
+  }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_H_
diff --git a/oneflow/core/ndarray/ndarray_reduce_impl.cpp b/oneflow/core/ndarray/ndarray_reduce_impl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e34366ae22e86738a90327076d586076684bb2f6
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_reduce_impl.cpp
@@ -0,0 +1,45 @@
+#include "oneflow/core/common/data_type.h"
+#include "oneflow/core/common/preprocessor.h"
+#include "oneflow/core/ndarray/ndarray_reduce_impl.h"
+#include "oneflow/core/ndarray/binary_func.h"
+
+namespace oneflow {
+
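+// On CPU the scalar / row / column fast paths are stubbed out (Matched always returns false), so anything that is not a plain copy goes through NdarrayDefaultReduce.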
+#define SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(struct_name)                                            \
+  template<typename T, const T (*binary_func)(const T, const T)>                                   \
+  struct struct_name<DeviceType::kCPU, T, binary_func> final {                                     \
+    static bool Matched(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {              \
+      return false;                                                                                \
+    }                                                                                              \
+    static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x, \
+                       const XpuVarNdarray<T>& tmp_storage) {                                      \
+      UNIMPLEMENTED();                                                                             \
+    }                                                                                              \
+  }
+SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(NdarrayScalarReduce);
+SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(NdarrayMatrixRowReduce);
+SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(NdarrayMatrixColReduce);
+#undef SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL
+
+#define INSTANTIATE_NDARRAY_REDUCE_IMPL(dtype, binary_func)                                       \
+  template struct NdarrayScalarReduce<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype), binary_func>;    \
+  template struct NdarrayMatrixRowReduce<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype), binary_func>; \
+  template struct NdarrayMatrixColReduce<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype), binary_func>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL, ARITHMETIC_DATA_TYPE_SEQ,
+                                 REDUCE_BINARY_FUNC_SEQ);
+
+template<typename T, int NDIMS, const T (*binary_func)(const T, const T)>
+struct NdArrayReduceCoreWrapper<DeviceType::kCPU, T, NDIMS, binary_func> final {
+  static void ReduceAxis(DeviceCtx* ctx, const XpuReducedNdarray<T, NDIMS>& dst_reduced,
+                         const XpuReducedNdarray<T, NDIMS>& x, int axis) {
+    NdArrayReduceCore<T, NDIMS, binary_func>::ReduceAxis(dst_reduced, x, axis);
+  }
+};
+
+#define INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER(dtype_pair, NDIMS, binary_func)                   \
+  template struct NdArrayReduceCoreWrapper<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, \
+                                           binary_func>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER, ARITHMETIC_DATA_TYPE_SEQ,
+                                 DIM_SEQ, REDUCE_BINARY_FUNC_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/ndarray/ndarray_reduce_impl.cu b/oneflow/core/ndarray/ndarray_reduce_impl.cu
new file mode 100644
index 0000000000000000000000000000000000000000..abec90b2c0e48baac7ab01bc067aa1889f5794ce
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_reduce_impl.cu
@@ -0,0 +1,193 @@
+#include <cub/cub.cuh>
+#include "oneflow/core/ndarray/ndarray_reduce_impl.h"
+#include "oneflow/core/ndarray/binary_func.h"
+#include "oneflow/core/common/preprocessor.h"
+#include "oneflow/core/common/shape.h"
+
+namespace oneflow {
+
+template<typename T, const T (*binary_func)(const T, const T), typename Enable = void>
+struct CubFunctor4BinaryFunc;
+
+#define SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC(binary_func, cub_functor)                            \
+  template<typename T, const T (*bfunc)(const T, const T)>                                        \
+  struct CubFunctor4BinaryFunc<T, bfunc, typename std::enable_if<bfunc == &binary_func<T>>::type> \
+      final {                                                                                     \
+    using type = cub_functor;                                                                     \
+  }
+
+SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC(BinaryFuncAdd, cub::Sum);
+SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC(BinaryFuncMax, cub::Max);
+SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC(BinaryFuncMin, cub::Min);
+
+#undef SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC
+
+namespace {
+
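+// Naive column reduce: one thread per output column walks down the rows and folds them with binary_func.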
+template<typename T, const T (*binary_func)(const T, const T)>
+void __global__ NdarrayMatrixColReduceNaiveCudaKernel(T* y_ptr, const T* x_ptr, int32_t num_rows,
+                                                      int32_t num_cols) {
+  CUDA_1D_KERNEL_LOOP(j, num_cols) {
+    T reduced = x_ptr[j];
+    FOR_RANGE(int32_t, i, 1, num_rows) { reduced = binary_func(reduced, x_ptr[i * num_cols + j]); }
+    y_ptr[j] = reduced;
+  }
+}
+
+}  // namespace
+
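+// Maps a row index to the offset of that row's first element; used as the segment-offset iterator for cub::DeviceSegmentedReduce.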
+struct RowOffsetFunctor final {
+  OF_DEVICE_FUNC explicit RowOffsetFunctor(int32_t num_cols) : num_cols_(num_cols) {}
+  OF_DEVICE_FUNC int32_t operator()(const int32_t& x) const { return x * num_cols_; }
+  int32_t num_cols_;
+};
+
+template<typename T, const T (*binary_func)(const T, const T)>
+struct NdarrayScalarReduce<DeviceType::kGPU, T, binary_func> final {
+  static bool Matched(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {
+    return y.shape().ElemNum() == 1;
+  }
+
+  static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x,
+                     const XpuVarNdarray<T>& tmp_storage) {
+    CHECK(Matched(y, x));
+    size_t x_size = x.shape().ElemNum();
+    size_t tmp_storage_bytes = 0;
+    auto DoReduce = [&](T* tmp_storage_ptr) {
+      int retcode =
+          cub::DeviceReduce::Reduce(tmp_storage_ptr, tmp_storage_bytes, x.ptr(), y.ptr(), x_size,
+                                    typename CubFunctor4BinaryFunc<T, binary_func>::type(),
+                                    UnitOfBinaryFunc<T, binary_func>::value, ctx->cuda_stream());
+      CHECK_EQ(retcode, 0) << "cub::DeviceReduce::Reduce error";
+    };
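+    // The first call passes a null workspace so cub only reports the required tmp_storage_bytes; the second call performs the reduction.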
+    DoReduce(nullptr);
+    CHECK_GE(tmp_storage.shape().ElemNum() * sizeof(T), tmp_storage_bytes);
+    DoReduce(tmp_storage.ptr());
+  }
+};
+
+template<typename T, const T (*binary_func)(const T, const T)>
+struct NdarrayMatrixRowReduce<DeviceType::kGPU, T, binary_func> final {
+  static bool Matched(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {
+    if (y.shape().ElemNum() > GetMaxVal<int32_t>()) { return false; }
+    const auto& x_squeezed = SqueezeRight(x.shape());
+    const auto& y_squeezed = SqueezeRight(y.shape());
+    if (x_squeezed.NumAxes() == 0) { return false; }
+    for (int i = 0; i < y_squeezed.NumAxes(); ++i) {
+      if (x_squeezed.At(i) != y_squeezed.At(i)) { return false; }
+    }
+    CHECK_EQ(x.shape().ElemNum() % y.shape().ElemNum(), 0);
+    return true;
+  }
+
+  static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x,
+                     const XpuVarNdarray<T>& tmp_storage) {
+    CHECK(Matched(y, x));
+    int32_t num_rows = y.shape().ElemNum();
+    int32_t num_cols = x.shape().ElemNum() / y.shape().ElemNum();
+    RowOffsetFunctor get_row_offset(num_cols);
+    cub::CountingInputIterator<int32_t> counting_input_it(0);
+    cub::TransformInputIterator<int32_t, RowOffsetFunctor, cub::CountingInputIterator<int32_t>>
+        transform_input_iter(counting_input_it, get_row_offset);
+    size_t tmp_storage_bytes = 0;
+    auto DoReduce = [&](T* tmp_storage_ptr) {
+      int retcode = cub::DeviceSegmentedReduce::Reduce(
+          tmp_storage_ptr, tmp_storage_bytes, x.ptr(), y.ptr(), num_rows, transform_input_iter,
+          transform_input_iter + 1, typename CubFunctor4BinaryFunc<T, binary_func>::type(),
+          UnitOfBinaryFunc<T, binary_func>::value, ctx->cuda_stream());
+      CHECK_EQ(retcode, 0) << "cub::DeviceSegmentedReduce::Reduce error";
+    };
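+    // As above: the null-workspace call only sizes tmp_storage_bytes; the second call runs the segmented reduction.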
+    DoReduce(nullptr);
+    CHECK_GE(tmp_storage.shape().ElemNum() * sizeof(T), tmp_storage_bytes);
+    DoReduce(tmp_storage.ptr());
+  }
+
+ private:
+  static XpuShape SqueezeRight(const XpuShape& shape) {
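+    // Drop trailing size-1 dimensions so the row-reduce pattern (x matching y on the leading axes) can be detected.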
+    std::vector<int64_t> dim_vec;
+    for (int i = 0; i < shape.NumAxes(); ++i) { dim_vec.push_back(shape.At(i)); }
+    for (int i = shape.NumAxes() - 1; i >= 0; --i) {
+      if (dim_vec.at(i) != 1) { break; }
+      dim_vec.pop_back();
+    }
+    if (dim_vec.empty()) { dim_vec.push_back(1LL); }
+    return XpuShape(Shape(dim_vec));
+  }
+};
+
+template<typename T, const T (*binary_func)(const T, const T)>
+struct NdarrayMatrixColReduce<DeviceType::kGPU, T, binary_func> final {
+  static bool Matched(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {
+    if (y.shape().ElemNum() > GetMaxVal<int32_t>()) { return false; }
+    const auto& x_squeezed = SqueezeLeft(x.shape());
+    const auto& y_squeezed = SqueezeLeft(y.shape());
+    if (x_squeezed.NumAxes() == 0) { return false; }
+    for (int i = 0; i < y_squeezed.NumAxes(); ++i) {
+      if (x_squeezed.At(x_squeezed.NumAxes() - 1 - i)
+          != y_squeezed.At(y_squeezed.NumAxes() - 1 - i)) {
+        return false;
+      }
+    }
+    CHECK_EQ(x.shape().ElemNum() % y.shape().ElemNum(), 0);
+    return true;
+  }
+
+  static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x,
+                     const XpuVarNdarray<T>& tmp_storage) {
+    CHECK(Matched(y, x));
+    int32_t num_rows = x.shape().ElemNum() / y.shape().ElemNum();
+    int32_t num_cols = y.shape().ElemNum();
+    RUN_CUDA_KERNEL((NdarrayMatrixColReduceNaiveCudaKernel<T, binary_func>), ctx, num_cols, y.ptr(),
+                    x.ptr(), num_rows, num_cols);
+  }
+
+ private:
+  static XpuShape SqueezeLeft(const XpuShape& shape) {
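+    // Drop leading size-1 dimensions so the column-reduce pattern (x matching y on the trailing axes) can be detected.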
+    std::vector<int64_t> dim_vec;
+    bool all_squeezed = false;
+    for (int i = 0; i < shape.NumAxes(); ++i) {
+      if (all_squeezed == false) {
+        if (shape.At(i) == 1) { continue; }
+        all_squeezed = true;
+      }
+      dim_vec.push_back(shape.At(i));
+    }
+    if (dim_vec.empty()) { dim_vec.push_back(1LL); }
+    return XpuShape(Shape(dim_vec));
+  }
+};
+
+namespace {
+
+template<typename T, int NDIMS, const T (*binary_func)(const T, const T)>
+__global__ void NdArrayReduceGpuImplaceReduceAxis(const XpuReducedNdarray<T, NDIMS> dst_reduced,
+                                                  const XpuReducedNdarray<T, NDIMS> x, int axis) {
+  NdArrayReduceCore<T, NDIMS, binary_func>::ReduceAxis(dst_reduced, x, axis);
+}
+
+}  // namespace
+
+template<typename T, int NDIMS, const T (*binary_func)(const T, const T)>
+struct NdArrayReduceCoreWrapper<DeviceType::kGPU, T, NDIMS, binary_func> final {
+  static void ReduceAxis(DeviceCtx* ctx, const XpuReducedNdarray<T, NDIMS>& dst_reduced,
+                         const XpuReducedNdarray<T, NDIMS>& x, int axis) {
+    size_t n = x.host_shape().HostElemNum();
+    RUN_CUDA_KERNEL((NdArrayReduceGpuImplaceReduceAxis<T, NDIMS, binary_func>), ctx, n, dst_reduced,
+                    x, axis);
+  }
+};
+
+#define INSTANTIATE_NDARRAY_REDUCE_IMPL(dtype, binary_func)                                       \
+  template struct NdarrayScalarReduce<DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype), binary_func>;    \
+  template struct NdarrayMatrixRowReduce<DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype), binary_func>; \
+  template struct NdarrayMatrixColReduce<DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype), binary_func>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL, ARITHMETIC_DATA_TYPE_SEQ,
+                                 REDUCE_BINARY_FUNC_SEQ);
+
+#define INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER(dtype_pair, NDIMS, binary_func)                   \
+  template struct NdArrayReduceCoreWrapper<DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, \
+                                           binary_func>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER, ARITHMETIC_DATA_TYPE_SEQ,
+                                 DIM_SEQ, REDUCE_BINARY_FUNC_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/ndarray/ndarray_reduce_impl.h b/oneflow/core/ndarray/ndarray_reduce_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..d6191db2b79c352c9a5668a80d2348cb410ad175
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_reduce_impl.h
@@ -0,0 +1,109 @@
+#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_IMPL_H_
+#define ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_IMPL_H_
+
+#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/common/util.h"
+#include "oneflow/core/common/switch_func.h"
+#include "oneflow/core/ndarray/xpu_ndarray_assign.h"
+#include "oneflow/core/ndarray/binary_func.h"
+
+namespace oneflow {
+
+#define DECLARE_NDARRAY_REDUCE_IMPL(struct_name)                                                   \
+  template<DeviceType device_type, typename T, const T (*binary_func)(const T, const T)>           \
+  struct struct_name final {                                                                       \
+    static bool Matched(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x);               \
+    static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x, \
+                       const XpuVarNdarray<T>& tmp_storage);                                       \
+  }
+DECLARE_NDARRAY_REDUCE_IMPL(NdarrayScalarReduce);
+DECLARE_NDARRAY_REDUCE_IMPL(NdarrayMatrixRowReduce);
+DECLARE_NDARRAY_REDUCE_IMPL(NdarrayMatrixColReduce);
+#undef DECLARE_NDARRAY_REDUCE_IMPL
+
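+// When x and y have identical shapes there is nothing to reduce; the result is a plain copy.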
+template<DeviceType device_type, typename T, const T (*binary_func)(const T, const T)>
+struct NdarrayNoReduce final {
+  static bool Matched(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {
+    return x.shape() == y.shape();
+  }
+  static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x,
+                     const XpuVarNdarray<T>& tmp_storage) {
+    XpuNdArrayAssign<device_type, T>::Assign(ctx, y, x);
+  }
+};
+
+template<DeviceType device_type, typename T, int NDIMS, const T (*binary_func)(const T, const T)>
+struct NdArrayReduceCoreWrapper final {
+  static void ReduceAxis(DeviceCtx* ctx, const XpuReducedNdarray<T, NDIMS>& dst_reduced,
+                         const XpuReducedNdarray<T, NDIMS>& x, int axis);
+};
+
+template<DeviceType device_type, typename T, const T (*binary_func)(const T, const T)>
+struct NdarrayDefaultReduce final {
+  static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x,
+                     const XpuVarNdarray<T>& tmp_storage) {
+    return SwitchReduce(SwitchCase(y.shape().NumAxes()), ctx, y, x, tmp_storage);
+  }
+
+ private:
+#define DEFINE_NDARRAY_REDUCE(func_name, NDIMS) func_name<NDIMS>
+  DEFINE_STATIC_SWITCH_FUNC(void, Reduce, DEFINE_NDARRAY_REDUCE, MAKE_NDIM_CTRV_SEQ(DIM_SEQ));
+#undef DEFINE_NDARRAY_REDUCE
+
+  template<int NDIMS>
+  static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x,
+                     const XpuVarNdarray<T>& tmp_storage) {
+    XpuVarNdarray<T> storage(x.shape(), tmp_storage.ptr());
+    XpuShape cur_shape(x.shape());
+    CHECK_EQ(y.shape().NumAxes(), x.shape().NumAxes());
+    CHECK(x.shape() != y.shape());
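+    // Copy x into the scratch buffer, collapse every axis where y has size 1, then assign the reduced view to y.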
+    XpuNdArrayAssign<device_type, T>::Assign(ctx, storage, x);
+    for (int i = 0; i < x.shape().NumAxes(); ++i) {
+      if (y.shape().At(i) == x.shape().At(i)) { continue; }
+      CHECK_EQ(y.shape().At(i), 1);
+      CHECK_GT(x.shape().At(i), y.shape().At(i));
+      ImplaceReduceAxis<NDIMS>(ctx, i, storage, &cur_shape);
+    }
+    XpuReducedNdarray<T, NDIMS> reduced(y.shape(), storage);
+    XpuNdArrayAssign<device_type, T>::template Assign<NDIMS>(ctx, y, reduced);
+  }
+
+  template<int NDIMS>
+  static void ImplaceReduceAxis(DeviceCtx* ctx, int axis, const XpuVarNdarray<T>& implace,
+                                XpuShape* cur_shape) {
+    int64_t target_elem_num = cur_shape->ElemNum() / cur_shape->At(axis);
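+    // Shrink the axis by a factor of roughly 8 + sqrt(target_elem_num) per pass, reducing in place until its length reaches 1.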
+    while (cur_shape->At(axis) > 1) {
+      int64_t shrink = 8 + std::sqrt(target_elem_num);
+      XpuReducedNdarray<T, NDIMS> from(*cur_shape, implace);
+      int64_t new_dim_value = (cur_shape->At(axis) + (shrink - 1)) / shrink;
+      cur_shape->Set(axis, new_dim_value);
+      XpuReducedNdarray<T, NDIMS> to(*cur_shape, implace);
+      NdArrayReduceCoreWrapper<device_type, T, NDIMS, binary_func>::ReduceAxis(ctx, to, from, axis);
+    }
+  }
+};
+
+template<typename T, int NDIMS, const T (*binary_func)(const T, const T)>
+struct NdArrayReduceCore final {
+  template<typename X>
+  OF_DEVICE_FUNC static void ReduceAxis(const XpuReducedNdarray<T, NDIMS>& dst_reduced, const X& x,
+                                        int axis) {
+    size_t n = dst_reduced.shape().ElemNum();
+    int64_t dst_dim_val = dst_reduced.shape().At(axis);
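+    // Each destination element folds the source elements along the reduced axis, stepping by dst_dim_val from its own coordinate.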
+    XPU_1D_KERNEL_LOOP(i, n) {
+      T* dst_reduced_ptr = dst_reduced.template Mut(i);
+      int64_t coord[NDIMS];
+      dst_reduced.shape().template Offset2Coordinate<NDIMS>(i, coord);
+      T reduced = UnitOfBinaryFunc<T, binary_func>::value;
+      while (coord[axis] < x.shape().At(axis)) {
+        reduced = binary_func(reduced, x.template Get<NDIMS>(coord));
+        coord[axis] += dst_dim_val;
+      }
+      *dst_reduced_ptr = reduced;
+    }
+  }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_IMPL_H_
diff --git a/oneflow/core/ndarray/ndarray_reduce_test.cpp b/oneflow/core/ndarray/ndarray_reduce_test.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f3d7e11506b8073a36e3c73ba68e6956d08a9885
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_reduce_test.cpp
@@ -0,0 +1,42 @@
+#include "oneflow/core/ndarray/ndarray_util.h"
+#include <gtest/gtest.h>
+
+namespace oneflow {
+
+namespace test {
+
+namespace {
+
+void TestMiddleAxis(int num) {
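+  // Reduce a num x num x num tensor of ones over the middle axis; every output element should equal num.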
+  std::vector<int32_t> data(num * num * num, 1);
+  std::vector<int32_t> tmp_storage(num * num * num, -8888);
+  XpuVarNdarray<const int32_t> x(XpuShape(Shape({num, num, num})), data.data());
+  XpuVarNdarray<int32_t> tmp(XpuShape(Shape({num, num, num})), tmp_storage.data());
+  std::vector<int32_t> ret(num * num, -999);
+  XpuVarNdarray<int32_t> y(XpuShape(Shape({num, 1, num})), ret.data());
+  NdArrayReduce<DeviceType::kCPU, int32_t, BinaryFuncAdd>::Reduce(nullptr, y, x, tmp);
+  for (int i = 0; i < num; ++i) {
+    for (int j = 0; j < num; ++j) { ASSERT_EQ(ret[i * num + j], num); }
+  }
+}
+
+}  // namespace
+
+TEST(NdArrayReduce, sum) {
+  std::vector<int32_t> data(100, 1);
+  std::vector<int32_t> tmp_storage(100, -1);
+  XpuVarNdarray<const int32_t> x(XpuShape(Shape({100})), data.data());
+  XpuVarNdarray<int32_t> tmp(XpuShape(Shape({100})), tmp_storage.data());
+  int32_t ret = -100;
+  XpuVarNdarray<int32_t> y(XpuShape(Shape({1})), &ret);
+  NdArrayReduce<DeviceType::kCPU, int32_t, BinaryFuncAdd>::Reduce(nullptr, y, x, tmp);
+  ASSERT_EQ(ret, 100);
+}
+
+TEST(NdArrayReduce, middle_axis_10) { TestMiddleAxis(10); }
+
+TEST(NdArrayReduce, middle_axis_125) { TestMiddleAxis(125); }
+
+}  // namespace test
+
+}  // namespace oneflow
diff --git a/oneflow/core/ndarray/ndarray_util.h b/oneflow/core/ndarray/ndarray_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..64813f3b356c2086d64c397e11df77795d89f8ad
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_util.h
@@ -0,0 +1,84 @@
+#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_UTIL_H_
+#define ONEFLOW_CORE_NDARRAY_NDARRAY_UTIL_H_
+
+#include "oneflow/core/common/util.h"
+#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/ndarray/ndarray_reduce.h"
+#include "oneflow/core/ndarray/ndarray_apply_unary.h"
+#include "oneflow/core/ndarray/ndarray_apply_broadcast_unary.h"
+#include "oneflow/core/ndarray/ndarray_apply_broadcast_binary.h"
+#include "oneflow/core/ndarray/xpu_reduced_ndarray.h"
+#include "oneflow/core/common/switch_func.h"
+#include "oneflow/core/common/util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+struct NdarrayUtil final {
+  template<const T (*unary_func)(const T)>
+  static void BroadcastApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
+                             const XpuVarNdarray<const T>& x) {
+    CHECK_EQ(x.shape().NumAxes(), y.shape().NumAxes());
+    return Unary<unary_func>::SwitchBroadcastApply(SwitchCase(x.shape().NumAxes()), ctx, y, x);
+  }
+  static void BroadcastTo(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
+                          const XpuVarNdarray<const T>& x) {
+    return BroadcastApply<UnaryFuncIdentity>(ctx, y, x);
+  }
+  template<const T (*binary_func)(const T, const T)>
+  static void BroadcastApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
+                             const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {
+    CHECK_EQ(a.shape().NumAxes(), y.shape().NumAxes());
+    CHECK_EQ(b.shape().NumAxes(), y.shape().NumAxes());
+    return Binary<binary_func>::SwitchBroadcastApply(SwitchCase(y.shape().NumAxes()), ctx, y, a, b);
+  }
+  static void ReduceSum(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x,
+                        const XpuVarNdarray<T>& tmp_storage) {
+    return NdArrayReduce<device_type, T, BinaryFuncAdd>::Reduce(ctx, y, x, tmp_storage);
+  }
+  static void ReduceMax(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x,
+                        const XpuVarNdarray<T>& tmp_storage) {
+    return NdArrayReduce<device_type, T, BinaryFuncMax>::Reduce(ctx, y, x, tmp_storage);
+  }
+  static void ReduceMin(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x,
+                        const XpuVarNdarray<T>& tmp_storage) {
+    return NdArrayReduce<device_type, T, BinaryFuncMin>::Reduce(ctx, y, x, tmp_storage);
+  }
+  template<const T (*unary_func)(const T)>
+  static void ImplaceApplyUnary(DeviceCtx* ctx, const XpuVarNdarray<T>& y) {
+    return NdArrayApplyUnary<device_type, T, unary_func>::ImplaceApply(ctx, y);
+  }
+
+ private:
+  template<const T (*unary_func)(const T)>
+  struct Unary final {
+    template<int NDIMS>
+    static void BroadcastApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
+                               const XpuVarNdarray<const T>& x) {
+      return NdArrayApplyBroadcastUnary<device_type, T, NDIMS, unary_func>::Apply(ctx, y, x);
+    }
+#define DEFINE_NDARRAY_BROADCAST_UNARY(func_name, NDIMS) \
+  NdarrayUtil<device_type, T>::Unary<unary_func>::func_name<NDIMS>
+    DEFINE_STATIC_SWITCH_FUNC(void, BroadcastApply, DEFINE_NDARRAY_BROADCAST_UNARY,
+                              MAKE_NDIM_CTRV_SEQ(DIM_SEQ));
+#undef DEFINE_NDARRAY_BROADCAST_UNARY
+  };
+
+  template<const T (*binary_func)(const T, const T)>
+  struct Binary final {
+    template<int NDIMS>
+    static void BroadcastApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
+                               const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {
+      return NdArrayApplyBroadcastBinary<device_type, T, NDIMS, binary_func>::Apply(ctx, y, a, b);
+    }
+#define DEFINE_NDARRAY_BROADCAST_BINARY(func_name, NDIMS) \
+  NdarrayUtil<device_type, T>::Binary<binary_func>::func_name<NDIMS>
+    DEFINE_STATIC_SWITCH_FUNC(void, BroadcastApply, DEFINE_NDARRAY_BROADCAST_BINARY,
+                              MAKE_NDIM_CTRV_SEQ(DIM_SEQ));
+#undef DEFINE_NDARRAY_BROADCAST_BINARY
+  };
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_NDARRAY_UTIL_H_
diff --git a/oneflow/core/ndarray/unary_func.h b/oneflow/core/ndarray/unary_func.h
new file mode 100644
index 0000000000000000000000000000000000000000..b6642a785ec5bde07bafccd390cd36f88bc1988b
--- /dev/null
+++ b/oneflow/core/ndarray/unary_func.h
@@ -0,0 +1,74 @@
+#ifndef ONEFLOW_CORE_NDARRAY_UNARY_FUNC_H_
+#define ONEFLOW_CORE_NDARRAY_UNARY_FUNC_H_
+
+#include "oneflow/core/common/util.h"
+
+namespace oneflow {
+
+#define ARITHMETIC_UNARY_FUNC_SEQ         \
+  OF_PP_MAKE_TUPLE_SEQ(UnaryFuncIdentity) \
+  OF_PP_MAKE_TUPLE_SEQ(UnaryFuncMinus)    \
+  OF_PP_MAKE_TUPLE_SEQ(UnaryFuncLog2)     \
+  OF_PP_MAKE_TUPLE_SEQ(UnaryFuncExp2)
+
+template<typename T>
+OF_DEVICE_FUNC const T UnaryFuncIdentity(const T x) {
+  return x;
+}
+
+template<typename T>
+OF_DEVICE_FUNC const T UnaryFuncMinus(const T x) {
+  return -x;
+}
+
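+// UnaryFuncLog2/UnaryFuncExp2 are split with enable_if: float and double call the math function
+// directly, while every other type is cast to float on the device path.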
+template<typename T>
+OF_DEVICE_FUNC
+    typename std::enable_if<std::is_same<T, float>::value || std::is_same<T, double>::value,
+                            const T>::type
+    UnaryFuncLog2(const T x) {
+#if defined(__CUDACC__)
+  return log2(x);
+#else
+  return std::log2(x);
+#endif
+}
+
+template<typename T>
+OF_DEVICE_FUNC
+    typename std::enable_if<!(std::is_same<T, float>::value || std::is_same<T, double>::value),
+                            const T>::type
+    UnaryFuncLog2(const T x) {
+#if defined(__CUDACC__)
+  return log2(static_cast<float>(x));
+#else
+  return std::log2(x);
+#endif
+}
+
+template<typename T>
+OF_DEVICE_FUNC
+    typename std::enable_if<std::is_same<T, float>::value || std::is_same<T, double>::value,
+                            const T>::type
+    UnaryFuncExp2(const T x) {
+#if defined(__CUDACC__)
+  return exp2(x);
+#else
+  return std::exp2(x);
+#endif
+}
+
+template<typename T>
+OF_DEVICE_FUNC
+    typename std::enable_if<!(std::is_same<T, float>::value || std::is_same<T, double>::value),
+                            const T>::type
+    UnaryFuncExp2(const T x) {
+#if defined(__CUDACC__)
+  return exp2(static_cast<float>(x));
+#else
+  return std::exp2(x);
+#endif
+}
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_UNARY_FUNC_H_
diff --git a/oneflow/core/ndarray/xpu_binary_func_ndarray.h b/oneflow/core/ndarray/xpu_binary_func_ndarray.h
new file mode 100644
index 0000000000000000000000000000000000000000..be2debecfd9b197b0245e1af2649ad599a50ff08
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_binary_func_ndarray.h
@@ -0,0 +1,23 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_BINARY_FUNC_NDARRAY_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_BINARY_FUNC_NDARRAY_H_
+
+namespace oneflow {
+
+template<typename T, const T (*binary_func)(const T, const T), typename A, typename B>
+class XpuBinaryFuncNdarray final {
+ public:
+  OF_DEVICE_FUNC XpuBinaryFuncNdarray(const A& a, const B& b) : a_(a), b_(b) {}
+
+  template<int NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t offset) const {
+    return binary_func(a_.template Get<NDIMS>(offset), b_.template Get<NDIMS>(offset));
+  }
+
+ private:
+  const A& a_;
+  const B& b_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_XPU_BINARY_FUNC_NDARRAY_H_
diff --git a/oneflow/core/ndarray/xpu_broadcast_ndarray.h b/oneflow/core/ndarray/xpu_broadcast_ndarray.h
new file mode 100644
index 0000000000000000000000000000000000000000..9f46e992df6b9ae718cda6a0c661d0f5eb7c9da7
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_broadcast_ndarray.h
@@ -0,0 +1,53 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_BROADCAST_NDARRAY_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_BROADCAST_NDARRAY_H_
+
+#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/ndarray/xpu_ndarray_base.h"
+
+namespace oneflow {
+
+template<typename T, int NDIMS>
+struct XpuBroadcastNdarrayUtil;
+
+template<typename T>
+class XpuBroadcastNdarray final : public XpuNdarrayBase<XpuBroadcastNdarray<T>, T> {
+ public:
+  OF_DEVICE_FUNC XpuBroadcastNdarray(const XpuShape& shape, const XpuVarNdarray<T>& var)
+      : shape_(shape), var_(var) {}
+  OF_DEVICE_FUNC ~XpuBroadcastNdarray() = default;
+
+  template<int NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t offset) const {
+    int64_t coord[NDIMS];
+    shape_.template Offset2Coordinate<NDIMS>(offset, coord);
+    XpuBroadcastNdarrayUtil<T, NDIMS>::SrcCoordinate(var_.shape(), coord);
+    return var_.template Get<NDIMS>(coord);
+  }
+
+  OF_DEVICE_FUNC const XpuShape& shape() const { return shape_; }
+  OF_DEVICE_FUNC const XpuVarNdarray<T>& var() const { return var_; }
+
+ private:
+  const XpuShape& shape_;
+  const XpuVarNdarray<T>& var_;
+};
+
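+// Map a coordinate of the broadcast (destination) shape back into the source shape: taking
+// coord[i] modulo src_shape.At(i) sends broadcast axes (size 1 in the source) to index 0 and
+// leaves matching axes unchanged.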
+#define IMPLACE_SET_SRC_COORD(i) coord[i] %= src_shape.At(i);
+#define SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(n)                                                \
+  template<typename T>                                                                          \
+  struct XpuBroadcastNdarrayUtil<T, n + 1> final {                                              \
+    OF_DEVICE_FUNC static void SrcCoordinate(const XpuShape& src_shape, int64_t coord[n + 1]) { \
+      OF_PP_FOR_EACH_TUPLE(IMPLACE_SET_SRC_COORD, GET_SEQ(n));                                  \
+    }                                                                                           \
+  }
+SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(0);
+SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(1);
+SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(2);
+SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(3);
+SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(4);
+#undef SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL
+#undef IMPLACE_SET_SRC_COORD
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_XPU_BROADCAST_NDARRAY_H_
diff --git a/oneflow/core/ndarray/xpu_ndarray_assign.cu b/oneflow/core/ndarray/xpu_ndarray_assign.cu
new file mode 100644
index 0000000000000000000000000000000000000000..c53ed205a87b4c019b5470651c61c1efc07ef4e5
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_ndarray_assign.cu
@@ -0,0 +1,29 @@
+#include "oneflow/core/ndarray/ndarray_assign_core.h"
+#include "oneflow/core/device/cuda_util.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T, int NDIMS>
+__global__ void NdArrayAssignGpu(XpuVarNdarray<T> y, const XpuReducedNdarray<T, NDIMS> reduced) {
+  NdArrayAssignCore<T, NDIMS>::Assign(y, reduced);
+}
+
+}  // namespace
+
+template<typename T, int NDIMS>
+struct NdArrayAssignCoreWrapper<DeviceType::kGPU, T, NDIMS> final {
+  static void Assign(DeviceCtx* ctx, XpuVarNdarray<T>* y,
+                     const XpuReducedNdarray<T, NDIMS>& reduced) {
+    size_t n = y->host_shape().HostElemNum();
+    RUN_CUDA_KERNEL((NdArrayAssignGpu<T, NDIMS>), ctx, n, *y, reduced);
+  }
+};
+
+#define INSTANTIATE_NDARRAY_ASSIGN(dtype_pair, NDIMS) \
+  template struct NdArrayAssignCoreWrapper<DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS>;
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, ARITHMETIC_DATA_TYPE_SEQ, DIM_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/ndarray/xpu_ndarray_assign.h b/oneflow/core/ndarray/xpu_ndarray_assign.h
new file mode 100644
index 0000000000000000000000000000000000000000..ecf34075b2897287f49d6add7a5c9636f458f814
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_ndarray_assign.h
@@ -0,0 +1,25 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_NDARRAY_ASSIGN_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_NDARRAY_ASSIGN_H_
+
+#include "oneflow/core/ndarray/ndarray_assign_core.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+struct XpuNdArrayAssign final {
+  template<int NDIMS>
+  static void Assign(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
+                     const XpuReducedNdarray<T, NDIMS>& reduced) {
+    NdArrayAssignCoreWrapper<device_type, T, NDIMS>::Assign(ctx, y, reduced);
+  }
+  static void Assign(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {
+    CHECK(y.shape() == x.shape());
+    if (x.ptr() == y.ptr()) { return; }
+    Memcpy<device_type>(ctx, y.ptr(), x.ptr(), y.shape().ElemNum() * sizeof(T));
+  }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_XPU_NDARRAY_ASSIGN_H_
diff --git a/oneflow/core/ndarray/xpu_ndarray_base.h b/oneflow/core/ndarray/xpu_ndarray_base.h
new file mode 100644
index 0000000000000000000000000000000000000000..aa3fa8e89fdcc23b950e630cafd4b022b4a41e76
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_ndarray_base.h
@@ -0,0 +1,48 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_NDARRAY_BASE_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_NDARRAY_BASE_H_
+
+namespace oneflow {
+
+template<typename T, const T (*unary_func)(const T), typename X>
+class XpuUnaryFuncNdarray;
+template<typename T, const T (*binary_func)(const T, const T), typename A, typename B>
+class XpuBinaryFuncNdarray;
+template<typename T>
+class XpuBroadcastNdarray;
+template<typename T, int, typename X>
+class XpuTransposeNdarray;
+template<typename T, int, typename X>
+class XpuReshapeNdarray;
+
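+// CRTP base that lets a concrete ndarray build lazy expression wrappers (unary/binary func,
+// broadcast, transpose, reshape) around itself.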
+template<typename DerivedT, typename T>
+class XpuNdarrayBase {
+ public:
+  OF_DEVICE_FUNC XpuNdarrayBase() = default;
+  OF_DEVICE_FUNC ~XpuNdarrayBase() = default;
+
+  template<const T (*unary_func)(const T)>
+  OF_DEVICE_FUNC XpuUnaryFuncNdarray<T, unary_func, DerivedT> UnaryFunc() const {
+    return XpuUnaryFuncNdarray<T, unary_func, DerivedT>(*static_cast<const DerivedT*>(this));
+  }
+  template<const T (*binary_func)(const T, const T), typename X>
+  OF_DEVICE_FUNC XpuBinaryFuncNdarray<T, binary_func, DerivedT, X> BinaryFunc(const X& x) const {
+    return XpuBinaryFuncNdarray<T, binary_func, DerivedT, X>(*static_cast<const DerivedT*>(this),
+                                                             x);
+  }
+  OF_DEVICE_FUNC XpuBroadcastNdarray<const T> Broadcast(const XpuShape& shape) const {
+    return XpuBroadcastNdarray<const T>(shape, *static_cast<const DerivedT*>(this));
+  }
+  template<int NDIMS>
+  OF_DEVICE_FUNC XpuTransposeNdarray<T, NDIMS, DerivedT> Transpose(
+      const int64_t perm[NDIMS]) const {
+    return XpuTransposeNdarray<T, NDIMS, DerivedT>(*static_cast<const DerivedT*>(this), perm);
+  }
+  template<int NDIMS>
+  OF_DEVICE_FUNC XpuReshapeNdarray<T, NDIMS, DerivedT> Reshape(const int64_t shape[NDIMS]) {
+    return XpuReshapeNdarray<T, NDIMS, DerivedT>(*static_cast<const DerivedT*>(this), shape);
+  }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_XPU_NDARRAY_BASE_H_
diff --git a/oneflow/core/ndarray/xpu_reduced_ndarray.h b/oneflow/core/ndarray/xpu_reduced_ndarray.h
new file mode 100644
index 0000000000000000000000000000000000000000..d2bdd21e979db5406fecf023dcd3ba3b7c874c4f
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_reduced_ndarray.h
@@ -0,0 +1,51 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_REDUCED_NDARRAY_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_REDUCED_NDARRAY_H_
+
+#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/ndarray/xpu_util.h"
+#include "oneflow/core/ndarray/unary_func.h"
+
+namespace oneflow {
+
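+// A view over an underlying ndarray X with an alternative (typically reduced) shape; Get/Mut
+// translate offsets in the view shape into coordinates and delegate to the underlying data.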
+template<typename T, int NDIMS, typename X = XpuVarNdarray<T>>
+class XpuReducedNdarray final {
+ public:
+  OF_DEVICE_FUNC XpuReducedNdarray(const XpuShape& shape, const X& data)
+      : shape_(shape), data_(data) {}
+
+  OF_DEVICE_FUNC const XpuShape& shape() const { return shape_; }
+  const XpuShape& host_shape() const { return shape_; }
+  OF_DEVICE_FUNC const X& data() const { return data_; }
+
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t offset) const {
+    int64_t coord[NDIMS];
+    shape_.template Offset2Coordinate<NDIMS>(offset, coord);
+    return Get(coord);
+  }
+
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t coord[ndims]) const {
+    return data_.template Get<ndims>(coord);
+  }
+
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T* Mut(int64_t offset) const {
+    int64_t coord[NDIMS];
+    shape_.template Offset2Coordinate<NDIMS>(offset, coord);
+    return Mut(coord);
+  }
+
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const {
+    return data_.template Mut<NDIMS>(coord);
+  }
+
+ private:
+  XpuShape shape_;
+  X data_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_XPU_REDUCED_NDARRAY_H_
diff --git a/oneflow/core/ndarray/xpu_reshape_ndarray.h b/oneflow/core/ndarray/xpu_reshape_ndarray.h
new file mode 100644
index 0000000000000000000000000000000000000000..612fa8fad97acc8f047f357de935cb1a532b5f2d
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_reshape_ndarray.h
@@ -0,0 +1,39 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_RESHAPE_NDARRAY_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_RESHAPE_NDARRAY_H_
+
+#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+
+namespace oneflow {
+
+template<typename T, int NDIMS, typename X = XpuVarNdarray<T>>
+class XpuReshapeNdarray final {
+ public:
+  OF_DEVICE_FUNC XpuReshapeNdarray(const X& x, const int64_t dim[NDIMS])
+      : x_(x), shape_(dim, NDIMS) {}
+
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t offset) const {
+    return x_.template Get<ndims>(offset);
+  }
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T* Mut(int64_t offset) const {
+    return x_.template Mut<ndims>(offset);
+  }
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t coord[ndims]) const {
+    return Get<ndims>(Coord2Offset(coord));
+  }
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const {
+    return Mut<NDIMS>(Coord2Offset(coord));
+  }
+
+ private:
+  OF_DEVICE_FUNC int64_t Coord2Offset(const int64_t coord[NDIMS]) const {
+    return XpuShapeUtil<NDIMS>::Coord2Offset(shape_, coord);
+  }
+  const X& x_;
+  XpuShape shape_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_XPU_RESHAPE_NDARRAY_H_
diff --git a/oneflow/core/ndarray/xpu_shape.h b/oneflow/core/ndarray/xpu_shape.h
new file mode 100644
index 0000000000000000000000000000000000000000..0099085d7b1155dc47d05ca1e9f0c46367106f19
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_shape.h
@@ -0,0 +1,96 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_SHAPE_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_SHAPE_H_
+
+#include "oneflow/core/common/shape.h"
+#include "oneflow/core/common/util.h"
+#include "oneflow/core/kernel/kernel_util.h"
+#include "oneflow/core/ndarray/xpu_util.h"
+
+namespace oneflow {
+
+template<int NDIMS>
+struct XpuShapeUtil;
+
+class XpuShape final {
+ public:
+  explicit XpuShape(const Shape& shape);
+  OF_DEVICE_FUNC XpuShape(const int64_t dim[], int num_axes);
+  OF_DEVICE_FUNC XpuShape(const XpuShape&) = default;
+
+  OF_DEVICE_FUNC int64_t At(int64_t dim) const { return dim_[dim]; }
+
+  OF_DEVICE_FUNC size_t ElemNum() const { return elem_num_; }
+  OF_DEVICE_FUNC size_t NumAxes() const { return num_axes_; }
+  size_t HostElemNum() const { return elem_num_; }
+  bool operator==(const XpuShape&) const;
+  bool operator!=(const XpuShape& rhs) const { return !(*this == rhs); }
+
+  OF_DEVICE_FUNC void Set(int64_t axis, int64_t value) {
+    dim_[axis] = value;
+    UpdateDimElemNumAndElemNum();
+  }
+
+  template<int NDIMS>
+  OF_DEVICE_FUNC int64_t Coordinate2Offset(const int64_t coord[NDIMS]) const {
+    return XpuShapeUtil<NDIMS>::Coordinate2Offset(*this, coord);
+  }
+  template<int NDIMS>
+  OF_DEVICE_FUNC void Offset2Coordinate(int64_t offset, int64_t coord[NDIMS]) const {
+    XpuShapeUtil<NDIMS>::Offset2Coordinate(*this, offset, coord);
+  }
+
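+  // dim_elem_num_[i] is the row-major stride of axis i, i.e. the number of elements in one slice
+  // along that axis.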
+  OF_DEVICE_FUNC void UpdateDimElemNumAndElemNum() {
+    elem_num_ = 1;
+    for (int i = num_axes_ - 1; i >= 0; --i) {
+      dim_elem_num_[i] = elem_num_;
+      elem_num_ *= dim_[i];
+    }
+  }
+
+  size_t num_axes_;
+  size_t elem_num_;
+  int64_t dim_[OF_PP_SEQ_SIZE(DIM_SEQ)];
+  int64_t dim_elem_num_[OF_PP_SEQ_SIZE(DIM_SEQ)];
+};
+
+template<>
+struct XpuShapeUtil<1> final {
+  OF_DEVICE_FUNC static int64_t Coordinate2Offset(const XpuShape& shape, const int64_t coord[1]) {
+    return coord[0];
+  }
+  OF_DEVICE_FUNC static void Offset2Coordinate(const XpuShape& shape, int64_t offset,
+                                               int64_t coord[1]) {
+    coord[0] = offset;
+  }
+};
+
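+// The specializations below expand to the usual row-major conversions for 2 to 5 axes:
+// offset = sum_i coord[i] * dim_elem_num_[i], and the inverse by successive division and modulo.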
+#define COORD_MUL_STRIDE(i) coord[i] * shape.dim_elem_num_[i] +
+#define EXTRACT_COORD(i)                      \
+  coord[i] = offset / shape.dim_elem_num_[i]; \
+  offset %= shape.dim_elem_num_[i];
+
+#define SPECIALIZE_XPU_SHAPE_UTIL(n)                                                    \
+  template<>                                                                            \
+  struct XpuShapeUtil<n + 2> final {                                                    \
+    OF_DEVICE_FUNC static int64_t Coordinate2Offset(const XpuShape& shape,              \
+                                                    const int64_t coord[n + 2]) {       \
+      return OF_PP_FOR_EACH_TUPLE(COORD_MUL_STRIDE, GET_SEQ(n)) coord[n + 1];           \
+    }                                                                                   \
+    OF_DEVICE_FUNC static void Offset2Coordinate(const XpuShape& shape, int64_t offset, \
+                                                 int64_t coord[n + 2]) {                \
+      OF_PP_FOR_EACH_TUPLE(EXTRACT_COORD, GET_SEQ(n));                                  \
+      coord[n + 1] = offset;                                                            \
+    }                                                                                   \
+  };
+
+SPECIALIZE_XPU_SHAPE_UTIL(0);
+SPECIALIZE_XPU_SHAPE_UTIL(1);
+SPECIALIZE_XPU_SHAPE_UTIL(2);
+SPECIALIZE_XPU_SHAPE_UTIL(3);
+#undef SPECIALIZE_XPU_SHAPE_UTIL
+#undef EXTRACT_COORD
+#undef COORD_MUL_STRIDE
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_XPU_SHAPE_H_
diff --git a/oneflow/core/ndarray/xpu_transpose_ndarray.h b/oneflow/core/ndarray/xpu_transpose_ndarray.h
new file mode 100644
index 0000000000000000000000000000000000000000..e233076541d6c6ee558214f185c850b2fe98dce6
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_transpose_ndarray.h
@@ -0,0 +1,64 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_TRANSPOSE_NDARRAY_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_TRANSPOSE_NDARRAY_H_
+
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+template<typename T, int NDIMS, typename X = XpuVarNdarray<T>>
+class XpuTransposeNdarray final {
+ public:
+  OF_DEVICE_FUNC XpuTransposeNdarray(const X& x, const int64_t perm[NDIMS])
+      : x_(x), shape_(x.shape()) {
+    for (int i = 0; i < NDIMS; ++i) {
+      perm_[i] = perm[i];
+      shape_.Set(i, x.shape().At(perm[i]));
+    }
+  }
+
+  template<int ndims, typename = typename std::enable_if<ndims == NDIMS>::type>
+  OF_DEVICE_FUNC T Get(int64_t offset) const {
+    int64_t coord[NDIMS];
+    Offset2Coord(offset, coord);
+    return Get(coord);
+  }
+
+  template<int ndims, typename = typename std::enable_if<ndims == NDIMS>::type>
+  OF_DEVICE_FUNC T* Mut(int64_t offset) const {
+    int64_t coord[NDIMS];
+    Offset2Coord(offset, coord);
+    return Mut(coord);
+  }
+
+  template<int ndims, typename = typename std::enable_if<ndims == NDIMS>::type>
+  OF_DEVICE_FUNC T Get(int64_t coord[ndims]) const {
+    int64_t permuted_coord[NDIMS];
+    PermuteCoord(coord, permuted_coord);
+    return x_.template Get<ndims>(permuted_coord);
+  }
+
+  template<int ndims, typename = typename std::enable_if<ndims == NDIMS>::type>
+  OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const {
+    int64_t permuted_coord[NDIMS];
+    PermuteCoord(coord, permuted_coord);
+    return x_.template Mut<NDIMS>(permuted_coord);
+  }
+
+ private:
+  OF_DEVICE_FUNC void Offset2Coord(int64_t offset, int64_t coord[NDIMS]) const {
+    shape_.Offset2Coordinate<NDIMS>(offset, coord);
+  }
+
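+  // coord is given in the transposed layout; writing coord[i] into permuted_coord[perm_[i]]
+  // recovers the coordinate in the layout of the wrapped ndarray x_.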
+  OF_DEVICE_FUNC void PermuteCoord(const int64_t coord[NDIMS],
+                                   int64_t permuted_coord[NDIMS]) const {
+    for (int i = 0; i < NDIMS; ++i) { permuted_coord[perm_[i]] = coord[i]; }
+  }
+
+  const X& x_;
+  XpuShape shape_;
+  int64_t perm_[NDIMS];
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_XPU_TRANSPOSE_NDARRAY_H_
diff --git a/oneflow/core/ndarray/xpu_unary_func_ndarray.h b/oneflow/core/ndarray/xpu_unary_func_ndarray.h
new file mode 100644
index 0000000000000000000000000000000000000000..d4dbdd2a4efb399cb695c14403255538eae19ad6
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_unary_func_ndarray.h
@@ -0,0 +1,22 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_UNARY_FUNC_NDARRAY_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_UNARY_FUNC_NDARRAY_H_
+
+namespace oneflow {
+
+template<typename T, const T (*unary_func)(const T), typename X>
+class XpuUnaryFuncNdarray final {
+ public:
+  OF_DEVICE_FUNC XpuUnaryFuncNdarray(const X& x) : x_(x) {}
+
+  template<int NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t offset) const {
+    return unary_func(x_.template Get<NDIMS>(offset));
+  }
+
+ private:
+  const X& x_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_XPU_UNARY_FUNC_NDARRAY_H_
diff --git a/oneflow/core/ndarray/xpu_util.h b/oneflow/core/ndarray/xpu_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..2876e46b1a6112013cb11ffd417fd36b4062757d
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_util.h
@@ -0,0 +1,40 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_UTIL_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_UTIL_H_
+
+#include "oneflow/core/kernel/kernel_util.h"
+#include "oneflow/core/device/cuda_util.h"
+
+namespace oneflow {
+
+#if defined(__CUDACC__)
+#define XPU_1D_KERNEL_LOOP(i, n) CUDA_1D_KERNEL_LOOP(i, n)
+#else
+#define XPU_1D_KERNEL_LOOP(i, n) FOR_RANGE(int64_t, i, 0, n)
+#endif
+
+#if defined(__CUDACC__)
+#define XPU_BLOAD_THREAD_2D_KERNEL_LOOP(i, j, m, n)     \
+  for (int64_t i = blockIdx.x; i < (m); i += gridDim.x) \
+    for (int64_t j = threadIdx.x; j < (n); j += blockDim.x)
+#else
+#define XPU_BLOAD_THREAD_2D_KERNEL_LOOP(i, j, m, n) \
+  for (int64_t i = 0; i < (m); ++i)                 \
+    for (int64_t j = 0; j < (n); ++j)
+#endif
+
+#if defined(__CUDACC__)
+#define OF_GLOBAL_FUNC __global__
+#else
+#define OF_GLOBAL_FUNC
+#endif
+
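+// GET_SEQ(n) expands to the tuple sequence (0)(1)...(n), used to unroll per-axis code.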
+#define GET_SEQ(n) OF_PP_CAT(OF_PP_CAT(GET_SEQ_, n), )
+#define GET_SEQ_0 OF_PP_MAKE_TUPLE_SEQ(0)
+#define GET_SEQ_1 GET_SEQ_0 OF_PP_MAKE_TUPLE_SEQ(1)
+#define GET_SEQ_2 GET_SEQ_1 OF_PP_MAKE_TUPLE_SEQ(2)
+#define GET_SEQ_3 GET_SEQ_2 OF_PP_MAKE_TUPLE_SEQ(3)
+#define GET_SEQ_4 GET_SEQ_3 OF_PP_MAKE_TUPLE_SEQ(4)
+#define GET_SEQ_5 GET_SEQ_4 OF_PP_MAKE_TUPLE_SEQ(5)
+}
+
+#endif  // ONEFLOW_CORE_NDARRAY_XPU_UTIL_H_
diff --git a/oneflow/core/ndarray/xpu_var_ndarray.h b/oneflow/core/ndarray/xpu_var_ndarray.h
new file mode 100644
index 0000000000000000000000000000000000000000..f7ecdc5550611f5bf0479018a6b7bda056c4b921
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_var_ndarray.h
@@ -0,0 +1,73 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_H_
+
+#include "oneflow/core/ndarray/xpu_shape.h"
+#include "oneflow/core/kernel/kernel_util.h"
+#include "oneflow/core/common/util.h"
+#include "oneflow/core/register/blob.h"
+#include "oneflow/core/ndarray/xpu_util.h"
+#include "oneflow/core/ndarray/xpu_ndarray_base.h"
+
+namespace oneflow {
+
+template<typename T>
+class XpuVarNdarray final : public XpuNdarrayBase<XpuVarNdarray<T>, T> {
+ public:
+  explicit XpuVarNdarray(const Blob* blob, int ndims_extend_to)
+      : shape_(blob->shape().CreateLeftExtendedShape(ndims_extend_to)),
+        ptr_(blob->dptr<typename std::remove_const<T>::type>()) {}
+  explicit XpuVarNdarray(Blob* blob, int ndims_extend_to)
+      : shape_(blob->shape().CreateLeftExtendedShape(ndims_extend_to)), ptr_(blob->mut_dptr<T>()) {}
+  XpuVarNdarray(const Shape& shape, T* ptr) : shape_(shape), ptr_(ptr) {}
+  OF_DEVICE_FUNC ALWAYS_INLINE XpuVarNdarray(const XpuVarNdarray&) = default;
+  OF_DEVICE_FUNC ALWAYS_INLINE XpuVarNdarray(const XpuShape& shape, T* ptr)
+      : shape_(shape), ptr_(ptr) {}
+
+  const XpuShape& host_shape() const { return shape_; }
+  T* host_ptr() const { return ptr_; }
+
+  OF_DEVICE_FUNC const XpuShape& shape() const { return shape_; }
+  OF_DEVICE_FUNC T* ptr() const { return ptr_; }
+
+  template<int NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t offset) const {
+    return ptr_[offset];
+  }
+  template<int NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t coord[NDIMS]) const {
+    return ptr_[shape().template Coordinate2Offset<NDIMS>(coord)];
+  }
+
+  template<int NDIMS>
+  OF_DEVICE_FUNC T* Mut(int64_t offset) const {
+    return ptr_ + offset;
+  }
+
+  template<int NDIMS>
+  OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const {
+    return ptr_ + shape().template Coordinate2Offset<NDIMS>(coord);
+  }
+
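+  // Element-wise assignment from an ndarray expression x; XPU_1D_KERNEL_LOOP expands to the CUDA
+  // kernel loop on GPU and a plain for loop on CPU.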
+  template<int NDIMS, typename X>
+  OF_DEVICE_FUNC void Assign(const X& x) const {
+    size_t n = shape_.ElemNum();
+    XPU_1D_KERNEL_LOOP(i, n) { ptr_[i] = x.template Get<NDIMS>(i); }
+  }
+
+  template<const T (*binary_func)(const T, const T), int NDIMS, typename X>
+  OF_DEVICE_FUNC void BinaryAssign(const X& x) const {
+    size_t n = shape_.ElemNum();
+    XPU_1D_KERNEL_LOOP(i, n) {
+      T* ptr_i = ptr_ + i;
+      *ptr_i = binary_func(*ptr_i, x.template Get<NDIMS>(i));
+    }
+  }
+
+ private:
+  XpuShape shape_;
+  T* ptr_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_H_
diff --git a/oneflow/core/operator/accumulate_op.h b/oneflow/core/operator/accumulate_op.h
index b3bd2dfa84a95a25c359215f6a35cfc7ead4d7ff..89e798ac34969666d7b96ba2aa9016ec0af502bc 100644
--- a/oneflow/core/operator/accumulate_op.h
+++ b/oneflow/core/operator/accumulate_op.h
@@ -16,8 +16,15 @@ class AccumulateOp final : public Operator {
 
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override {}
+  void InferOutputBlobTimeShape(std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp,
+                                const ParallelContext* parallel_ctx,
+                                Shape* time_shape) const override {
+    TODO();
+  }
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
   LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); }
 };
diff --git a/oneflow/core/operator/accuracy_op.cpp b/oneflow/core/operator/accuracy_op.cpp
index 40372efffe12cfeb4e04fe31d85a469da6bac12b..1d742a6b13184e9e4528d83c44c02f482d45baed 100644
--- a/oneflow/core/operator/accuracy_op.cpp
+++ b/oneflow/core/operator/accuracy_op.cpp
@@ -7,6 +7,10 @@ void AccuracyOp::InitFromOpConf() {
   EnrollInputBn("label", false);
   EnrollOutputBn("accuracy", false);
   EnrollOutputBn("accuracy_instance_num", false);
+  if (op_conf().accuracy_conf().has_weight()) {
+    EnrollInputBn("weight", false);
+    EnrollDataTmpBn("weight_reduce_tmp");
+  }
 }
 
 const PbMessage& AccuracyOp::GetCustomizedConf() const { return op_conf().accuracy_conf(); }
@@ -20,8 +24,7 @@ void AccuracyOp::VirtualGenKernelConf(
 }
 
 void AccuracyOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                                const ParallelContext* parallel_ctx,
-                                std::function<void(OpContext*)>) const {
+                                const ParallelContext* parallel_ctx) const {
   BlobDesc* pred_blob_desc = GetBlobDesc4BnInOp("prediction");
   BlobDesc* label_blob_desc = GetBlobDesc4BnInOp("label");
   CHECK_EQ(pred_blob_desc->has_data_id_field(), label_blob_desc->has_data_id_field());
@@ -34,6 +37,20 @@ void AccuracyOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> Get
   CHECK_GE(pred_blob_desc->shape().NumAxes(), 2);
   CHECK_EQ(label_blob_desc->shape(), Shape({pred_blob_desc->shape().At(0)}));
 
+  if (op_conf().accuracy_conf().has_weight()) {
+    const BlobDesc* weight = GetBlobDesc4BnInOp("weight");
+    CHECK_EQ(weight->shape(), label_blob_desc->shape());
+    CHECK_EQ(weight->data_type(), pred_blob_desc->data_type());
+    CHECK_EQ(weight->has_dim0_valid_num_field(), label_blob_desc->has_dim0_valid_num_field());
+    CHECK_EQ(weight->has_dim0_inner_shape(), label_blob_desc->has_dim0_inner_shape());
+    if (label_blob_desc->has_dim0_inner_shape()) {
+      CHECK_EQ(weight->dim0_inner_shape(), label_blob_desc->dim0_inner_shape());
+    }
+    BlobDesc* weight_reduce_tmp = GetBlobDesc4BnInOp("weight_reduce_tmp");
+    weight_reduce_tmp->mut_shape() = weight->shape();
+    weight_reduce_tmp->set_data_type(weight->data_type());
+  }
+
   // accuracy
   BlobDesc* accuracy_blob_desc = GetBlobDesc4BnInOp("accuracy");
   *accuracy_blob_desc = *pred_blob_desc;
diff --git a/oneflow/core/operator/accuracy_op.h b/oneflow/core/operator/accuracy_op.h
index 2fea279dc66a0b654e932adf850d6f06e784486d..ece825f2c64da93b5edac889217266ffbe312dbf 100644
--- a/oneflow/core/operator/accuracy_op.h
+++ b/oneflow/core/operator/accuracy_op.h
@@ -17,12 +17,13 @@ class AccuracyOp final : public Operator {
 
   const PbMessage& GetCustomizedConf() const override;
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                      const ParallelContext* parallel_ctx,
-                      std::function<void(OpContext*)> EnrollOpCtx) const override;
+                      const ParallelContext* parallel_ctx) const override;
   void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                             const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   LogicalBlobId obn2lbi(const std::string& output_bn) const override;
 };
 
diff --git a/oneflow/core/operator/accuracy_print_op.h b/oneflow/core/operator/accuracy_print_op.h
index 60bf09bfcb84fa86cfdab79bf7e9e1bd7aa4e352..ac027e6210d5c038913661ef09965021eb0d0778 100644
--- a/oneflow/core/operator/accuracy_print_op.h
+++ b/oneflow/core/operator/accuracy_print_op.h
@@ -15,6 +15,8 @@ class AccuracyPrintOp final : public Operator {
   const PbMessage& GetCustomizedConf() const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override;
 };
 
diff --git a/oneflow/core/operator/adam_model_update_op.cpp b/oneflow/core/operator/adam_model_update_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f95358e170521cb74b90af29244751429148423f
--- /dev/null
+++ b/oneflow/core/operator/adam_model_update_op.cpp
@@ -0,0 +1,40 @@
+#include "oneflow/core/operator/adam_model_update_op.h"
+
+namespace oneflow {
+
+void AdamModelUpdateOp::MdUpdtVirtualInitFromOpConf() {
+  const auto& adam_conf = op_conf().normal_mdupdt_conf().user_conf().adam_conf();
+  CHECK_GE(adam_conf.beta1(), 0);
+  CHECK_LT(adam_conf.beta1(), 1);
+  CHECK_GE(adam_conf.beta2(), 0);
+  CHECK_LT(adam_conf.beta2(), 1);
+
+  EnrollForwardModelBn("m");
+  EnrollForwardModelBn("v");
+  if (adam_conf.do_bias_correction()) {
+    EnrollForwardModelBn("beta1_t");
+    EnrollForwardModelBn("beta2_t");
+  }
+}
+
+void AdamModelUpdateOp::MdUpdtVirtualInferBlobDescs(
+    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx) const {
+  const auto& adam_conf = op_conf().normal_mdupdt_conf().user_conf().adam_conf();
+  const BlobDesc* model_blob_desc = GetBlobDesc4BnInOp("model");
+  CHECK_EQ(model_blob_desc->data_type(), Global<JobDesc>::Get()->DefaultDataType());
+  CHECK_EQ(model_blob_desc->has_data_id_field(), false);
+  *GetBlobDesc4BnInOp("m") = *model_blob_desc;
+  *GetBlobDesc4BnInOp("v") = *model_blob_desc;
+
+  if (adam_conf.do_bias_correction()) {
+    *GetBlobDesc4BnInOp("beta1_t") = *model_blob_desc;
+    *GetBlobDesc4BnInOp("beta2_t") = *model_blob_desc;
+    GetBlobDesc4BnInOp("beta1_t")->mut_shape() = Shape({1});
+    GetBlobDesc4BnInOp("beta2_t")->mut_shape() = Shape({1});
+  }
+}
+
+REGISTER_CLASS(NormalModelUpdateOpUserConf::kAdamConf, NormalModelUpdtOp, AdamModelUpdateOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/adam_model_update_op.h b/oneflow/core/operator/adam_model_update_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..21cc800bfac47a8f578b29ef3a6f039b7d130a24
--- /dev/null
+++ b/oneflow/core/operator/adam_model_update_op.h
@@ -0,0 +1,22 @@
+#ifndef ONEFLOW_CORE_OPERATOR_ADAM_MODEL_UPDATE_OP_H_
+#define ONEFLOW_CORE_OPERATOR_ADAM_MODEL_UPDATE_OP_H_
+
+#include "oneflow/core/operator/normal_model_update_op.h"
+
+namespace oneflow {
+
+class AdamModelUpdateOp final : public NormalModelUpdtOp {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(AdamModelUpdateOp);
+  AdamModelUpdateOp() = default;
+  ~AdamModelUpdateOp() = default;
+
+ private:
+  void MdUpdtVirtualInitFromOpConf() override;
+  void MdUpdtVirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                   const ParallelContext* parallel_ctx) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_ADAM_MODEL_UPDATE_OP_H_
diff --git a/oneflow/core/operator/add_op.cpp b/oneflow/core/operator/add_op.cpp
index 17d9ff2b6aa2ab615aaad19d10dbb787d1081f0a..365e734186acba54b8b885c078cf5f8807689605 100644
--- a/oneflow/core/operator/add_op.cpp
+++ b/oneflow/core/operator/add_op.cpp
@@ -4,8 +4,9 @@ namespace oneflow {
 
 void AddOp::VirtualInitFromOpConf() { CHECK(op_conf().has_add_conf()); }
 const PbMessage& AddOp::GetCustomizedConf() const { return op_conf().add_conf(); }
-void AddOp::FixInDiffBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                               const ParallelContext* parallel_ctx) const {
+void AddOp::VirtualFixInDiffBlobDescs(
+    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx) const {
   if (!Global<JobDesc>::Get()->enable_blob_mem_sharing()) { return; }
   int64_t blob_mem_id = oneflow_cast<int64_t>(NewUniqueId());
   FOR_RANGE(size_t, i, 0, input_diff_bns().size()) {
diff --git a/oneflow/core/operator/add_op.h b/oneflow/core/operator/add_op.h
index c607e03a2a153d39b3e5eb984fa1e9fc3095d43a..086cb16100550c1f31266030cc792189ff78627f 100644
--- a/oneflow/core/operator/add_op.h
+++ b/oneflow/core/operator/add_op.h
@@ -15,9 +15,13 @@ class AddOp final : public CWiseOp {
   const PbMessage& GetCustomizedConf() const override;
   bool NeedInBlobWhenBackward() const override { return false; }
   bool NeedOutBlobWhenBackward() const override { return false; }
-  virtual void FixInDiffBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                                  const ParallelContext*) const override;
+  void VirtualFixInDiffBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                 const ParallelContext*) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
 };
+
 }  // namespace oneflow
 
 #endif  // ONEFLOW_CORE_OPERATOR_ADD_OP_H_
diff --git a/oneflow/core/operator/basic_rnn_op.h b/oneflow/core/operator/basic_rnn_op.h
index d01a167b1903479cb0ca0129f47166436c20edfd..99b4dcd29ef500cd4f2681e6024bef716314ee6b 100644
--- a/oneflow/core/operator/basic_rnn_op.h
+++ b/oneflow/core/operator/basic_rnn_op.h
@@ -13,6 +13,8 @@ class BasicRnnOp final : public RecurrentOp {
   const PbMessage& GetCustomizedConf() const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
+
   void VirtualInitFromOpConf();
   void VirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                              const ParallelContext* parallel_ctx) const;
diff --git a/oneflow/core/operator/batch_gather_op.cpp b/oneflow/core/operator/batch_gather_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..989fe38f3d1be0b91d25b88276974f9795ccf9e0
--- /dev/null
+++ b/oneflow/core/operator/batch_gather_op.cpp
@@ -0,0 +1,38 @@
+#include "oneflow/core/operator/batch_gather_op.h"
+
+namespace oneflow {
+
+void BatchGatherOp::InitFromOpConf() {
+  CHECK(op_conf().has_batch_gather_conf());
+  EnrollInputBn("in");
+  EnrollInputBn("indices", false);
+  EnrollOutputBn("out");
+}
+
+const PbMessage& BatchGatherOp::GetCustomizedConf() const { return op_conf().batch_gather_conf(); }
+
+void BatchGatherOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                   const ParallelContext* parallel_ctx) const {
+  const BlobDesc* in = GetBlobDesc4BnInOp("in");
+  CHECK_GT(in->shape().NumAxes(), 0);
+  const BlobDesc* indices = GetBlobDesc4BnInOp("indices");
+  CHECK_GT(indices->shape().NumAxes(), 0);
+  CHECK(IsIntegralDataType(indices->data_type()));
+  const std::vector<int64_t>& in_dim_vec = in->shape().dim_vec();
+  const std::vector<int64_t>& indices_dim_vec = indices->shape().dim_vec();
+  CHECK_LE(indices_dim_vec.size(), in_dim_vec.size());
+  FOR_RANGE(int64_t, i, 0, indices_dim_vec.size() - 1) {
+    CHECK_EQ(indices_dim_vec.at(i), in_dim_vec.at(i));
+  }
+  // out
+  std::vector<int64_t> out_dim_vec(indices_dim_vec);
+  out_dim_vec.insert(out_dim_vec.end(), in_dim_vec.cbegin() + indices_dim_vec.size(),
+                     in_dim_vec.cend());
+  BlobDesc* out = GetBlobDesc4BnInOp("out");
+  *out = *in;
+  out->mut_shape() = Shape(out_dim_vec);
+}
+
+REGISTER_OP(OperatorConf::kBatchGatherConf, BatchGatherOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/batch_gather_op.h b/oneflow/core/operator/batch_gather_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..469100f9057ae2482edd72dd63fddbba593f78cd
--- /dev/null
+++ b/oneflow/core/operator/batch_gather_op.h
@@ -0,0 +1,26 @@
+#ifndef ONEFLOW_CORE_OPERATOR_BATCH_GATHER_OP_H_
+#define ONEFLOW_CORE_OPERATOR_BATCH_GATHER_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class BatchGatherOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BatchGatherOp);
+  BatchGatherOp() = default;
+  ~BatchGatherOp() override = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override;
+  bool NeedOutBlobWhenBackward() const override { return false; }
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_BATCH_GATHER_OP_H_
diff --git a/oneflow/core/operator/bias_add_op.cpp b/oneflow/core/operator/bias_add_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..66b0107cc3a6f78a35d8048674be71cf995b8821
--- /dev/null
+++ b/oneflow/core/operator/bias_add_op.cpp
@@ -0,0 +1,45 @@
+#include "oneflow/core/operator/bias_add_op.h"
+#include "oneflow/core/common/balanced_splitter.h"
+
+namespace oneflow {
+
+void BiasAddOp::InitFromOpConf() {
+  CHECK(op_conf().has_bias_add_conf());
+  EnrollInputBn("a");
+  EnrollInputBn("b");
+  EnrollOutputBn("out");
+  EnrollConstBufBn("bias_multiplier");
+}
+
+const PbMessage& BiasAddOp::GetCustomizedConf() const { return op_conf().bias_add_conf(); }
+
+bool BiasAddOp::IsInputBlobAllowedModelSplit(const std::string& ibn) const {
+  CHECK(std::find(input_bns().begin(), input_bns().end(), ibn) != input_bns().end());
+  return ibn == "b";
+}
+
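+// Two signatures are offered: data-split input with broadcast bias, and broadcast input with the
+// bias model-split along axis 0.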
+void BiasAddOp::GetOpParallelSignatures(
+    std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const {
+  op_parallel_signatures->emplace_back(Make_DS_MB_2_DS_OpParallelSignature(this));
+  auto EqZero = [](int32_t axis) { return axis == 0; };
+  op_parallel_signatures->emplace_back(Make_DB_MS_2_MS_OpParallelSignature(this, EqZero));
+}
+
+void BiasAddOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                               const ParallelContext* parallel_ctx) const {
+  const BlobDesc* a_blob_desc = GetBlobDesc4BnInOp("a");
+  const BlobDesc* b_blob_desc = GetBlobDesc4BnInOp("b");
+
+  CHECK_EQ(a_blob_desc->shape().NumAxes(), 2);
+  CHECK_EQ(b_blob_desc->shape().NumAxes(), 1);
+  CHECK_EQ(a_blob_desc->shape().At(1), b_blob_desc->shape().At(0));
+  CHECK_EQ(a_blob_desc->data_type(), Global<JobDesc>::Get()->DefaultDataType());
+  CHECK_EQ(b_blob_desc->data_type(), Global<JobDesc>::Get()->DefaultDataType());
+
+  *GetBlobDesc4BnInOp("out") = *a_blob_desc;
+  GetBlobDesc4BnInOp("bias_multiplier")->mut_shape() = Shape({a_blob_desc->shape().At(0), 1});
+}
+
+REGISTER_OP(OperatorConf::kBiasAddConf, BiasAddOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/bias_add_op.h b/oneflow/core/operator/bias_add_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..f9fa0332199b731969fe8c5fdfbb5762d0ef9f1a
--- /dev/null
+++ b/oneflow/core/operator/bias_add_op.h
@@ -0,0 +1,34 @@
+#ifndef ONEFLOW_CORE_OPERATOR_BIAS_ADD_OP_H_
+#define ONEFLOW_CORE_OPERATOR_BIAS_ADD_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class BiasAddOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BiasAddOp);
+  BiasAddOp() = default;
+  ~BiasAddOp() = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override;
+  bool NeedInBlobWhenBackward() const override { return false; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+  int32_t OutputBlobModelSplitAxis(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const std::string& obn) const override {
+    return 1;
+  }
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override;
+  void GetOpParallelSignatures(
+      std::vector<std::unique_ptr<const OpParallelSignature>>*) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_BIAS_ADD_OP_H_
diff --git a/oneflow/core/operator/boxing_op.h b/oneflow/core/operator/boxing_op.h
index 2df43329bbc9d782c92143058d2f76fba8f79e9b..a5607953e4f7cd8d6a9f23a089f52a145407af53 100644
--- a/oneflow/core/operator/boxing_op.h
+++ b/oneflow/core/operator/boxing_op.h
@@ -23,6 +23,8 @@ class BoxingOp final : public Operator {
                             KernelConf* kernel_conf) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
+
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override;
   LogicalBlobId obn2lbi(const std::string& output_bn) const override;
   void InferDataTmpBlobDesc(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
diff --git a/oneflow/core/operator/broadcast_add_op.cpp b/oneflow/core/operator/broadcast_add_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..67634bca061c1132b1d20d5225f7cae704ceba4a
--- /dev/null
+++ b/oneflow/core/operator/broadcast_add_op.cpp
@@ -0,0 +1,11 @@
+#include "oneflow/core/operator/broadcast_add_op.h"
+
+namespace oneflow {
+
+const PbMessage& BroadcastAddOp::GetCustomizedConf() const {
+  return op_conf().broadcast_add_conf();
+}
+
+REGISTER_OP(OperatorConf::kBroadcastAddConf, BroadcastAddOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/broadcast_add_op.h b/oneflow/core/operator/broadcast_add_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..94dbcf02a75d8d2eaf8957cd1dd51213371dfb1f
--- /dev/null
+++ b/oneflow/core/operator/broadcast_add_op.h
@@ -0,0 +1,22 @@
+#ifndef ONEFLOW_CORE_OPERATOR_BROADCAST_ADD_OP_H_
+#define ONEFLOW_CORE_OPERATOR_BROADCAST_ADD_OP_H_
+
+#include "oneflow/core/operator/broadcast_binary_op.h"
+
+namespace oneflow {
+
+class BroadcastAddOp final : public BroadcastBinaryOp {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BroadcastAddOp);
+  BroadcastAddOp() = default;
+  ~BroadcastAddOp() = default;
+
+  bool NeedInBlobWhenBackward() const override { return false; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
+
+  const PbMessage& GetCustomizedConf() const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_BROADCAST_ADD_OP_H_
diff --git a/oneflow/core/operator/broadcast_binary_op.cpp b/oneflow/core/operator/broadcast_binary_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b411e90f026f260e5c94ed41da53272c92ada30b
--- /dev/null
+++ b/oneflow/core/operator/broadcast_binary_op.cpp
@@ -0,0 +1,119 @@
+#include "oneflow/core/operator/broadcast_binary_op.h"
+
+namespace oneflow {
+
+namespace {
+
+bool IsScalarBlob(const BlobDesc* blob) {
+  return blob->shape().NumAxes() == 1 && blob->shape().At(0) == 1;
+}
+
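+// Parallel signature for broadcast binary ops: inputs listed in model_input_bns must be model
+// blobs and are broadcast; the remaining inputs keep a broadcast hint or are split along axis 0,
+// and every output is split along axis 0.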
+class BroadcastBinaryOpParallelSignature final : public OpParallelSignature {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BroadcastBinaryOpParallelSignature);
+  ~BroadcastBinaryOpParallelSignature() override = default;
+
+  BroadcastBinaryOpParallelSignature(const Operator* op,
+                                     const HashSet<std::string>& model_input_bns)
+      : OpParallelSignature(op), model_input_bns_(model_input_bns) {}
+
+  const std::string Description() const override {
+    return op().op_name() + ": (C, ..., S(0), ...) -> (S(0), ...)";
+  }
+
+  const OpParallelMatchResult GetMatchResult(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const ParallelContext* parallel_ctx) const override {
+    for (const auto& bn : op().input_bns()) {
+      const auto& sbp_infer_hint = SbpInferHint4Ibn(bn);
+      bool is_model_input_bns = (model_input_bns_.find(bn) != model_input_bns_.end());
+      bool has_actual_model_input = sbp_infer_hint.is_model_blob();
+      if (is_model_input_bns ^ has_actual_model_input) {
+        return MakeOpParallelMatchSignatureMismatch();
+      }
+    }
+    if (parallel_ctx->policy() == kDataParallel) { return MakeOpParallelMatchSuccess(); }
+    return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kDataParallel);
+  }
+
+  void GenerateSignature(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      HashMap<std::string, SbpParallel>* bn2sbp) const override {
+    for (const auto& bn : op().input_bns()) {
+      if (model_input_bns_.find(bn) != model_input_bns_.end()) {
+        (*bn2sbp)[bn].mutable_broadcast_parallel();
+      } else {
+        const auto& in_sbp = SbpInferHint4Ibn(bn).sbp_parallel();
+        if (in_sbp.has_broadcast_parallel()) {
+          (*bn2sbp)[bn].mutable_broadcast_parallel();
+        } else {
+          (*bn2sbp)[bn].mutable_split_parallel()->set_axis(0);
+        }
+      }
+    }
+    for (const auto& bn : op().output_bns()) {
+      (*bn2sbp)[bn].mutable_split_parallel()->set_axis(0);
+    }
+  }
+
+ private:
+  HashSet<std::string> model_input_bns_;
+};
+
+std::unique_ptr<const OpParallelSignature> MakeBroadcastBinaryOpParallelSignature(
+    const Operator* op, const HashSet<std::string>& model_input_bns) {
+  return std::unique_ptr<const OpParallelSignature>(
+      new BroadcastBinaryOpParallelSignature(op, model_input_bns));
+}
+
+}  // namespace
+
+void BroadcastBinaryOp::InitFromOpConf() {
+  EnrollInputBn("a");
+  EnrollInputBn("b");
+  EnrollOutputBn("out");
+  EnrollBwBufBn("bw_buf");
+}
+
+void BroadcastBinaryOp::InferBlobDescs(
+    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx) const {
+  const BlobDesc* a_blob_desc = GetBlobDesc4BnInOp("a");
+  const BlobDesc* b_blob_desc = GetBlobDesc4BnInOp("b");
+  CHECK_EQ(a_blob_desc->data_type(), b_blob_desc->data_type());
+  BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
+  size_t output_num_axes = std::max(a_blob_desc->shape().NumAxes(), b_blob_desc->shape().NumAxes());
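+  // A scalar operand simply adopts the other operand's blob description; otherwise both shapes
+  // are left-extended to the same rank and every axis pair must be equal or contain a 1.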
+  if (IsScalarBlob(a_blob_desc)) {
+    *out_blob_desc = *b_blob_desc;
+  } else if (IsScalarBlob(b_blob_desc)) {
+    *out_blob_desc = *a_blob_desc;
+  } else {
+    const auto& a_shape = a_blob_desc->shape().CreateLeftExtendedShape(output_num_axes);
+    const auto& b_shape = b_blob_desc->shape().CreateLeftExtendedShape(output_num_axes);
+    *out_blob_desc = *a_blob_desc;
+    Shape out_shape(a_shape);
+    FOR_RANGE(int64_t, i, 0, a_shape.NumAxes()) {
+      CHECK(a_shape.At(i) == 1 || b_shape.At(i) == 1 || a_shape.At(i) == b_shape.At(i));
+      out_shape.Set(i, std::max(a_shape.At(i), b_shape.At(i)));
+    }
+    out_blob_desc->mut_shape() = out_shape;
+  }
+}
+
+void BroadcastBinaryOp::InferBwBufBlobDescs(
+    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*) const {
+  const BlobDesc* out = GetBlobDesc4BnInOp("out");
+  BlobDesc* bw_buf = GetBlobDesc4BnInOp("bw_buf");
+  bw_buf->mut_shape() = Shape({out->shape().elem_cnt()});
+  bw_buf->set_data_type(out->data_type());
+}
+
+void BroadcastBinaryOp::GetOpParallelSignatures(
+    std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const {
+  op_parallel_signatures->emplace_back(MakeBroadcastBinaryOpParallelSignature(this, {}));
+  op_parallel_signatures->emplace_back(MakeBroadcastBinaryOpParallelSignature(this, {"a"}));
+  op_parallel_signatures->emplace_back(MakeBroadcastBinaryOpParallelSignature(this, {"b"}));
+  op_parallel_signatures->emplace_back(MakeBroadcastBinaryOpParallelSignature(this, {"a", "b"}));
+}
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/broadcast_binary_op.h b/oneflow/core/operator/broadcast_binary_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..669c634401df9f9d942d1f2a41e0c20d928078ee
--- /dev/null
+++ b/oneflow/core/operator/broadcast_binary_op.h
@@ -0,0 +1,29 @@
+#ifndef ONEFLOW_CORE_OPERATOR_BROADCAST_BINARY_OP_H_
+#define ONEFLOW_CORE_OPERATOR_BROADCAST_BINARY_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class BroadcastBinaryOp : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BroadcastBinaryOp);
+  BroadcastBinaryOp() = default;
+  virtual ~BroadcastBinaryOp() = default;
+
+  void InitFromOpConf() override;
+  bool IsAllOutputConst() const override { return GetValFromCustomizedConf<bool>("is_const"); }
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+  void InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                           const ParallelContext*) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
+  void GetOpParallelSignatures(
+      std::vector<std::unique_ptr<const OpParallelSignature>>*) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_BROADCAST_BINARY_OP_H_
diff --git a/oneflow/core/operator/broadcast_div_op.cpp b/oneflow/core/operator/broadcast_div_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2dc4d625b79167a308ca599c1a083e84966089db
--- /dev/null
+++ b/oneflow/core/operator/broadcast_div_op.cpp
@@ -0,0 +1,11 @@
+#include "oneflow/core/operator/broadcast_div_op.h"
+
+namespace oneflow {
+
+const PbMessage& BroadcastDivOp::GetCustomizedConf() const {
+  return op_conf().broadcast_div_conf();
+}
+
+REGISTER_OP(OperatorConf::kBroadcastDivConf, BroadcastDivOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/broadcast_div_op.h b/oneflow/core/operator/broadcast_div_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..a6eb5eec6abb8fc14eb6922e0f03d4916d47f2c2
--- /dev/null
+++ b/oneflow/core/operator/broadcast_div_op.h
@@ -0,0 +1,19 @@
+#ifndef ONEFLOW_CORE_OPERATOR_BROADCAST_DIV_OP_H_
+#define ONEFLOW_CORE_OPERATOR_BROADCAST_DIV_OP_H_
+
+#include "oneflow/core/operator/broadcast_binary_op.h"
+
+namespace oneflow {
+
+class BroadcastDivOp final : public BroadcastBinaryOp {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BroadcastDivOp);
+  BroadcastDivOp() = default;
+  ~BroadcastDivOp() = default;
+
+  const PbMessage& GetCustomizedConf() const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_BROADCAST_DIV_OP_H_
diff --git a/oneflow/core/operator/broadcast_mul_op.cpp b/oneflow/core/operator/broadcast_mul_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..17aad41f418e92ddefedb6eefb14b9b18eee742d
--- /dev/null
+++ b/oneflow/core/operator/broadcast_mul_op.cpp
@@ -0,0 +1,11 @@
+#include "oneflow/core/operator/broadcast_mul_op.h"
+
+namespace oneflow {
+
+const PbMessage& BroadcastMulOp::GetCustomizedConf() const {
+  return op_conf().broadcast_mul_conf();
+}
+
+REGISTER_OP(OperatorConf::kBroadcastMulConf, BroadcastMulOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/broadcast_mul_op.h b/oneflow/core/operator/broadcast_mul_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..5a7576936ab3dfbcd79c0d448d5fa1c09b239a28
--- /dev/null
+++ b/oneflow/core/operator/broadcast_mul_op.h
@@ -0,0 +1,20 @@
+#ifndef ONEFLOW_CORE_OPERATOR_BROADCAST_MUL_OP_H_
+#define ONEFLOW_CORE_OPERATOR_BROADCAST_MUL_OP_H_
+
+#include "oneflow/core/operator/broadcast_binary_op.h"
+
+namespace oneflow {
+
+class BroadcastMulOp final : public BroadcastBinaryOp {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BroadcastMulOp);
+  BroadcastMulOp() = default;
+  ~BroadcastMulOp() = default;
+
+  const PbMessage& GetCustomizedConf() const override;
+  bool NeedInBlobWhenBackward() const override { return true; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_BROADCAST_MUL_OP_H_
diff --git a/oneflow/core/operator/broadcast_sub_op.cpp b/oneflow/core/operator/broadcast_sub_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c91b3e18fa4e5becffe55941cb8396278876fc90
--- /dev/null
+++ b/oneflow/core/operator/broadcast_sub_op.cpp
@@ -0,0 +1,11 @@
+#include "oneflow/core/operator/broadcast_sub_op.h"
+
+namespace oneflow {
+
+const PbMessage& BroadcastSubOp::GetCustomizedConf() const {
+  return op_conf().broadcast_sub_conf();
+}
+
+REGISTER_OP(OperatorConf::kBroadcastSubConf, BroadcastSubOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/broadcast_sub_op.h b/oneflow/core/operator/broadcast_sub_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..2ef143bb3219393c2b4123330e08ac8f733291c2
--- /dev/null
+++ b/oneflow/core/operator/broadcast_sub_op.h
@@ -0,0 +1,21 @@
+#ifndef ONEFLOW_CORE_OPERATOR_BROADCAST_SUB_OP_H_
+#define ONEFLOW_CORE_OPERATOR_BROADCAST_SUB_OP_H_
+
+#include "oneflow/core/operator/broadcast_binary_op.h"
+
+namespace oneflow {
+
+class BroadcastSubOp final : public BroadcastBinaryOp {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BroadcastSubOp);
+  BroadcastSubOp() = default;
+  ~BroadcastSubOp() = default;
+
+  const PbMessage& GetCustomizedConf() const override;
+  bool NeedInBlobWhenBackward() const override { return false; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_BROADCAST_SUB_OP_H_
diff --git a/oneflow/core/operator/cast_op.cpp b/oneflow/core/operator/cast_op.cpp
index 4fbdefc8c78b192c3ffe5ee569e0254febf2388f..4c3d58080b7f82eaa66ab233ed44f0e6872ece5b 100644
--- a/oneflow/core/operator/cast_op.cpp
+++ b/oneflow/core/operator/cast_op.cpp
@@ -18,6 +18,6 @@ void CastOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlob
   out_blob_desc->set_data_type(op_conf().cast_conf().data_type());
 }
 
-REGISTER_CPU_OP(OperatorConf::kCastConf, CastOp);
+REGISTER_OP(OperatorConf::kCastConf, CastOp);
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/cast_op.h b/oneflow/core/operator/cast_op.h
index 2e8585bd0162a3b4e077efcb2790da4129836022..3f4f52f7201cfc90fc0f4949df24513df9bd3699 100644
--- a/oneflow/core/operator/cast_op.h
+++ b/oneflow/core/operator/cast_op.h
@@ -13,9 +13,11 @@ class CastOp final : public Operator {
 
   void InitFromOpConf() override;
   const PbMessage& GetCustomizedConf() const override;
-  bool IsElemWiseOp() const override { return true; }
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/clone_op.h b/oneflow/core/operator/clone_op.h
index af9a70f7e73c8e002598fd17dcdac4879816ba5b..699f8357ea97d7b059c8fd960839d6dafd44cb86 100644
--- a/oneflow/core/operator/clone_op.h
+++ b/oneflow/core/operator/clone_op.h
@@ -22,6 +22,8 @@ class CloneOp final : public Operator {
       const ParallelContext* parallel_ctx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
+
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return LogicalBlobId(); }
   LogicalBlobId obn2lbi(const std::string& output_bn) const override { return LogicalBlobId(); }
   void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
diff --git a/oneflow/core/operator/concat_op.h b/oneflow/core/operator/concat_op.h
index e5cb8994609fa067aab7e38e703e12fa2d0144f4..2d61e95f0c0ca9b74642e975c691f8cfcc571a31 100644
--- a/oneflow/core/operator/concat_op.h
+++ b/oneflow/core/operator/concat_op.h
@@ -19,6 +19,9 @@ class ConcatOp final : public Operator {
 
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/constant_op.cpp b/oneflow/core/operator/constant_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..8467491cd186b701a0a43c6e7c752b2f900ee0f7
--- /dev/null
+++ b/oneflow/core/operator/constant_op.cpp
@@ -0,0 +1,58 @@
+#include "oneflow/core/operator/constant_op.h"
+
+namespace oneflow {
+
+void ConstantOp::InitFromOpConf() {
+  CHECK(op_conf().has_constant_conf());
+  EnrollInputBn("tick", false);
+  EnrollOutputBn("out", false);
+}
+
+const PbMessage& ConstantOp::GetCustomizedConf() const { return op_conf().constant_conf(); }
+
+void ConstantOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                const ParallelContext* parallel_ctx,
+                                int64_t record_piece_size) const {
+  CHECK_EQ(parallel_ctx->policy(), ParallelPolicy::kDataParallel);
+  const ConstantOpConf& conf = op_conf().constant_conf();
+  const DataType& data_type =
+      conf.has_data_type() ? conf.data_type() : Global<JobDesc>::Get()->DefaultDataType();
+  std::vector<int64_t> dim_vec;
+  if (conf.use_device_piece_size_as_dim0()) {
+    CHECK_EQ(record_piece_size % parallel_ctx->parallel_num(), 0);
+    dim_vec.push_back(record_piece_size / parallel_ctx->parallel_num());
+  }
+  if (conf.has_shape()) {
+    dim_vec.insert(dim_vec.end(), conf.shape().dim().cbegin(), conf.shape().dim().cend());
+  }
+  if (dim_vec.empty()) { dim_vec.push_back(1); }
+  BlobDesc* out = GetBlobDesc4BnInOp("out");
+  out->set_data_type(data_type);
+  out->mut_shape() = Shape(dim_vec);
+}
+
+void ConstantOp::VirtualGenKernelConf(
+    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
+  kernel_conf->mutable_constant_conf()->set_random_seed(NewRandomSeed());
+  const DataType& data_type = GetBlobDesc4BnInOp("out")->data_type();
+  if (op_conf().constant_conf().has_initializer()) {
+    *kernel_conf->mutable_constant_conf()->mutable_initializer() =
+        op_conf().constant_conf().initializer();
+  } else if (IsFloatingDataType(data_type)) {
+    InitializerConf conf;
+    conf.mutable_constant_conf()->set_value(0);
+    *kernel_conf->mutable_constant_conf()->mutable_initializer() = conf;
+  } else if (IsIntegralDataType(data_type)) {
+    InitializerConf conf;
+    conf.mutable_constant_int_conf()->set_value(0);
+    *kernel_conf->mutable_constant_conf()->mutable_initializer() = conf;
+  } else {
+    UNIMPLEMENTED();
+  }
+  kernel_conf->set_data_type(data_type);
+}
+
+REGISTER_OP(OperatorConf::kConstantConf, ConstantOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/constant_op.h b/oneflow/core/operator/constant_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..4c791cb01603b3f5633bd2dd53afc64fa1941107
--- /dev/null
+++ b/oneflow/core/operator/constant_op.h
@@ -0,0 +1,33 @@
+#ifndef ONEFLOW_CORE_OPERATOR_CONSTANT_OP_H_
+#define ONEFLOW_CORE_OPERATOR_CONSTANT_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class ConstantOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ConstantOp);
+  ConstantOp() = default;
+  ~ConstantOp() override = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override;
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx,
+                      int64_t record_piece_size) const override;
+  bool NeedInBlobWhenBackward() const override { return false; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
+  bool IsAllOutputConst() const override { return true; }
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
+  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                            const ParallelContext* parallel_ctx,
+                            KernelConf* kernel_conf) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_CONSTANT_OP_H_
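Reviewer note (not part of the patch): a minimal standalone sketch of the shape rule in ConstantOp::InferBlobDescs above, assuming only the standard library; InferConstantDims is an illustrative name, not a symbol from the codebase.

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> InferConstantDims(bool use_device_piece_size_as_dim0,
                                       int64_t record_piece_size, int64_t parallel_num,
                                       const std::vector<int64_t>& conf_shape) {
  std::vector<int64_t> dim_vec;
  if (use_device_piece_size_as_dim0) {
    assert(record_piece_size % parallel_num == 0);      // mirrors the CHECK_EQ in the op
    dim_vec.push_back(record_piece_size / parallel_num);
  }
  dim_vec.insert(dim_vec.end(), conf_shape.begin(), conf_shape.end());
  if (dim_vec.empty()) { dim_vec.push_back(1); }        // fall back to a single-element blob
  return dim_vec;
}
// e.g. piece size 64 split over 4 devices with conf shape {3, 5} -> {16, 3, 5};
// with no dim0 and no shape configured the result is {1}.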
diff --git a/oneflow/core/operator/conv_op.cpp b/oneflow/core/operator/conv_op.cpp
index 550fa3fc101df74e5956f1d75323ca0fa2509dd1..7df482164e6109dbef249688f2e5360b02e6d3ef 100644
--- a/oneflow/core/operator/conv_op.cpp
+++ b/oneflow/core/operator/conv_op.cpp
@@ -94,7 +94,7 @@ void ConvOp<NDims>::InitFromOpConf() {
 
 template<int32_t NDims>
 void ConvOp<NDims>::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                                   const ParallelContext* parallel_ctx,
+                                   const ParallelContext* parallel_ctx, int64_t record_piece_size,
                                    std::function<void(OpContext*)> EnrollOpCtx) const {
   const std::string& data_format = GetValFromCustomizedConf<std::string>("data_format");
 
@@ -288,7 +288,9 @@ PbMessage* ConvOp<NDims>::MutableCustomizedKernelConf(KernelConf* kernel_conf) c
 }
 
 template<int32_t NDims>
-int32_t ConvOp<NDims>::ModelSplitAxis() const {
+int32_t ConvOp<NDims>::OutputBlobModelSplitAxis(
+    const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+    const std::string& obn) const {
   if (GetValFromCustomizedConf<std::string>("data_format") == "channels_first") {
     return 1;
   } else if (GetValFromCustomizedConf<std::string>("data_format") == "channels_last") {
@@ -298,11 +300,6 @@ int32_t ConvOp<NDims>::ModelSplitAxis() const {
   }
 }
 
-template<int32_t NDims>
-int32_t ConvOp<NDims>::MaxModelSplitNum() const {
-  return GetValFromCustomizedConf<int32_t>("filters");
-}
-
 #ifdef WITH_CUDA
 template<int32_t NDims>
 void ConvOp<NDims>::InferCudnnAlgo(
diff --git a/oneflow/core/operator/conv_op.h b/oneflow/core/operator/conv_op.h
index d0d782edce761f57870407483c375f807a813911..4d2316b0b77ac5501367012f2ee18bd5d3512e07 100644
--- a/oneflow/core/operator/conv_op.h
+++ b/oneflow/core/operator/conv_op.h
@@ -41,15 +41,18 @@ class ConvOp : public Operator {
   void InitFromOpConf() override;
   bool NeedOutBlobWhenBackward() const override { return false; }
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                      const ParallelContext*,
+                      const ParallelContext* parallel_ctx, int64_t record_piece_size,
                       std::function<void(OpContext*)> EnrollOpCtx) const override;
   void InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                            const ParallelContext*, const OpContext*) const override;
 
-  int32_t ModelSplitAxis() const override;
-  int32_t MaxModelSplitNum() const override;
+  int32_t OutputBlobModelSplitAxis(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const std::string& obn) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   PbMessage* MutableCustomizedKernelConf(KernelConf*) const override;
   void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                             const ParallelContext*, KernelConf*, const OpContext*) const override;
diff --git a/oneflow/core/operator/copy_comm_net_op.h b/oneflow/core/operator/copy_comm_net_op.h
index 3709ed055674d3fd31abfb93e31da58d663eb4c9..1954eda28d8ddbde5b7ca2bd91335e15a8e4a888 100644
--- a/oneflow/core/operator/copy_comm_net_op.h
+++ b/oneflow/core/operator/copy_comm_net_op.h
@@ -15,6 +15,8 @@ class CopyCommNetOp final : public Operator {
   const PbMessage& GetCustomizedConf() const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
+
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override;
   LogicalBlobId obn2lbi(const std::string& output_bn) const override;
 };
diff --git a/oneflow/core/operator/copy_hd_op.h b/oneflow/core/operator/copy_hd_op.h
index 657d7b0f906bc8d8041579c011fb21f12102457f..fa35b16b185d0cf42c588963d0e2e7e5148a49fd 100644
--- a/oneflow/core/operator/copy_hd_op.h
+++ b/oneflow/core/operator/copy_hd_op.h
@@ -15,6 +15,8 @@ class CopyHdOp final : public Operator {
   const PbMessage& GetCustomizedConf() const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
+
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
   LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); }
 };
diff --git a/oneflow/core/operator/debug_op.cpp b/oneflow/core/operator/debug_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d7ae2a9ee4fa6cebf7ae1b882602101a67187260
--- /dev/null
+++ b/oneflow/core/operator/debug_op.cpp
@@ -0,0 +1,17 @@
+#include "oneflow/core/operator/debug_op.h"
+
+namespace oneflow {
+
+void DebugOp::InitFromOpConf() {
+  EnrollInputBn("in");
+  EnrollOutputBn("out");
+}
+
+void DebugOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                             const ParallelContext* parallel_ctx) const {
+  *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in");
+}
+
+REGISTER_CPU_OP(OperatorConf::kDebugConf, DebugOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/debug_op.h b/oneflow/core/operator/debug_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..0b854da59a82ebcd817de7f8cb3734bacc498d53
--- /dev/null
+++ b/oneflow/core/operator/debug_op.h
@@ -0,0 +1,28 @@
+#ifndef ONEFLOW_CORE_OPERATOR_DEBUG_OP_H_
+#define ONEFLOW_CORE_OPERATOR_DEBUG_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class DebugOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(DebugOp);
+  DebugOp() = default;
+  ~DebugOp() = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override { return op_conf().debug_conf(); }
+  bool NeedInBlobWhenBackward() const override { return false; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
+
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_DEBUG_OP_H_
diff --git a/oneflow/core/operator/decode_ofrecord_op.h b/oneflow/core/operator/decode_ofrecord_op.h
index cab4107ec7411dba39896f56339904bb0b5ba785..f94dcbde24334f2bb364784a96f5e5679b68e903 100644
--- a/oneflow/core/operator/decode_ofrecord_op.h
+++ b/oneflow/core/operator/decode_ofrecord_op.h
@@ -24,6 +24,8 @@ class DecodeOFRecordOp final : public Operator {
                             KernelConf* kernel_conf) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   LogicalBlobId obn2lbi(const std::string& output_bn) const override;
 };
 
diff --git a/oneflow/core/operator/decode_random_op.cpp b/oneflow/core/operator/decode_random_op.cpp
index 22c1b6a00214e2ac87d315debd1392d726c5604a..4e35ca19aaabca7b71d0c32f62372740ff69ae33 100644
--- a/oneflow/core/operator/decode_random_op.cpp
+++ b/oneflow/core/operator/decode_random_op.cpp
@@ -18,13 +18,13 @@ void DecodeRandomOp::VirtualGenKernelConf(
 }
 
 void DecodeRandomOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                                    const ParallelContext* parallel_ctx) const {
+                                    const ParallelContext* parallel_ctx,
+                                    int64_t record_piece_size) const {
   BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
   const DecodeRandomOpConf& conf = op_conf().decode_random_conf();
   std::vector<int64_t> dim_vec(1 + conf.shape().dim_size());
-  int64_t global_piece_size = Global<JobDesc>::Get()->PieceSize();
-  CHECK_EQ(global_piece_size % parallel_ctx->parallel_num(), 0);
-  dim_vec[0] = global_piece_size / parallel_ctx->parallel_num();
+  CHECK_EQ(record_piece_size % parallel_ctx->parallel_num(), 0);
+  dim_vec[0] = record_piece_size / parallel_ctx->parallel_num();
   FOR_RANGE(size_t, j, 1, dim_vec.size()) { dim_vec[j] = conf.shape().dim(j - 1); }
   out_blob_desc->mut_shape() = Shape(dim_vec);
   out_blob_desc->set_data_type(conf.data_type());
diff --git a/oneflow/core/operator/decode_random_op.h b/oneflow/core/operator/decode_random_op.h
index 0aaad916d826c883bddb95f4dfc8d14646db6f92..2ab009feffd36b07e75e40291b8b07c403872168 100644
--- a/oneflow/core/operator/decode_random_op.h
+++ b/oneflow/core/operator/decode_random_op.h
@@ -17,9 +17,12 @@ class DecodeRandomOp final : public Operator {
   LogicalNode* NewProperLogicalNode() override { return new DecodeRandomLogicalNode; }
 
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                      const ParallelContext* parallel_ctx) const override;
+                      const ParallelContext* parallel_ctx,
+                      int64_t record_piece_size) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                             const ParallelContext* parallel_ctx,
                             KernelConf* kernel_conf) const override;
diff --git a/oneflow/core/operator/define_test_blob_op.cpp b/oneflow/core/operator/define_test_blob_op.cpp
index a56f8fb13d6fdabc695678601ad25f282fa95435..bdbd8fc0955151464963445f2ec9d1226cc0e68d 100644
--- a/oneflow/core/operator/define_test_blob_op.cpp
+++ b/oneflow/core/operator/define_test_blob_op.cpp
@@ -4,7 +4,7 @@ namespace oneflow {
 
 void DefineTestBlobOp::InitFromOpConf() {
   CHECK(op_conf().has_define_test_blob_conf());
-  EnrollOutputBn("out", false);
+  EnrollOutputBn("out", op_conf().define_test_blob_conf().has_diff());
 }
 
 const PbMessage& DefineTestBlobOp::GetCustomizedConf() const {
diff --git a/oneflow/core/operator/define_test_blob_op.h b/oneflow/core/operator/define_test_blob_op.h
index 69a91479c51052f04984e4fe9904d0cca6a9a63d..f8034c787668cc5bb39df4b7108717b1e740b20c 100644
--- a/oneflow/core/operator/define_test_blob_op.h
+++ b/oneflow/core/operator/define_test_blob_op.h
@@ -18,6 +18,9 @@ class DefineTestBlobOp final : public Operator {
 
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/dot_op.h b/oneflow/core/operator/dot_op.h
index dc2291eb6d50b517100765c30021a4dea06bbb99..bc926bfe0ecce38672a2cfd29fed2f32d091d3c7 100644
--- a/oneflow/core/operator/dot_op.h
+++ b/oneflow/core/operator/dot_op.h
@@ -12,6 +12,9 @@ class DotOp final : public Operator {
   const PbMessage& GetCustomizedConf() const override;
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/dropout_op.h b/oneflow/core/operator/dropout_op.h
index c634d1b05d658c98971c3801221a43a90c569103..b837f27862f919da110a487eade2987dec987c9f 100644
--- a/oneflow/core/operator/dropout_op.h
+++ b/oneflow/core/operator/dropout_op.h
@@ -13,7 +13,6 @@ class DropoutOp final : public Operator {
 
   void InitFromOpConf() override;
   const PbMessage& GetCustomizedConf() const override;
-  bool IsElemWiseOp() const override { return true; }
   bool NeedInBlobWhenBackward() const override { return false; }
   bool NeedOutBlobWhenBackward() const override { return false; }
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
@@ -22,6 +21,9 @@ class DropoutOp final : public Operator {
   void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                             const ParallelContext* parallel_ctx,
                             KernelConf* kernel_conf) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/embedding_lookup_accumulate_op.h b/oneflow/core/operator/embedding_lookup_accumulate_op.h
index dc42310fde5ffa21907d6fa0fcd38f7a43d7077d..c95274ded53a5f7c418b8cbb7489858b42582a7d 100644
--- a/oneflow/core/operator/embedding_lookup_accumulate_op.h
+++ b/oneflow/core/operator/embedding_lookup_accumulate_op.h
@@ -13,11 +13,11 @@ class EmbeddingLookupAccumulateOp final : public Operator {
 
   void InitFromOpConf() override;
   const PbMessage& GetCustomizedConf() const override;
-
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
   LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); }
 };
diff --git a/oneflow/core/operator/embedding_lookup_op.h b/oneflow/core/operator/embedding_lookup_op.h
index ef9611f0b797c188119c738fa5c807ffd6336737..4ac9c0ce343c1bc962ac45d5ad68e1d5442f3985 100644
--- a/oneflow/core/operator/embedding_lookup_op.h
+++ b/oneflow/core/operator/embedding_lookup_op.h
@@ -16,10 +16,14 @@ class EmbeddingLookupOp final : public Operator {
   const PbMessage& GetCustomizedConf() const override;
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
-  int32_t ModelSplitAxis() const override { return 1; }
-  int32_t MaxModelSplitNum() const override { return op_conf().embedding_lookup_conf().units(); }
+  int32_t OutputBlobModelSplitAxis(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const std::string& obn) const override {
+    return 1;
+  }
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/fully_connected_op.h b/oneflow/core/operator/fully_connected_op.h
index 0adf7a3a2612e2b53a972367c5355967e5e7d52a..efa229120d17549385f2c89657d7cc3ab59df39d 100644
--- a/oneflow/core/operator/fully_connected_op.h
+++ b/oneflow/core/operator/fully_connected_op.h
@@ -16,8 +16,14 @@ class FullyConnectedOp final : public Operator {
   bool NeedOutBlobWhenBackward() const override { return false; }
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
-  int32_t ModelSplitAxis() const override { return 1; }
-  int32_t MaxModelSplitNum() const override { return op_conf().fully_connected_conf().units(); }
+  int32_t OutputBlobModelSplitAxis(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const std::string& obn) const override {
+    return 1;
+  }
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/gather_op.cpp b/oneflow/core/operator/gather_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e4f8658ca88adf69f11a6fe9b243a468a32c5859
--- /dev/null
+++ b/oneflow/core/operator/gather_op.cpp
@@ -0,0 +1,115 @@
+#include "oneflow/core/operator/gather_op.h"
+
+namespace oneflow {
+
+namespace {
+
+int64_t GetGatherAxis(const GatherOpConf& conf, int64_t num_axes) {
+  const int64_t axis = conf.axis() < 0 ? num_axes + conf.axis() : conf.axis();
+  CHECK_GE(axis, 0);
+  CHECK_LT(axis, num_axes);
+  return axis;
+}
+
+int64_t GetGatherAxis(const GatherOpConf& conf, const BlobDesc* in_blob_desc) {
+  return GetGatherAxis(conf, in_blob_desc->shape().NumAxes());
+}
+
+class Gather_DB_MS_2_P_OpParallelSignature final : public OpParallelSignature {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(Gather_DB_MS_2_P_OpParallelSignature);
+  ~Gather_DB_MS_2_P_OpParallelSignature() override = default;
+
+  Gather_DB_MS_2_P_OpParallelSignature(const Operator* op) : OpParallelSignature(op) {}
+
+  const std::string Description() const override { return op().op_name() + ": (C, S) -> P"; }
+
+  const OpParallelMatchResult GetMatchResult(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4BnInOp,
+      const ParallelContext* parallel_ctx) const override {
+    const SbpInferHint& in_sbp_infer_hint = SbpInferHint4BnInOp("in");
+    if (!in_sbp_infer_hint.is_model_split()) { return MakeOpParallelMatchSignatureMismatch(); }
+    if (in_sbp_infer_hint.split_axis() != 0) { return MakeOpParallelMatchSignatureMismatch(); }
+    if (parallel_ctx->policy() == kModelParallel) { return MakeOpParallelMatchSuccess(); }
+    return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kModelParallel);
+  }
+
+  void GenerateSignature(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4BnInOp,
+      HashMap<std::string, SbpParallel>* bn2sbp) const override {
+    (*bn2sbp)["indices"].mutable_broadcast_parallel();
+    (*bn2sbp)["in"].mutable_split_parallel()->set_axis(0);
+    (*bn2sbp)["out"].mutable_partial_sum_parallel();
+  }
+};
+
+}  // namespace
+
+void GatherOp::InitFromOpConf() {
+  CHECK(op_conf().has_gather_conf());
+  EnrollInputBn("indices", false);
+  EnrollInputBn("in");
+  EnrollOutputBn("out");
+}
+
+const PbMessage& GatherOp::GetCustomizedConf() const { return op_conf().gather_conf(); }
+
+bool GatherOp::IsInputBlobAllowedModelSplit(const std::string& ibn) const {
+  CHECK(std::find(input_bns().begin(), input_bns().end(), ibn) != input_bns().end());
+  return ibn == "in";
+}
+
+void GatherOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                              const ParallelContext* parallel_ctx) const {
+  const BlobDesc* indices = GetBlobDesc4BnInOp("indices");
+  CHECK(IsIntegralDataType(indices->data_type()));
+  CHECK_GT(indices->shape().NumAxes(), 0);
+  const BlobDesc* in = GetBlobDesc4BnInOp("in");
+  CHECK_GT(in->shape().NumAxes(), 0);
+  const int64_t axis = GetGatherAxis(op_conf().gather_conf(), in);
+  BlobDesc* out = GetBlobDesc4BnInOp("out");
+  *out = *in;
+  std::vector<int64_t> dim_vec;
+  dim_vec.insert(dim_vec.end(), in->shape().dim_vec().cbegin(),
+                 in->shape().dim_vec().cbegin() + axis);
+  dim_vec.insert(dim_vec.end(), indices->shape().dim_vec().cbegin(),
+                 indices->shape().dim_vec().cend());
+  dim_vec.insert(dim_vec.end(), in->shape().dim_vec().cbegin() + axis + 1,
+                 in->shape().dim_vec().end());
+  out->mut_shape() = Shape(dim_vec);
+}
+
+void GatherOp::VirtualGenKernelConf(
+    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
+  const int64_t axis = GetGatherAxis(op_conf().gather_conf(), GetBlobDesc4BnInOp("in"));
+  kernel_conf->mutable_gather_conf()->set_axis(axis);
+}
+
+void GatherOp::GetOpParallelSignatures(
+    std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const {
+  op_parallel_signatures->emplace_back(MakeDataSplitOpParallelSignature(this));
+  op_parallel_signatures->emplace_back(Make_DS_MB_2_DS_OpParallelSignature(this));
+  auto GtZero = [](int32_t axis) { return axis > 0; };
+  op_parallel_signatures->emplace_back(Make_DB_MS_2_MS_OpParallelSignature(this, GtZero));
+  op_parallel_signatures->emplace_back(new Gather_DB_MS_2_P_OpParallelSignature(this));
+}
+
+int32_t GatherOp::OutputBlobModelSplitAxis(
+    const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+    const std::string& obn) const {
+  const SbpInferHint& indices_sbp_infer_hint = SbpInferHint4Ibn("indices");
+  CHECK(indices_sbp_infer_hint.is_data_blob());
+  const SbpInferHint& in_sbp_infer_hint = SbpInferHint4Ibn("in");
+  const int64_t in_num_axes = in_sbp_infer_hint.num_axes();
+  const int64_t gather_axis = GetGatherAxis(op_conf().gather_conf(), in_num_axes);
+  CHECK(in_sbp_infer_hint.is_model_split());
+  CHECK_GT(in_sbp_infer_hint.split_axis(), 0);
+  CHECK_GT(in_sbp_infer_hint.split_axis(), gather_axis);
+  CHECK_LT(in_sbp_infer_hint.split_axis(), in_num_axes);
+  return in_sbp_infer_hint.split_axis() + indices_sbp_infer_hint.num_axes() - 1;
+}
+
+REGISTER_OP(OperatorConf::kGatherConf, GatherOp);
+
+}  // namespace oneflow
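Reviewer note (not part of the patch): the output shape assembled in GatherOp::InferBlobDescs, restated as a standalone helper; GatherOutDims is an illustrative name only.

#include <cstdint>
#include <vector>

std::vector<int64_t> GatherOutDims(const std::vector<int64_t>& in,
                                   const std::vector<int64_t>& indices, int64_t axis) {
  std::vector<int64_t> out(in.begin(), in.begin() + axis);   // dims before the gather axis
  out.insert(out.end(), indices.begin(), indices.end());     // gather axis replaced by all indices dims
  out.insert(out.end(), in.begin() + axis + 1, in.end());    // dims after the gather axis
  return out;
}
// e.g. in = {4, 6, 8}, indices = {2, 3}, axis = 1 -> out = {4, 2, 3, 8}.

The same composition is what OutputBlobModelSplitAxis relies on: a model split of "in" at an axis strictly past the gather axis reappears in the output at split_axis + indices.NumAxes() - 1.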
diff --git a/oneflow/core/operator/gather_op.h b/oneflow/core/operator/gather_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..7d9e3287fc2ff91a011f73d1e3339f9d38df1669
--- /dev/null
+++ b/oneflow/core/operator/gather_op.h
@@ -0,0 +1,34 @@
+#ifndef ONEFLOW_CORE_OPERATOR_GATHER_OP_H_
+#define ONEFLOW_CORE_OPERATOR_GATHER_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class GatherOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(GatherOp);
+  GatherOp() = default;
+  ~GatherOp() override = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override;
+  bool NeedOutBlobWhenBackward() const override { return false; }
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                            const ParallelContext* parallel_ctx,
+                            KernelConf* kernel_conf) const override;
+  int32_t OutputBlobModelSplitAxis(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const std::string& obn) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override;
+  void GetOpParallelSignatures(
+      std::vector<std::unique_ptr<const OpParallelSignature>>*) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_GATHER_OP_H_
diff --git a/oneflow/core/operator/gelu_op.cpp b/oneflow/core/operator/gelu_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..bfc4f64e1ccf5ed325a35a9ab986188929a12fc4
--- /dev/null
+++ b/oneflow/core/operator/gelu_op.cpp
@@ -0,0 +1,22 @@
+#include "oneflow/core/operator/gelu_op.h"
+#include "oneflow/core/common/balanced_splitter.h"
+
+namespace oneflow {
+
+void GeluOp::InitFromOpConf() {
+  CHECK(op_conf().has_gelu_conf());
+
+  EnrollInputBn("in");
+  EnrollOutputBn("out");
+}
+
+const PbMessage& GeluOp::GetCustomizedConf() const { return op_conf().gelu_conf(); }
+
+void GeluOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                            const ParallelContext* parallel_ctx) const {
+  *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in");
+}
+
+REGISTER_OP(OperatorConf::kGeluConf, GeluOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/gelu_op.h b/oneflow/core/operator/gelu_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..66a083b3f13f7ef4e8659d0acc1cb8d6c4c4a016
--- /dev/null
+++ b/oneflow/core/operator/gelu_op.h
@@ -0,0 +1,26 @@
+#ifndef ONEFLOW_CORE_OPERATOR_GELU_OP_H_
+#define ONEFLOW_CORE_OPERATOR_GELU_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class GeluOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(GeluOp);
+  GeluOp() = default;
+  ~GeluOp() = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override;
+  bool NeedOutBlobWhenBackward() const override { return false; }
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_GELU_OP_H_
diff --git a/oneflow/core/operator/hinge_loss_op.h b/oneflow/core/operator/hinge_loss_op.h
index 527754fd6c2d27b61a89844c5f214f910ed814cb..f17cc11c6ccb956faf7ff93bd2055c40627a6ca0 100644
--- a/oneflow/core/operator/hinge_loss_op.h
+++ b/oneflow/core/operator/hinge_loss_op.h
@@ -14,6 +14,8 @@ class HingeLossOp final : public LossOp {
   const PbMessage& GetCustomizedConf() const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   void VirtualInitFromOpConf() override;
   void VirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                              const ParallelContext* parallel_ctx) const override;
diff --git a/oneflow/core/operator/identity_loss_op.cpp b/oneflow/core/operator/identity_loss_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..92100a568f360d5a9795b5d13ab40f34f4f4890b
--- /dev/null
+++ b/oneflow/core/operator/identity_loss_op.cpp
@@ -0,0 +1,26 @@
+#include "oneflow/core/operator/identity_loss_op.h"
+
+namespace oneflow {
+
+const PbMessage& IdentityLossOp::GetCustomizedConf() const {
+  return op_conf().identity_loss_conf();
+}
+
+LossKernelConf* IdentityLossOp::GetMutLossKernelConf(KernelConf* kernel_conf) const {
+  return kernel_conf->mutable_identity_loss_conf()->mutable_loss_conf();
+}
+
+void IdentityLossOp::VirtualInitFromOpConf() { EnrollConstBufBn("ones"); }
+
+void IdentityLossOp::VirtualInferBlobDescs(
+    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx) const {
+  const BlobDesc* prediction = GetBlobDesc4BnInOp("prediction");
+  BlobDesc* ones = GetBlobDesc4BnInOp("ones");
+  ones->set_data_type(prediction->data_type());
+  ones->mut_shape() = prediction->shape();
+}
+
+REGISTER_OP(OperatorConf::kIdentityLossConf, IdentityLossOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/identity_loss_op.h b/oneflow/core/operator/identity_loss_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..0fc406d710799149c20e561b8b818caf9a43c336
--- /dev/null
+++ b/oneflow/core/operator/identity_loss_op.h
@@ -0,0 +1,27 @@
+#ifndef ONEFLOW_CORE_OPERATOR_IDENTITY_LOSS_OP_H_
+#define ONEFLOW_CORE_OPERATOR_IDENTITY_LOSS_OP_H_
+
+#include "oneflow/core/operator/loss_op.h"
+
+namespace oneflow {
+
+class IdentityLossOp final : public LossOp {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(IdentityLossOp);
+  IdentityLossOp() = default;
+  ~IdentityLossOp() override = default;
+
+  const PbMessage& GetCustomizedConf() const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
+  LossKernelConf* GetMutLossKernelConf(KernelConf*) const override;
+  void VirtualInitFromOpConf() override;
+  void VirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                             const ParallelContext* parallel_ctx) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_IDENTITY_LOSS_OP_H_
diff --git a/oneflow/core/operator/identity_op.cpp b/oneflow/core/operator/identity_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7287398718997ae51b2588a81ed59945ab58de72
--- /dev/null
+++ b/oneflow/core/operator/identity_op.cpp
@@ -0,0 +1,15 @@
+#include "oneflow/core/operator/identity_op.h"
+
+namespace oneflow {
+
+void IdentityOp::InitFromOpConf() {
+  EnrollInputBn("in");
+  EnrollOutputBn("out");
+}
+
+void IdentityOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                const ParallelContext* parallel_ctx) const {
+  *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in");
+}
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/identity_op.h b/oneflow/core/operator/identity_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..bca97c449adb7640fdc455ad7e99ecda2f10e9ff
--- /dev/null
+++ b/oneflow/core/operator/identity_op.h
@@ -0,0 +1,27 @@
+#ifndef ONEFLOW_CORE_OPERATOR_IDENTITY_OP_H_
+#define ONEFLOW_CORE_OPERATOR_IDENTITY_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class IdentityOp : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(IdentityOp);
+  IdentityOp() = default;
+  virtual ~IdentityOp() = default;
+
+  void InitFromOpConf() override;
+  bool NeedInBlobWhenBackward() const override { return false; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
+
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_IDENTITY_OP_H_
diff --git a/oneflow/core/operator/l2_normalize_op.h b/oneflow/core/operator/l2_normalize_op.h
index 1d5a8c5369d1b1dc2a98d7f2d3439484e44a3026..9f827c4079db3e2ef0238d1e3d91d77b57fdad50 100644
--- a/oneflow/core/operator/l2_normalize_op.h
+++ b/oneflow/core/operator/l2_normalize_op.h
@@ -16,6 +16,9 @@ class L2NormalizeOp final : public Operator {
   bool NeedInBlobWhenBackward() const override { return false; }
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/lars_model_update_op.cpp b/oneflow/core/operator/lars_model_update_op.cpp
index 3c3d24a9a29af3f9e61b11f8136b6967d27fe7e4..51a0ca916f58baa1a42874e2076382c30e855d37 100644
--- a/oneflow/core/operator/lars_model_update_op.cpp
+++ b/oneflow/core/operator/lars_model_update_op.cpp
@@ -7,7 +7,7 @@ void LARSModelUpdateOp::MdUpdtVirtualInitFromOpConf() {
   EnrollDataTmpBn("data_tmp");
 }
 
-void LARSModelUpdateOp::InferBlobDescs(
+void LARSModelUpdateOp::MdUpdtVirtualInferBlobDescs(
     std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
     const ParallelContext* parallel_ctx) const {
   const BlobDesc* model_blob_desc = GetBlobDesc4BnInOp("model");
diff --git a/oneflow/core/operator/lars_model_update_op.h b/oneflow/core/operator/lars_model_update_op.h
index 1040b8101d3cd98f7c4fdadd315344bfb7848a22..35673b31877b29eaaa52a6eed5be3638e23e19d5 100644
--- a/oneflow/core/operator/lars_model_update_op.h
+++ b/oneflow/core/operator/lars_model_update_op.h
@@ -11,11 +11,10 @@ class LARSModelUpdateOp final : public NormalModelUpdtOp {
   LARSModelUpdateOp() = default;
   ~LARSModelUpdateOp() = default;
 
-  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                      const ParallelContext* parallel_ctx) const override;
-
  private:
   void MdUpdtVirtualInitFromOpConf() override;
+  void MdUpdtVirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                   const ParallelContext* parallel_ctx) const override;
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/layer_norm_op.cpp b/oneflow/core/operator/layer_norm_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..35069f6e4c07bac6c6af5a03669ddb9196a6255c
--- /dev/null
+++ b/oneflow/core/operator/layer_norm_op.cpp
@@ -0,0 +1,86 @@
+#include "oneflow/core/operator/layer_norm_op.h"
+
+namespace oneflow {
+
+namespace {
+
+int64_t ShiftNegativeAxisIfNeed(const Shape& shape, int64_t axis) {
+  const int64_t shifted = axis < 0 ? axis + shape.NumAxes() : axis;
+  CHECK_GE(shifted, 0);
+  CHECK_LT(shifted, shape.NumAxes());
+  return shifted;
+}
+
+}  // namespace
+
+void LayerNormOp::InitFromOpConf() {
+  CHECK(op_conf().has_layer_norm_conf());
+  const LayerNormOpConf& conf = op_conf().layer_norm_conf();
+  if (!(conf.center() || conf.scale())) { mut_op_conf()->set_trainable(false); }
+  EnrollInputBn("in");
+  EnrollOutputBn("out");
+  if (conf.center()) { EnrollModelBn("beta"); }
+  if (conf.scale()) {
+    EnrollModelBn("gamma");
+    EnrollDataTmpBn("normalize_out");
+  }
+  EnrollDataTmpBn("cudnn_bn_mean");
+  EnrollDataTmpBn("cudnn_bn_inv_variance");
+  EnrollConstBufBn("cudnn_bn_scale_ones");
+  EnrollConstBufBn("cudnn_bn_bias_zeros");
+  EnrollBwBufBn("cudnn_bn_scale_diff_buf");
+  EnrollBwBufBn("cudnn_bn_bias_diff_buf");
+  if (op_conf().trainable()) { EnrollBwBufBn("bw_reduce_buf"); }
+}
+
+void LayerNormOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                 const ParallelContext* parallel_ctx) const {
+  CHECK(parallel_ctx->policy() != kModelParallel);
+  const BlobDesc* in = GetBlobDesc4BnInOp("in");
+  *GetBlobDesc4BnInOp("out") = *in;
+  const LayerNormOpConf& conf = op_conf().layer_norm_conf();
+  const int64_t begin_params_axis = ShiftNegativeAxisIfNeed(in->shape(), conf.begin_params_axis());
+  const Shape param_shape = Shape({in->shape().Count(begin_params_axis)});
+  if (conf.center()) {
+    BlobDesc* beta = GetBlobDesc4BnInOp("beta");
+    beta->mut_shape() = param_shape;
+    beta->set_data_type(in->data_type());
+  }
+  if (conf.scale()) {
+    BlobDesc* gamma = GetBlobDesc4BnInOp("gamma");
+    gamma->mut_shape() = param_shape;
+    gamma->set_data_type(in->data_type());
+    *GetBlobDesc4BnInOp("normalize_out") = *in;
+  }
+  const int64_t begin_norm_axis = ShiftNegativeAxisIfNeed(in->shape(), conf.begin_norm_axis());
+  const Shape bn_param_shape = Shape({in->shape().Count(0, begin_norm_axis)});
+  BlobDesc* cudnn_bn_mean = GetBlobDesc4BnInOp("cudnn_bn_mean");
+  cudnn_bn_mean->mut_shape() = bn_param_shape;
+  cudnn_bn_mean->set_data_type(in->data_type());
+  *GetBlobDesc4BnInOp("cudnn_bn_inv_variance") = *cudnn_bn_mean;
+  *GetBlobDesc4BnInOp("cudnn_bn_scale_ones") = *cudnn_bn_mean;
+  *GetBlobDesc4BnInOp("cudnn_bn_bias_zeros") = *cudnn_bn_mean;
+}
+
+void LayerNormOp::InferBwBufBlobDescs(
+    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx) const {
+  CHECK(parallel_ctx->policy() != kModelParallel);
+  const BlobDesc* in = GetBlobDesc4BnInOp("in");
+  const LayerNormOpConf& conf = op_conf().layer_norm_conf();
+  if (op_conf().trainable()) { *GetBlobDesc4BnInOp("bw_reduce_buf") = *in; }
+  const int64_t begin_norm_axis = ShiftNegativeAxisIfNeed(in->shape(), conf.begin_norm_axis());
+  const Shape bn_param_shape = Shape({in->shape().Count(0, begin_norm_axis)});
+  BlobDesc* bn_scale_diff = GetBlobDesc4BnInOp("cudnn_bn_scale_diff_buf");
+  bn_scale_diff->mut_shape() = bn_param_shape;
+  bn_scale_diff->set_data_type(in->data_type());
+  *GetBlobDesc4BnInOp("cudnn_bn_bias_diff_buf") = *bn_scale_diff;
+}
+
+void LayerNormOp::VirtualFixParallelDesc(ParallelDesc* pr_desc) const {
+  pr_desc->set_policy(ParallelPolicy::kDataParallel);
+}
+
+REGISTER_OP(OperatorConf::kLayerNormConf, LayerNormOp);
+
+}  // namespace oneflow
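Reviewer note (not part of the patch): a worked example of the parameter shapes inferred above, assuming Shape::Count(i) multiplies the dims from i to the end and Count(0, j) multiplies the dims in [0, j); CountDims below is an illustrative stand-in for that reduction.

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Product of dims in [begin, end), the reduction Shape::Count performs.
int64_t CountDims(const std::vector<int64_t>& dims, size_t begin, size_t end) {
  return std::accumulate(dims.begin() + begin, dims.begin() + end, int64_t{1},
                         std::multiplies<int64_t>());
}
// in = {8, 16, 32}, begin_params_axis = -1 (shifted to 2), begin_norm_axis = 1:
//   beta / gamma shape                    -> { CountDims(in, 2, 3) } = {32}
//   cudnn_bn_mean / cudnn_bn_inv_variance -> { CountDims(in, 0, 1) } = {8}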
diff --git a/oneflow/core/operator/layer_norm_op.h b/oneflow/core/operator/layer_norm_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..9fcf8327ee0ae7d32d50a06d2b8da81e9d3733db
--- /dev/null
+++ b/oneflow/core/operator/layer_norm_op.h
@@ -0,0 +1,30 @@
+#ifndef ONEFLOW_CORE_OPERATOR_LAYER_NORM_OP_H_
+#define ONEFLOW_CORE_OPERATOR_LAYER_NORM_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class LayerNormOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(LayerNormOp);
+  LayerNormOp() = default;
+  ~LayerNormOp() override = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override { return op_conf().layer_norm_conf(); }
+  bool NeedInBlobWhenBackward() const override { return true; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext*) const override;
+  void InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                           const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+  void VirtualFixParallelDesc(ParallelDesc* pr_desc) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_LAYER_NORM_OP_H_
diff --git a/oneflow/core/operator/local_reponse_normalization_op.h b/oneflow/core/operator/local_reponse_normalization_op.h
index 22864365d8fc7351250c59765c7452d1eaf1c128..1028687023f7141c3dea67cd3ca79a2181b8cb6f 100644
--- a/oneflow/core/operator/local_reponse_normalization_op.h
+++ b/oneflow/core/operator/local_reponse_normalization_op.h
@@ -17,6 +17,8 @@ class LocalResponseNormalizationOp final : public Operator {
                       const ParallelContext* parallel_ctx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                             const ParallelContext* parallel_ctx,
                             KernelConf* kernel_conf) const override;
diff --git a/oneflow/core/operator/log_counter_op.h b/oneflow/core/operator/log_counter_op.h
index f464a9d926b72fbb9db581050d6763bbd85c6a91..28ad2ac1506c671d5605fcc654cd3ff2d552366f 100644
--- a/oneflow/core/operator/log_counter_op.h
+++ b/oneflow/core/operator/log_counter_op.h
@@ -15,6 +15,9 @@ class LogCounterOp final : public Operator {
   void InitFromOpConf() override;
   const PbMessage& GetCustomizedConf() const override;
   virtual LogicalNode* NewProperLogicalNode() { return new PrintLogicalNode; }
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/loss_op.cpp b/oneflow/core/operator/loss_op.cpp
index d91c02ee1a14c5f9f29ab75e86081881beced44f..ed450083398f9eb5d08a72e34e2e39ea991d9aa7 100644
--- a/oneflow/core/operator/loss_op.cpp
+++ b/oneflow/core/operator/loss_op.cpp
@@ -4,7 +4,7 @@ namespace oneflow {
 
 void LossOp::InitFromOpConf() {
   EnrollInputBn("prediction");
-  EnrollInputBn("label", false);
+  if (HasFieldInCustomizedConf("label")) { EnrollInputBn("label", false); }
   EnrollOutputBn("loss", false);
   EnrollOutputBn("loss_instance_num", false);
   if (!GetValFromCustomizedConf<std::string>("weight").empty()) {
@@ -20,23 +20,29 @@ void LossOp::VirtualGenKernelConf(
     const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
   LossKernelConf* conf = GetMutLossKernelConf(kernel_conf);
   conf->set_prediction_type(GetBlobDesc4BnInOp("prediction")->data_type());
-  conf->set_label_type(GetBlobDesc4BnInOp("label")->data_type());
+  if (HasFieldInCustomizedConf("label")) {
+    conf->set_label_type(GetBlobDesc4BnInOp("label")->data_type());
+  } else {
+    conf->set_label_type(DataType::kInvalidDataType);
+  }
   conf->set_weight_scalar(GetValFromCustomizedConf<float>("weight_scalar"));
   conf->set_reduction(static_cast<LossReductionType>(GetEnumFromCustomizedConf("reduction")));
 }
 
 void LossOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                            const ParallelContext* parallel_ctx,
-                            std::function<void(OpContext*)>) const {
+                            const ParallelContext* parallel_ctx) const {
   const BlobDesc* pred_blob_desc = GetBlobDesc4BnInOp("prediction");
-  const BlobDesc* label_blob_desc = GetBlobDesc4BnInOp("label");
-  CHECK_EQ(pred_blob_desc->has_data_id_field(), label_blob_desc->has_data_id_field());
-  CHECK_EQ(pred_blob_desc->has_dim0_valid_num_field(), label_blob_desc->has_dim0_valid_num_field());
-  CHECK_EQ(pred_blob_desc->has_dim0_inner_shape(), label_blob_desc->has_dim0_inner_shape());
+  if (HasFieldInCustomizedConf("label")) {
+    const BlobDesc* label_blob_desc = GetBlobDesc4BnInOp("label");
+    CHECK_EQ(pred_blob_desc->has_data_id_field(), label_blob_desc->has_data_id_field());
+    CHECK_EQ(pred_blob_desc->has_dim0_valid_num_field(),
+             label_blob_desc->has_dim0_valid_num_field());
+    CHECK_EQ(pred_blob_desc->has_dim0_inner_shape(), label_blob_desc->has_dim0_inner_shape());
+  }
   if (pred_blob_desc->has_dim0_inner_shape()) {
     CHECK_EQ(pred_blob_desc->dim0_inner_shape().At(0), 1);
   }
-  CHECK_GE(pred_blob_desc->shape().NumAxes(), 2);
+  CHECK_GT(pred_blob_desc->shape().NumAxes(), 0);
   // loss
   BlobDesc* loss_blob_desc = GetBlobDesc4BnInOp("loss");
   *loss_blob_desc = *pred_blob_desc;
diff --git a/oneflow/core/operator/loss_op.h b/oneflow/core/operator/loss_op.h
index ad05fd4e4420dd4b4766945b83eb735dc1bf17ed..1dc47d2d08940c562162ff2c9ff01b33c35b0cb5 100644
--- a/oneflow/core/operator/loss_op.h
+++ b/oneflow/core/operator/loss_op.h
@@ -16,8 +16,7 @@ class LossOp : public Operator {
   LogicalNode* NewProperLogicalNode() override { return new LossLogicalNode; }
 
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                      const ParallelContext* parallel_ctx,
-                      std::function<void(OpContext*)> EnrollOpCtx) const override;
+                      const ParallelContext* parallel_ctx) const override;
   bool IsLossOp() const override { return true; }
 
  protected:
@@ -32,6 +31,8 @@ class LossOp : public Operator {
                             KernelConf* kernel_conf) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   LogicalBlobId obn2lbi(const std::string& output_bn) const override;
 };
 
diff --git a/oneflow/core/operator/loss_print_op.h b/oneflow/core/operator/loss_print_op.h
index b7d5ef9d025b91cdeb159be9fd8ce9466f72160b..4c62db725486c4085b153991fbb03d4446b62c3e 100644
--- a/oneflow/core/operator/loss_print_op.h
+++ b/oneflow/core/operator/loss_print_op.h
@@ -15,6 +15,8 @@ class LossPrintOp final : public Operator {
   const PbMessage& GetCustomizedConf() const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override;
 };
 
diff --git a/oneflow/core/operator/matmul_op.cpp b/oneflow/core/operator/matmul_op.cpp
index 62bcb2a1daacb077c47500b288160ffca1c64da2..a53920733d5c1a62d980fac521bd6db8b382eca5 100644
--- a/oneflow/core/operator/matmul_op.cpp
+++ b/oneflow/core/operator/matmul_op.cpp
@@ -2,37 +2,139 @@
 #include "oneflow/core/common/balanced_splitter.h"
 namespace oneflow {
 
+namespace {
+
+class Matmul_MS_MS_2_P_OpParallelSignature final : public OpParallelSignature {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(Matmul_MS_MS_2_P_OpParallelSignature);
+  ~Matmul_MS_MS_2_P_OpParallelSignature() override = default;
+
+  Matmul_MS_MS_2_P_OpParallelSignature(const Operator* op) : OpParallelSignature(op) {}
+
+  const std::string Description() const override { return op().op_name() + ": (S, S) -> P"; }
+
+  const OpParallelMatchResult GetMatchResult(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const ParallelContext* parallel_ctx) const override {
+    const auto& b_sbp_infer_hint = SbpInferHint4Ibn("b");
+    if (!b_sbp_infer_hint.is_model_split()) { return MakeOpParallelMatchSignatureMismatch(); }
+    int32_t b_expected_split_axis = (op().op_conf().matmul_conf().transpose_b() ? 1 : 0);
+    if (b_sbp_infer_hint.split_axis() != b_expected_split_axis) {
+      return MakeOpParallelMatchSignatureMismatch();
+    }
+    if (parallel_ctx->policy() == kModelParallel) { return MakeOpParallelMatchSuccess(); }
+    return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kModelParallel);
+  }
+
+  void GenerateSignature(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      HashMap<std::string, SbpParallel>* bn2sbp) const override {
+    int32_t a_split_axis = (op().op_conf().matmul_conf().transpose_a() ? 0 : 1);
+    (*bn2sbp)["a"].mutable_split_parallel()->set_axis(a_split_axis);
+    (*bn2sbp)["b"] = SbpInferHint4Ibn("b").sbp_parallel();
+    (*bn2sbp)["out"].mutable_partial_sum_parallel();
+  }
+};
+
+}  // namespace
+
 void MatmulOp::InitFromOpConf() {
   CHECK(op_conf().has_matmul_conf());
-
-  EnrollInputBn("in");
-  EnrollInputBn("weight");
+  EnrollInputBn("a");
+  EnrollInputBn("b");
   EnrollOutputBn("out");
-  if (op_conf().matmul_conf().has_bias()) {
-    EnrollInputBn("bias");
-    EnrollConstBufBn("bias_multiplier");
-  }
+  EnrollFwBufBn("fw_buf");
+  EnrollBwBufBn("bw_buf");
 }
 
 const PbMessage& MatmulOp::GetCustomizedConf() const { return op_conf().matmul_conf(); }
 
+bool MatmulOp::IsInputBlobAllowedModelSplit(const std::string& ibn) const {
+  CHECK(std::find(input_bns().begin(), input_bns().end(), ibn) != input_bns().end());
+  return ibn == "b";
+}
+
+void MatmulOp::GetOpParallelSignatures(
+    std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const {
+  op_parallel_signatures->emplace_back(MakeDataSplitOpParallelSignature(this));
+  op_parallel_signatures->emplace_back(Make_DS_MB_2_DS_OpParallelSignature(this));
+  auto IsValidSplit = [this](int32_t axis) {
+    int32_t b_expected_split_axis = (op_conf().matmul_conf().transpose_b() ? 0 : 1);
+    return axis == b_expected_split_axis;
+  };
+  op_parallel_signatures->emplace_back(Make_DB_MS_2_MS_OpParallelSignature(this, IsValidSplit));
+  op_parallel_signatures->emplace_back(new Matmul_MS_MS_2_P_OpParallelSignature(this));
+}
+
 void MatmulOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                               const ParallelContext* parallel_ctx) const {
   const MatmulOpConf& conf = op_conf().matmul_conf();
-  BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
-  CHECK_EQ(in_blob_desc->data_type(), Global<JobDesc>::Get()->DefaultDataType());
-  int32_t units = conf.units();
-  if (parallel_ctx->policy() == kModelParallel) {
-    BalancedSplitter splitter(units, parallel_ctx->parallel_num());
-    units = splitter.At(parallel_ctx->parallel_id()).size();
-  }
-  // out
+  BlobDesc* a_blob_desc = GetBlobDesc4BnInOp("a");
+  BlobDesc* b_blob_desc = GetBlobDesc4BnInOp("b");
+  CHECK_EQ(a_blob_desc->shape().NumAxes(), b_blob_desc->shape().NumAxes());
+  CHECK_GE(a_blob_desc->shape().NumAxes(), 2);
+  size_t num_axes = a_blob_desc->shape().NumAxes();
+  if (conf.transpose_a()) {
+    CHECK(!a_blob_desc->has_dim0_valid_num_field());
+    CHECK(!a_blob_desc->has_dim1_valid_num_field());
+    CHECK(!a_blob_desc->has_dim2_valid_num_field());
+  }
+  if (conf.transpose_b()) {
+    CHECK(!b_blob_desc->has_dim0_valid_num_field());
+    CHECK(!b_blob_desc->has_dim1_valid_num_field());
+    CHECK(!b_blob_desc->has_dim2_valid_num_field());
+  }
+  BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
+  *out_blob_desc = *a_blob_desc;
+  FOR_RANGE(int32_t, i, 0, num_axes - 2) {
+    CHECK_EQ(a_blob_desc->shape().At(i), b_blob_desc->shape().At(i));
+  }
+  int64_t a_dim_index = conf.transpose_a() ? num_axes - 1 : num_axes - 2;
+  out_blob_desc->mut_shape().Set(num_axes - 2, a_blob_desc->shape().At(a_dim_index));
+  int64_t b_dim_index = conf.transpose_b() ? num_axes - 2 : num_axes - 1;
+  out_blob_desc->mut_shape().Set(num_axes - 1, b_blob_desc->shape().At(b_dim_index));
+  int64_t a_mid_dim_index = conf.transpose_a() ? num_axes - 2 : num_axes - 1;
+  int64_t b_mid_dim_index = conf.transpose_b() ? num_axes - 1 : num_axes - 2;
+  CHECK_EQ(a_blob_desc->shape().At(a_mid_dim_index), b_blob_desc->shape().At(b_mid_dim_index));
+  if (device_type() == DeviceType::kGPU && num_axes >= 3) {
+    int batch_num = a_blob_desc->shape().Count(0, num_axes - 2);
+    // Assume GPU addresses are 64-bit
+    BlobDesc* fw_buf_blob_desc = GetBlobDesc4BnInOp("fw_buf");
+    *fw_buf_blob_desc = *out_blob_desc;
+    fw_buf_blob_desc->mut_shape() = {3 * batch_num};
+    fw_buf_blob_desc->set_data_type(DataType::kInt64);
+    fw_buf_blob_desc->set_has_data_id_field(false);
+  }
+}
+
+int32_t MatmulOp::OutputBlobModelSplitAxis(
+    const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+    const std::string& obn) const {
+  CHECK_EQ(SbpInferHint4Ibn("a").num_axes(), SbpInferHint4Ibn("b").num_axes());
+  const auto& b_sbp_infer_hint = SbpInferHint4Ibn("b");
+  CHECK_EQ(SbpInferHint4Ibn("b").num_axes(), 2);
+  CHECK(b_sbp_infer_hint.is_model_split());
+  int32_t b_model_split_axis = b_sbp_infer_hint.split_axis();
+  if (op_conf().matmul_conf().transpose_b()) {
+    if (b_model_split_axis == 0) { return 1; }
+  } else {
+    if (b_model_split_axis == 1) { return 1; }
+  }
+  UNIMPLEMENTED();
+  return -1;
+}
+
+void MatmulOp::InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                   const ParallelContext*) const {
   BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
-  *out_blob_desc = *in_blob_desc;
-  out_blob_desc->mut_shape() = Shape({in_blob_desc->shape().At(0), units});
-  if (conf.has_bias()) {
-    // bias_multiplier
-    GetBlobDesc4BnInOp("bias_multiplier")->mut_shape() = Shape({in_blob_desc->shape().At(0), 1});
+  size_t num_axes = out_blob_desc->shape().NumAxes();
+  if (device_type() == DeviceType::kGPU && num_axes >= 3) {
+    BlobDesc* bw_buf_blob_desc = GetBlobDesc4BnInOp("bw_buf");
+    int32_t batch_num = out_blob_desc->shape().Count(0, num_axes - 2);
+    *bw_buf_blob_desc = *out_blob_desc;
+    bw_buf_blob_desc->mut_shape() = {3 * batch_num};
+    bw_buf_blob_desc->set_data_type(DataType::kInt64);
+    bw_buf_blob_desc->set_has_data_id_field(false);
   }
 }
 
diff --git a/oneflow/core/operator/matmul_op.h b/oneflow/core/operator/matmul_op.h
index 9c07cb64157015f71c3467727c730919c21f25ea..a7ced9dabb37a9f5240eb702ca5272f10597f8e6 100644
--- a/oneflow/core/operator/matmul_op.h
+++ b/oneflow/core/operator/matmul_op.h
@@ -8,12 +8,22 @@ class MatmulOp final : public Operator {
   OF_DISALLOW_COPY_AND_MOVE(MatmulOp);
   MatmulOp() = default;
   ~MatmulOp() = default;
+
   void InitFromOpConf() override;
   const PbMessage& GetCustomizedConf() const override;
+  bool NeedOutBlobWhenBackward() const override { return false; }
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
-  int32_t ModelSplitAxis() const override { return 1; }
-  int32_t MaxModelSplitNum() const override { return op_conf().matmul_conf().units(); }
+  void InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                           const ParallelContext*) const override;
+  int32_t OutputBlobModelSplitAxis(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const std::string& obn) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override;
+  void GetOpParallelSignatures(
+      std::vector<std::unique_ptr<const OpParallelSignature>>*) const override;
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/maximum_op.h b/oneflow/core/operator/maximum_op.h
index e76aed1dd0b1783d41839266fc2a8f1aee07355a..d67c7ee35b80b82192e6472556b19685f229e02d 100644
--- a/oneflow/core/operator/maximum_op.h
+++ b/oneflow/core/operator/maximum_op.h
@@ -14,9 +14,11 @@ class MaximumOp final : public CWiseOp {
   void VirtualInitFromOpConf() override;
 
   const PbMessage& GetCustomizedConf() const override;
-
   void VirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                              const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/mean_op.cpp b/oneflow/core/operator/mean_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..851d14b6c2457bcab165f824241beb1f5ea05a26
--- /dev/null
+++ b/oneflow/core/operator/mean_op.cpp
@@ -0,0 +1,31 @@
+#include "oneflow/core/operator/mean_op.h"
+
+namespace oneflow {
+
+void MeanOp::InitFromOpConf() {
+  CHECK(op_conf().has_mean_conf());
+  EnrollInputBn("in");
+  EnrollOutputBn("out");
+  EnrollFwBufBn("fw_tmp");
+  EnrollBwBufBn("bw_tmp");
+}
+
+void MeanOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                            const ParallelContext* parallel_ctx) const {
+  const BlobDesc* in_blob = GetBlobDesc4BnInOp("in");
+  BlobDesc* out_blob = GetBlobDesc4BnInOp("out");
+  *out_blob = *in_blob;
+  std::vector<int64_t> dim_vec = in_blob->shape().dim_vec();
+  dim_vec.back() = 1;
+  out_blob->mut_shape() = Shape(std::move(dim_vec));
+  *GetBlobDesc4BnInOp("fw_tmp") = *in_blob;
+}
+
+void MeanOp::InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                 const ParallelContext*) const {
+  *GetBlobDesc4BnInOp("bw_tmp") = *GetBlobDesc4BnInOp("out");
+}
+
+REGISTER_OP(OperatorConf::kMeanConf, MeanOp);
+
+}  // namespace oneflow
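As inferred above, MeanOp reduces only the trailing axis and keeps it as a length-1 dimension, so an input of shape (N, C, W) yields an output of shape (N, C, 1). A standalone sketch of that shape rule (the helper name is illustrative, not OneFlow API):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Mirrors the dim_vec manipulation in MeanOp::InferBlobDescs: copy the input
    // dims and force the last one to 1.
    std::vector<int64_t> MeanOutDims(std::vector<int64_t> in_dims) {
      in_dims.back() = 1;
      return in_dims;
    }

    int main() {
      assert((MeanOutDims({32, 16, 100}) == std::vector<int64_t>{32, 16, 1}));
      return 0;
    }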
diff --git a/oneflow/core/operator/mean_op.h b/oneflow/core/operator/mean_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..2a977d23d4705532e3302af68dd48e0c48e6f99e
--- /dev/null
+++ b/oneflow/core/operator/mean_op.h
@@ -0,0 +1,29 @@
+#ifndef ONEFLOW_CORE_OPERATOR_MEAN_OP_H_
+#define ONEFLOW_CORE_OPERATOR_MEAN_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class MeanOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(MeanOp);
+  MeanOp() = default;
+  ~MeanOp() = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override { return op_conf().mean_conf(); }
+  bool NeedInBlobWhenBackward() const override { return false; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+  void InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                           const ParallelContext*) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_MEAN_OP_H_
diff --git a/oneflow/core/operator/model_save_op.h b/oneflow/core/operator/model_save_op.h
index 157af7e8fa778b2e86f3d41848397f072eb33be9..e80dd03ecfcaf0e4ac4d8cefcbb1c3ab61460945 100644
--- a/oneflow/core/operator/model_save_op.h
+++ b/oneflow/core/operator/model_save_op.h
@@ -13,6 +13,9 @@ class ModelSaveOp final : public Operator {
 
   void InitFromOpConf() override;
   const PbMessage& GetCustomizedConf() const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/momentum_model_update_op.cpp b/oneflow/core/operator/momentum_model_update_op.cpp
index f39cc3892e34b9ac032f8a50d00d1969c81c487d..674dd29e1e93ef0a4ac0ec8a5ee43ec34b088781 100644
--- a/oneflow/core/operator/momentum_model_update_op.cpp
+++ b/oneflow/core/operator/momentum_model_update_op.cpp
@@ -4,7 +4,7 @@ namespace oneflow {
 
 void MomentumModelUpdateOp::MdUpdtVirtualInitFromOpConf() { EnrollForwardModelBn("momentum"); }
 
-void MomentumModelUpdateOp::InferBlobDescs(
+void MomentumModelUpdateOp::MdUpdtVirtualInferBlobDescs(
     std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
     const ParallelContext* parallel_ctx) const {
   const BlobDesc* model_blob_desc = GetBlobDesc4BnInOp("model");
diff --git a/oneflow/core/operator/momentum_model_update_op.h b/oneflow/core/operator/momentum_model_update_op.h
index 7be6e82f85b9daca548fb8c358082b96bc4dd92b..085deb7f3525c11652b46a280b95767b39591caf 100644
--- a/oneflow/core/operator/momentum_model_update_op.h
+++ b/oneflow/core/operator/momentum_model_update_op.h
@@ -11,11 +11,10 @@ class MomentumModelUpdateOp final : public NormalModelUpdtOp {
   MomentumModelUpdateOp() = default;
   ~MomentumModelUpdateOp() = default;
 
-  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                      const ParallelContext* parallel_ctx) const override;
-
  private:
   void MdUpdtVirtualInitFromOpConf() override;
+  void MdUpdtVirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                   const ParallelContext* parallel_ctx) const override;
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/multiply_op.h b/oneflow/core/operator/multiply_op.h
index 9194a1c210a94cd188d422e0f0fb85388a56555f..5500497c508215ba38637d99b3ddbd4ba99a37a5 100644
--- a/oneflow/core/operator/multiply_op.h
+++ b/oneflow/core/operator/multiply_op.h
@@ -12,6 +12,9 @@ class MultiplyOp final : public Operator {
   const PbMessage& GetCustomizedConf() const override;
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/naive_model_update_op.h b/oneflow/core/operator/naive_model_update_op.h
index 2582f20aa6921f6f1d6c00c3c68755524c67e0e5..414d5f14d4252d1d2c94a1aa5e1ec57d1db1c78e 100644
--- a/oneflow/core/operator/naive_model_update_op.h
+++ b/oneflow/core/operator/naive_model_update_op.h
@@ -10,9 +10,6 @@ class NaiveModelUpdateOp final : public NormalModelUpdtOp {
   OF_DISALLOW_COPY_AND_MOVE(NaiveModelUpdateOp);
   NaiveModelUpdateOp() = default;
   ~NaiveModelUpdateOp() = default;
-
-  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                      const ParallelContext* parallel_ctx) const override {}
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/nccl_all_gather_op.h b/oneflow/core/operator/nccl_all_gather_op.h
index d57f0189778001356e73e1fe159a5ad02b7ff508..b5fc4c1a289cbfc381db2add952d6b98b7852995 100644
--- a/oneflow/core/operator/nccl_all_gather_op.h
+++ b/oneflow/core/operator/nccl_all_gather_op.h
@@ -18,6 +18,8 @@ class NcclAllGatherOp final : public Operator {
                       const ParallelContext* parallel_ctx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
+
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
   LogicalBlobId obn2lbi(const std::string& output_bn) const override;
 };
diff --git a/oneflow/core/operator/nccl_all_reduce_op.h b/oneflow/core/operator/nccl_all_reduce_op.h
index 969f758b906a814cc86e7ebd9660ba410cbcfbb5..b634fa962a4143a95f928af532b43cd46d5c5059 100644
--- a/oneflow/core/operator/nccl_all_reduce_op.h
+++ b/oneflow/core/operator/nccl_all_reduce_op.h
@@ -18,6 +18,8 @@ class NcclAllReduceOp final : public Operator {
                       const ParallelContext* parallel_ctx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
+
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
   LogicalBlobId obn2lbi(const std::string& output_bn) const override;
 };
diff --git a/oneflow/core/operator/nccl_reduce_scatter_op.h b/oneflow/core/operator/nccl_reduce_scatter_op.h
index 9441205425cc7dbb48bf704f7d260dc21d93b3e2..086d1732b3099f8aae5580ce9007d30299c16128 100644
--- a/oneflow/core/operator/nccl_reduce_scatter_op.h
+++ b/oneflow/core/operator/nccl_reduce_scatter_op.h
@@ -18,6 +18,7 @@ class NcclReduceScatterOp final : public Operator {
                       const ParallelContext* parallel_ctx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
   LogicalBlobId obn2lbi(const std::string& output_bn) const override;
 };
diff --git a/oneflow/core/operator/normal_model_update_op.cpp b/oneflow/core/operator/normal_model_update_op.cpp
index a36e4ab03a24b1e7ca8cdde99c8a409193d7f406..28c7a9eb6df72270b6ee637e85758cb75f3f9ec0 100644
--- a/oneflow/core/operator/normal_model_update_op.cpp
+++ b/oneflow/core/operator/normal_model_update_op.cpp
@@ -2,6 +2,7 @@
 #include "oneflow/core/operator/rmsprop_model_update_op.h"
 #include "oneflow/core/operator/momentum_model_update_op.h"
 #include "oneflow/core/operator/lars_model_update_op.h"
+#include "oneflow/core/operator/adam_model_update_op.h"
 
 namespace oneflow {
 
@@ -9,9 +10,24 @@ void NormalModelUpdtOp::InitFromOpConf() {
   EnrollInputBn("model_diff", false);
   EnrollInputBn("total_instance_num_diff", false);
   EnrollOutputBn("model", false);
+  if (op_conf().normal_mdupdt_conf().user_conf().has_clip_conf()
+      && op_conf().normal_mdupdt_conf().user_conf().clip_conf().has_clip_by_global_norm()) {
+    EnrollDataTmpBn("data_tmp");
+  }
   MdUpdtVirtualInitFromOpConf();
 }
 
+void NormalModelUpdtOp::InferBlobDescs(
+    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx) const {
+  if (op_conf().normal_mdupdt_conf().user_conf().has_clip_conf()
+      && op_conf().normal_mdupdt_conf().user_conf().clip_conf().has_clip_by_global_norm()) {
+    *GetBlobDesc4BnInOp("data_tmp") = *GetBlobDesc4BnInOp("model_diff");
+    GetBlobDesc4BnInOp("data_tmp")->mut_shape() = Shape({1});
+  }
+  MdUpdtVirtualInferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx);
+}
+
 const PbMessage& NormalModelUpdtOp::GetCustomizedConf() const {
   return op_conf().normal_mdupdt_conf();
 }
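The `data_tmp` blob of shape {1} enrolled above exists only when clip_by_global_norm is configured, which is consistent with it holding the computed global norm of the gradients. The actual kernel is not in this diff; the sketch below is only the standard clip-by-global-norm rule on the CPU, for reference.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>

    // Standard clip-by-global-norm: rescale the gradient so its L2 norm does not
    // exceed clip_norm; a no-op when it is already within the bound.
    void ClipByGlobalNorm(float* model_diff, size_t n, float clip_norm) {
      double sum_sq = 0.0;
      for (size_t i = 0; i < n; ++i) { sum_sq += double(model_diff[i]) * model_diff[i]; }
      const double global_norm = std::sqrt(sum_sq);  // plausibly what the {1}-shaped data_tmp holds
      const double scale = clip_norm / std::max(global_norm, double(clip_norm));
      for (size_t i = 0; i < n; ++i) { model_diff[i] = float(model_diff[i] * scale); }
    }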
diff --git a/oneflow/core/operator/normal_model_update_op.h b/oneflow/core/operator/normal_model_update_op.h
index 160c20e1a9277cd9ed70c9700de2c16bf8c57a28..32b2f1fcbd45cae3f0d0817d13f69164d44ba49a 100644
--- a/oneflow/core/operator/normal_model_update_op.h
+++ b/oneflow/core/operator/normal_model_update_op.h
@@ -10,14 +10,21 @@ class NormalModelUpdtOp : public Operator {
   OF_DISALLOW_COPY_AND_MOVE(NormalModelUpdtOp);
   virtual ~NormalModelUpdtOp() = default;
 
-  virtual void InitFromOpConf();
+  void InitFromOpConf() override;
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
   const PbMessage& GetCustomizedConf() const override;
 
  protected:
   NormalModelUpdtOp() = default;
   virtual void MdUpdtVirtualInitFromOpConf() {}
+  virtual void MdUpdtVirtualInferBlobDescs(
+      std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+      const ParallelContext*) const {}
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
+
   LogicalBlobId obn2lbi(const std::string& output_bn) const override;
 };
 
diff --git a/oneflow/core/operator/normalization_op.cpp b/oneflow/core/operator/normalization_op.cpp
index f53dfd90fdf457b5a511626c4f4f0b21fad36ce8..211e81b1d8d18d4b2a2e284d94d5b0efae5c3bde 100644
--- a/oneflow/core/operator/normalization_op.cpp
+++ b/oneflow/core/operator/normalization_op.cpp
@@ -39,7 +39,8 @@ const PbMessage& NormalizationOp::GetCustomizedConf() const {
 
 void NormalizationOp::InferBlobDescs(
     std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-    const ParallelContext* parallel_ctx, std::function<void(OpContext*)> EnrollOpCtx) const {
+    const ParallelContext* parallel_ctx, int64_t record_piece_size,
+    std::function<void(OpContext*)> EnrollOpCtx) const {
   const auto& conf = op_conf().normalization_conf();
   const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
   const DataType in_data_type = in_blob_desc->data_type();
diff --git a/oneflow/core/operator/normalization_op.h b/oneflow/core/operator/normalization_op.h
index 373bae04dbceb22a1e6534a9a872e4632d850a80..5d63921d8e3c88605aa1d3c14aab3631f020b8f4 100644
--- a/oneflow/core/operator/normalization_op.h
+++ b/oneflow/core/operator/normalization_op.h
@@ -25,10 +25,12 @@ class NormalizationOp final : public Operator {
   bool NeedOutBlobWhenBackward() const override { return false; }
 
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                      const ParallelContext*,
+                      const ParallelContext* parallel_ctx, int64_t record_piece_size,
                       std::function<void(OpContext*)> EnrollOpCtx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   void InferParamBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                            const NormalizationOpConf&, int64_t norm_part_num, DataType in_data_type,
                            bool use_cudnn) const;
diff --git a/oneflow/core/operator/one_hot_op.cpp b/oneflow/core/operator/one_hot_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..3c510125f44a4e21dfcdeb6760935171cdffe933
--- /dev/null
+++ b/oneflow/core/operator/one_hot_op.cpp
@@ -0,0 +1,33 @@
+#include "oneflow/core/operator/one_hot_op.h"
+
+namespace oneflow {
+
+void OneHotOp::InitFromOpConf() {
+  CHECK(op_conf().has_one_hot_conf());
+  EnrollInputBn("indices", false);
+  EnrollOutputBn("out", false);
+}
+
+const PbMessage& OneHotOp::GetCustomizedConf() const { return op_conf().one_hot_conf(); }
+
+void OneHotOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                              const ParallelContext* parallel_ctx) const {
+  const OneHotOpConf& conf = op_conf().one_hot_conf();
+  const int64_t depth = conf.depth();
+  const DataType data_type =
+      conf.has_data_type() ? conf.data_type() : Global<JobDesc>::Get()->DefaultDataType();
+  CHECK_GT(depth, 0);
+  const BlobDesc* indices = GetBlobDesc4BnInOp("indices");
+  CHECK(IsIntegralDataType(indices->data_type()));
+  CHECK_GT(indices->shape().NumAxes(), 0);
+  BlobDesc* out = GetBlobDesc4BnInOp("out");
+  *out = *indices;
+  out->set_data_type(data_type);
+  std::vector<int64_t> dim_vec = indices->shape().dim_vec();
+  dim_vec.push_back(depth);
+  out->mut_shape() = Shape(dim_vec);
+}
+
+REGISTER_OP(OperatorConf::kOneHotConf, OneHotOp);
+
+}  // namespace oneflow
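For reference, the shape inference above appends `depth` as a new trailing axis, so indices of shape (N,) with depth D produce an (N, D) output in which out[i][j] is 1 exactly when indices[i] == j. A minimal CPU sketch, not the OneFlow kernel:

    #include <cstdint>
    #include <vector>

    std::vector<std::vector<float>> OneHot(const std::vector<int64_t>& indices, int64_t depth) {
      std::vector<std::vector<float>> out(indices.size(), std::vector<float>(depth, 0.f));
      for (size_t i = 0; i < indices.size(); ++i) {
        // Out-of-range indices are simply left as all-zero rows in this sketch.
        if (indices[i] >= 0 && indices[i] < depth) { out[i][indices[i]] = 1.f; }
      }
      return out;
    }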
diff --git a/oneflow/core/operator/one_hot_op.h b/oneflow/core/operator/one_hot_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..b8618e290a6b34cdefb5649947330d42d09c464a
--- /dev/null
+++ b/oneflow/core/operator/one_hot_op.h
@@ -0,0 +1,26 @@
+#ifndef ONEFLOW_CORE_OPERATOR_ONE_HOT_OP_H_
+#define ONEFLOW_CORE_OPERATOR_ONE_HOT_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class OneHotOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(OneHotOp);
+  OneHotOp() = default;
+  ~OneHotOp() override = default;
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override;
+  bool NeedOutBlobWhenBackward() const override { return false; }
+  bool NeedInBlobWhenBackward() const override { return false; }
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_ONE_HOT_OP_H_
diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto
index 19a856a481790816beb68464f721f8a802096e6b..033834a309d4f9c1e13dbba5af05d8493c295083 100644
--- a/oneflow/core/operator/op_conf.proto
+++ b/oneflow/core/operator/op_conf.proto
@@ -4,7 +4,9 @@ package oneflow;
 import "oneflow/core/common/shape.proto";
 import "oneflow/core/common/data_type.proto";
 import "oneflow/core/record/image.proto";
+import "oneflow/core/record/record.proto";
 import "oneflow/core/job/resource.proto";
+import "oneflow/core/job/sbp_parallel.proto";
 import "oneflow/core/register/logical_blob_id.proto";
 
 enum ActivationType {
@@ -37,6 +39,10 @@ message RandomNormalInitializerConf {
   optional float std = 2 [default = 1];
 }
 
+message TruncatedNormalInitializerConf {
+  optional float std = 1 [default = 1];
+}
+
 enum VarianceNorm {
   kFanIn = 0;
   kFanOut = 1;
@@ -51,6 +57,20 @@ message MsraInitializerConf {
   required VarianceNorm variance_norm = 1;
 }
 
+// output[D_0, ..., D_(axis - 1), i, D_(axis + 1), ..., D_n] = start + i * stride
+message RangeInitializerConf {
+  optional double start = 1 [default = 0];
+  optional double stride = 2 [default = 1];
+  optional int64 axis = 3 [default = -1];
+}
+
+message IntRangeInitializerConf {
+  optional int64 start = 1 [default = 0];
+  optional int64 stride = 2 [default = 1];
+  optional int64 axis = 3 [default = -1];
+}
+
 message InitializerConf {
   oneof type {
     ConstantInitializerConf constant_conf = 1;
@@ -58,8 +78,11 @@ message InitializerConf {
     RandomUniformInitializerConf random_uniform_conf = 3;
     RandomUniformIntInitializerConf random_uniform_int_conf = 4;
     RandomNormalInitializerConf random_normal_conf = 5;
-    XavierInitializerConf xavier_conf = 6;
-    MsraInitializerConf msra_conf = 7;
+    TruncatedNormalInitializerConf truncated_normal_conf = 6;
+    XavierInitializerConf xavier_conf = 7;
+    MsraInitializerConf msra_conf = 8;
+    RangeInitializerConf range_conf = 9;
+    IntRangeInitializerConf int_range_conf = 10;
   }
 }
 
@@ -206,6 +229,12 @@ enum LossReductionType {
   kSumOverNonZeroWeight = 3;
 }
 
+message SparseCrossEntropyOpConf {
+  required string prediction = 1;
+  required string label = 2;
+  required string out = 3;
+}
+
 message SparseSoftmaxCrossEntropyLossOpConf {
   required string prediction = 1;
   required string label = 2;
@@ -224,6 +253,25 @@ message SparseCrossEntropyLossOpConf {
   optional string weight = 6;
 }
 
+message SigmoidCrossEntropyLossOpConf {
+  required string prediction = 1;
+  required string label = 2;
+  required string loss = 3;
+  optional bool normalize = 4 [default = true];
+  optional float scale = 5 [default = 1.0];
+  optional LossReductionType reduction = 6 [default = kSumOverN];
+  optional float weight_scalar = 7 [default = 1.0];
+  optional string weight = 8;
+}
+
+message IdentityLossOpConf {
+  required string prediction = 1;
+  required string loss = 2;
+  optional LossReductionType reduction = 3 [default = kSumOverN];
+  optional float weight_scalar = 4 [default = 1.0];
+  optional string weight = 5;
+}
+
 message ConcatOpConf {
   repeated string in = 1;
   required string out = 2;
@@ -293,6 +341,13 @@ message LARSModelUpdateConf {
   optional float lars_coefficient = 3 [default = 0.0001];
 }
 
+message AdamModelUpdateConf {
+  optional float beta1 = 1 [default = 0.9];
+  optional float beta2 = 2 [default = 0.999];
+  optional float epsilon = 3 [default = 1e-8];
+  optional bool do_bias_correction = 4 [default = false];
+}
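The fields above match the standard Adam update rule (first/second moment decay rates, epsilon, optional bias correction). The OneFlow kernel lives elsewhere, so this scalar sketch is just the textbook rule for reference.

    #include <cmath>

    // One Adam step on a single weight; m and v are the running first and second
    // moments kept by the model update op. t is the 1-based step count, used only
    // when do_bias_correction is set.
    float AdamStep(float w, float grad, float lr, float beta1, float beta2, float epsilon,
                   bool do_bias_correction, int t, float* m, float* v) {
      *m = beta1 * (*m) + (1.f - beta1) * grad;
      *v = beta2 * (*v) + (1.f - beta2) * grad * grad;
      const float m_hat = do_bias_correction ? *m / (1.f - std::pow(beta1, t)) : *m;
      const float v_hat = do_bias_correction ? *v / (1.f - std::pow(beta2, t)) : *v;
      return w - lr * m_hat / (std::sqrt(v_hat) + epsilon);
    }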
+
 message ExponentialDecayConf {
   required int64 decay_batches = 1;
   required double decay_rate = 2;
@@ -324,12 +379,12 @@ message PolynomialDecayConf {
 }
 
 message CosineDecayConf {
-  required int64 decay_batches = 1; 
+  required int64 decay_batches = 1;
   optional double alpha = 2 [default = 0.0];
 }
 
 message LinearCosineDecayConf {
-  required int64 decay_batches = 1; 
+  required int64 decay_batches = 1;
   optional double num_periods = 2 [default = 0.5];
   optional double alpha = 3 [default = 0.0];
   optional double beta = 4 [default = 0.001];
@@ -364,14 +419,27 @@ message WarmupConf {
   }
 }
 
+message ClipByGlobalNormConf {
+  required float clip_norm = 1;
+  optional float global_norm = 2;
+}
+
+message ClipConf {
+  oneof type {
+    ClipByGlobalNormConf clip_by_global_norm = 1;
+  }
+}
+
 message NormalModelUpdateOpUserConf {
   optional LearningRateDecayConf learning_rate_decay = 1;
   optional WarmupConf warmup_conf = 2;
+  optional ClipConf clip_conf = 3;
   oneof normal_mdupdt {
     NaiveModelUpdateConf naive_conf = 1000;
     MomentumModelUpdateConf momentum_conf = 1001;
     RMSPropModelUpdateConf rmsprop_conf = 1002;
     LARSModelUpdateConf lars_conf = 1003;
+    AdamModelUpdateConf adam_conf = 1004;
   }
 }
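To show how the new `clip_conf` and `adam_conf` fields compose, here is a hedged example using the protoc-generated C++ setters for the messages defined in this file; the generated header path follows the usual protoc layout and the function name is illustrative.

    #include "oneflow/core/operator/op_conf.pb.h"  // generated from this proto file

    void ConfigureAdamWithClipping(oneflow::NormalModelUpdateOpUserConf* user_conf) {
      auto* adam = user_conf->mutable_adam_conf();  // selects adam in the normal_mdupdt oneof
      adam->set_beta1(0.9f);
      adam->set_beta2(0.999f);
      adam->set_do_bias_correction(true);
      // Enables the clip step handled in NormalModelUpdtOp::InferBlobDescs above.
      user_conf->mutable_clip_conf()->mutable_clip_by_global_norm()->set_clip_norm(1.0f);
    }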
 
@@ -413,6 +481,11 @@ message LogCounterOpConf {
   optional int32 interval = 2 [default = 1];
 }
 
+message GeluOpConf {
+  required string in = 1;
+  required string out = 2;
+}
+
 message LossPrintOpConf {
   required LogicalBlobId loss_lbi = 1;
   required LogicalBlobId loss_instance_num_lbi = 2;
@@ -433,8 +506,15 @@ message ReduceSumOpConf {
     LogicalBlobId in_sys = 2; // For System
   }
   required string out = 3;
-  optional int32 axis = 4;
-  optional bool keepdims = 5 [default = false];
+  repeated int32 axis = 4;
+  optional bool keep_dims = 5 [default = false];
+}
+
+message ReduceMeanOpConf {
+  required string in = 1;
+  required string out = 2;
+  repeated int32 axis = 3;
+  optional bool keep_dims = 4 [default = false];
 }
 
 message BasicRnnOpConf {
@@ -460,6 +540,7 @@ message ReshapeOpConf {
   required string in = 1;
   required string out = 2;
   required ShapeProto shape = 3;
+  optional bool has_dim0_in_shape = 4;
 }
 
 message EmbeddingLookupOpConf {
@@ -491,10 +572,20 @@ message CastOpConf {
   required DataType data_type = 3;
 }
 
+message VariableOpConf {
+  required string tick = 1;
+  required string out = 2;
+  required ShapeProto shape = 3;
+  optional DataType data_type = 4;
+  optional InitializerConf initializer = 5;
+  optional string model_name = 6 [default = "weight"];
+  optional int32 model_split_axis = 7 [default = 0];
+}
+
 message LocalResponseNormalizationOpConf {
   required string in = 1;
   required string out = 2;
-  required string data_format = 3; 
+  required string data_format = 3;
   optional int32 depth_radius = 4 [default = 5];
   optional double bias = 5 [default = 1];
   optional double alpha = 6 [default = 1];
@@ -542,11 +633,16 @@ message PreprocessConf {
   }
 }
 
+message RandomShuffleConf {
+  optional int32 buffer_size = 1 [default = 1024];
+}
+
 message RecordLoadOpConf {
   required string out = 1;
   required string data_dir = 2;
   optional string part_name_prefix = 3 [default = "part-"];
   optional int32 part_name_suffix_length = 4 [default = -1];
+  optional RandomShuffleConf random_shuffle_conf = 5;
 }
 
 message BlobConf {
@@ -564,6 +660,7 @@ message DecodeOFRecordOpConf {
   optional int32 part_name_suffix_length = 3 [default = -1];
   optional string in = 4;
   repeated BlobConf blob = 5;
+  optional RandomShuffleConf random_shuffle_conf = 6;
 }
 
 message DecodeRandomOpConf {
@@ -582,6 +679,7 @@ message DefineTestBlobOpConf {
   optional int64 dim1_valid_num = 6;
   optional int64 dim2_valid_num = 7;
   repeated int64 record_id_in_device_piece = 8;
+  optional bool has_diff = 9 [default = false];
 }
 
 message NormalizationOpConf {
@@ -648,14 +746,17 @@ message AccuracyOpConf {
   required string label = 2;
   optional int32 top_k = 3 [default = 1];
   required string accuracy = 4;
+  optional string weight = 5;
 }
 
 message MatmulOpConf {
-  required string in = 1;
-  required string weight = 2;
-  optional string bias = 3;
-  required int32 units = 4;
+  // input lbns
+  required string a = 1;
+  required string b = 2;
+  // output bn
   required string out = 5;
+  optional bool transpose_a = 6 [default = false];
+  optional bool transpose_b = 7 [default = false];
 }
 
 message DotOpConf {
@@ -689,28 +790,193 @@ message HingeLossOpConf {
 message PackOpConf {
   required string in = 1;
   required string out = 2;
-  oneof pack_num_conf {
-    int32 pack_num = 3;
-    int32 pack_num_per_record = 4;
-  }
-  required string related_unpack = 5;
+  required int32 pack_num = 3;
+  required string related_unpack = 4;
 }
 
 message UnpackOpConf {
   required string in = 1;
   required string out = 2;
-  oneof unpack_num_conf {
-    int32 unpack_num = 3;
-    int32 unpack_num_per_record = 4;
-  }
+  required int32 unpack_num = 3;
 }
 
 message RepeatOpConf {
   required string in = 1;
   required string out = 2;
-  oneof repeat_num_conf {
-    int32 repeat_num = 3;
-    int32 repeat_num_per_record = 4;
+  required int32 repeat_num = 3;
+}
+
+message GatherOpConf {
+  required string in = 1;
+  required string indices = 2;
+  required string out = 3;
+  optional int64 axis = 4 [default = 0];
+}
+
+message BatchGatherOpConf {
+  required string in = 1;
+  required string indices = 2;
+  required string out = 3;
+}
+
+message SqrtOpConf {
+  required string in = 1;
+  required string out = 2;
+}
+
+message RsqrtOpConf {
+  required string in = 1;
+  required string out = 2;
+}
+
+message SquareOpConf {
+  required string in = 1;
+  required string out = 2;
+}
+
+message BroadcastAddOpConf {
+  required string a = 1;
+  required string b = 2;
+  required string out = 3;
+  optional bool is_const = 4 [default = false];
+}
+
+message BroadcastSubOpConf {
+  required string a = 1;
+  required string b = 2;
+  required string out = 3;
+  optional bool is_const = 4 [default = false];
+}
+
+message BroadcastMulOpConf {
+  required string a = 1;
+  required string b = 2;
+  required string out = 3;
+  optional bool is_const = 4 [default = false];
+}
+
+message BroadcastDivOpConf {
+  required string a = 1;
+  required string b = 2;
+  required string out = 3;
+  optional bool is_const = 4 [default = false];
+}
+
+message BiasAddOpConf {
+  // inputs
+  required string a = 1;
+  required string b = 2;
+  // output
+  required string out = 3;
+}
+
+message MeanOpConf {
+  required string in = 1;
+  required string out = 2;
+  // TODO: axis of mean
+}
+
+message DimSliceConf {
+  optional int32 start = 1 [default = 0];
+  optional int32 end = 2 [default = 0];
+  optional int32 stride = 3 [default = 1];
+}
+
+message SliceOpConf {
+  required string in = 1;
+  required string out = 2;
+  repeated DimSliceConf dim_slice_conf = 3;
+}
+
+message LayerNormOpConf {
+  required string in = 1;
+  required string out = 2;
+  optional bool center = 3 [default = true];
+  optional bool scale = 4 [default = true];
+  optional ActivationType activation = 5 [default = kNone];
+  optional int64 begin_norm_axis = 6 [default = 1];
+  optional int64 begin_params_axis = 7 [default = -1];
+  optional double epsilon = 8 [default = 1e-5];
+}
+
+message ConstantOpConf {
+  required string tick = 1;
+  required string out = 2;
+  optional ShapeProto shape = 3;
+  optional DataType data_type = 4;
+  optional InitializerConf initializer = 5;
+  optional bool use_device_piece_size_as_dim0 = 6 [default = false];
+}
+
+message DebugOpConf {
+  required string in = 1;
+  required string out = 2;
+  optional string in_blob_dump_dir = 3;
+  optional string out_diff_blob_dump_dir = 4;
+  optional string part_name_prefix = 5 [default = "part-"];
+  optional int32 part_name_suffix_length = 6 [default = -1];
+  oneof const_out {
+    string const_out_feature_load_filepath = 7;
+    Feature const_out_feature = 8;
+  }
+  oneof const_in_diff {
+    string const_in_diff_feature_load_filepath = 9;
+    Feature const_in_diff_feature = 10;
+  }
+}
+
+message OneHotOpConf {
+  required string indices = 1;
+  required string out = 2;
+  required int64 depth = 3;
+  optional DataType data_type = 4;
+}
+
+message ScalarAddOpConf {
+  required string in = 1;
+  required string out = 2;
+  oneof scalar_operand {
+    int64 int_operand = 3;
+    double float_operand = 4;
+  }
+}
+
+message ScalarMulOpConf {
+  required string in = 1;
+  required string out = 2;
+  oneof scalar_operand {
+    int64 int_operand = 3;
+    double float_operand = 4;
+  }
+}
+
+message ReduceIdentityOpConf {
+}
+
+message TickOpConf {
+  optional string in = 1;
+  required string out = 2;
+}
+
+message TupleIdentityOpConf {
+  repeated string in = 1;
+  repeated string out = 2;
+}
+
+message TopKOpConf {
+  required string in = 1;
+  required string out = 2;
+  optional int32 k = 3 [default = 1];
+  optional bool sorted = 4 [default = true];
+}
+
+message ParallelCastOpConf {
+  required string in = 1;
+  required string out = 2;
+  oneof parallel_type {
+    SplitParallel split_parallel = 3;
+    BroadcastParallel broadcast_parallel = 4;
   }
 }
 
@@ -751,8 +1017,12 @@ message OperatorConf {
     ModelSaveOpConf model_save_conf = 119;
     SharedModelDiffAddOpConf shared_model_diff_add_conf = 120;
     CastOpConf cast_conf = 121;
-    
+    VariableOpConf variable_conf = 122;
+    ReduceIdentityOpConf reduce_identity_conf = 123;
+    TickOpConf tick_conf = 124;
+
     // domain op
+    TupleIdentityOpConf tuple_identity_conf = 200;
     TransposeOpConf transpose_conf = 201;
     ReshapeOpConf reshape_conf = 202;
     BasicRnnOpConf basic_rnn_conf = 203;
@@ -793,7 +1063,34 @@ message OperatorConf {
     UnpackOpConf unpack_conf = 238;
     RepeatOpConf repeat_conf = 239;
     LogCounterOpConf log_counter_conf = 240;
-    L2NormalizeOpConf l2_normalize_conf = 241;
+    GeluOpConf gelu_conf = 241;
+    GatherOpConf gather_conf = 242;
+    BatchGatherOpConf batch_gather_conf = 243;
+    MeanOpConf mean_conf = 251;
+    SliceOpConf slice_conf = 252;
+    BiasAddOpConf bias_add_conf = 253;
+    LayerNormOpConf layer_norm_conf = 254;
+    ConstantOpConf constant_conf = 255;
+    DebugOpConf debug_conf = 256;
+    SigmoidCrossEntropyLossOpConf sigmoid_cross_entropy_loss_conf = 257;
+    OneHotOpConf one_hot_conf = 258;
+    IdentityLossOpConf identity_loss_conf = 259;
+    SparseCrossEntropyOpConf sparse_cross_entropy_conf = 260;
+    ReduceMeanOpConf reduce_mean_conf = 261;
+    TopKOpConf top_k_conf = 262;
+    ParallelCastOpConf parallel_cast_conf = 263;
+    L2NormalizeOpConf l2_normalize_conf = 264;
+
+    // math op
+    BroadcastAddOpConf broadcast_add_conf = 500;
+    BroadcastSubOpConf broadcast_sub_conf = 501;
+    BroadcastMulOpConf broadcast_mul_conf = 502;
+    BroadcastDivOpConf broadcast_div_conf = 503;
+    SquareOpConf square_conf = 504;
+    SqrtOpConf sqrt_conf = 505;
+    RsqrtOpConf rsqrt_conf = 506;
+    ScalarAddOpConf scalar_add_conf = 507;
+    ScalarMulOpConf scalar_mul_conf = 508;
   }
 }
 
diff --git a/oneflow/core/operator/op_parallel_match_result.proto b/oneflow/core/operator/op_parallel_match_result.proto
new file mode 100644
index 0000000000000000000000000000000000000000..2984274383bd3b33dee38eafe2ad5055e19bdb7d
--- /dev/null
+++ b/oneflow/core/operator/op_parallel_match_result.proto
@@ -0,0 +1,39 @@
+syntax = "proto2";
+package oneflow;
+
+import "oneflow/core/job/placement.proto";
+
+message OpParallelMatchSuccess {
+}
+
+message OpParallelSignatureMismatch {
+}
+
+message ParallelPolicyError {
+  required ParallelPolicy configured = 1;
+  required ParallelPolicy expected = 2;
+}
+
+message ParallelNumError {
+  required int64 configured = 1;
+  required int64 expected = 2;
+}
+
+message OpParallelConfError {
+  optional ParallelPolicyError parallel_policy_error = 1;
+  optional ParallelNumError parallel_num_error = 2;
+}
+
+message OpParallelMatchFail {
+  oneof fail_type {
+    OpParallelSignatureMismatch signature_mismatch = 1;
+    OpParallelConfError conf_error = 2;
+  }
+}
+
+message OpParallelMatchResult {
+  oneof result_type {
+    OpParallelMatchSuccess success = 1;
+    OpParallelMatchFail fail = 2;
+  }
+}
\ No newline at end of file
diff --git a/oneflow/core/operator/op_parallel_signature.cpp b/oneflow/core/operator/op_parallel_signature.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..decf39b7913c11afe1b098271503b02a8d07dbfa
--- /dev/null
+++ b/oneflow/core/operator/op_parallel_signature.cpp
@@ -0,0 +1,347 @@
+#include "oneflow/core/operator/op_parallel_signature.h"
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+const OpParallelMatchResult MakeOpParallelMatchSuccess() {
+  OpParallelMatchResult success;
+  success.mutable_success();
+  return success;
+}
+
+const OpParallelMatchResult MakeOpParallelMatchSignatureMismatch() {
+  OpParallelMatchResult signature_mismatch;
+  signature_mismatch.mutable_fail()->mutable_signature_mismatch();
+  return signature_mismatch;
+}
+
+const OpParallelMatchResult MakeOpParallelMatchParallelPolicyError(ParallelPolicy configured,
+                                                                   ParallelPolicy expected) {
+  OpParallelMatchResult policy_error;
+  auto* err = policy_error.mutable_fail()->mutable_conf_error()->mutable_parallel_policy_error();
+  err->set_configured(configured);
+  err->set_expected(expected);
+  return policy_error;
+}
+
+const OpParallelMatchResult MakeOpParallelMatchParallelNumError(int64_t configured,
+                                                                int64_t expected) {
+  OpParallelMatchResult parallel_num_error;
+  auto* err = parallel_num_error.mutable_fail()->mutable_conf_error()->mutable_parallel_num_error();
+  err->set_configured(configured);
+  err->set_expected(expected);
+  return parallel_num_error;
+}
+
+namespace {
+
+class DataSplitOpParallelSignature final : public OpParallelSignature {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(DataSplitOpParallelSignature);
+  ~DataSplitOpParallelSignature() override = default;
+
+  DataSplitOpParallelSignature(const Operator* op) : OpParallelSignature(op) {}
+
+  const std::string Description() const override {
+    return op().op_name() + ": (S(0), ...) -> (S(0), ...)";
+  }
+
+  const OpParallelMatchResult GetMatchResult(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const ParallelContext* parallel_ctx) const override {
+    bool is_data_split = true;
+    for (const auto& bn : op().input_bns()) {
+      const SbpInferHint& sbp_infer_hint = SbpInferHint4Ibn(bn);
+      if (!sbp_infer_hint.is_data_blob()) {
+        is_data_split = false;
+        break;
+      }
+    }
+    if (!is_data_split) { return MakeOpParallelMatchSignatureMismatch(); }
+    if (parallel_ctx->policy() == kDataParallel) { return MakeOpParallelMatchSuccess(); }
+    return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kDataParallel);
+  }
+
+  void GenerateSignature(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      HashMap<std::string, SbpParallel>* bn2sbp) const override {
+    for (const auto& bn : op().input_bns()) { (*bn2sbp)[bn].mutable_split_parallel()->set_axis(0); }
+    for (const auto& bn : op().output_bns()) {
+      (*bn2sbp)[bn].mutable_split_parallel()->set_axis(0);
+    }
+  }
+};
+
+class BroadcastOpParallelSignature final : public OpParallelSignature {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(BroadcastOpParallelSignature);
+  ~BroadcastOpParallelSignature() override = default;
+
+  BroadcastOpParallelSignature(const Operator* op) : OpParallelSignature(op) {
+    CHECK_EQ(op->input_bns().size(), 1);
+    CHECK(op->model_bns().empty());
+    CHECK(op->const_model_bns().empty());
+  }
+
+  const std::string Description() const override { return op().op_name() + ": (B,) -> (B, ...)"; }
+
+  const OpParallelMatchResult GetMatchResult(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const ParallelContext* parallel_ctx) const override {
+    if (!SbpInferHint4Ibn(op().SoleIbn()).sbp_parallel().has_broadcast_parallel()) {
+      return MakeOpParallelMatchSignatureMismatch();
+    }
+    int64_t expected_parallel_num = SbpInferHint4Ibn(op().SoleIbn()).parallel_num();
+    bool parallel_policy_matched = (parallel_ctx->policy() == kDataParallel);
+    bool parallel_num_matched = (parallel_ctx->parallel_num() == expected_parallel_num);
+    if (parallel_policy_matched && parallel_num_matched) {
+      return MakeOpParallelMatchSuccess();
+    } else {
+      OpParallelMatchResult ret;
+      if (!parallel_policy_matched) {
+        auto* err = ret.mutable_fail()->mutable_conf_error()->mutable_parallel_policy_error();
+        err->set_configured(parallel_ctx->policy());
+        err->set_expected(kDataParallel);
+      } else {
+        auto* err = ret.mutable_fail()->mutable_conf_error()->mutable_parallel_num_error();
+        err->set_configured(parallel_ctx->parallel_num());
+        err->set_expected(expected_parallel_num);
+      }
+      return ret;
+    }
+  }
+
+  void GenerateSignature(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      HashMap<std::string, SbpParallel>* bn2sbp) const override {
+    for (const auto& bn : op().input_bns()) { (*bn2sbp)[bn].mutable_broadcast_parallel(); }
+    for (const auto& bn : op().output_bns()) { (*bn2sbp)[bn].mutable_broadcast_parallel(); }
+  }
+};
+
+class DS_MB_2_DS_OpParallelSignature final : public OpParallelSignature {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(DS_MB_2_DS_OpParallelSignature);
+  ~DS_MB_2_DS_OpParallelSignature() override = default;
+
+  DS_MB_2_DS_OpParallelSignature(const Operator* op) : OpParallelSignature(op) {
+    std::vector<std::string> model_input_bns;
+    for (const auto& bn : op->input_bns()) {
+      if (op->IsInputBlobAllowedModelSplit(bn)) {
+        model_input_bns.push_back(bn);
+      } else {
+        data_input_bns_.push_back(bn);
+      }
+    }
+    CHECK_GT(data_input_bns_.size(), 0);
+    CHECK_EQ(model_input_bns.size(), 1);
+    model_input_bn_ = model_input_bns.at(0);
+  }
+
+  const std::string Description() const override {
+    return op().op_name() + ": (B, S(0), ...) -> (S(0), ...)";
+  }
+
+  const OpParallelMatchResult GetMatchResult(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const ParallelContext* parallel_ctx) const override {
+    const auto& sbp_infer_hint = SbpInferHint4Ibn(model_input_bn_);
+    if (!sbp_infer_hint.is_model_broadcast()) { return MakeOpParallelMatchSignatureMismatch(); }
+    bool parallel_policy_matched = (parallel_ctx->policy() == kDataParallel);
+    bool parallel_num_matched = (parallel_ctx->parallel_num() == sbp_infer_hint.parallel_num());
+    if (!parallel_policy_matched || !parallel_num_matched) {
+      OpParallelMatchResult ret;
+      if (!parallel_policy_matched) {
+        auto* err = ret.mutable_fail()->mutable_conf_error()->mutable_parallel_policy_error();
+        err->set_configured(parallel_ctx->policy());
+        err->set_expected(kDataParallel);
+      }
+      if (!parallel_num_matched) {
+        auto* err = ret.mutable_fail()->mutable_conf_error()->mutable_parallel_num_error();
+        err->set_configured(parallel_ctx->parallel_num());
+        err->set_expected(sbp_infer_hint.parallel_num());
+      }
+      return ret;
+    }
+    return MakeOpParallelMatchSuccess();
+  }
+
+  void GenerateSignature(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      HashMap<std::string, SbpParallel>* bn2sbp) const override {
+    for (const auto& bn : data_input_bns_) { (*bn2sbp)[bn].mutable_split_parallel()->set_axis(0); }
+    (*bn2sbp)[model_input_bn_].mutable_broadcast_parallel();
+    for (const auto& bn : op().output_bns()) {
+      (*bn2sbp)[bn].mutable_split_parallel()->set_axis(0);
+    }
+  }
+
+ private:
+  std::vector<std::string> data_input_bns_;
+  std::string model_input_bn_;
+};
+
+class SoleIbnOpModelSplitOpParallelSignature final : public OpParallelSignature {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(SoleIbnOpModelSplitOpParallelSignature);
+  ~SoleIbnOpModelSplitOpParallelSignature() override = default;
+
+  SoleIbnOpModelSplitOpParallelSignature(const Operator* op) : OpParallelSignature(op) {}
+
+  const std::string Description() const override { return op().op_name() + ": (S,) -> (S, ...)"; }
+
+  const OpParallelMatchResult GetMatchResult(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const ParallelContext* parallel_ctx) const override {
+    const SbpInferHint& sbp_infer_hint = SbpInferHint4Ibn(op().SoleIbn());
+    if (!(sbp_infer_hint.is_model_split()
+          || (sbp_infer_hint.is_data_split() && sbp_infer_hint.split_axis() > 0))) {
+      return MakeOpParallelMatchSignatureMismatch();
+    }
+    int64_t expected_parallel_num = sbp_infer_hint.parallel_num();
+    bool parallel_policy_matched = (parallel_ctx->policy() == kModelParallel);
+    bool parallel_num_matched = (parallel_ctx->parallel_num() == expected_parallel_num);
+    if (!(parallel_policy_matched && parallel_num_matched)) {
+      OpParallelMatchResult ret;
+      if (!parallel_policy_matched) {
+        auto* err = ret.mutable_fail()->mutable_conf_error()->mutable_parallel_policy_error();
+        err->set_configured(parallel_ctx->policy());
+        err->set_expected(kModelParallel);
+      }
+      if (!parallel_num_matched) {
+        auto* err = ret.mutable_fail()->mutable_conf_error()->mutable_parallel_num_error();
+        err->set_configured(parallel_ctx->parallel_num());
+        err->set_expected(expected_parallel_num);
+      }
+      return ret;
+    }
+    return MakeOpParallelMatchSuccess();
+  }
+
+  void GenerateSignature(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      HashMap<std::string, SbpParallel>* bn2sbp) const override {
+    (*bn2sbp)[op().SoleIbn()] = SbpInferHint4Ibn(op().SoleIbn()).sbp_parallel();
+    for (const auto& bn : op().output_bns()) {
+      (*bn2sbp)[bn].mutable_split_parallel()->set_axis(
+          op().OutputBlobModelSplitAxis(SbpInferHint4Ibn, bn));
+    }
+  }
+};
+
+class ModelBnOpModelSplitOpParallelSignature final : public OpParallelSignature {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ModelBnOpModelSplitOpParallelSignature);
+  ~ModelBnOpModelSplitOpParallelSignature() override = default;
+
+  ModelBnOpModelSplitOpParallelSignature(const Operator* op) : OpParallelSignature(op) {}
+
+  const std::string Description() const override {
+    return op().op_name() + ": (B, ...) -> (S, ...)";
+  }
+
+  const OpParallelMatchResult GetMatchResult(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const ParallelContext* parallel_ctx) const override {
+    if (parallel_ctx->policy() == kModelParallel) { return MakeOpParallelMatchSuccess(); }
+    return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kModelParallel);
+  }
+
+  void GenerateSignature(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      HashMap<std::string, SbpParallel>* bn2sbp) const override {
+    for (const auto& bn : op().input_bns()) { (*bn2sbp)[bn].mutable_broadcast_parallel(); }
+    for (const auto& bn : op().output_bns()) {
+      (*bn2sbp)[bn].mutable_split_parallel()->set_axis(
+          op().OutputBlobModelSplitAxis(SbpInferHint4Ibn, bn));
+    }
+  }
+};
+
+class DB_MS_2_MS_OpParallelSignature final : public OpParallelSignature {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(DB_MS_2_MS_OpParallelSignature);
+  ~DB_MS_2_MS_OpParallelSignature() override = default;
+
+  DB_MS_2_MS_OpParallelSignature(const Operator* op, std::function<bool(int32_t)> IsExpectedAxis)
+      : OpParallelSignature(op), IsExpectedAxis_(IsExpectedAxis) {
+    std::vector<std::string> model_input_bns;
+    for (const auto& bn : op->input_bns()) {
+      if (op->IsInputBlobAllowedModelSplit(bn)) {
+        model_input_bns.push_back(bn);
+      } else {
+        data_input_bns_.push_back(bn);
+      }
+    }
+    CHECK_GT(data_input_bns_.size(), 0);
+    CHECK_EQ(model_input_bns.size(), 1);
+    model_input_bn_ = model_input_bns.at(0);
+  }
+
+  const std::string Description() const override {
+    return op().op_name() + ": (B, S, ...) -> (S, ...)";
+  }
+
+  const OpParallelMatchResult GetMatchResult(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const ParallelContext* parallel_ctx) const override {
+    const SbpInferHint& model_sbp_infer_hint = SbpInferHint4Ibn(model_input_bn_);
+    if (!(model_sbp_infer_hint.is_model_split()
+          && IsValidSplit(model_sbp_infer_hint.split_axis()))) {
+      return MakeOpParallelMatchSignatureMismatch();
+    }
+    if (parallel_ctx->policy() == kModelParallel) { return MakeOpParallelMatchSuccess(); }
+    return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kModelParallel);
+  }
+
+  void GenerateSignature(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      HashMap<std::string, SbpParallel>* bn2sbp) const override {
+    for (const auto& bn : data_input_bns_) { (*bn2sbp)[bn].mutable_broadcast_parallel(); }
+    (*bn2sbp)[model_input_bn_] = SbpInferHint4Ibn(model_input_bn_).sbp_parallel();
+    for (const auto& bn : op().output_bns()) {
+      (*bn2sbp)[bn].mutable_split_parallel()->set_axis(
+          op().OutputBlobModelSplitAxis(SbpInferHint4Ibn, bn));
+    }
+  }
+
+ private:
+  bool IsValidSplit(int32_t axis) const { return axis != -1 && IsExpectedAxis_(axis); }
+
+  const std::function<bool(int32_t)> IsExpectedAxis_;
+  std::vector<std::string> data_input_bns_;
+  std::string model_input_bn_;
+};
+
+}  // namespace
+
+std::unique_ptr<const OpParallelSignature> MakeDataSplitOpParallelSignature(const Operator* op) {
+  return std::unique_ptr<const OpParallelSignature>(new DataSplitOpParallelSignature(op));
+}
+
+std::unique_ptr<const OpParallelSignature> MakeBroadcastOpParallelSignature(const Operator* op) {
+  return std::unique_ptr<const OpParallelSignature>(new BroadcastOpParallelSignature(op));
+}
+
+std::unique_ptr<const OpParallelSignature> MakeModelSplitOpParallelSignature(const Operator* op) {
+  if (op->IsSoleInputBlobAllowedModelSplit()) {
+    return std::unique_ptr<const OpParallelSignature>(
+        new SoleIbnOpModelSplitOpParallelSignature(op));
+  } else {
+    CHECK(!op->model_bns().empty() || !op->const_model_bns().empty());
+    return std::unique_ptr<const OpParallelSignature>(
+        new ModelBnOpModelSplitOpParallelSignature(op));
+  }
+}
+
+std::unique_ptr<const OpParallelSignature> Make_DS_MB_2_DS_OpParallelSignature(const Operator* op) {
+  return std::unique_ptr<const OpParallelSignature>(new DS_MB_2_DS_OpParallelSignature(op));
+}
+
+std::unique_ptr<const OpParallelSignature> Make_DB_MS_2_MS_OpParallelSignature(
+    const Operator* op, std::function<bool(int32_t)> IsExpectedAxis) {
+  return std::unique_ptr<const OpParallelSignature>(
+      new DB_MS_2_MS_OpParallelSignature(op, IsExpectedAxis));
+}
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/op_parallel_signature.h b/oneflow/core/operator/op_parallel_signature.h
new file mode 100644
index 0000000000000000000000000000000000000000..94b034b6973c2f620e15406da2627a238bed087a
--- /dev/null
+++ b/oneflow/core/operator/op_parallel_signature.h
@@ -0,0 +1,63 @@
+#ifndef ONEFLOW_CORE_OPERATOR_OP_PARALLEL_SIGNATURE_H_
+#define ONEFLOW_CORE_OPERATOR_OP_PARALLEL_SIGNATURE_H_
+
+#include "oneflow/core/common/util.h"
+#include "oneflow/core/job/sbp_parallel.h"
+#include "oneflow/core/operator/op_parallel_match_result.pb.h"
+#include "oneflow/core/job/sbp_infer_hint.h"
+
+namespace oneflow {
+
+class Operator;
+
+class OpParallelSignature {
+ public:
+  virtual ~OpParallelSignature() = default;
+  virtual const std::string Description() const = 0;
+  virtual const OpParallelMatchResult GetMatchResult(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const ParallelContext* parallel_ctx) const = 0;
+  virtual void GenerateSignature(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      HashMap<std::string, SbpParallel>* bn2sbp) const = 0;
+
+ protected:
+  OpParallelSignature(const Operator* op) : op_(op) {}
+  const Operator& op() const { return *op_; }
+
+ private:
+  const Operator* op_;
+};
+
+const OpParallelMatchResult MakeOpParallelMatchSuccess();
+const OpParallelMatchResult MakeOpParallelMatchSignatureMismatch();
+const OpParallelMatchResult MakeOpParallelMatchParallelPolicyError(ParallelPolicy configured,
+                                                                   ParallelPolicy expected);
+const OpParallelMatchResult MakeOpParallelMatchParallelNumError(int64_t configured,
+                                                                int64_t expected);
+
+class Operator;
+
+// (S(0), ...) -> (S(0), ...)
+std::unique_ptr<const OpParallelSignature> MakeDataSplitOpParallelSignature(const Operator* op);
+
+// (S,) -> (S, ...) or (B, ...) -> (S, ...)
+std::unique_ptr<const OpParallelSignature> MakeModelSplitOpParallelSignature(const Operator* op);
+
+// (B,) -> (B, ...)
+std::unique_ptr<const OpParallelSignature> MakeBroadcastOpParallelSignature(const Operator* op);
+
+// (B, S(0), ...) -> (S(0), ...)
+// output blobs: data split
+// input blobs: split data input blobs and broadcast model input blobs
+std::unique_ptr<const OpParallelSignature> Make_DS_MB_2_DS_OpParallelSignature(const Operator* op);
+
+// (B, S, ...) -> (S, ...)
+// output blobs: model split
+// input blobs: broadcast data input blobs and split model input blobs
+std::unique_ptr<const OpParallelSignature> Make_DB_MS_2_MS_OpParallelSignature(
+    const Operator* op, std::function<bool(int32_t)> IsExpectedAxis);
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_OP_PARALLEL_SIGNATURE_H_
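As a usage sketch (the concrete overrides, e.g. MatmulOp's GetOpParallelSignatures, are declared but not shown in this diff), an operator with broadcastable data inputs and one model input might combine these factories as below; `MyFcLikeOp` and the axis predicate are purely illustrative.

    // Hypothetical operator subclass combining the factory functions declared above.
    void MyFcLikeOp::GetOpParallelSignatures(
        std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const {
      // (B, S(0), ...) -> (S(0), ...): data-parallel, model broadcast to every device
      op_parallel_signatures->emplace_back(Make_DS_MB_2_DS_OpParallelSignature(this));
      // (B, S, ...) -> (S, ...): model-parallel, model split along its output dimension
      op_parallel_signatures->emplace_back(
          Make_DB_MS_2_MS_OpParallelSignature(this, [](int32_t axis) { return axis == 1; }));
    }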
diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp
index 8d7aa8188998432edc035124633592abcc2b2922..7e050ecb9241c817e51e4da5aa119a11990561eb 100644
--- a/oneflow/core/operator/operator.cpp
+++ b/oneflow/core/operator/operator.cpp
@@ -1,5 +1,6 @@
 #include "oneflow/core/operator/operator.h"
 #include "oneflow/core/graph/logical_node.h"
+#include "oneflow/core/common/balanced_splitter.h"
 
 namespace oneflow {
 
@@ -73,17 +74,23 @@ const std::string& Operator::SoleBbbn() const {
 }
 
 void Operator::InferBlobDescsIf(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                                const ParallelContext* parallel_ctx,
+                                const ParallelContext* parallel_ctx, int64_t record_piece_size,
                                 std::function<void(OpContext*)> EnrollOpCtx) const {
-  InferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx, EnrollOpCtx);
+  InferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx, record_piece_size, EnrollOpCtx);
   if (op_attribute_.model_bns().size() > 0) {
     InferTotalInstanceNumDesc(GetBlobDesc4BnInOp, parallel_ctx, EnrollOpCtx);
   }
 }
 
 void Operator::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                              const ParallelContext* parallel_ctx,
+                              const ParallelContext* parallel_ctx, int64_t record_piece_size,
                               std::function<void(OpContext*)> EnrollOpCtx) const {
+  InferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx, record_piece_size);
+}
+
+void Operator::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                              const ParallelContext* parallel_ctx,
+                              int64_t record_piece_size) const {
   InferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx);
 }
 
@@ -101,25 +108,149 @@ void Operator::InferBwBufBlobDescsIf(
   }
 }
 
+void Operator::InferOutputBlobTimeShapeIf(
+    std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp,
+    const ParallelContext* parallel_ctx, Shape* time_shape) const {
+  InferOutputBlobTimeShape(GetTimeShape4BnInOp, parallel_ctx, time_shape);
+}
+
+void Operator::InferOutputBlobTimeShape(
+    std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, const ParallelContext*,
+    Shape* time_shape) const {
+  for (const std::string& bn : input_bns()) {
+    CHECK_EQ(*GetTimeShape4BnInOp(input_bns().Get(0)), *GetTimeShape4BnInOp(bn));
+  }
+  if (!input_bns().empty()) {
+    *time_shape = *GetTimeShape4BnInOp(input_bns().Get(0));
+  } else {
+    *time_shape = Shape(
+        {Global<JobDesc>::Get()->TotalBatchNum(), Global<JobDesc>::Get()->NumOfPiecesInBatch()});
+  }
+}
+
+int32_t Operator::OutputBlobModelSplitAxis(
+    const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+    const std::string& obn) const {
+  if (IsSoleInputBlobAllowedModelSplit()) {
+    return SbpInferHint4Ibn(SoleIbn()).split_axis();
+  } else {
+    UNIMPLEMENTED();
+    return -1;
+  }
+}
+
+void Operator::GetOpParallelSignatures(
+    std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const {
+  bool has_model = !(model_bns().empty() && const_model_bns().empty());
+  op_parallel_signatures->emplace_back(MakeDataSplitOpParallelSignature(this));
+  if (IsSoleInputBlobAllowedModelSplit()) {
+    CHECK(!has_model);
+    op_parallel_signatures->emplace_back(MakeModelSplitOpParallelSignature(this));
+    op_parallel_signatures->emplace_back(MakeBroadcastOpParallelSignature(this));
+  } else if (has_model) {
+    for (const auto& ibn : input_bns()) { CHECK(!IsInputBlobAllowedModelSplit(ibn)); }
+    op_parallel_signatures->emplace_back(MakeModelSplitOpParallelSignature(this));
+  } else if (input_bns().size() == 1) {
+    op_parallel_signatures->emplace_back(MakeBroadcastOpParallelSignature(this));
+  } else {
+    // do nothing
+  }
+}
+
+void Operator::InferInputOutputSbpParallelIf(
+    std::function<SbpParallel*(const std::string&)> SbpParallel4BnInOp,
+    std::function<const SbpInferHint&(const std::string&)> SbpInferHint4Ibn,
+    const ParallelContext* parallel_ctx) const {
+  std::vector<std::unique_ptr<const OpParallelSignature>> op_parallel_signatures;
+  GetOpParallelSignatures(&op_parallel_signatures);
+  std::vector<OpParallelMatchResult> match_results;
+  for (const auto& signature : op_parallel_signatures) {
+    match_results.push_back(signature->GetMatchResult(SbpInferHint4Ibn, parallel_ctx));
+  }
+  int32_t match_success_cnt = 0;
+  for (const auto& result : match_results) {
+    if (result.has_success()) { ++match_success_cnt; }
+  }
+  if (match_success_cnt == 1) {
+    const OpParallelSignature* match_signature = nullptr;
+    FOR_RANGE(int32_t, i, 0, op_parallel_signatures.size()) {
+      if (match_results.at(i).has_success()) {
+        match_signature = op_parallel_signatures.at(i).get();
+      }
+    }
+    HashMap<std::string, SbpParallel> bn2sbp;
+    match_signature->GenerateSignature(SbpInferHint4Ibn, &bn2sbp);
+    for (const auto& pair : bn2sbp) {
+      auto* sbp_parallel = SbpParallel4BnInOp(pair.first);
+      *sbp_parallel = pair.second;
+    }
+  } else if (match_success_cnt == 0) {
+    std::stringstream ss;
+    FOR_RANGE(int32_t, i, 0, op_parallel_signatures.size()) {
+      CHECK(match_results.at(i).has_fail());
+      const auto& failed_msg = match_results.at(i).fail();
+      ss << "op_parallel_signature match failed\n"
+         << op_parallel_signatures.at(i)->Description() << ":\n";
+      if (failed_msg.has_signature_mismatch()) {
+        ss << "\t"
+           << "signature mismatch"
+           << "\n";
+      } else {
+        CHECK(failed_msg.has_conf_error());
+        if (failed_msg.conf_error().has_parallel_policy_error()) {
+          const auto& policy_error_msg = failed_msg.conf_error().parallel_policy_error();
+          ss << "\t"
+             << "parallel_policy conf error, configured: "
+             << ParallelPolicy_Name(policy_error_msg.configured())
+             << ", expected: " << ParallelPolicy_Name(policy_error_msg.expected()) << "\n";
+        }
+        if (failed_msg.conf_error().has_parallel_num_error()) {
+          const auto& parallel_num_error_msg = failed_msg.conf_error().parallel_num_error();
+          ss << "\t"
+             << "parallel_num conf error, configured: " << parallel_num_error_msg.configured()
+             << ", expected: " << parallel_num_error_msg.expected() << "\n";
+        }
+      }
+    }
+    LOG(FATAL) << ss.str();
+  } else {
+    UNIMPLEMENTED();
+  }
+}
+
+bool Operator::IsSoleInputBlobAllowedModelSplit() const {
+  return input_bns().size() == 1 && IsInputBlobAllowedModelSplit(SoleIbn());
+}
+
+void Operator::InferIsModelBlob4OutputBlobsIf(
+    std::function<bool*(const std::string&)> IsModelBlob4BnInOp) const {
+  InferIsModelBlob4OutputBlobs(IsModelBlob4BnInOp);
+}
+
+void Operator::InferIsModelBlob4OutputBlobs(
+    std::function<bool*(const std::string&)> IsModelBlob4BnInOp) const {
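+  // Outputs are model blobs only when the sole input allows model split and is itself a model blob.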
+  bool is_model_blob = (IsSoleInputBlobAllowedModelSplit() && *IsModelBlob4BnInOp(SoleIbn()));
+  for (const std::string& obn : output_bns()) { *IsModelBlob4BnInOp(obn) = is_model_blob; }
+}
+
 void Operator::InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                                    const ParallelContext* parallel_ctx,
                                    const OpContext* op_ctx) const {
   InferBwBufBlobDescs(GetBlobDesc4BnInOp, parallel_ctx);
 }
 
-void Operator::FixParallelDesc(ParallelDesc* pr_desc) const {
-  if (model_bns().empty() && const_model_bns().empty()) {
-    pr_desc->set_policy(ParallelPolicy::kDataParallel);
-  }
-  if (pr_desc->policy() == kModelParallel && MaxModelSplitNum() != -1) {
-    pr_desc->RemoveNeedlessDevice(op_name(), MaxModelSplitNum());
+void Operator::FixInDiffBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                  const ParallelContext* ctx) const {
+  VirtualFixInDiffBlobDescs(GetBlobDesc4BnInOp, ctx);
+  for (const std::string& input_diff_bn : input_diff_bns()) {
+    BlobDesc* blob_desc = GetBlobDesc4BnInOp(input_diff_bn);
+    if (!blob_desc) { continue; }
+    blob_desc->set_has_loss_instance_num_field(true);
   }
-  if (pr_desc->policy() == kDataParallel) {
-    pr_desc->RemoveNeedlessDevice(op_name(), Global<JobDesc>::Get()->PieceSize());
-  }
-  VirtualFixParallelDesc(pr_desc);
 }
 
+void Operator::FixParallelDesc(ParallelDesc* pr_desc) const { VirtualFixParallelDesc(pr_desc); }
+
 void Operator::FixLbiWhenShareModel(const std::string& shared_op_name) {
   for (const std::string& model_bn : model_bns()) {
     mut_bn_in_op2lbi()->at(model_bn).set_op_name(shared_op_name);
diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h
index 135752a438ff35b966cfaf349c5430111c963698..c7e4bb4840525318a1300d4f98476c85c29062eb 100644
--- a/oneflow/core/operator/operator.h
+++ b/oneflow/core/operator/operator.h
@@ -6,10 +6,11 @@
 #include "oneflow/core/common/auto_registration_factory.h"
 #include "oneflow/core/job/keyword.h"
 #include "oneflow/core/job/parallel_desc.h"
-#include "oneflow/core/job/placement.pb.h"
+#include "oneflow/core/job/sbp_parallel.h"
 #include "oneflow/core/kernel/kernel.pb.h"
 #include "oneflow/core/operator/op_conf.pb.h"
 #include "oneflow/core/register/blob_desc.h"
+#include "oneflow/core/operator/op_parallel_signature.h"
 
 namespace oneflow {
 
@@ -28,7 +29,8 @@ class Operator {
   //
   void InitFromOpConf(const OperatorConf& op_conf);
   virtual void InitFromOpConf() = 0;
-  virtual bool IsElemWiseOp() const { return false; }
+  bool IsSoleInputBlobAllowedModelSplit() const;
+  virtual bool IsInputBlobAllowedModelSplit(const std::string& ibn) const = 0;
 
   ActivationType GetActivationType() const;
 
@@ -37,6 +39,7 @@ class Operator {
   virtual bool IsLossOp() const { return false; }
   virtual bool IsRecurrentOp() const { return false; }
   virtual bool IsEmbeddingLookupOp() const { return false; }
+  virtual bool IsAllOutputConst() const { return false; }
 
   bool NeedOutBlobWhenBackwardIf() const {
     return NeedOutBlobWhenBackward() || (GetActivationType() != ActivationType::kNone);
@@ -44,6 +47,8 @@ class Operator {
   virtual bool NeedOutBlobWhenBackward() const { return true; }
   bool NeedInBlobWhenBackwardIf() const { return NeedInBlobWhenBackward(); }
   virtual bool NeedInBlobWhenBackward() const { return true; }
+  virtual bool IsForwardInplace() const { return false; }
+  virtual bool IsBackwardInplace() const { return false; }
 
   // bn_in_op <-> lbi
   const LogicalBlobId& BnInOp2Lbi(const std::string& bn_in_op) const;
@@ -114,10 +119,13 @@ class Operator {
   // Read: shape of input_blobs
   // Write: shape of output_blobs, model_blobs, data_tmp_blobs, const_model_blobs, const_buf_blobs
   void InferBlobDescsIf(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                        const ParallelContext*, std::function<void(OpContext*)> EnrollOpCtx) const;
+                        const ParallelContext*, int64_t record_piece_size,
+                        std::function<void(OpContext*)> EnrollOpCtx) const;
   virtual void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                              const ParallelContext*,
+                              const ParallelContext*, int64_t record_piece_size,
                               std::function<void(OpContext*)> EnrollOpCtx) const;
+  virtual void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                              const ParallelContext*, int64_t record_piece_size) const;
   virtual void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                               const ParallelContext*) const;
   void InferBwBufBlobDescsIf(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
@@ -131,17 +139,39 @@ class Operator {
       const ParallelContext*) const {
     UNIMPLEMENTED();
   }
+  // Infer out blob's time shape
+  void InferOutputBlobTimeShapeIf(
+      std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, const ParallelContext*,
+      Shape* time_shape) const;
+  virtual void InferOutputBlobTimeShape(
+      std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, const ParallelContext*,
+      Shape* time_shape) const;
+  // Infer blob's SbpParallel
+  void InferInputOutputSbpParallelIf(
+      std::function<SbpParallel*(const std::string&)> SbpParallel4BnInOp,
+      std::function<const SbpInferHint&(const std::string&)> SbpInferHint4Ibn,
+      const ParallelContext* parallel_ctx) const;
+  // Infer is_model_blob
+  void InferIsModelBlob4OutputBlobsIf(
+      std::function<bool*(const std::string&)> IsModelBlob4BnInOp) const;
   virtual void FixInDiffBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                                  const ParallelContext*) const {}
+                                  const ParallelContext*) const;
+  virtual void VirtualFixInDiffBlobDescs(
+      std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+      const ParallelContext*) const {}
 
   void FixParallelDesc(ParallelDesc* pr_desc) const;
   void FixLbiWhenShareModel(const std::string& shared_op_name);
-  virtual int32_t ModelSplitAxis() const { return -1; }
-  virtual int32_t MaxModelSplitNum() const { return -1; }
+  virtual int32_t OutputBlobModelSplitAxis(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const std::string& obn) const;
   void GenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                      bool is_forward, const ParallelContext*, KernelConf*, const OpContext*) const;
 
  protected:
+  virtual void InferIsModelBlob4OutputBlobs(
+      std::function<bool*(const std::string&)> IsModelBlob4BnInOp) const;
+
   int64_t cudnn_buf_limit_byte() const;
 
   virtual PbMessage* MutableCustomizedKernelConf(KernelConf*) const {
@@ -220,6 +250,8 @@ class Operator {
   void StrFieldTolower(const std::string& field_name);
 
  private:
+  virtual void GetOpParallelSignatures(
+      std::vector<std::unique_ptr<const OpParallelSignature>>*) const;
   LogicalBlobId dtbn2lbi(const std::string& data_tmp_bn) const;
   LogicalBlobId fbbn2lbi(const std::string& fw_buf_bn) const { return dtbn2lbi(fw_buf_bn); }
   LogicalBlobId bbbn2lbi(const std::string& bw_buf_bn) const { return dtbn2lbi(bw_buf_bn); }
@@ -265,6 +297,12 @@ struct OnlyCpuSupportPredicator {
 
 std::shared_ptr<Operator> ConstructOp(const OperatorConf& op_conf);
 
+inline std::shared_ptr<Operator> ConstructOp(const OperatorConf& op_conf, DeviceType device_type) {
+  OperatorConf dev_op_conf = op_conf;
+  dev_op_conf.set_device_type(device_type);
+  return ConstructOp(dev_op_conf);
+}
+
 void EraseEmptyBnInVec(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                        PbRpf<std::string>* bns);
 
@@ -283,6 +321,14 @@ inline LogicalBlobId GenLogicalBlobId(const std::string& lbn) {
   return lbi;
 }
 
+inline std::string GenLogicalBlobName(const LogicalBlobId& lbi) {
+  CHECK_EQ(lbi.has_op_name(), true);
+  CHECK_EQ(lbi.has_blob_name(), true);
+  CHECK_EQ(lbi.has_clone_id(), false);
+  CHECK_EQ(lbi.is_packed_id(), false);
+  return lbi.op_name() + "/" + lbi.blob_name();
+}
+
 }  // namespace oneflow
 
 #endif  // ONEFLOW_CORE_OPERATOR_OPERATOR_H_
diff --git a/oneflow/core/operator/pack_op.cpp b/oneflow/core/operator/pack_op.cpp
index a38c9a531b4df0d4cff51638fe6b02313795f130..6dec573945d1298316f21dcffe8fb9d96a256054 100644
--- a/oneflow/core/operator/pack_op.cpp
+++ b/oneflow/core/operator/pack_op.cpp
@@ -9,20 +9,20 @@ void PackOp::InitFromOpConf() {
   EnrollOutputBn("out", false);
 }
 
-int32_t PackOp::GetPackNum(int64_t parallel_num) const {
+void PackOp::InferOutputBlobTimeShape(
+    std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp,
+    const ParallelContext* parallel_ctx, Shape* time_shape) const {
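+  // Packing consumes the innermost time dimension, which must equal pack_num.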
+  std::vector<int64_t> dim_vec(GetTimeShape4BnInOp("in")->dim_vec());
+  int32_t pack_num = GetPackNum();
+  CHECK_EQ(pack_num, dim_vec.back());
+  dim_vec.pop_back();
+  *time_shape = Shape(dim_vec);
+}
+
+int32_t PackOp::GetPackNum() const {
   CHECK(op_conf().has_pack_conf());
   const PackOpConf& conf = op_conf().pack_conf();
-  if (conf.has_pack_num()) {
-    return conf.pack_num();
-  } else if (conf.has_pack_num_per_record()) {
-    CHECK_EQ(Global<JobDesc>::Get()->PieceSize() % parallel_num, 0);
-    int64_t pack_num =
-        Global<JobDesc>::Get()->PieceSize() / parallel_num * conf.pack_num_per_record();
-    CHECK_LE(pack_num, static_cast<int64_t>(MaxVal<int32_t>()));
-    return static_cast<int32_t>(pack_num);
-  } else {
-    UNIMPLEMENTED();
-  }
+  return conf.pack_num();
 }
 
 REGISTER_OP(OperatorConf::kPackConf, PackOp);
diff --git a/oneflow/core/operator/pack_op.h b/oneflow/core/operator/pack_op.h
index 206d5a323a8a52b5c220759149c21d408e4d24be..835514527c5f7169698046cc3d1fb77a01923713 100644
--- a/oneflow/core/operator/pack_op.h
+++ b/oneflow/core/operator/pack_op.h
@@ -15,10 +15,16 @@ class PackOp final : public Operator {
   void InitFromOpConf() override;
   const PbMessage& GetCustomizedConf() const override { return op_conf().pack_conf(); }
   LogicalNode* NewProperLogicalNode() { return new PackForwardLogicalNode; }
+  void InferOutputBlobTimeShape(std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp,
+                                const ParallelContext* parallel_ctx,
+                                Shape* time_shape) const override;
 
   bool NeedInBlobWhenBackward() const override { return true; }
   bool NeedOutBlobWhenBackward() const override { return false; }
-  int32_t GetPackNum(int64_t parallel_num) const;
+  int32_t GetPackNum() const;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/parallel_cast_op.cpp b/oneflow/core/operator/parallel_cast_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d7d803aa0df41884cdc05a024abecee07a1fd484
--- /dev/null
+++ b/oneflow/core/operator/parallel_cast_op.cpp
@@ -0,0 +1,58 @@
+#include "oneflow/core/operator/parallel_cast_op.h"
+
+namespace oneflow {
+
+namespace {
+
+SbpParallel GetSbpParallel(const ParallelCastOpConf& conf) {
+  SbpParallel ret;
+  if (conf.has_split_parallel()) {
+    *ret.mutable_split_parallel() = conf.split_parallel();
+  } else if (conf.has_broadcast_parallel()) {
+    *ret.mutable_broadcast_parallel() = conf.broadcast_parallel();
+  } else {
+    UNIMPLEMENTED();
+  }
+  return ret;
+}
+
+class ParallelCastOpParallelSignature final : public OpParallelSignature {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ParallelCastOpParallelSignature);
+  ~ParallelCastOpParallelSignature() override = default;
+
+  ParallelCastOpParallelSignature(const Operator* op) : OpParallelSignature(op) {}
+
+  const std::string Description() const override { return op().op_name() + ": A -> A"; }
+
+  const OpParallelMatchResult GetMatchResult(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4BnInOp,
+      const ParallelContext* parallel_ctx) const override {
+    const auto& configured_sbp_parallel = GetSbpParallel(op().op_conf().parallel_cast_conf());
+    if (SbpInferHint4BnInOp("in").sbp_parallel() == configured_sbp_parallel
+        && parallel_ctx->parallel_num() != SbpInferHint4BnInOp("in").parallel_num()) {
+      return MakeOpParallelMatchParallelNumError(parallel_ctx->parallel_num(),
+                                                 SbpInferHint4BnInOp("in").parallel_num());
+    }
+    return MakeOpParallelMatchSuccess();
+  }
+
+  void GenerateSignature(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4BnInOp,
+      HashMap<std::string, SbpParallel>* bn2sbp) const override {
+    const auto& sbp_parallel = GetSbpParallel(op().op_conf().parallel_cast_conf());
+    (*bn2sbp)["in"] = sbp_parallel;
+    (*bn2sbp)["out"] = sbp_parallel;
+  }
+};
+
+}  // namespace
+
+void ParallelCastOp::GetOpParallelSignatures(
+    std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const {
+  op_parallel_signatures->emplace_back(new ParallelCastOpParallelSignature(this));
+}
+
+REGISTER_OP(OperatorConf::kParallelCastConf, ParallelCastOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/parallel_cast_op.h b/oneflow/core/operator/parallel_cast_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..b2805dd05438dc6f8204745ae5ac8ffb86e46120
--- /dev/null
+++ b/oneflow/core/operator/parallel_cast_op.h
@@ -0,0 +1,23 @@
+#ifndef ONEFLOW_CORE_OPERATOR_PARALLEL_CAST_OP_H_
+#define ONEFLOW_CORE_OPERATOR_PARALLEL_CAST_OP_H_
+
+#include "oneflow/core/operator/identity_op.h"
+
+namespace oneflow {
+
+class ParallelCastOp final : public IdentityOp {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ParallelCastOp);
+  ParallelCastOp() = default;
+  ~ParallelCastOp() override = default;
+
+  const PbMessage& GetCustomizedConf() const override { return op_conf().parallel_cast_conf(); }
+
+ private:
+  void GetOpParallelSignatures(
+      std::vector<std::unique_ptr<const OpParallelSignature>>*) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_PARALLEL_CAST_OP_H_
diff --git a/oneflow/core/operator/pooling_op.h b/oneflow/core/operator/pooling_op.h
index db08d6a0c7fecd4fe1bfb54b3a2f51f88bbe7c80..a23153a21189622048747bddf1285e5175bc1252 100644
--- a/oneflow/core/operator/pooling_op.h
+++ b/oneflow/core/operator/pooling_op.h
@@ -24,6 +24,8 @@ class PoolingOp : public Operator {
                             KernelConf* kernel_conf) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   void CheckPoolSizeAndStrides() const;
   Shape GetOutShape(int64_t in_n, int64_t in_c, const std::vector<int64_t>& out) const;
 };
diff --git a/oneflow/core/operator/print_op.h b/oneflow/core/operator/print_op.h
index 030d02d25ab1f6741b0b5cdf1cc11dcec9b3a69c..a02fac123b27d1a0e4f7c0653e6444af5cb0d2af 100644
--- a/oneflow/core/operator/print_op.h
+++ b/oneflow/core/operator/print_op.h
@@ -20,6 +20,8 @@ class PrintOp final : public Operator {
                       const ParallelContext* parallel_ctx) const override {}
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
+
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override;
 };
 
diff --git a/oneflow/core/operator/record_load_op.cpp b/oneflow/core/operator/record_load_op.cpp
index 28f8ba5872eadcac5917bf2f28cf00a09488ef51..b119e979ccf05a0cc3724a03c8bae34f81538118 100644
--- a/oneflow/core/operator/record_load_op.cpp
+++ b/oneflow/core/operator/record_load_op.cpp
@@ -10,14 +10,21 @@ void RecordLoadOp::InitFromOpConf() {
 const PbMessage& RecordLoadOp::GetCustomizedConf() const { return op_conf().record_load_conf(); }
 
 void RecordLoadOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                                  const ParallelContext* parallel_ctx) const {
+                                  const ParallelContext* parallel_ctx,
+                                  int64_t record_piece_size) const {
   BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
-  int64_t global_piece_size = Global<JobDesc>::Get()->PieceSize();
-  CHECK_EQ(global_piece_size % parallel_ctx->parallel_num(), 0);
-  out_blob_desc->mut_shape() = Shape({global_piece_size / parallel_ctx->parallel_num()});
+  CHECK_EQ(record_piece_size % parallel_ctx->parallel_num(), 0);
+  out_blob_desc->mut_shape() = Shape({record_piece_size / parallel_ctx->parallel_num()});
   out_blob_desc->set_data_type(kOFRecord);
 }
 
+void RecordLoadOp::VirtualGenKernelConf(
+    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
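+  // device_piece_size is the per-device share of record_piece_size computed in InferBlobDescs.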
+  int64_t device_piece_size = GetBlobDesc4BnInOp("out")->shape().At(0);
+  kernel_conf->mutable_record_load_conf()->set_device_piece_size(device_piece_size);
+}
+
 REGISTER_CPU_OP(OperatorConf::kRecordLoadConf, RecordLoadOp);
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/record_load_op.h b/oneflow/core/operator/record_load_op.h
index cc069578176561ac1c650f9f7a122f3e0c002264..8ad1274de7ee4344bff3cdd86f498a66fc65a859 100644
--- a/oneflow/core/operator/record_load_op.h
+++ b/oneflow/core/operator/record_load_op.h
@@ -15,11 +15,15 @@ class RecordLoadOp final : public Operator {
   void InitFromOpConf() override;
   const PbMessage& GetCustomizedConf() const override;
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                      const ParallelContext* parallel_ctx) const override;
+                      const ParallelContext* parallel_ctx,
+                      int64_t record_piece_size) const override;
+  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                            const ParallelContext*, KernelConf*) const override;
 
   LogicalNode* NewProperLogicalNode() override { return new RecordLoadLogicalNode; }
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/recurrent_op.cpp b/oneflow/core/operator/recurrent_op.cpp
index 5afbdbcc0d84a8e22e040c89c0f012733d4d02d1..84948bb42bfb30a7685cb2eb7adf51f275888ef4 100644
--- a/oneflow/core/operator/recurrent_op.cpp
+++ b/oneflow/core/operator/recurrent_op.cpp
@@ -19,10 +19,6 @@ void RecurrentOp::InitFromOpConf() {
   VirtualInitFromOpConf();
 }
 
-int32_t RecurrentOp::MaxModelSplitNum() const {
-  return GetValFromCustomizedConf<int32_t>("hidden_size");
-}
-
 void RecurrentOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                                  const ParallelContext* parallel_ctx) const {
   const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
diff --git a/oneflow/core/operator/recurrent_op.h b/oneflow/core/operator/recurrent_op.h
index cef95876e727ea184d9d9e5dbf13f63ad57dfc49..ae210368e45a69b984f17fd8bf66405c5bec153a 100644
--- a/oneflow/core/operator/recurrent_op.h
+++ b/oneflow/core/operator/recurrent_op.h
@@ -16,16 +16,21 @@ class RecurrentOp : public Operator {
 
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
-  int32_t ModelSplitAxis() const override { return 1; }
-  int32_t MaxModelSplitNum() const override;
+  int32_t OutputBlobModelSplitAxis(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const std::string& obn) const override {
+    return 1;
+  }
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
   virtual void VirtualInitFromOpConf() { UNIMPLEMENTED(); }
   virtual void VirtualInferBlobDescs(
       std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
       const ParallelContext* parallel_ctx) const {
     UNIMPLEMENTED();
   }
+
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override;
   LogicalBlobId obn2lbi(const std::string& output_bn) const override;
 };
diff --git a/oneflow/core/operator/reduce_add_op.h b/oneflow/core/operator/reduce_add_op.h
index 276abf7fc78035952c2f7329d98d95c2c3eb8f6d..12d37e36475d682e9d46885154188d49e6a44863 100644
--- a/oneflow/core/operator/reduce_add_op.h
+++ b/oneflow/core/operator/reduce_add_op.h
@@ -18,6 +18,8 @@ class ReduceAddOp final : public Operator {
                       const ParallelContext* parallel_ctx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
+
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
   LogicalBlobId obn2lbi(const std::string& output_bn) const override;
 };
diff --git a/oneflow/core/operator/reduce_concat_op.cpp b/oneflow/core/operator/reduce_concat_op.cpp
index 09861d11d8b37173cfb7edd624380b7e6821a350..65b4d171241431200a379efa95588a8d0c8ffb79 100644
--- a/oneflow/core/operator/reduce_concat_op.cpp
+++ b/oneflow/core/operator/reduce_concat_op.cpp
@@ -4,6 +4,11 @@
 
 namespace oneflow {
 
+struct ReduceConcatOpCtx : public OpContext {
+  ReduceConcatOpCtx(const int64_t elem_cnt) : out_blob_elem_cnt(elem_cnt) {}
+  int64_t out_blob_elem_cnt;
+};
+
 void ReduceConcatOp::InitFromOpConf() {
   CHECK(op_conf().has_reduce_concat_conf());
   for (int32_t i = 0; i < op_conf().reduce_concat_conf().in_num(); ++i) {
@@ -17,29 +22,49 @@ const PbMessage& ReduceConcatOp::GetCustomizedConf() const {
 }
 
 void ReduceConcatOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                                    const ParallelContext* parallel_ctx) const {
-  int32_t in_num = op_conf().reduce_concat_conf().in_num();
-  CHECK_GE(in_num, 2);
-  BlobDesc* first_in_blob = GetBlobDesc4BnInOp(input_bns().Get(0));
+                                    const ParallelContext* parallel_ctx, int64_t record_piece_size,
+                                    std::function<void(OpContext*)> EnrollOpCtx) const {
+  const BlobDesc* first_in_blob = GetBlobDesc4BnInOp(input_bns().Get(0));
+  const DataType data_type = first_in_blob->data_type();
+  for (int32_t i = 1; i < op_conf().reduce_concat_conf().in_num(); ++i) {
+    CHECK_EQ(data_type, GetBlobDesc4BnInOp(input_bns().Get(i))->data_type());
+  }
+
   BlobDesc* out_blob = GetBlobDesc4BnInOp(SoleObn());
   *out_blob = *first_in_blob;
-  int64_t out_blob_elem_cnt = first_in_blob->shape().elem_cnt();
-  for (int32_t i = 1; i < in_num; ++i) {
-    out_blob_elem_cnt += GetBlobDesc4BnInOp(input_bns().Get(i))->shape().elem_cnt();
+  int64_t in_blob_body_size_sum = 0;
+  for (int32_t i = 0; i < op_conf().reduce_concat_conf().in_num(); ++i) {
+    in_blob_body_size_sum +=
+        RtBlobDesc(*(GetBlobDesc4BnInOp(input_bns().Get(i)))).ByteSizeOfBlobBody();
   }
+  const int64_t data_type_byte_size =
+      static_cast<int64_t>(GetSizeOfDataType(first_in_blob->data_type()));
+  CHECK_EQ(in_blob_body_size_sum % data_type_byte_size, 0);
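+  // RoundUp keeps the output element count a multiple of parallel_num.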
+  const int64_t out_blob_elem_cnt =
+      RoundUp(in_blob_body_size_sum / data_type_byte_size, parallel_ctx->parallel_num());
   out_blob->mut_shape() = Shape({out_blob_elem_cnt});
+
+  // construct reduce_concat_op_ctx for later CHECK in ReduceConcatOp::VirtualGenKernelConf
+  ReduceConcatOpCtx* reduce_concat_op_ctx = new ReduceConcatOpCtx(out_blob_elem_cnt);
+  EnrollOpCtx(reduce_concat_op_ctx);
 }
 
 void ReduceConcatOp::VirtualGenKernelConf(
-    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*,
-    KernelConf* kernel_conf) const {
+    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx, KernelConf* kernel_conf, const OpContext* op_ctx) const {
   ReduceConcatKernelConf* reduce_concat_conf = kernel_conf->mutable_reduce_concat_conf();
   int64_t offset = 0;
   for (int32_t i = 0; i < op_conf().reduce_concat_conf().in_num(); ++i) {
     reduce_concat_conf->mutable_data_offset()->Add(offset);
     offset += RtBlobDesc(*(GetBlobDesc4BnInOp(input_bns().Get(i)))).ByteSizeOfBlobBody();
   }
-  CHECK_EQ(offset, RtBlobDesc(*GetBlobDesc4BnInOp(SoleObn())).ByteSizeOfBlobBody());
+  const int64_t data_type_byte_size =
+      static_cast<int64_t>(GetSizeOfDataType(GetBlobDesc4BnInOp(input_bns().Get(0))->data_type()));
+  CHECK_EQ(offset % data_type_byte_size, 0);
+  const int64_t out_blob_elem_cnt =
+      RoundUp(offset / data_type_byte_size, parallel_ctx->parallel_num());
+  const ReduceConcatOpCtx* reduce_concat_op_ctx = static_cast<const ReduceConcatOpCtx*>(op_ctx);
+  CHECK_EQ(reduce_concat_op_ctx->out_blob_elem_cnt, out_blob_elem_cnt);
 }
 
 LogicalBlobId ReduceConcatOp::obn2lbi(const std::string& output_bn) const {
diff --git a/oneflow/core/operator/reduce_concat_op.h b/oneflow/core/operator/reduce_concat_op.h
index c2419c5178972343666f205b54df79355e9ec1d2..7e86ebd4f976d36e7116673f39905749e846211e 100644
--- a/oneflow/core/operator/reduce_concat_op.h
+++ b/oneflow/core/operator/reduce_concat_op.h
@@ -15,11 +15,15 @@ class ReduceConcatOp final : public Operator {
   const PbMessage& GetCustomizedConf() const override;
 
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                      const ParallelContext* parallel_ctx) const override;
+                      const ParallelContext* parallel_ctx, int64_t record_piece_size,
+                      std::function<void(OpContext*)> EnrollOpCtx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
+
   void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                            const ParallelContext*, KernelConf*) const override;
+                            const ParallelContext*, KernelConf*,
+                            const OpContext* op_ctx) const override;
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
   LogicalBlobId obn2lbi(const std::string& output_bn) const override;
 };
diff --git a/oneflow/core/operator/reduce_gather_op.h b/oneflow/core/operator/reduce_gather_op.h
index 2753330de0403791d7d4ab646784301f5cb990be..0b97f9e8f1729f7d242a1d6cc99fda7bac40d44b 100644
--- a/oneflow/core/operator/reduce_gather_op.h
+++ b/oneflow/core/operator/reduce_gather_op.h
@@ -18,6 +18,8 @@ class ReduceGatherOp final : public Operator {
                       const ParallelContext* parallel_ctx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
+
   void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                             const ParallelContext*, KernelConf*) const override;
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
diff --git a/oneflow/core/operator/reduce_identity_op.cpp b/oneflow/core/operator/reduce_identity_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..baffcc20dc086f865d6cf081cb68edc3a7601b18
--- /dev/null
+++ b/oneflow/core/operator/reduce_identity_op.cpp
@@ -0,0 +1,26 @@
+#include "oneflow/core/operator/reduce_identity_op.h"
+
+namespace oneflow {
+
+void ReduceIdentityOp::InitFromOpConf() {
+  EnrollInputBn("in", false);
+  EnrollOutputBn("out", false);
+}
+
+void ReduceIdentityOp::InferBlobDescs(
+    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx) const {
+  *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in");
+  CHECK_EQ(GetBlobDesc4BnInOp("out")->shape().elem_cnt() % parallel_ctx->parallel_num(), 0);
+}
+
+LogicalBlobId ReduceIdentityOp::obn2lbi(const std::string& output_bn) const {
+  LogicalBlobId ret;
+  ret.set_op_name(op_name());
+  ret.set_blob_name(output_bn);
+  return ret;
+}
+
+REGISTER_OP(OperatorConf::kReduceIdentityConf, ReduceIdentityOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/reduce_identity_op.h b/oneflow/core/operator/reduce_identity_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..f5b2d2933bace02abad785a0148885daef1bb51c
--- /dev/null
+++ b/oneflow/core/operator/reduce_identity_op.h
@@ -0,0 +1,32 @@
+#ifndef ONEFLOW_CORE_OPERATOR_REDUCE_IDENTITY_OP_H_
+#define ONEFLOW_CORE_OPERATOR_REDUCE_IDENTITY_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+#include "oneflow/core/graph/logical_node.h"
+
+namespace oneflow {
+
+class ReduceIdentityOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ReduceIdentityOp);
+  ReduceIdentityOp() = default;
+  ~ReduceIdentityOp() = default;
+
+  LogicalNode* NewProperLogicalNode() override { return new ReduceIdentityLogicalNode; }
+  void InitFromOpConf() override;
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+  const PbMessage& GetCustomizedConf() const override { return op_conf().reduce_identity_conf(); }
+  bool NeedInBlobWhenBackward() const override { return false; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
+
+  LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
+  LogicalBlobId obn2lbi(const std::string& output_bn) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_REDUCE_IDENTITY_OP_H_
diff --git a/oneflow/core/operator/reduce_mean_op.cpp b/oneflow/core/operator/reduce_mean_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0e39682fcd1bf7876363ea11a4931da36960f592
--- /dev/null
+++ b/oneflow/core/operator/reduce_mean_op.cpp
@@ -0,0 +1,103 @@
+#include "oneflow/core/operator/reduce_mean_op.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+namespace {
+std::vector<int64_t> KeepDims(const std::vector<int64_t> dim_vec,
+                              const std::vector<int64_t> axis_vec) {
+  std::vector<int64_t> ret = dim_vec;
+  for (const auto& axis : axis_vec) { ret[axis] = 1; }
+  return ret;
+}
+
+std::vector<int64_t> DropDims(const std::vector<int64_t> dim_vec,
+                              const std::vector<int64_t> axis_vec) {
+  std::vector<int64_t> ret;
+  std::vector<int32_t> dim2is_reduced(dim_vec.size());
+  for (const auto& axis : axis_vec) { dim2is_reduced[axis] = 1; }
+  FOR_RANGE(int64_t, i, 0, dim_vec.size()) {
+    if (dim2is_reduced[i] != 1) { ret.push_back(dim_vec[i]); }
+  }
+  if (ret.empty()) { ret.push_back(1); }
+  return ret;
+}
+
+std::vector<int64_t> ShiftAxisIfNegative(std::vector<int64_t> axis_vec, const int64_t num_axes) {
+  FOR_RANGE(size_t, i, 0, axis_vec.size()) {
+    if (axis_vec[i] < 0) { axis_vec[i] += num_axes; }
+    CHECK_LT(axis_vec[i], num_axes);
+    CHECK_GE(axis_vec[i], 0);
+  }
+  return axis_vec;
+}
+
+}  // namespace
+
+void ReduceMeanOp::InitFromOpConf() {
+  CHECK(op_conf().has_reduce_mean_conf());
+  EnrollInputBn("in");
+  EnrollOutputBn("out");
+  EnrollFwBufBn("fw_tmp");
+  EnrollBwBufBn("bw_tmp");
+}
+
+const PbMessage& ReduceMeanOp::GetCustomizedConf() const { return op_conf().reduce_mean_conf(); }
+
+void ReduceMeanOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                  const ParallelContext*) const {
+  const ReduceMeanOpConf& conf = op_conf().reduce_mean_conf();
+  const BlobDesc* in_blob = GetBlobDesc4BnInOp("in");
+  *GetBlobDesc4BnInOp("fw_tmp") = *in_blob;
+  std::vector<int64_t> out_dim_vec;
+  if (conf.axis().empty()) {
+    if (conf.keep_dims() == true) {
+      out_dim_vec.resize(in_blob->shape().NumAxes());
+      std::fill(out_dim_vec.begin(), out_dim_vec.end(), 1);
+    } else {
+      out_dim_vec = {1};
+    }
+  } else {
+    const PbRf<int32_t>& axis_repeated = conf.axis();
+    std::vector<int64_t> axis_vec = {axis_repeated.begin(), axis_repeated.end()};
+    axis_vec = ShiftAxisIfNegative(axis_vec, in_blob->shape().NumAxes());
+    std::sort(axis_vec.begin(), axis_vec.end());
+    CHECK(std::unique(axis_vec.begin(), axis_vec.end()) == axis_vec.end())
+        << "duplicate found in axis";
+    if (conf.keep_dims() == true) {
+      out_dim_vec = KeepDims(in_blob->shape().dim_vec(), axis_vec);
+    } else {
+      out_dim_vec = DropDims(in_blob->shape().dim_vec(), axis_vec);
+    }
+  }
+  CHECK(!out_dim_vec.empty());
+  BlobDesc* out_blob = GetBlobDesc4BnInOp("out");
+  out_blob->set_data_type(in_blob->data_type());
+  out_blob->mut_shape() = Shape(out_dim_vec);
+}
+
+void ReduceMeanOp::VirtualGenKernelConf(
+    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*,
+    KernelConf* kernel_conf) const {
+  const ReduceMeanOpConf& conf = op_conf().reduce_mean_conf();
+  const BlobDesc* in_blob = GetBlobDesc4BnInOp("in");
+  std::vector<int64_t> kept_dims;
+  if (conf.axis().empty()) {
+    kept_dims.resize(in_blob->shape().NumAxes());
+    std::fill(kept_dims.begin(), kept_dims.end(), 1);
+  } else {
+    const PbRf<int32_t>& axis_repeated = op_conf().reduce_mean_conf().axis();
+    std::vector<int64_t> axis_vec = {axis_repeated.begin(), axis_repeated.end()};
+    kept_dims = KeepDims(in_blob->shape().dim_vec(),
+                         ShiftAxisIfNegative(axis_vec, in_blob->shape().NumAxes()));
+  }
+  Shape(kept_dims).ToProto(kernel_conf->mutable_reduce_sum_conf()->mutable_kept_dims_shape());
+}
+
+void ReduceMeanOp::InferBwBufBlobDescs(
+    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*) const {
+  *GetBlobDesc4BnInOp("bw_tmp") = *GetBlobDesc4BnInOp("out");
+}
+
+REGISTER_OP(OperatorConf::kReduceMeanConf, ReduceMeanOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/reduce_mean_op.h b/oneflow/core/operator/reduce_mean_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..0477d2d25899146bdf4fea2af5eef235f240b739
--- /dev/null
+++ b/oneflow/core/operator/reduce_mean_op.h
@@ -0,0 +1,32 @@
+#ifndef ONEFLOW_CORE_OPERATOR_REDUCE_MEAN_OP_H_
+#define ONEFLOW_CORE_OPERATOR_REDUCE_MEAN_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class ReduceMeanOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ReduceMeanOp);
+  ReduceMeanOp() = default;
+  ~ReduceMeanOp() = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override;
+  bool NeedInBlobWhenBackward() const override { return false; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+  void InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                           const ParallelContext*) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                            const ParallelContext* parallel_ctx,
+                            KernelConf* kernel_conf) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_REDUCE_MEAN_OP_H_
diff --git a/oneflow/core/operator/reduce_scatter_op.h b/oneflow/core/operator/reduce_scatter_op.h
index 745c4d1bca108b7cfcdb2ec65e031d67e4e8696b..be911903031770b6c49623493029a2e7874234ab 100644
--- a/oneflow/core/operator/reduce_scatter_op.h
+++ b/oneflow/core/operator/reduce_scatter_op.h
@@ -18,6 +18,8 @@ class ReduceScatterOp final : public Operator {
                       const ParallelContext* parallel_ctx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
+
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
   LogicalBlobId obn2lbi(const std::string& output_bn) const override;
 };
diff --git a/oneflow/core/operator/reduce_split_op.cpp b/oneflow/core/operator/reduce_split_op.cpp
index 78995a57cc7de19e5cb0fd8ccdf2d2865355edc5..da809794d612f6515f4a1aea4f738b78c01d499d 100644
--- a/oneflow/core/operator/reduce_split_op.cpp
+++ b/oneflow/core/operator/reduce_split_op.cpp
@@ -15,15 +15,21 @@ void ReduceSplitOp::InitFromOpConf() {
 const PbMessage& ReduceSplitOp::GetCustomizedConf() const { return op_conf().reduce_split_conf(); }
 
 void ReduceSplitOp::VirtualGenKernelConf(
-    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*,
-    KernelConf* kernel_conf) const {
+    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
   ReduceSplitKernelConf* reduce_split_conf = kernel_conf->mutable_reduce_split_conf();
   int64_t offset = 0;
   for (int32_t i = 0; i < op_conf().reduce_split_conf().out_num(); ++i) {
     reduce_split_conf->mutable_data_offset()->Add(offset);
     offset += RtBlobDesc(*(GetBlobDesc4BnInOp(output_bns().Get(i)))).ByteSizeOfBlobBody();
   }
-  CHECK_EQ(offset, RtBlobDesc(*GetBlobDesc4BnInOp(SoleIbn())).ByteSizeOfBlobBody());
+  const int64_t data_type_byte_size =
+      static_cast<int64_t>(GetSizeOfDataType(GetBlobDesc4BnInOp(SoleIbn())->data_type()));
+  CHECK_EQ(offset % data_type_byte_size, 0);
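+  // After rounding up to a multiple of parallel_num, the summed output sizes must equal
+  // the input blob's element count.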
+  const int64_t out_blob_elem_cnt_sum =
+      RoundUp(offset / data_type_byte_size, parallel_ctx->parallel_num());
+  const int64_t in_blob_elem_cnt = GetBlobDesc4BnInOp(SoleIbn())->shape().elem_cnt();
+  CHECK_EQ(out_blob_elem_cnt_sum, in_blob_elem_cnt);
 }
 
 REGISTER_OP(OperatorConf::kReduceSplitConf, ReduceSplitOp);
diff --git a/oneflow/core/operator/reduce_split_op.h b/oneflow/core/operator/reduce_split_op.h
index a6f632fe6d3a58842f7f9a67c627bd8e535f933e..d5f02679dc9ed4022973bf0eacabd48d5bb56f2f 100644
--- a/oneflow/core/operator/reduce_split_op.h
+++ b/oneflow/core/operator/reduce_split_op.h
@@ -13,11 +13,11 @@ class ReduceSplitOp final : public Operator {
 
   void InitFromOpConf() override;
   const PbMessage& GetCustomizedConf() const override;
-
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override {}
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
   void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                             const ParallelContext*, KernelConf*) const override;
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
diff --git a/oneflow/core/operator/reduce_sum_op.cpp b/oneflow/core/operator/reduce_sum_op.cpp
index 45d470a6483bc9f2dc93e8ebbef9deca3425ab65..38367bc1cb8fb2314d4ed860e7caf3d84a00f508 100644
--- a/oneflow/core/operator/reduce_sum_op.cpp
+++ b/oneflow/core/operator/reduce_sum_op.cpp
@@ -2,55 +2,98 @@
 #include "oneflow/core/kernel/kernel_util.h"
 
 namespace oneflow {
+namespace {
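+// KeepDims sets each reduced axis to 1 (rank preserved); DropDims removes the reduced axes
+// (keeping at least one axis); ShiftAxisIfNegative maps negative axis indices to positive ones.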
+std::vector<int64_t> KeepDims(const std::vector<int64_t> dim_vec,
+                              const std::vector<int64_t> axis_vec) {
+  std::vector<int64_t> ret = dim_vec;
+  for (const auto& axis : axis_vec) { ret[axis] = 1; }
+  return ret;
+}
+
+std::vector<int64_t> DropDims(const std::vector<int64_t> dim_vec,
+                              const std::vector<int64_t> axis_vec) {
+  std::vector<int64_t> ret;
+  std::vector<int32_t> dim2is_reduced(dim_vec.size());
+  for (const auto& axis : axis_vec) { dim2is_reduced[axis] = 1; }
+  FOR_RANGE(int64_t, i, 0, dim_vec.size()) {
+    if (dim2is_reduced[i] != 1) { ret.push_back(dim_vec[i]); }
+  }
+  if (ret.empty()) { ret.push_back(1); }
+  return ret;
+}
+
+std::vector<int64_t> ShiftAxisIfNegative(std::vector<int64_t> axis_vec, const int64_t num_axes) {
+  FOR_RANGE(size_t, i, 0, axis_vec.size()) {
+    if (axis_vec[i] < 0) { axis_vec[i] += num_axes; }
+    CHECK_LT(axis_vec[i], num_axes);
+    CHECK_GE(axis_vec[i], 0);
+  }
+  return axis_vec;
+}
+
+}  // namespace
 
 void ReduceSumOp::InitFromOpConf() {
+  CHECK(op_conf().has_reduce_sum_conf());
   EnrollInputBn("in");
   EnrollOutputBn("out");
-  EnrollDataTmpBn("fw_tmp");
+  if (op_conf().reduce_sum_conf().has_in_sys()) {
+    EnrollDataTmpBn("fw_tmp");
+  } else {
+    EnrollFwBufBn("fw_tmp");
+  }
 }
 
 const PbMessage& ReduceSumOp::GetCustomizedConf() const { return op_conf().reduce_sum_conf(); }
 
 void ReduceSumOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                                  const ParallelContext*) const {
+  const ReduceSumOpConf& conf = op_conf().reduce_sum_conf();
   const BlobDesc* in_blob = GetBlobDesc4BnInOp("in");
-  std::vector<int64_t> out_dim_vec = {1};
-  if (op_conf().reduce_sum_conf().has_axis()) {
-    out_dim_vec = in_blob->shape().dim_vec();
-    int32_t axis = GetCorrectAxis(GetBlobDesc4BnInOp);
-    if (op_conf().reduce_sum_conf().keepdims() == true) {
-      out_dim_vec[axis] = 1;
+  *GetBlobDesc4BnInOp("fw_tmp") = *in_blob;
+  std::vector<int64_t> out_dim_vec;
+  if (conf.axis().empty()) {
+    if (conf.keep_dims() == true) {
+      out_dim_vec.resize(in_blob->shape().NumAxes());
+      std::fill(out_dim_vec.begin(), out_dim_vec.end(), 1);
     } else {
-      out_dim_vec.erase(out_dim_vec.begin() + axis);
+      out_dim_vec = {1};
     }
-    if (out_dim_vec.empty()) { out_dim_vec.push_back(1); }
   } else {
-    BlobDesc* fw_tmp_blob = GetBlobDesc4BnInOp("fw_tmp");
-    fw_tmp_blob->mut_shape() = Shape({static_cast<int64_t>(
-        GetTmpSizeForReduceSum(in_blob->data_type(), in_blob->shape().elem_cnt()))});
-    fw_tmp_blob->set_data_type(DataType::kChar);
+    const PbRf<int32_t>& axis_repeated = conf.axis();
+    std::vector<int64_t> axis_vec = {axis_repeated.begin(), axis_repeated.end()};
+    axis_vec = ShiftAxisIfNegative(axis_vec, in_blob->shape().NumAxes());
+    std::sort(axis_vec.begin(), axis_vec.end());
+    CHECK(std::unique(axis_vec.begin(), axis_vec.end()) == axis_vec.end())
+        << "duplicate found in axis";
+    if (conf.keep_dims() == true) {
+      out_dim_vec = KeepDims(in_blob->shape().dim_vec(), axis_vec);
+    } else {
+      out_dim_vec = DropDims(in_blob->shape().dim_vec(), axis_vec);
+    }
   }
+  CHECK(!out_dim_vec.empty());
   BlobDesc* out_blob = GetBlobDesc4BnInOp("out");
-  if (op_conf().reduce_sum_conf().has_axis() && GetCorrectAxis(GetBlobDesc4BnInOp) > 0) {
-    *out_blob = *in_blob;
-  } else {
-    out_blob->set_data_type(in_blob->data_type());
-  }
+  out_blob->set_data_type(in_blob->data_type());
   out_blob->mut_shape() = Shape(out_dim_vec);
 }
 
 void ReduceSumOp::VirtualGenKernelConf(
     std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*,
     KernelConf* kernel_conf) const {
-  if (op_conf().reduce_sum_conf().has_axis() == false) { return; }
-  kernel_conf->mutable_reduce_sum_conf()->set_axis(GetCorrectAxis(GetBlobDesc4BnInOp));
-}
-
-int32_t ReduceSumOp::GetCorrectAxis(
-    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp) const {
-  int32_t axis = op_conf().reduce_sum_conf().axis();
-  if (axis < 0) { axis += GetBlobDesc4BnInOp("in")->shape().NumAxes(); }
-  return axis;
+  const ReduceSumOpConf& conf = op_conf().reduce_sum_conf();
+  const BlobDesc* in_blob = GetBlobDesc4BnInOp("in");
+  std::vector<int64_t> kept_dims;
+  if (conf.axis().empty()) {
+    kept_dims.resize(in_blob->shape().NumAxes());
+    std::fill(kept_dims.begin(), kept_dims.end(), 1);
+  } else {
+    const PbRf<int32_t>& axis_repeated = op_conf().reduce_sum_conf().axis();
+    std::vector<int64_t> axis_vec = {axis_repeated.begin(), axis_repeated.end()};
+    kept_dims = KeepDims(in_blob->shape().dim_vec(),
+                         ShiftAxisIfNegative(axis_vec, in_blob->shape().NumAxes()));
+  }
+  Shape(kept_dims).ToProto(kernel_conf->mutable_reduce_sum_conf()->mutable_kept_dims_shape());
 }
 
 REGISTER_OP(OperatorConf::kReduceSumConf, ReduceSumOp);
diff --git a/oneflow/core/operator/reduce_sum_op.h b/oneflow/core/operator/reduce_sum_op.h
index 0a43b4fe52e7c523ccb8fa3efb516f1a2db508a9..b8f0276653d7d9cc97d0e724a74d11b00d12442f 100644
--- a/oneflow/core/operator/reduce_sum_op.h
+++ b/oneflow/core/operator/reduce_sum_op.h
@@ -13,23 +13,24 @@ class ReduceSumOp final : public Operator {
 
   void InitFromOpConf() override;
   const PbMessage& GetCustomizedConf() const override;
-
+  bool NeedInBlobWhenBackward() const override { return false; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                             const ParallelContext* parallel_ctx,
                             KernelConf* kernel_conf) const override;
-  int32_t GetCorrectAxis(
-      std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp) const;
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override {
     const ReduceSumOpConf& conf = op_conf().reduce_sum_conf();
     if (conf.has_in_sys()) {
-      CHECK_EQ(conf.axis(), 0);
+      CHECK_EQ(conf.axis_size(), 1);
+      CHECK_EQ(conf.axis().Get(0), 0);
       return conf.in_sys();
     } else if (conf.has_in()) {
-      CHECK_GE(conf.axis(), 1);
       return GenLogicalBlobId(conf.in());
     } else {
       UNIMPLEMENTED();
diff --git a/oneflow/core/operator/relu_op.h b/oneflow/core/operator/relu_op.h
index 6a8f82025482b6d527dd7a73ecad119d6911a434..b3e64b543c55483fee6657c5f8abd665fe849c26 100644
--- a/oneflow/core/operator/relu_op.h
+++ b/oneflow/core/operator/relu_op.h
@@ -13,13 +13,13 @@ class ReluOp final : public Operator {
 
   void InitFromOpConf() override;
   const PbMessage& GetCustomizedConf() const override;
-  bool IsElemWiseOp() const override { return true; }
   bool NeedInBlobWhenBackward() const override { return false; }
 
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/repeat_op.cpp b/oneflow/core/operator/repeat_op.cpp
index 5a1e396a804d9b16364f7076ad179cbb804fab0a..5721a994e7eb5ed552d229ff68f42001065e936c 100644
--- a/oneflow/core/operator/repeat_op.cpp
+++ b/oneflow/core/operator/repeat_op.cpp
@@ -6,31 +6,24 @@ namespace oneflow {
 void oneflow::RepeatOp::InitFromOpConf() {
   CHECK(op_conf().has_repeat_conf());
   const RepeatOpConf& conf = op_conf().repeat_conf();
-  if (conf.has_repeat_num()) {
-    CHECK_GE(conf.repeat_num(), 1);
-  } else if (conf.has_repeat_num_per_record()) {
-    CHECK_GE(conf.repeat_num_per_record(), 1);
-  } else {
-    UNIMPLEMENTED();
-  }
+  CHECK_GE(conf.repeat_num(), 1);
   EnrollInputBn("in");
   EnrollOutputBn("out");
 }
 
-int32_t RepeatOp::GetRepeatNum(int64_t parallel_num) const {
+void RepeatOp::InferOutputBlobTimeShape(
+    std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp,
+    const ParallelContext* parallel_ctx, Shape* time_shape) const {
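+  // Repeating appends repeat_num as a new innermost time dimension.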
+  std::vector<int64_t> dim_vec(GetTimeShape4BnInOp("in")->dim_vec());
+  int32_t repeat_num = GetRepeatNum();
+  dim_vec.push_back(repeat_num);
+  *time_shape = Shape(dim_vec);
+}
+
+int32_t RepeatOp::GetRepeatNum() const {
   CHECK(op_conf().has_repeat_conf());
   const RepeatOpConf& conf = op_conf().repeat_conf();
-  if (conf.has_repeat_num()) {
-    return conf.repeat_num();
-  } else if (conf.has_repeat_num_per_record()) {
-    CHECK_EQ(Global<JobDesc>::Get()->PieceSize() % parallel_num, 0);
-    int64_t repeat_num =
-        Global<JobDesc>::Get()->PieceSize() / parallel_num * conf.repeat_num_per_record();
-    CHECK_LE(repeat_num, static_cast<int64_t>(MaxVal<int32_t>()));
-    return static_cast<int32_t>(repeat_num);
-  } else {
-    UNIMPLEMENTED();
-  }
+  return conf.repeat_num();
 }
 
 const PbMessage& RepeatOp::GetCustomizedConf() const { return op_conf().repeat_conf(); }
diff --git a/oneflow/core/operator/repeat_op.h b/oneflow/core/operator/repeat_op.h
index 3635aaebef4fe8a05dd327f6aad8f96c50390eef..d7b0b27b48d28d551720aeea19d21e4fbfe28b7b 100644
--- a/oneflow/core/operator/repeat_op.h
+++ b/oneflow/core/operator/repeat_op.h
@@ -11,12 +11,17 @@ class RepeatOp final : public Operator {
   RepeatOp() = default;
   ~RepeatOp() override = default;
 
-  int32_t GetRepeatNum(int64_t parallel_num) const;
+  int32_t GetRepeatNum() const;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
   const PbMessage& GetCustomizedConf() const override;
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
+  void InferOutputBlobTimeShape(std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp,
+                                const ParallelContext* parallel_ctx,
+                                Shape* time_shape) const override;
+
   void InferDiffBlobDescsWithoutFwBlob(
       std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
       const ParallelContext*) const override;
diff --git a/oneflow/core/operator/reshape_op.cpp b/oneflow/core/operator/reshape_op.cpp
index cae8727816de3b2955fdb2d8224292804ceca5a5..66395f865c0768347fe36d302ba59050d0bf1155 100644
--- a/oneflow/core/operator/reshape_op.cpp
+++ b/oneflow/core/operator/reshape_op.cpp
@@ -17,9 +17,25 @@ void ReshapeOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetB
   *out_blob_desc = *in_blob_desc;
 
   const ReshapeOpConf& conf = op_conf().reshape_conf();
-  std::vector<int64_t> dim_vec(1 + conf.shape().dim_size());
-  dim_vec[0] = in_blob_desc->shape().At(0);
-  FOR_RANGE(size_t, i, 1, dim_vec.size()) { dim_vec[i] = conf.shape().dim(i - 1); }
+
+  std::vector<int64_t> dim_vec;
+  if (!conf.has_dim0_in_shape()) { dim_vec.push_back(in_blob_desc->shape().At(0)); }
+  for (int32_t i = 0; i < conf.shape().dim_size(); ++i) { dim_vec.push_back(conf.shape().dim(i)); }
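+  // At most one dimension may be -1; its extent is inferred from the remaining element count.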
+  int32_t dim_cnt_need_infer = 0;
+  int32_t dim_index_need_infer = -1;
+  int64_t elem_cnt = 1;
+  for (int32_t i = 0; i < dim_vec.size(); ++i) {
+    if (dim_vec[i] == -1) {
+      ++dim_cnt_need_infer;
+      dim_index_need_infer = i;
+    } else {
+      elem_cnt *= dim_vec[i];
+    }
+  }
+  CHECK_LE(dim_cnt_need_infer, 1);
+  if (dim_cnt_need_infer == 1) {
+    dim_vec[dim_index_need_infer] = in_blob_desc->shape().elem_cnt() / elem_cnt;
+  }
   out_blob_desc->mut_shape() = Shape(dim_vec);
   CHECK_EQ(out_blob_desc->shape().elem_cnt(), in_blob_desc->shape().elem_cnt());
 }
diff --git a/oneflow/core/operator/reshape_op.h b/oneflow/core/operator/reshape_op.h
index 53bdb3ab279f15a496706ac510e0837f941ef29b..8c3b3dc111125382c1c1d56f61844340a7a0ff64 100644
--- a/oneflow/core/operator/reshape_op.h
+++ b/oneflow/core/operator/reshape_op.h
@@ -13,12 +13,16 @@ class ReshapeOp final : public Operator {
 
   void InitFromOpConf() override;
   const PbMessage& GetCustomizedConf() const override;
-  bool IsElemWiseOp() const override { return true; }
+  bool NeedInBlobWhenBackward() const override { return false; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
+  bool IsForwardInplace() const override { return true; }
+  bool IsBackwardInplace() const override { return true; }
 
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/rmsprop_model_update_op.cpp b/oneflow/core/operator/rmsprop_model_update_op.cpp
index c9702acec5c72b365ffef6abd0b583753b791c11..80c7ebfddd10d24c8a69b1ddca2a167d8c0a9dd2 100644
--- a/oneflow/core/operator/rmsprop_model_update_op.cpp
+++ b/oneflow/core/operator/rmsprop_model_update_op.cpp
@@ -4,7 +4,7 @@ namespace oneflow {
 
 void RMSPropModelUpdateOp::MdUpdtVirtualInitFromOpConf() { EnrollDataTmpBn("mean_square"); }
 
-void RMSPropModelUpdateOp::InferBlobDescs(
+void RMSPropModelUpdateOp::MdUpdtVirtualInferBlobDescs(
     std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
     const ParallelContext* parallel_ctx) const {
   const BlobDesc* model_blob_desc = GetBlobDesc4BnInOp("model");
diff --git a/oneflow/core/operator/rmsprop_model_update_op.h b/oneflow/core/operator/rmsprop_model_update_op.h
index 52c80395ae9ee607be1972688ebe078be3469e24..f2deced85ecc5f20ded3d73a355e62188f0cbb86 100644
--- a/oneflow/core/operator/rmsprop_model_update_op.h
+++ b/oneflow/core/operator/rmsprop_model_update_op.h
@@ -11,11 +11,10 @@ class RMSPropModelUpdateOp final : public NormalModelUpdtOp {
   RMSPropModelUpdateOp() = default;
   ~RMSPropModelUpdateOp() = default;
 
-  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                      const ParallelContext* parallel_ctx) const override;
-
  private:
   void MdUpdtVirtualInitFromOpConf() override;
+  void MdUpdtVirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                   const ParallelContext* parallel_ctx) const override;
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/rsqrt_op.cpp b/oneflow/core/operator/rsqrt_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..8a135b7f51cee26e5914fd520d8e867520a7e702
--- /dev/null
+++ b/oneflow/core/operator/rsqrt_op.cpp
@@ -0,0 +1,20 @@
+#include "oneflow/core/operator/rsqrt_op.h"
+
+namespace oneflow {
+
+void RsqrtOp::InitFromOpConf() {
+  CHECK(op_conf().has_rsqrt_conf());
+  EnrollInputBn("in");
+  EnrollOutputBn("out");
+}
+
+const PbMessage& RsqrtOp::GetCustomizedConf() const { return op_conf().rsqrt_conf(); }
+
+void RsqrtOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                             const ParallelContext* parallel_ctx) const {
+  *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in");
+}
+
+REGISTER_OP(OperatorConf::kRsqrtConf, RsqrtOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/rsqrt_op.h b/oneflow/core/operator/rsqrt_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..c5d9a09ed70b3c3ad9f4f075d3a2569270d39706
--- /dev/null
+++ b/oneflow/core/operator/rsqrt_op.h
@@ -0,0 +1,27 @@
+#ifndef ONEFLOW_CORE_OPERATOR_RSQRT_OP_H_
+#define ONEFLOW_CORE_OPERATOR_RSQRT_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class RsqrtOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(RsqrtOp);
+  RsqrtOp() = default;
+  ~RsqrtOp() = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override;
+  bool NeedInBlobWhenBackward() const override { return false; }
+
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_RSQRT_OP_H_
diff --git a/oneflow/core/operator/scalar_add_op.cpp b/oneflow/core/operator/scalar_add_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..8cac8b8203d5917a1eca90f88d6a86206343c106
--- /dev/null
+++ b/oneflow/core/operator/scalar_add_op.cpp
@@ -0,0 +1,17 @@
+#include "oneflow/core/operator/scalar_add_op.h"
+
+namespace oneflow {
+
+void ScalarAddOp::InitFromOpConf() {
+  EnrollInputBn("in");
+  EnrollOutputBn("out");
+}
+
+void ScalarAddOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                 const ParallelContext* parallel_ctx) const {
+  *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in");
+}
+
+REGISTER_OP(OperatorConf::kScalarAddConf, ScalarAddOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/scalar_add_op.h b/oneflow/core/operator/scalar_add_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..12f52b53f712583f0caadeffe8b0e138aa9a88c3
--- /dev/null
+++ b/oneflow/core/operator/scalar_add_op.h
@@ -0,0 +1,27 @@
+#ifndef ONEFLOW_CORE_OPERATOR_SCALAR_ADD_OP_H_
+#define ONEFLOW_CORE_OPERATOR_SCALAR_ADD_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class ScalarAddOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ScalarAddOp);
+  ScalarAddOp() = default;
+  ~ScalarAddOp() = default;
+
+  void InitFromOpConf() override;
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+  const PbMessage& GetCustomizedConf() const override { return op_conf().scalar_add_conf(); }
+  bool NeedInBlobWhenBackward() const override { return false; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_SCALAR_ADD_OP_H_
diff --git a/oneflow/core/operator/scalar_mul_op.cpp b/oneflow/core/operator/scalar_mul_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..175d49a5a173c5b9fa7525d95907c9ee373d79a7
--- /dev/null
+++ b/oneflow/core/operator/scalar_mul_op.cpp
@@ -0,0 +1,17 @@
+#include "oneflow/core/operator/scalar_mul_op.h"
+
+namespace oneflow {
+
+void ScalarMulOp::InitFromOpConf() {
+  EnrollInputBn("in");
+  EnrollOutputBn("out");
+}
+
+void ScalarMulOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                 const ParallelContext* parallel_ctx) const {
+  *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in");
+}
+
+REGISTER_OP(OperatorConf::kScalarMulConf, ScalarMulOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/scalar_mul_op.h b/oneflow/core/operator/scalar_mul_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..fc0126135aa1bdf7c052b7dae43a5c78802b5653
--- /dev/null
+++ b/oneflow/core/operator/scalar_mul_op.h
@@ -0,0 +1,27 @@
+#ifndef ONEFLOW_CORE_OPERATOR_SCALAR_MUL_OP_H_
+#define ONEFLOW_CORE_OPERATOR_SCALAR_MUL_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class ScalarMulOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ScalarMulOp);
+  ScalarMulOp() = default;
+  ~ScalarMulOp() = default;
+
+  void InitFromOpConf() override;
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+  const PbMessage& GetCustomizedConf() const override { return op_conf().scalar_mul_conf(); }
+  bool NeedInBlobWhenBackward() const override { return false; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_SCALAR_MUL_OP_H_
diff --git a/oneflow/core/operator/shared_model_diff_add_op.h b/oneflow/core/operator/shared_model_diff_add_op.h
index 61d6b36a49db76f07774f478a1bb4a86d84369b6..a95b61bb3018327cdc093696efcb88a9da287ab6 100644
--- a/oneflow/core/operator/shared_model_diff_add_op.h
+++ b/oneflow/core/operator/shared_model_diff_add_op.h
@@ -17,6 +17,10 @@ class SharedModelDiffAddOp final : public Operator {
   const PbMessage& GetCustomizedConf() const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override {
+    return UNIMPLEMENTED();
+  }
+
   LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
   LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); }
 };
diff --git a/oneflow/core/operator/sigmoid_cross_entropy_loss_op.cpp b/oneflow/core/operator/sigmoid_cross_entropy_loss_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..adb5a828747208e5537e74f925ce54b7deb4b90b
--- /dev/null
+++ b/oneflow/core/operator/sigmoid_cross_entropy_loss_op.cpp
@@ -0,0 +1,61 @@
+#include "oneflow/core/operator/sigmoid_cross_entropy_loss_op.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+void SigmoidCrossEntropyLossOp::VirtualInitFromOpConf() {
+  EnrollDataTmpBn("count");
+  EnrollDataTmpBn("label_num");
+  EnrollDataTmpBn("loss_buf");
+  EnrollDataTmpBn("sum_buf");
+}
+
+const PbMessage& SigmoidCrossEntropyLossOp::GetCustomizedConf() const {
+  return op_conf().sigmoid_cross_entropy_loss_conf();
+}
+
+LossKernelConf* SigmoidCrossEntropyLossOp::GetMutLossKernelConf(KernelConf* kernel_conf) const {
+  return kernel_conf->mutable_sigmoid_cross_entropy_loss_conf()->mutable_loss_conf();
+}
+
+void SigmoidCrossEntropyLossOp::VirtualInferBlobDescs(
+    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx) const {
+  // label
+  const BlobDesc* label_blob_desc = GetBlobDesc4BnInOp("label");
+  // each label must be -1, 0, or 1, where -1 means the element is ignored
+  CHECK_GE(label_blob_desc->shape().NumAxes(), 2);
+  // prediction
+  const BlobDesc* pred_blob_desc = GetBlobDesc4BnInOp("prediction");
+  CHECK_EQ(pred_blob_desc->shape(), label_blob_desc->shape());
+  CHECK_GE(pred_blob_desc->shape().NumAxes(), 2);
+
+  int64_t data_num = pred_blob_desc->shape().At(0);
+  int64_t data_dim = pred_blob_desc->shape().Count(1);
+
+  // loss
+  BlobDesc* loss_blob_desc = GetBlobDesc4BnInOp("loss");
+  loss_blob_desc->mut_shape() = Shape({data_num});
+  loss_blob_desc->set_data_type(pred_blob_desc->data_type());
+  // count
+  BlobDesc* count_blob_desc = GetBlobDesc4BnInOp("count");
+  count_blob_desc->mut_shape() = Shape({data_dim});
+  count_blob_desc->set_data_type(pred_blob_desc->data_type());
+  // loss_buf
+  BlobDesc* loss_buf_desc = GetBlobDesc4BnInOp("loss_buf");
+  loss_buf_desc->mut_shape() = Shape({data_dim});
+  loss_buf_desc->set_data_type(pred_blob_desc->data_type());
+  // label_num
+  BlobDesc* label_num_blob_desc = GetBlobDesc4BnInOp("label_num");
+  label_num_blob_desc->mut_shape() = Shape({1});
+  label_num_blob_desc->set_data_type(pred_blob_desc->data_type());
+  // sum buf
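+  // (byte-sized workspace for the device reduce-sum, hence data type kChar)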
+  BlobDesc* sum_buf_blob_desc = GetBlobDesc4BnInOp("sum_buf");
+  const int64_t sum_buf_size = GetTmpSizeForReduceSum(pred_blob_desc->data_type(), data_dim);
+  sum_buf_blob_desc->mut_shape() = Shape({sum_buf_size});
+  sum_buf_blob_desc->set_data_type(DataType::kChar);
+}
+
+REGISTER_OP(OperatorConf::kSigmoidCrossEntropyLossConf, SigmoidCrossEntropyLossOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/sigmoid_cross_entropy_loss_op.h b/oneflow/core/operator/sigmoid_cross_entropy_loss_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..33cffad76301f51fe9037f71045862f304fb1cb0
--- /dev/null
+++ b/oneflow/core/operator/sigmoid_cross_entropy_loss_op.h
@@ -0,0 +1,27 @@
+#ifndef ONEFLOW_CORE_OPERATOR_SIGMOID_CROSS_ENTROPY_LOSS_OP_H_
+#define ONEFLOW_CORE_OPERATOR_SIGMOID_CROSS_ENTROPY_LOSS_OP_H_
+
+#include "oneflow/core/operator/loss_op.h"
+
+namespace oneflow {
+
+class SigmoidCrossEntropyLossOp final : public LossOp {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(SigmoidCrossEntropyLossOp);
+  SigmoidCrossEntropyLossOp() = default;
+  ~SigmoidCrossEntropyLossOp() = default;
+
+  const PbMessage& GetCustomizedConf() const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
+  void VirtualInitFromOpConf() override;
+  void VirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                             const ParallelContext* parallel_ctx) const override;
+  LossKernelConf* GetMutLossKernelConf(KernelConf*) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_SIGMOID_CROSS_ENTROPY_LOSS_OP_H_
diff --git a/oneflow/core/operator/sigmoid_op.h b/oneflow/core/operator/sigmoid_op.h
index e206b0876c44b2ae86d9393c37525d4af7256543..e98707e03a1ba097810c0b773f21b66f7d4c0a71 100644
--- a/oneflow/core/operator/sigmoid_op.h
+++ b/oneflow/core/operator/sigmoid_op.h
@@ -13,13 +13,13 @@ class SigmoidOp final : public Operator {
 
   void InitFromOpConf() override;
   const PbMessage& GetCustomizedConf() const override;
-  bool IsElemWiseOp() const override { return true; }
   bool NeedInBlobWhenBackward() const override { return false; }
 
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/slice_op.cpp b/oneflow/core/operator/slice_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6442ba3ed41fe09ecd99e712fc528f1cb0210205
--- /dev/null
+++ b/oneflow/core/operator/slice_op.cpp
@@ -0,0 +1,61 @@
+#include "oneflow/core/operator/slice_op.h"
+
+namespace oneflow {
+
+void SliceOp::InitFromOpConf() {
+  CHECK(op_conf().has_slice_conf());
+  EnrollInputBn("in");
+  EnrollOutputBn("out");
+  if (op_conf().device_type() == DeviceType::kGPU) { EnrollConstBufBn("out_to_in_offset"); }
+}
+
+const PbMessage& SliceOp::GetCustomizedConf() const { return op_conf().slice_conf(); }
+
+void SliceOp::VirtualGenKernelConf(
+    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
+  const Shape& in_shape = GetBlobDesc4BnInOp("in")->shape();
+  in_shape.ToProto(kernel_conf->mutable_slice_conf()->mutable_in_shape());
+}
+
+bool SliceOp::IsInputBlobAllowedModelSplit(const std::string& ibn) const {
+  CHECK(std::find(input_bns().begin(), input_bns().end(), ibn) != input_bns().end());
+  return ibn == "in";
+}
+
+void SliceOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                             const ParallelContext* parallel_ctx) const {
+  const SliceOpConf& conf = op_conf().slice_conf();
+  const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
+  CHECK_EQ(conf.dim_slice_conf_size(), in_blob_desc->shape().NumAxes() - 1);
+  std::vector<int64_t> shape_vec(in_blob_desc->shape().NumAxes());
+  shape_vec[0] = in_blob_desc->shape().At(0);
+  FOR_RANGE(size_t, i, 0, conf.dim_slice_conf_size()) {
+    int32_t dim_len = in_blob_desc->shape().At(i + 1);
+    const DimSliceConf& dim_slice_conf = conf.dim_slice_conf(i);
+    int32_t step = dim_slice_conf.stride();
+    CHECK_GT(step, 0);
+    int32_t start = dim_slice_conf.has_start() ? dim_slice_conf.start() : 0;
+    int32_t end = dim_slice_conf.has_end() ? dim_slice_conf.end() : dim_len;
+    if (start < 0) { start += dim_len; }
+    if (end < 0) { end += dim_len; }
+    CHECK_GE(start, 0);
+    CHECK_LT(start, end);
+    CHECK_LE(end, dim_len);
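+    // number of elements kept along this axis: ceil((end - start) / step)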
+    shape_vec[i + 1] = (end - start - 1) / std::abs(step) + 1;
+  }
+
+  BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
+  *out_blob_desc = *in_blob_desc;
+  out_blob_desc->mut_shape() = Shape(shape_vec);
+
+  BlobDesc* offset_blob_desc = GetBlobDesc4BnInOp("out_to_in_offset");
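+  // enrolled only on GPU (see InitFromOpConf): one int64 input offset per output element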
+  if (offset_blob_desc) {
+    *offset_blob_desc = *out_blob_desc;
+    offset_blob_desc->set_data_type(DataType::kInt64);
+  }
+}
+
+REGISTER_OP(OperatorConf::kSliceConf, SliceOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/slice_op.h b/oneflow/core/operator/slice_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..bb15463cea499d17412cd425fec769a96f6c4d50
--- /dev/null
+++ b/oneflow/core/operator/slice_op.h
@@ -0,0 +1,28 @@
+#ifndef ONEFLOW_CORE_OPERATOR_SLICE_OP_H_
+#define ONEFLOW_CORE_OPERATOR_SLICE_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class SliceOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(SliceOp);
+  SliceOp() = default;
+  ~SliceOp() = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override;
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                            const ParallelContext* parallel_ctx,
+                            KernelConf* kernel_conf) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_SLICE_OP_H_
diff --git a/oneflow/core/operator/softmax_op.cpp b/oneflow/core/operator/softmax_op.cpp
index b2b3152c66835e519f5bd6923ed69471e3b20213..5a59e9e4bcf2a911c35c4ff258f6f51c8b72d9d9 100644
--- a/oneflow/core/operator/softmax_op.cpp
+++ b/oneflow/core/operator/softmax_op.cpp
@@ -20,7 +20,7 @@ void SoftmaxOp::InitFromOpConf() {
 const PbMessage& SoftmaxOp::GetCustomizedConf() const { return op_conf().softmax_conf(); }
 
 void SoftmaxOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                               const ParallelContext* parallel_ctx,
+                               const ParallelContext* parallel_ctx, int64_t record_piece_size,
                                std::function<void(OpContext*)> EnrollOpCtx) const {
   // in
   const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
diff --git a/oneflow/core/operator/softmax_op.h b/oneflow/core/operator/softmax_op.h
index a3303a58aeb9b47ffcf2e0743a0073fd4aabfabb..d4295a576f727596dfd4e90ebc76f703e8f550c7 100644
--- a/oneflow/core/operator/softmax_op.h
+++ b/oneflow/core/operator/softmax_op.h
@@ -23,12 +23,14 @@ class SoftmaxOp final : public Operator {
   const PbMessage& GetCustomizedConf() const override;
 
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                      const ParallelContext*,
+                      const ParallelContext* parallel_ctx, int64_t record_piece_size,
                       std::function<void(OpContext*)> EnrollOpCtx) const override;
   void InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                            const ParallelContext*, const OpContext*) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                             const ParallelContext*, KernelConf*, const OpContext*) const override;
   SoftmaxOpCtx* NewSoftmaxOpCtx(const Shape& in_shape) const;
diff --git a/oneflow/core/operator/sparse_cross_entropy_loss_op.h b/oneflow/core/operator/sparse_cross_entropy_loss_op.h
index 716aac25f7047b484af04d20e4fd5f983f118af2..7fb10aa9661aaca5ae18c769b380d35a6af5f53f 100644
--- a/oneflow/core/operator/sparse_cross_entropy_loss_op.h
+++ b/oneflow/core/operator/sparse_cross_entropy_loss_op.h
@@ -14,6 +14,8 @@ class SparseCrossEntropyLossOp final : public LossOp {
   const PbMessage& GetCustomizedConf() const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   LossKernelConf* GetMutLossKernelConf(KernelConf*) const override;
 };
 
diff --git a/oneflow/core/operator/sparse_cross_entropy_op.cpp b/oneflow/core/operator/sparse_cross_entropy_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..833ce7863e79a7c481f435922922e2d4ff9c7684
--- /dev/null
+++ b/oneflow/core/operator/sparse_cross_entropy_op.cpp
@@ -0,0 +1,46 @@
+#include "oneflow/core/operator/sparse_cross_entropy_op.h"
+
+namespace oneflow {
+
+void SparseCrossEntropyOp::InitFromOpConf() {
+  CHECK(op_conf().has_sparse_cross_entropy_conf());
+  EnrollInputBn("prediction");
+  EnrollInputBn("label", false);
+  EnrollOutputBn("out");
+}
+
+const PbMessage& SparseCrossEntropyOp::GetCustomizedConf() const {
+  return op_conf().sparse_cross_entropy_conf();
+}
+
+void SparseCrossEntropyOp::InferBlobDescs(
+    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx, int64_t record_piece_size,
+    std::function<void(OpContext*)> EnrollOpCtx) const {
+  const BlobDesc* pred_blob_desc = GetBlobDesc4BnInOp("prediction");
+  const BlobDesc* label_blob_desc = GetBlobDesc4BnInOp("label");
+  CHECK(IsIntegralDataType(label_blob_desc->data_type()));
+  CHECK(IsFloatingDataType(pred_blob_desc->data_type()));
+  CHECK_EQ(pred_blob_desc->has_data_id_field(), label_blob_desc->has_data_id_field());
+  CHECK_EQ(pred_blob_desc->has_dim0_valid_num_field(), label_blob_desc->has_dim0_valid_num_field());
+  CHECK_EQ(pred_blob_desc->has_dim0_inner_shape(), label_blob_desc->has_dim0_inner_shape());
+  if (pred_blob_desc->has_dim0_inner_shape()) {
+    CHECK_EQ(pred_blob_desc->dim0_inner_shape().At(0), 1);
+    CHECK_EQ(pred_blob_desc->dim0_inner_shape(), label_blob_desc->dim0_inner_shape());
+  }
+  CHECK_GE(pred_blob_desc->shape().NumAxes(), 2);
+  const int64_t num_out_axes = pred_blob_desc->shape().NumAxes() - 1;
+  CHECK_GE(label_blob_desc->shape().NumAxes(), num_out_axes);
+  CHECK_EQ(label_blob_desc->shape().Count(num_out_axes), 1);
+  FOR_RANGE(int64_t, i, 0, num_out_axes) {
+    CHECK_EQ(pred_blob_desc->shape().At(i), label_blob_desc->shape().At(i));
+  }
+  BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
+  *out_blob_desc = *pred_blob_desc;
+  out_blob_desc->mut_shape() = Shape(std::vector<int64_t>(
+      pred_blob_desc->shape().dim_vec().cbegin(), pred_blob_desc->shape().dim_vec().cend() - 1));
+}
+
+REGISTER_OP(OperatorConf::kSparseCrossEntropyConf, SparseCrossEntropyOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/sparse_cross_entropy_op.h b/oneflow/core/operator/sparse_cross_entropy_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..6e188b22059d8a4ad866df5bf4a2cd9331efe5b4
--- /dev/null
+++ b/oneflow/core/operator/sparse_cross_entropy_op.h
@@ -0,0 +1,28 @@
+#ifndef ONEFLOW_CORE_OPERATOR_SPARSE_CROSS_ENTROPY_OP_H_
+#define ONEFLOW_CORE_OPERATOR_SPARSE_CROSS_ENTROPY_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class SparseCrossEntropyOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(SparseCrossEntropyOp);
+  SparseCrossEntropyOp() = default;
+  ~SparseCrossEntropyOp() override = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override;
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx, int64_t record_piece_size,
+                      std::function<void(OpContext*)> EnrollOpCtx) const override;
+  bool NeedOutBlobWhenBackward() const override { return false; }
+  bool NeedInBlobWhenBackward() const override { return true; }
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_SPARSE_CROSS_ENTROPY_OP_H_
diff --git a/oneflow/core/operator/sparse_softmax_cross_entropy_loss_op.h b/oneflow/core/operator/sparse_softmax_cross_entropy_loss_op.h
index 2836a7f6182eaea4bbb872f88a5b9173d44d451b..b8f510746d700560ce46dfa8e664c787f9315b7c 100644
--- a/oneflow/core/operator/sparse_softmax_cross_entropy_loss_op.h
+++ b/oneflow/core/operator/sparse_softmax_cross_entropy_loss_op.h
@@ -14,6 +14,8 @@ class SparseSoftmaxCrossEntropyLossOp final : public LossOp {
   const PbMessage& GetCustomizedConf() const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+
   void VirtualInitFromOpConf() override;
   void VirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                              const ParallelContext* parallel_ctx) const override;
diff --git a/oneflow/core/operator/sqrt_op.cpp b/oneflow/core/operator/sqrt_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..3f19e25a0c65ec6bb7394ecc87d4eaef5d4e98e0
--- /dev/null
+++ b/oneflow/core/operator/sqrt_op.cpp
@@ -0,0 +1,20 @@
+#include "oneflow/core/operator/sqrt_op.h"
+
+namespace oneflow {
+
+void SqrtOp::InitFromOpConf() {
+  CHECK(op_conf().has_sqrt_conf());
+  EnrollInputBn("in");
+  EnrollOutputBn("out");
+}
+
+const PbMessage& SqrtOp::GetCustomizedConf() const { return op_conf().sqrt_conf(); }
+
+void SqrtOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                            const ParallelContext* parallel_ctx) const {
+  *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in");
+}
+
+REGISTER_OP(OperatorConf::kSqrtConf, SqrtOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/sqrt_op.h b/oneflow/core/operator/sqrt_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..70904865234a138235192b4261c996db250985fb
--- /dev/null
+++ b/oneflow/core/operator/sqrt_op.h
@@ -0,0 +1,27 @@
+#ifndef ONEFLOW_CORE_OPERATOR_SQRT_OP_H_
+#define ONEFLOW_CORE_OPERATOR_SQRT_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class SqrtOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(SqrtOp);
+  SqrtOp() = default;
+  ~SqrtOp() = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override;
+  bool NeedInBlobWhenBackward() const override { return false; }
+
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_SQRT_OP_H_
diff --git a/oneflow/core/operator/square_op.cpp b/oneflow/core/operator/square_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..77afab0f8763e7142d0999cbaae0297491bb6d22
--- /dev/null
+++ b/oneflow/core/operator/square_op.cpp
@@ -0,0 +1,20 @@
+#include "oneflow/core/operator/square_op.h"
+
+namespace oneflow {
+
+void SquareOp::InitFromOpConf() {
+  CHECK(op_conf().has_square_conf());
+  EnrollInputBn("in");
+  EnrollOutputBn("out");
+}
+
+const PbMessage& SquareOp::GetCustomizedConf() const { return op_conf().square_conf(); }
+
+void SquareOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                              const ParallelContext* parallel_ctx) const {
+  *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in");
+}
+
+REGISTER_OP(OperatorConf::kSquareConf, SquareOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/square_op.h b/oneflow/core/operator/square_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..7af18ce16184a02ce7552135a2b581c0ec3d3db6
--- /dev/null
+++ b/oneflow/core/operator/square_op.h
@@ -0,0 +1,27 @@
+#ifndef ONEFLOW_CORE_OPERATOR_SQUARE_OP_H_
+#define ONEFLOW_CORE_OPERATOR_SQUARE_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class SquareOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(SquareOp);
+  SquareOp() = default;
+  ~SquareOp() = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override;
+  bool NeedOutBlobWhenBackward() const override { return false; }
+
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_SQUARE_OP_H_
diff --git a/oneflow/core/operator/tanh_op.h b/oneflow/core/operator/tanh_op.h
index 2d402b8452fda1dd772c0111f2a402f753552a48..b0b8deaab159905931ebec754268432b72e54144 100644
--- a/oneflow/core/operator/tanh_op.h
+++ b/oneflow/core/operator/tanh_op.h
@@ -13,13 +13,13 @@ class TanHOp final : public Operator {
 
   void InitFromOpConf() override;
   const PbMessage& GetCustomizedConf() const override;
-  bool IsElemWiseOp() const override { return true; }
   bool NeedInBlobWhenBackward() const override { return false; }
 
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/tick_op.cpp b/oneflow/core/operator/tick_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a79b3ed7e43dd6e0a80cb11f25fa2c0eb1ffc837
--- /dev/null
+++ b/oneflow/core/operator/tick_op.cpp
@@ -0,0 +1,20 @@
+#include "oneflow/core/operator/tick_op.h"
+
+namespace oneflow {
+
+void TickOp::InitFromOpConf() {
+  EnrollInputBn("in", false);
+  EnrollOutputBn("out", false);
+}
+
+void TickOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                            const ParallelContext* parallel_ctx) const {
+  const BlobDesc* in = GetBlobDesc4BnInOp("in");
+  BlobDesc* out = GetBlobDesc4BnInOp("out");
+  *out = *in;
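+  // tick only signals that its input is ready, so the output is shrunk to a single element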
+  out->mut_shape() = Shape({1});
+}
+
+REGISTER_CPU_OP(OperatorConf::kTickConf, TickOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/tick_op.h b/oneflow/core/operator/tick_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..8f37fe8818cab8ecdfa6f10c628b1d8481353edc
--- /dev/null
+++ b/oneflow/core/operator/tick_op.h
@@ -0,0 +1,25 @@
+#ifndef ONEFLOW_CORE_OPERATOR_TICK_OP_H_
+#define ONEFLOW_CORE_OPERATOR_TICK_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class TickOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(TickOp);
+  TickOp() = default;
+  ~TickOp() = default;
+
+  void InitFromOpConf() override;
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+  const PbMessage& GetCustomizedConf() const override { return op_conf().tick_conf(); }
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_TICK_OP_H_
diff --git a/oneflow/core/operator/top_k_op.cpp b/oneflow/core/operator/top_k_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..78a5d33d4cecfa3a0c30e929a677a9e6d21deaf4
--- /dev/null
+++ b/oneflow/core/operator/top_k_op.cpp
@@ -0,0 +1,38 @@
+#include "oneflow/core/operator/top_k_op.h"
+
+namespace oneflow {
+
+void TopKOp::InitFromOpConf() {
+  CHECK(op_conf().has_top_k_conf());
+  EnrollInputBn("in", false);
+  EnrollFwBufBn("fw_buf");
+  EnrollOutputBn("out", false);
+}
+
+void TopKOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                            const ParallelContext* parallel_ctx) const {
+  const BlobDesc* in = GetBlobDesc4BnInOp("in");
+  CHECK_LE(in->shape().elem_cnt(), GetMaxVal<int32_t>());
+  const TopKOpConf& conf = op_conf().top_k_conf();
+  CHECK_GE(conf.k(), 1);
+  CHECK_LE(conf.k(), in->shape().dim_vec().back());
+  // fw_buf
+  BlobDesc* fw_buf = GetBlobDesc4BnInOp("fw_buf");
+  fw_buf->mut_shape() = Shape({in->shape().dim_vec().back()});
+  fw_buf->set_data_type(DataType::kInt32);
+  // out
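+  // the output holds the indices of the k largest entries along the last axis, hence int32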
+  BlobDesc* out = GetBlobDesc4BnInOp("out");
+  *out = *in;
+  out->mut_shape().Set(in->shape().NumAxes() - 1, conf.k());
+  out->set_data_type(DataType::kInt32);
+}
+
+void TopKOp::VirtualGenKernelConf(
+    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*,
+    KernelConf* kernel_conf) const {
+  kernel_conf->set_data_type(GetBlobDesc4BnInOp("in")->data_type());
+}
+
+REGISTER_CPU_OP(OperatorConf::kTopKConf, TopKOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/top_k_op.h b/oneflow/core/operator/top_k_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..31d5eb5dc48eb0355ead7d0425922df741b0e5a4
--- /dev/null
+++ b/oneflow/core/operator/top_k_op.h
@@ -0,0 +1,27 @@
+#ifndef ONEFLOW_CORE_OPERATOR_TOP_K_OP_H_
+#define ONEFLOW_CORE_OPERATOR_TOP_K_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class TopKOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(TopKOp);
+  TopKOp() = default;
+  ~TopKOp() override = default;
+
+  void InitFromOpConf() override;
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+  const PbMessage& GetCustomizedConf() const override { return op_conf().top_k_conf(); }
+
+ private:
+  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                            const ParallelContext*, KernelConf*) const override;
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_TOP_K_OP_H_
diff --git a/oneflow/core/operator/transpose_op.cpp b/oneflow/core/operator/transpose_op.cpp
index f553d3451500a4a6d42484f6ea74868bdefb25f2..df40d757dbf8f3a252a053082df3ec7a02657ba6 100644
--- a/oneflow/core/operator/transpose_op.cpp
+++ b/oneflow/core/operator/transpose_op.cpp
@@ -5,12 +5,12 @@ namespace oneflow {
 namespace {
 
 void CheckIsPerm(const PbRf<int32_t>& perm) {
-  std::vector<bool> is_used(perm.size(), 0);
+  std::vector<bool> is_used(perm.size(), false);
   FOR_RANGE(size_t, i, 0, perm.size()) {
-    CHECK_GE(perm[i], 1);
+    CHECK_GE(perm[i], 0);
-    CHECK_LE(perm[i], perm.size());
+    CHECK_LT(perm[i], perm.size());
-    CHECK_EQ(is_used[perm[i] - 1], false);
-    is_used[perm[i] - 1] = true;
+    CHECK_EQ(is_used[perm[i]], false);
+    is_used[perm[i]] = true;
   }
 }
 
@@ -29,13 +29,44 @@ void TransposeOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> Ge
   const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
   const Shape& in_blob_shape = in_blob_desc->shape();
   const PbRf<int32_t>& perm = op_conf().transpose_conf().perm();
-  CHECK_EQ(perm.size(), in_blob_shape.NumAxes() - 1);
+  CHECK_EQ(perm.size(), in_blob_shape.NumAxes());
   CheckIsPerm(perm);
+  if (perm.Get(0) != 0) {
+    CHECK(!in_blob_desc->has_dim0_valid_num_field());
+  } else if (perm.size() >= 2 && perm.Get(1) != 1) {
+    CHECK(!in_blob_desc->has_dim1_valid_num_field());
+  } else if (perm.size() >= 3 && perm.Get(2) != 2) {
+    CHECK(!in_blob_desc->has_dim2_valid_num_field());
+  } else {
+    // do nothing
+  }
   BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
   *out_blob_desc = *in_blob_desc;
   FOR_RANGE(size_t, i, 0, perm.size()) {
-    out_blob_desc->mut_shape().Set(i + 1, in_blob_shape.At(perm[i]));
+    out_blob_desc->mut_shape().Set(i, in_blob_shape.At(perm[i]));
+  }
+}
+
+int32_t TransposeOp::OutputBlobModelSplitAxis(
+    const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+    const std::string& obn) const {
+  const auto& in_sbp_infer_hint = SbpInferHint4Ibn("in");
+  const PbRf<int32_t>& perm = op_conf().transpose_conf().perm();
+  CHECK_GT(perm.size(), 0);
+  CHECK_EQ(perm.size(), in_sbp_infer_hint.num_axes());
+  int32_t split_axis = -1;
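+  // the output's split axis is wherever the permutation places the input's split axis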
+  FOR_RANGE(int32_t, i, 0, perm.size()) {
+    if (perm[i] == in_sbp_infer_hint.split_axis()) {
+      split_axis = i;
+      break;
+    }
+  }
+  CHECK_NE(split_axis, -1);
+  if (in_sbp_infer_hint.is_data_blob()) {
+    CHECK_GT(in_sbp_infer_hint.split_axis(), 0);
+    CHECK_GT(split_axis, 0);
   }
+  return split_axis;
 }
 
 void TransposeOp::VirtualGenKernelConf(
@@ -43,9 +74,8 @@ void TransposeOp::VirtualGenKernelConf(
     const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
   const PbRf<int32_t>& src_perm = op_conf().transpose_conf().perm();
   PbRf<int32_t>* perm = kernel_conf->mutable_transpose_conf()->mutable_perm();
-  perm->Add(0);
-  perm->MergeFrom(src_perm);
-  CHECK_EQ(perm->size(), src_perm.size() + 1);
+  *perm = src_perm;
+  CHECK_EQ(perm->size(), src_perm.size());
   PbRf<int32_t>* invert_perm = kernel_conf->mutable_transpose_conf()->mutable_invert_perm();
   invert_perm->Reserve(perm->size());
   invert_perm->CopyFrom(*perm);
diff --git a/oneflow/core/operator/transpose_op.h b/oneflow/core/operator/transpose_op.h
index 889062c114f3345b5d458ad7d5368f0d28d6d196..4f37ed3c4fbdf8ecffdfe2b1194ea0f911e436d7 100644
--- a/oneflow/core/operator/transpose_op.h
+++ b/oneflow/core/operator/transpose_op.h
@@ -18,8 +18,13 @@ class TransposeOp final : public Operator {
 
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
+  int32_t OutputBlobModelSplitAxis(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const std::string& obn) const override;
 
  private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; }
+
   void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                             const ParallelContext*, KernelConf*) const override;
 };
diff --git a/oneflow/core/operator/tuple_identity_op.cpp b/oneflow/core/operator/tuple_identity_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2ff4a0ec0150b86082704e9e2b27c4c6b401474c
--- /dev/null
+++ b/oneflow/core/operator/tuple_identity_op.cpp
@@ -0,0 +1,70 @@
+#include "oneflow/core/operator/tuple_identity_op.h"
+
+namespace oneflow {
+
+namespace {
+
+class TupleIdentityOpParallelSignature final : public OpParallelSignature {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(TupleIdentityOpParallelSignature);
+  ~TupleIdentityOpParallelSignature() override = default;
+
+  TupleIdentityOpParallelSignature(const Operator* op) : OpParallelSignature(op) {}
+
+  const std::string Description() const override { return op().op_name() + ": A -> A"; }
+
+  const OpParallelMatchResult GetMatchResult(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4BnInOp,
+      const ParallelContext* parallel_ctx) const override {
+    const auto& ibn = op().input_bns().Get(0);
+    if (parallel_ctx->parallel_num() != SbpInferHint4BnInOp(ibn).parallel_num()) {
+      return MakeOpParallelMatchParallelNumError(parallel_ctx->parallel_num(),
+                                                 SbpInferHint4BnInOp(ibn).parallel_num());
+    }
+    return MakeOpParallelMatchSuccess();
+  }
+
+  void GenerateSignature(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4BnInOp,
+      HashMap<std::string, SbpParallel>* bn2sbp) const override {
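+    // each output blob inherits the SBP parallel of its corresponding input blob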
+    FOR_RANGE(int32_t, i, 0, op().input_bns().size()) {
+      const auto& sbp_parallel = SbpInferHint4BnInOp(op().input_bns().Get(i)).sbp_parallel();
+      (*bn2sbp)[op().input_bns().Get(i)] = sbp_parallel;
+      (*bn2sbp)[op().output_bns().Get(i)] = sbp_parallel;
+    }
+  }
+};
+
+}  // namespace
+
+void TupleIdentityOp::InitFromOpConf() {
+  CHECK(op_conf().has_tuple_identity_conf());
+  int32_t in_size = op_conf().tuple_identity_conf().in_size();
+  int32_t out_size = op_conf().tuple_identity_conf().out_size();
+  CHECK_GT(in_size, 0);
+  CHECK_EQ(in_size, out_size);
+  EnrollRepeatedInputBn("in", in_size);
+  EnrollRepeatedOutputBn("out", out_size);
+}
+
+const PbMessage& TupleIdentityOp::GetCustomizedConf() const {
+  return op_conf().tuple_identity_conf();
+}
+
+void TupleIdentityOp::InferBlobDescs(
+    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx) const {
+  size_t bn_size = op_conf().tuple_identity_conf().in_size();
+  FOR_RANGE(int, i, 0, bn_size) {
+    *GetBlobDesc4BnInOp(output_bns().Get(i)) = *GetBlobDesc4BnInOp(input_bns().Get(i));
+  }
+}
+
+void TupleIdentityOp::GetOpParallelSignatures(
+    std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const {
+  op_parallel_signatures->emplace_back(new TupleIdentityOpParallelSignature(this));
+}
+
+REGISTER_OP(OperatorConf::kTupleIdentityConf, TupleIdentityOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/tuple_identity_op.h b/oneflow/core/operator/tuple_identity_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..d4319efe1f16f4080fcbc35c1fb780c01ae26dce
--- /dev/null
+++ b/oneflow/core/operator/tuple_identity_op.h
@@ -0,0 +1,32 @@
+#ifndef ONEFLOW_CORE_OPERATOR_TUPLE_IDENTITY_OP_H_
+#define ONEFLOW_CORE_OPERATOR_TUPLE_IDENTITY_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class TupleIdentityOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(TupleIdentityOp);
+  TupleIdentityOp() = default;
+  ~TupleIdentityOp() = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override;
+  bool NeedInBlobWhenBackward() const override { return false; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
+
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override {
+    return op_conf().tuple_identity_conf().in_size() == 1 && ibn == SoleIbn();
+  }
+  void GetOpParallelSignatures(
+      std::vector<std::unique_ptr<const OpParallelSignature>>*) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_TUPLE_IDENTITY_OP_H_
diff --git a/oneflow/core/operator/unpack_op.cpp b/oneflow/core/operator/unpack_op.cpp
index bda0945ae765cba5cd831f95807140c52531785a..3dfe6b06ba15199ff1743ce8ef7a9cebf61c1d3c 100644
--- a/oneflow/core/operator/unpack_op.cpp
+++ b/oneflow/core/operator/unpack_op.cpp
@@ -14,7 +14,7 @@ void UnpackOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBl
   const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp(SoleIbn());
   BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(SoleObn());
   *out_blob_desc = *in_blob_desc;
-  int32_t unpack_num = GetUnpackNum(parallel_ctx->parallel_num());
+  int32_t unpack_num = GetUnpackNum();
   if (in_blob_desc->has_dim0_inner_shape()) {
     CHECK_EQ(0, unpack_num % in_blob_desc->dim0_inner_shape().At(0));
     CHECK_EQ(0, in_blob_desc->dim0_inner_shape().Count(1)
@@ -26,22 +26,17 @@ void UnpackOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBl
   out_blob_desc->mut_dim0_inner_shape() = Shape({1, out_blob_desc->shape().At(0)});
 }
 
-int32_t UnpackOp::GetUnpackNum(int64_t parallel_num) const {
-  CHECK(op_conf().has_unpack_conf());
-  const UnpackOpConf& conf = op_conf().unpack_conf();
-  if (conf.has_unpack_num()) {
-    return conf.unpack_num();
-  } else if (conf.has_unpack_num_per_record()) {
-    CHECK_EQ(Global<JobDesc>::Get()->PieceSize() % parallel_num, 0);
-    int64_t unpack_num =
-        Global<JobDesc>::Get()->PieceSize() / parallel_num * conf.unpack_num_per_record();
-    CHECK_LE(unpack_num, static_cast<int64_t>(MaxVal<int32_t>()));
-    return static_cast<int32_t>(unpack_num);
-  } else {
-    UNIMPLEMENTED();
-  }
+void UnpackOp::InferOutputBlobTimeShape(
+    std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp,
+    const ParallelContext* parallel_ctx, Shape* time_shape) const {
+  std::vector<int64_t> dim_vec(GetTimeShape4BnInOp("in")->dim_vec());
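+  // unpacking emits unpack_num pieces per input piece, so the time shape gains one trailing axis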
+  int32_t unpack_num = GetUnpackNum();
+  dim_vec.push_back(unpack_num);
+  *time_shape = Shape(dim_vec);
 }
 
+int32_t UnpackOp::GetUnpackNum() const { return op_conf().unpack_conf().unpack_num(); }
+
 REGISTER_OP(OperatorConf::kUnpackConf, UnpackOp);
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/unpack_op.h b/oneflow/core/operator/unpack_op.h
index 72b12b45575cfb120f1de322fb43ef097345bf02..f52869c4d89dceafcca420c2257445db30012124 100644
--- a/oneflow/core/operator/unpack_op.h
+++ b/oneflow/core/operator/unpack_op.h
@@ -18,9 +18,15 @@ class UnpackOp final : public Operator {
 
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
+  void InferOutputBlobTimeShape(std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp,
+                                const ParallelContext* parallel_ctx,
+                                Shape* time_shape) const override;
   bool NeedInBlobWhenBackward() const override { return true; }
   bool NeedOutBlobWhenBackward() const override { return false; }
-  int32_t GetUnpackNum(int64_t parallel_num) const;
+  int32_t GetUnpackNum() const;
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
 };
 
 }  // namespace oneflow
diff --git a/oneflow/core/operator/variable_op.cpp b/oneflow/core/operator/variable_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6382046311a837a6ef5524998b61642c2612cc3a
--- /dev/null
+++ b/oneflow/core/operator/variable_op.cpp
@@ -0,0 +1,125 @@
+#include "oneflow/core/operator/variable_op.h"
+#include "oneflow/core/common/balanced_splitter.h"
+
+namespace oneflow {
+
+namespace {
+
+// S(0) -> C
+class VariableOpDataSplitOpParallelSignature final : public OpParallelSignature {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(VariableOpDataSplitOpParallelSignature);
+  ~VariableOpDataSplitOpParallelSignature() override = default;
+
+  VariableOpDataSplitOpParallelSignature(const Operator* op) : OpParallelSignature(op) {}
+
+  const std::string Description() const override { return op().op_name() + ": S(0) -> C"; }
+
+  const OpParallelMatchResult GetMatchResult(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const ParallelContext* parallel_ctx) const override {
+    if (parallel_ctx->policy() == kDataParallel) {
+      return MakeOpParallelMatchSuccess();
+    } else {
+      return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kDataParallel);
+    }
+  }
+
+  void GenerateSignature(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      HashMap<std::string, SbpParallel>* bn2sbp) const override {
+    CHECK(SbpInferHint4Ibn("tick").is_data_split());
+    (*bn2sbp)["tick"].mutable_split_parallel()->set_axis(0);
+    (*bn2sbp)["out"].mutable_broadcast_parallel();
+  }
+};
+
+// S(0) -> S
+class VariableOpModelSplitOpParallelSignature final : public OpParallelSignature {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(VariableOpModelSplitOpParallelSignature);
+  ~VariableOpModelSplitOpParallelSignature() override = default;
+
+  VariableOpModelSplitOpParallelSignature(const Operator* op) : OpParallelSignature(op) {}
+
+  const std::string Description() const override { return op().op_name() + ": S(0) -> S"; }
+
+  const OpParallelMatchResult GetMatchResult(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const ParallelContext* parallel_ctx) const override {
+    if (parallel_ctx->policy() == kModelParallel) {
+      return MakeOpParallelMatchSuccess();
+    } else {
+      return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kModelParallel);
+    }
+  }
+
+  void GenerateSignature(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      HashMap<std::string, SbpParallel>* bn2sbp) const override {
+    CHECK(SbpInferHint4Ibn("tick").is_data_split());
+    (*bn2sbp)["tick"].mutable_split_parallel()->set_axis(0);
+    (*bn2sbp)["out"].mutable_split_parallel()->set_axis(
+        (op().OutputBlobModelSplitAxis(SbpInferHint4Ibn, "out")));
+  }
+};
+
+}  // namespace
+
+void VariableOp::InitFromOpConf() {
+  CHECK(op_conf().has_variable_conf());
+  EnrollInputBn("tick", false);
+  EnrollOutputBn("out", Global<JobDesc>::Get()->IsTrain() && op_conf().trainable());
+  EnrollModelBn(op_conf().variable_conf().model_name());
+}
+
+const PbMessage& VariableOp::GetCustomizedConf() const { return op_conf().variable_conf(); }
+
+int32_t VariableOp::OutputBlobModelSplitAxis(
+    const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+    const std::string& obn) const {
+  return op_conf().variable_conf().model_split_axis();
+}
+
+void VariableOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                                const ParallelContext* parallel_ctx) const {
+  const VariableOpConf& variable_conf = op_conf().variable_conf();
+  BlobDesc* model_blob_desc = GetBlobDesc4BnInOp(variable_conf.model_name());
+  model_blob_desc->mut_shape() = Shape(variable_conf.shape());
+  model_blob_desc->set_data_type(variable_conf.has_data_type()
+                                     ? variable_conf.data_type()
+                                     : Global<JobDesc>::Get()->DefaultDataType());
+  if (parallel_ctx->policy() == kModelParallel) {
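+    // model parallelism: every rank keeps a balanced slice of the model along model_split_axis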
+    int32_t model_split_axis = variable_conf.model_split_axis();
+    CHECK_GE(model_split_axis, 0);
+    CHECK_LT(model_split_axis, model_blob_desc->shape().NumAxes());
+    int64_t split_dim_num = model_blob_desc->shape().At(model_split_axis);
+    BalancedSplitter bs(split_dim_num, parallel_ctx->parallel_num());
+    model_blob_desc->mut_shape().Set(model_split_axis, bs.At(parallel_ctx->parallel_id()).size());
+  } else {
+    CHECK_EQ(parallel_ctx->policy(), kDataParallel);
+  }
+  *GetBlobDesc4BnInOp("out") = *model_blob_desc;
+}
+
+void VariableOp::GetOpParallelSignatures(
+    std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const {
+  op_parallel_signatures->emplace_back(new VariableOpDataSplitOpParallelSignature(this));
+  op_parallel_signatures->emplace_back(new VariableOpModelSplitOpParallelSignature(this));
+}
+
+void VariableOp::InferIsModelBlob4OutputBlobs(
+    std::function<bool*(const std::string&)> IsModelBlob4BnInOp) const {
+  *IsModelBlob4BnInOp("out") = true;
+}
+
+void VariableOp::VirtualGenKernelConf(
+    std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*,
+    KernelConf* conf) const {
+  conf->mutable_variable_conf()->set_is_fw_inplace(*is_fw_inplace_);
+  conf->mutable_variable_conf()->set_is_bw_inplace(*is_bw_inplace_);
+}
+
+REGISTER_OP(OperatorConf::kVariableConf, VariableOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/variable_op.h b/oneflow/core/operator/variable_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..79c8291c6417382fc7ee80b6b7137a46e65ac434
--- /dev/null
+++ b/oneflow/core/operator/variable_op.h
@@ -0,0 +1,45 @@
+#ifndef ONEFLOW_CORE_OPERATOR_VARIABLE_OP_H_
+#define ONEFLOW_CORE_OPERATOR_VARIABLE_OP_H_
+
+#include "oneflow/core/operator/operator.h"
+
+namespace oneflow {
+
+class VariableOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(VariableOp);
+  VariableOp()
+      : Operator(),
+        is_fw_inplace_(std::make_unique<bool>(false)),
+        is_bw_inplace_(std::make_unique<bool>(false)) {}
+  ~VariableOp() = default;
+
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override;
+  bool NeedInBlobWhenBackward() const override { return false; }
+  bool NeedOutBlobWhenBackward() const override { return false; }
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+  int32_t OutputBlobModelSplitAxis(
+      const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn,
+      const std::string& obn) const override;
+
+  void set_is_fw_inplace(bool val) const { *is_fw_inplace_ = val; }
+  void set_is_bw_inplace(bool val) const { *is_bw_inplace_ = val; }
+
+ private:
+  bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
+  void GetOpParallelSignatures(
+      std::vector<std::unique_ptr<const OpParallelSignature>>*) const override;
+  void InferIsModelBlob4OutputBlobs(
+      std::function<bool*(const std::string&)> IsModelBlob4BnInOp) const;
+  void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                            const ParallelContext*, KernelConf*) const override;
+
+  std::unique_ptr<bool> is_fw_inplace_;
+  std::unique_ptr<bool> is_bw_inplace_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_VARIABLE_OP_H_
diff --git a/oneflow/core/persistence/tee_persistent_log_stream.h b/oneflow/core/persistence/tee_persistent_log_stream.h
index efe21e30ddb659810c413115e1416e17e594fdd6..5b29239cf44344faf7bde0363f12342b8374e072 100644
--- a/oneflow/core/persistence/tee_persistent_log_stream.h
+++ b/oneflow/core/persistence/tee_persistent_log_stream.h
@@ -29,10 +29,10 @@ class TeePersistentLogStream final {
   void Write(const PbMessage& proto);
 
   static std::unique_ptr<TeePersistentLogStream> Create(const std::string& path);
+  void Flush();
 
  private:
   explicit TeePersistentLogStream(const std::string& path);
-  void Flush();
   std::vector<LogStreamDestination> destinations_;
   std::vector<std::unique_ptr<PersistentOutStream>> branches_;
 };
diff --git a/oneflow/core/persistence/windows/windows_file_system.cpp b/oneflow/core/persistence/windows/windows_file_system.cpp
index 9b5335ed3dbd947c72f8567008b975c7385027a9..eac14309e9538026c7ddff32e1fb9a8d2570f237 100644
--- a/oneflow/core/persistence/windows/windows_file_system.cpp
+++ b/oneflow/core/persistence/windows/windows_file_system.cpp
@@ -17,7 +17,7 @@ typedef std::unique_ptr<void, decltype(CloseHandleFunc)> UniqueCloseHandlePtr;
 // PLEASE NOTE: hfile is expected to be an async handle
 // (i.e. opened with FILE_FLAG_OVERLAPPED)
 SSIZE_T pread(HANDLE hfile, char* src, size_t num_bytes, uint64_t offset) {
-  assert(num_bytes <= MaxVal<DWORD>());
+  assert(num_bytes <= GetMaxVal<DWORD>());
   OVERLAPPED overlapped = {0};
   ULARGE_INTEGER offset_union;
   offset_union.QuadPart = offset;
diff --git a/oneflow/core/record/ofrecord_raw_decoder.cpp b/oneflow/core/record/ofrecord_raw_decoder.cpp
index 5db8113b3807381ba0180beb429708643a9b9c0b..8b112fe0fca431dbc93348f9086dfb41eedd154f 100644
--- a/oneflow/core/record/ofrecord_raw_decoder.cpp
+++ b/oneflow/core/record/ofrecord_raw_decoder.cpp
@@ -85,7 +85,12 @@ void OFRecordDecoderImpl<EncodeCase::kRaw, T>::ReadOneCol(
   else if (feature.has_##PbT##_list()) {                                              \
     const auto& list = feature.PbT##_list();                                          \
     const CppT* in_dptr = list.value().data();                                        \
-    one_col_elem_num = std::max<int64_t>(one_col_elem_num, list.value_size());        \
+    if (blob_conf.encode_case().raw().dim1_varying_length()) {                        \
+      CHECK_LE(list.value_size(), one_col_elem_num);                                  \
+      one_col_elem_num = list.value_size();                                           \
+    } else {                                                                          \
+      CHECK_EQ(one_col_elem_num, list.value_size());                                  \
+    }                                                                                 \
     FixInDptrThenCopyElem<CppT, T>(ctx, in_dptr, col_id, one_col_elem_num, out_dptr); \
   }
   DEFINE_ONE_ELIF(float, float)
diff --git a/oneflow/core/record/ofrecord_raw_decoder.h b/oneflow/core/record/ofrecord_raw_decoder.h
index dc72bddcb63d521688633ac88ef14791f0d2e9a6..29e4d83e1a372f96a4dd10a131c809f9974078d1 100644
--- a/oneflow/core/record/ofrecord_raw_decoder.h
+++ b/oneflow/core/record/ofrecord_raw_decoder.h
@@ -11,11 +11,12 @@ class OFRecordDecoderImpl<EncodeCase::kRaw, T> final : public OFRecordDecoder<En
   bool HasDim1ValidNumField(const EncodeConf& encode_conf) const override;
   bool HasDim2ValidNumField(const EncodeConf& encode_conf) const override { return false; }
 
- private:
-  int32_t GetColNumOfFeature(const Feature&, int64_t one_col_elem_num) const override;
   void ReadOneCol(DeviceCtx*, const Feature&, const BlobConf& blob_conf, int32_t col_id,
                   T* out_dptr, int64_t one_col_elem_num,
                   std::function<int32_t(void)> NextRandomInt) const override;
+
+ private:
+  int32_t GetColNumOfFeature(const Feature&, int64_t one_col_elem_num) const override;
   void SetDim1ValidNum(const Feature& feature, Blob* out_blob, int64_t dim0_idx) const override;
 };
 
diff --git a/oneflow/core/record/ofrecord_raw_encoder.cpp b/oneflow/core/record/ofrecord_raw_encoder.cpp
index 53afc259d0b65e49719efdff7bd891110c390871..db85c1ef155befeb1cd15ee8e9a609bc4e16993a 100644
--- a/oneflow/core/record/ofrecord_raw_encoder.cpp
+++ b/oneflow/core/record/ofrecord_raw_encoder.cpp
@@ -2,6 +2,31 @@
 
 namespace oneflow {
 
+namespace {
+
+template<typename T>
+void EncodeDataToFeature(DeviceCtx* ctx, Feature* feature, const T* in_dptr, size_t elem_num) {
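+  // int8 data is stored raw in the bytes list; float/double/int32 are copied into the matching typed list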
+  DataType data_type = GetDataType<T>();
+  if (data_type == DataType::kInt8) {
+    feature->mutable_bytes_list()->add_value(reinterpret_cast<const char*>(in_dptr), elem_num);
+  }
+#define DEFINE_ONE_ELIF(CppT, ListT)                                                     \
+  else if (data_type == GetDataType<CppT>()) {                                           \
+    feature->mutable_##ListT##_list()->mutable_value()->Resize(elem_num, 0);             \
+    CppT* out_dptr = feature->mutable_##ListT##_list()->mutable_value()->mutable_data(); \
+    Memcpy<DeviceType::kCPU>(nullptr, out_dptr, in_dptr, elem_num * sizeof(T));          \
+  }
+  DEFINE_ONE_ELIF(float, float)
+  DEFINE_ONE_ELIF(double, double)
+  DEFINE_ONE_ELIF(int32_t, int32)
+#undef DEFINE_ONE_ELIF
+  else {
+    UNIMPLEMENTED();
+  }
+}
+
+}  // namespace
+
 template<typename T>
 void OFRecordEncoderImpl<EncodeCase::kRaw, T>::EncodeOneCol(DeviceCtx* ctx, const Blob* in_blob,
                                                             int64_t in_offset, Feature& feature,
@@ -15,23 +40,13 @@ void OFRecordEncoderImpl<EncodeCase::kRaw, T>::EncodeOneCol(DeviceCtx* ctx, cons
   int64_t elem_num = shape.NumAxes() == 1 ? 1 : in_blob->dim1_valid_num(dim0_idx) * shape.Count(2);
 
   const T* in_dptr = in_blob->dptr<T>() + in_offset;
-  DataType data_type = GetDataType<T>();
-  if (data_type == DataType::kInt8) {
-    feature.mutable_bytes_list()->add_value(reinterpret_cast<const char*>(in_dptr), elem_num);
-  }
-#define DEFINE_ONE_ELIF(CppT, ListT)                                                    \
-  else if (data_type == GetDataType<CppT>()) {                                          \
-    feature.mutable_##ListT##_list()->mutable_value()->Resize(elem_num, 0);             \
-    CppT* out_dptr = feature.mutable_##ListT##_list()->mutable_value()->mutable_data(); \
-    Memcpy<DeviceType::kCPU>(nullptr, out_dptr, in_dptr, elem_num * sizeof(T));         \
-  }
-  DEFINE_ONE_ELIF(float, float)
-  DEFINE_ONE_ELIF(double, double)
-  DEFINE_ONE_ELIF(int32_t, int32)
-#undef DEFINE_ONE_ELIF
-  else {
-    UNIMPLEMENTED();
-  }
+  EncodeDataToFeature(ctx, &feature, in_dptr, elem_num);
+}
+
+template<typename T>
+void OFRecordEncoderImpl<EncodeCase::kRaw, T>::EncodeBlob(DeviceCtx* ctx, const Blob* in_blob,
+                                                          Feature* feature) const {
+  EncodeDataToFeature(ctx, feature, in_blob->dptr<T>(), in_blob->shape().elem_cnt());
 }
 
 #define INSTANTIATE_OFRECORD_RAW_ENCODER(type_cpp, type_proto) \
diff --git a/oneflow/core/record/ofrecord_raw_encoder.h b/oneflow/core/record/ofrecord_raw_encoder.h
index b8c985bebfe7b35b2b4db6c2d8bcd51ebf5d6806..fcd5ffbd01bd9149c5f870a78db81cd4e9824dde 100644
--- a/oneflow/core/record/ofrecord_raw_encoder.h
+++ b/oneflow/core/record/ofrecord_raw_encoder.h
@@ -7,6 +7,9 @@ namespace oneflow {
 
 template<typename T>
 class OFRecordEncoderImpl<EncodeCase::kRaw, T> final : public OFRecordEncoderIf {
+ public:
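+  // Encodes the whole blob body (shape().elem_cnt() elements) into `feature` at once.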
+  void EncodeBlob(DeviceCtx* ctx, const Blob* in_blob, Feature* feature) const;
+
  private:
   void EncodeOneCol(DeviceCtx*, const Blob* in_blob, int64_t in_offset, Feature&,
                     const std::string& field_name, int64_t one_col_elem_num) const override;
diff --git a/oneflow/core/record/ofrecord_reader.cpp b/oneflow/core/record/ofrecord_reader.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..93d36d2fa5ce79262b36df28d9aa248552a028db
--- /dev/null
+++ b/oneflow/core/record/ofrecord_reader.cpp
@@ -0,0 +1,81 @@
+#include "oneflow/core/record/ofrecord_reader.h"
+
+namespace oneflow {
+
+constexpr int64_t MAX_CHUNK_SIZE = 64 * 1024 * 1024;  // 64M
+
+namespace {
+
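+// Reads one length-prefixed chunk: an int64_t byte count followed by the payload.
+// Returns true on success and false once the stream is exhausted.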
+bool ReadChunk(PersistentInStream* is, OFRecordChunk* chunk) {
+  if (is->Read(reinterpret_cast<char*>(&chunk->size), sizeof(int64_t)) == 0) {
+    CHECK_GE(chunk->size, 0);
+    CHECK_LE(chunk->size, MAX_CHUNK_SIZE);
+    chunk->data.reset(new char[chunk->size]);
+    CHECK_EQ(is->Read(chunk->data.get(), chunk->size), 0);
+    return true;
+  }
+  return false;
+}
+
+}  // namespace
+
+NaiveOFRecordReader::NaiveOFRecordReader(PersistentInStream* in, size_t num_max_read)
+    : in_stream_(in), num_read_(0), num_max_read_(num_max_read) {}
+
+size_t NaiveOFRecordReader::Read(size_t n, OFRecord* allocated_records) {
+  OFRecordChunk chunk;
+  const size_t can_read = std::min(n, num_max_read_ - num_read_);
+  FOR_RANGE(size_t, i, 0, can_read) {
+    if (ReadChunk(in_stream_, &chunk)) {
+      CHECK(allocated_records[i].ParseFromArray(chunk.data.get(), chunk.size));
+      ++num_read_;
+    } else {
+      return i;
+    }
+  }
+  return can_read;
+}
+
+RandomShuffleOFRecordReader::RandomShuffleOFRecordReader(PersistentInStream* in, size_t buffer_size,
+                                                         size_t num_max_read, int32_t random_seed)
+    : in_stream_(in),
+      buffer_size_(buffer_size),
+      num_max_read_(num_max_read),
+      random_gen_(random_seed),
+      num_read_(0),
+      is_eof_(false) {
+  CHECK_GT(buffer_size, 0);
+  buffered_chunks_.reserve(buffer_size);
+}
+
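+// Tops up the shuffle buffer until it holds buffer_size_ chunks, num_max_read_ records
+// have been taken from the stream, or the stream hits EOF.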
+void RandomShuffleOFRecordReader::FillBuffer() {
+  for (; num_read_ < num_max_read_ && buffered_chunks_.size() < buffer_size_; ++num_read_) {
+    OFRecordChunk chunk;
+    if (ReadChunk(in_stream_, &chunk)) {
+      buffered_chunks_.emplace_back(std::move(chunk));
+    } else {
+      is_eof_ = true;
+      break;
+    }
+  }
+  if (num_read_ == num_max_read_) { is_eof_ = true; }
+}
+
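+// Returns up to n records, each drawn uniformly at random from the buffered chunks,
+// refilling the buffer as records are consumed.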
+size_t RandomShuffleOFRecordReader::Read(size_t n, OFRecord* allocated_records) {
+  size_t cur_read = 0;
+  while (cur_read < n) {
+    if (!is_eof_) { FillBuffer(); }
+    if (buffered_chunks_.empty()) { break; }
+    const size_t pos =
+        std::uniform_int_distribution<size_t>(0, buffered_chunks_.size() - 1)(random_gen_);
+    if (pos != buffered_chunks_.size() - 1) {
+      std::swap(buffered_chunks_[pos], buffered_chunks_.back());
+    }
+    CHECK(allocated_records[cur_read].ParseFromArray(buffered_chunks_.back().data.get(),
+                                                     buffered_chunks_.back().size));
+    buffered_chunks_.pop_back();
+    ++cur_read;
+  }
+  return cur_read;
+}
+
+}  // namespace oneflow
diff --git a/oneflow/core/record/ofrecord_reader.h b/oneflow/core/record/ofrecord_reader.h
new file mode 100644
index 0000000000000000000000000000000000000000..559647d64b3a23e4cd6b0f34d1599bbe6da2f85b
--- /dev/null
+++ b/oneflow/core/record/ofrecord_reader.h
@@ -0,0 +1,66 @@
+#ifndef ONEFLOW_CORE_RECORD_OFRECORD_READER_H_
+#define ONEFLOW_CORE_RECORD_OFRECORD_READER_H_
+
+#include "oneflow/core/record/record.pb.h"
+#include "oneflow/core/common/util.h"
+#include "oneflow/core/persistence/persistent_in_stream.h"
+
+namespace oneflow {
+
+struct OFRecordChunk {
+  int64_t size = 0;
+  std::unique_ptr<char[]> data;
+};
+
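+// Interface for readers that fill pre-allocated OFRecord protos from a PersistentInStream;
+// Read() returns the number of records actually filled.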
+class OFRecordReader {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(OFRecordReader);
+  OFRecordReader() = default;
+  virtual ~OFRecordReader() = default;
+
+  virtual size_t Read(size_t n, OFRecord* allocated_records) = 0;
+};
+
+class NaiveOFRecordReader final : public OFRecordReader {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(NaiveOFRecordReader);
+  explicit NaiveOFRecordReader(PersistentInStream* in)
+      : NaiveOFRecordReader(in, MaxVal<size_t>::value) {}
+  NaiveOFRecordReader(PersistentInStream* in, size_t num_max_read);
+  ~NaiveOFRecordReader() override = default;
+
+ private:
+  size_t Read(size_t n, OFRecord* allocated_records) override;
+
+  PersistentInStream* in_stream_;
+  size_t num_read_;
+  const size_t num_max_read_;
+};
+
+class RandomShuffleOFRecordReader final : public OFRecordReader {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(RandomShuffleOFRecordReader);
+  RandomShuffleOFRecordReader(PersistentInStream* in, size_t buffer_size, size_t num_max_read,
+                              int32_t random_seed);
+  RandomShuffleOFRecordReader(PersistentInStream* in, size_t buffer_size, size_t num_max_read)
+      : RandomShuffleOFRecordReader(in, buffer_size, num_max_read, std::random_device()()) {}
+  RandomShuffleOFRecordReader(PersistentInStream* in, size_t buffer_size)
+      : RandomShuffleOFRecordReader(in, buffer_size, MaxVal<size_t>::value) {}
+  ~RandomShuffleOFRecordReader() override = default;
+
+ private:
+  size_t Read(size_t n, OFRecord* allocated_records) override;
+  void FillBuffer();
+
+  PersistentInStream* in_stream_;
+  const size_t buffer_size_;
+  const size_t num_max_read_;
+  std::mt19937 random_gen_;
+  size_t num_read_;
+  std::vector<OFRecordChunk> buffered_chunks_;
+  bool is_eof_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_RECORD_OFRECORD_READER_H_
diff --git a/oneflow/core/register/blob.cpp b/oneflow/core/register/blob.cpp
index 38cc3a9d8a4e87ccac52185a8dfee034d6a4d3aa..49dfd2cc33d84ac777513de87f6cdcad6e0cdb90 100644
--- a/oneflow/core/register/blob.cpp
+++ b/oneflow/core/register/blob.cpp
@@ -28,6 +28,7 @@ void Blob::Init(Regst* regst, const RtBlobDesc* blob_desc, char* header_ptr, cha
   dim2_valid_num_ptr_ = header_pod_ptr_.MutTensorPtr<int64_t>(FieldKey::kDim2ValidNum, nullptr);
   record_id_in_device_piece_ptr_ =
       header_pod_ptr_.MutTensorPtr<int64_t>(FieldKey::kRecordIdInDevicePiece, nullptr);
+  loss_instance_num_ptr_ = header_pod_ptr_.MutTensorPtr<float>(FieldKey::kLossInstanceNum, nullptr);
   dptr_ = body_ptr;
   dynamic_shape_ = blob_desc->shape();
 }
@@ -131,6 +132,16 @@ void Blob::set_record_id_in_device_piece(int64_t no, int64_t val) {
   record_id_in_device_piece_ptr_[no] = val;
 }
 
+float Blob::loss_instance_num() const {
+  CHECK_NOTNULL(loss_instance_num_ptr_);
+  return *loss_instance_num_ptr_;
+}
+
+void Blob::set_loss_instance_num(float val) {
+  CHECK_NOTNULL(loss_instance_num_ptr_);
+  *loss_instance_num_ptr_ = val;
+}
+
 void Blob::set_dim2_valid_num(int64_t dim0_idx, int64_t dim1_idx, int64_t val) {
   CHECK_NOTNULL(dim2_valid_num_ptr_);
   CHECK_GE(dim0_idx, 0);
diff --git a/oneflow/core/register/blob.h b/oneflow/core/register/blob.h
index af2edcb031fb838bd2eb2b7892da08bc6f2bd862..a9c45e113da1593e9fb67204cf5b698f265a4759 100644
--- a/oneflow/core/register/blob.h
+++ b/oneflow/core/register/blob.h
@@ -54,6 +54,10 @@ class Blob final {
   void set_record_id_in_device_piece(int64_t no, int64_t val);
   const int64_t* record_id_in_device_piece_ptr() const { return record_id_in_device_piece_ptr_; }
   int64_t* mut_record_id_in_device_piece_ptr() { return record_id_in_device_piece_ptr_; }
+  float loss_instance_num() const;
+  void set_loss_instance_num(float val);
+  const float* loss_instance_num_ptr() const { return loss_instance_num_ptr_; }
+  float* mut_loss_instance_num_ptr() { return loss_instance_num_ptr_; }
 
   const void* header_ptr() const { return header_ptr_; }
   void* mut_header_ptr() { return header_ptr_; }
@@ -99,6 +103,7 @@ class Blob final {
   bool has_record_id_in_device_piece_field() const {
     return blob_desc_->has_record_id_in_device_piece_field();
   }
+  bool has_loss_instance_num_field() const { return blob_desc_->has_loss_instance_num_field(); }
   int32_t max_col_num() const { return blob_desc_->max_col_num(); }
   size_t ByteSizeOfBlobHeader() const { return blob_desc_->ByteSizeOfBlobHeader(); }
   size_t ByteSizeOfDataIdField() const { return blob_desc_->ByteSizeOfDataIdField(); }
@@ -161,6 +166,7 @@ class Blob final {
   int64_t* dim1_valid_num_ptr_;
   int64_t* dim2_valid_num_ptr_;
   int64_t* record_id_in_device_piece_ptr_;
+  float* loss_instance_num_ptr_;
   void* dptr_;
   const RtBlobDesc* blob_desc_;
   Regst* regst_;
diff --git a/oneflow/core/register/blob_desc.cpp b/oneflow/core/register/blob_desc.cpp
index 4714eeee09aac5abb030e6c9ca8808368341b1ce..e35e5a6ec5c62c52ea729e8ca9f386a69f4c1039 100644
--- a/oneflow/core/register/blob_desc.cpp
+++ b/oneflow/core/register/blob_desc.cpp
@@ -16,6 +16,7 @@ BlobDesc::BlobDesc(const Shape& shape, DataType data_type, bool has_data_id, boo
       has_dim1_valid_num_(false),
       has_dim2_valid_num_(false),
       has_record_id_in_device_piece_(false),
+      has_loss_instance_num_(false),
       max_col_num_(max_col_num),
       blob_mem_id_(-1),
       body_field_(shape, data_type) {}
@@ -35,6 +36,7 @@ void BlobDesc::InitFromProto(const BlobDescProto& proto) {
     has_dim1_valid_num_ = false;
     has_dim2_valid_num_ = false;
     has_record_id_in_device_piece_ = false;
+    has_loss_instance_num_ = false;
     opaque_header_ = FieldDesc(proto.header().opaque_header());
   } else {
     CHECK(proto.header().has_field_header());
@@ -45,6 +47,7 @@ void BlobDesc::InitFromProto(const BlobDescProto& proto) {
     has_dim1_valid_num_ = header_pod_desc_.HasField(FieldKey::kDim1ValidNum);
     has_dim2_valid_num_ = header_pod_desc_.HasField(FieldKey::kDim2ValidNum);
     has_record_id_in_device_piece_ = header_pod_desc_.HasField(FieldKey::kRecordIdInDevicePiece);
+    has_loss_instance_num_ = header_pod_desc_.HasField(FieldKey::kLossInstanceNum);
   }
   if (proto.has_dim0_inner_shape()) {
     dim0_inner_shape_.reset(new Shape(proto.dim0_inner_shape()));
@@ -59,6 +62,7 @@ BlobDesc::BlobDesc(const StructPodDesc& header_pod_desc, int64_t header_byte_siz
       has_dim1_valid_num_(false),
       has_dim2_valid_num_(false),
       has_record_id_in_device_piece_(false),
+      has_loss_instance_num_(false),
       max_col_num_(max_col_num),
       blob_mem_id_(-1),
       body_field_(shape, data_type) {
@@ -102,6 +106,11 @@ void BlobDesc::set_has_record_id_in_device_piece_field(bool val) {
   has_record_id_in_device_piece_ = val;
 }
 
+void BlobDesc::set_has_loss_instance_num_field(bool val) {
+  CHECK(!header_is_opaque_);
+  has_loss_instance_num_ = val;
+}
+
 Shape& BlobDesc::mut_dim0_inner_shape() {
   CHECK(!header_is_opaque_);
   if (!dim0_inner_shape_) { dim0_inner_shape_.reset(new Shape()); }
@@ -145,6 +154,10 @@ void BlobDesc::RecordIdInDevicePieceToProto(StructPodDesc* header_pod_desc) cons
                             TensorPodDesc(shape, DataType::kInt64));
 }
 
+void BlobDesc::LossInstanceNumToProto(StructPodDesc* header_pod_desc) const {
+  header_pod_desc->AddField(FieldKey::kLossInstanceNum, TensorPodDesc({1}, DataType::kFloat));
+}
+
 void BlobDesc::HeaderToProto(BlobDescProto* proto) const {
   proto->mutable_header()->set_max_col_num(max_col_num_);
   proto->mutable_header()->set_blob_mem_id(blob_mem_id_);
@@ -157,6 +170,7 @@ void BlobDesc::HeaderToProto(BlobDescProto* proto) const {
     if (has_dim1_valid_num_field()) { Dim1ValidNumToProto(&header_pod_desc); }
     if (has_dim2_valid_num_field()) { Dim2ValidNumToProto(&header_pod_desc); }
     if (has_record_id_in_device_piece_field()) { RecordIdInDevicePieceToProto(&header_pod_desc); }
+    if (has_loss_instance_num_field()) { LossInstanceNumToProto(&header_pod_desc); }
     header_pod_desc.ToProto(proto->mutable_header()->mutable_header_pod_desc());
   } else {
     opaque_header_.ToProto(proto->mutable_header()->mutable_opaque_header());
@@ -177,8 +191,8 @@ bool BlobDesc::operator==(const BlobDesc& rhs) const {
          && has_dim1_valid_num_ == rhs.has_dim1_valid_num_
          && has_dim2_valid_num_ == rhs.has_dim2_valid_num_
          && has_record_id_in_device_piece_ == rhs.has_record_id_in_device_piece_
-         && max_col_num_ == rhs.max_col_num_ && blob_mem_id_ == rhs.blob_mem_id_
-         && body_field_ == rhs.body_field_;
+         && has_loss_instance_num_ == rhs.has_loss_instance_num_ && max_col_num_ == rhs.max_col_num_
+         && blob_mem_id_ == rhs.blob_mem_id_ && body_field_ == rhs.body_field_;
 }
 
 BlobDesc& BlobDesc::operator=(const BlobDesc& blob_desc) {
@@ -248,4 +262,9 @@ std::unique_ptr<BlobDesc> ComputePackedBlobDesc(
   return ret;
 }
 
+bool CompareLbiBlobDescPair(const LbiBlobDescPair& lhs, const LbiBlobDescPair& rhs) {
+  // Order by blob_mem_id first so memory-sharing blobs stay adjacent, then break ties by lbi.
+  if (lhs.blob_desc().header().blob_mem_id() != rhs.blob_desc().header().blob_mem_id()) {
+    return lhs.blob_desc().header().blob_mem_id() < rhs.blob_desc().header().blob_mem_id();
+  }
+  return lhs.lbi() < rhs.lbi();
+}
+
 }  // namespace oneflow
diff --git a/oneflow/core/register/blob_desc.h b/oneflow/core/register/blob_desc.h
index 39467de729fae56509dd6806ee15e6982b88b9e0..80bbe18d135cff32afe1f3b0a54f762e48e889e3 100644
--- a/oneflow/core/register/blob_desc.h
+++ b/oneflow/core/register/blob_desc.h
@@ -7,6 +7,7 @@
 #include "oneflow/core/register/blob_desc.pb.h"
 #include "oneflow/core/register/pod_desc.h"
 #include "oneflow/core/job/job_desc.h"
+#include "oneflow/core/register/register_desc.pb.h"
 
 namespace oneflow {
 
@@ -50,6 +51,9 @@ class BlobDesc {
   bool has_record_id_in_device_piece_field() const { return has_record_id_in_device_piece_; }
   void set_has_record_id_in_device_piece_field(bool val);
 
+  bool has_loss_instance_num_field() const { return has_loss_instance_num_; }
+  void set_has_loss_instance_num_field(bool val);
+
   bool has_col_num_field() const { return has_col_num_; }
   void set_has_col_num_field(bool val);
 
@@ -72,6 +76,7 @@ class BlobDesc {
   void Dim1ValidNumToProto(StructPodDesc* header_pod_desc) const;
   void Dim2ValidNumToProto(StructPodDesc* header_pod_desc) const;
   void RecordIdInDevicePieceToProto(StructPodDesc* header_pod_desc) const;
+  void LossInstanceNumToProto(StructPodDesc* header_pod_desc) const;
 
   bool header_is_opaque_;
   FieldDesc opaque_header_;
@@ -83,6 +88,7 @@ class BlobDesc {
   bool has_dim1_valid_num_;
   bool has_dim2_valid_num_;
   bool has_record_id_in_device_piece_;
+  bool has_loss_instance_num_;
   int64_t max_col_num_;
   int32_t blob_mem_id_;
 
@@ -93,6 +99,8 @@ class BlobDesc {
 std::unique_ptr<BlobDesc> ComputePackedBlobDesc(
     const HashMap<LogicalBlobId, std::unique_ptr<BlobDesc>>& lbi2blob_desc);
 
+bool CompareLbiBlobDescPair(const LbiBlobDescPair& lhs, const LbiBlobDescPair& rhs);
+
 }  // namespace oneflow
 
 #endif  // ONEFLOW_CORE_REGISTER_BLOB_DESC_H_
diff --git a/oneflow/core/register/pod.proto b/oneflow/core/register/pod.proto
index 3ca40a886206c61ddf939258b9be66ca50db94b6..a2a1c83f6f66117fc0d83249a3a3f92b5836e37a 100644
--- a/oneflow/core/register/pod.proto
+++ b/oneflow/core/register/pod.proto
@@ -22,6 +22,7 @@ enum FieldKey {
   kDim1ValidNum = 4;
   kDim2ValidNum = 5;
   kRecordIdInDevicePiece = 6;
+  kLossInstanceNum = 7;
 }
 
 message FieldId {
diff --git a/oneflow/core/register/pod_desc.h b/oneflow/core/register/pod_desc.h
index 19444c2077b9ab27286d8c17fb894cfe0ad5df27..145dcf890f86afa6139eacc9f8ad1386147d7296 100644
--- a/oneflow/core/register/pod_desc.h
+++ b/oneflow/core/register/pod_desc.h
@@ -81,23 +81,25 @@ class StructPodDesc final : public PodDesc {
   ~StructPodDesc() = default;
 
   StructPodDesc* MutStructField(const FieldId& field_id);
+  StructPodDesc* MutStructField(const FieldId& field_id, int32_t default_alignment);
   const PodDesc& Field(FieldKey field_key) const { return Field(NewFieldId(field_key)); }
   const PodDesc& Field(const FieldId& field_id) const;
   void AddField(FieldKey field_key, const PodDesc& pod_desc);
   void AddField(const FieldId& field_id, const PodDesc& pod_desc);
-  size_t ByteSize() const override;
-  void InitFromProto(const StructPodProto& struct_pod);
-
+  void AddField(const FieldId& field_id, const PodDesc& pod_desc, size_t alignment);
   bool HasField(FieldKey field_key) const { return HasField(NewFieldId(field_key)); }
   bool HasField(const FieldId& field_id) const;
-  StructPodDesc& operator=(const StructPodDesc&);
+
   std::unique_ptr<PodDesc> Clone() const override { return std::make_unique<StructPodDesc>(*this); }
+  void InitFromProto(const StructPodProto& struct_pod);
   void ToProto(PodProto* pod_proto) const override { ToProto(pod_proto->mutable_struct_pod()); }
   void ToProto(StructPodProto* pod_proto) const;
-  StructPodDesc* MutStructField(const FieldId& field_id, int32_t default_alignment);
-  void AddField(const FieldId& field_id, const PodDesc& pod_desc, size_t alignment);
-  bool operator==(const PodDesc& rhs) const override;
+
   size_t ByteOffset4Field(const FieldId& field_name) const;
+  size_t ByteSize() const override;
+
+  StructPodDesc& operator=(const StructPodDesc&);
+  bool operator==(const PodDesc& rhs) const override;
 
  private:
   void Clear();
diff --git a/oneflow/core/register/register_desc.cpp b/oneflow/core/register/register_desc.cpp
index 04b736b61b3c5df5fbaa991de036f66dbfbd58a7..7730c3f348ed97b16601376a64cb71331a9f0e25 100644
--- a/oneflow/core/register/register_desc.cpp
+++ b/oneflow/core/register/register_desc.cpp
@@ -3,6 +3,7 @@
 #include "oneflow/core/graph/copy_task_node.h"
 #include "oneflow/core/job/id_manager.h"
 #include "oneflow/core/register/runtime_blob_desc.h"
+#include "oneflow/core/register/runtime_register_desc.h"
 
 namespace oneflow {
 
@@ -25,6 +26,12 @@ RegstDesc::RegstDesc() {
   enable_mem_sharing_ = false;
   mem_shared_id_ = -1;
   mem_shared_offset_ = -1;
+  mem_shared_inplace_block_id_ = -1;
+}
+
+int64_t RegstDesc::mem_shared_offset() const {
+  CHECK_GE(mem_shared_offset_, 0);
+  return mem_shared_offset_;
 }
 
 void RegstDesc::AddConsumer(const TaskNode* new_consumer) {
@@ -60,6 +67,12 @@ void RegstDesc::CopyBlobDescFrom(const RegstDesc* rhs) {
   CopyBlobDescWithoutAddLbi(rhs);
 }
 
+void RegstDesc::CopyMemSharedInfoFrom(const RegstDesc* rhs) {
+  enable_mem_sharing_ = rhs->enable_mem_sharing_;
+  mem_shared_id_ = rhs->mem_shared_id_;
+  mem_shared_offset_ = rhs->mem_shared_offset_;
+}
+
 void RegstDesc::CopyBlobDescWithoutAddLbi(const RegstDesc* rhs) {
   CHECK_EQ(is_locked_, false);
   for (const auto& pair : lbi2blob_desc_) {
@@ -141,6 +154,15 @@ void RegstDesc::ToProto(RegstDescProto* ret) const {
   ret->set_enable_mem_sharing(enable_mem_sharing_);
   ret->set_mem_shared_id(mem_shared_id_);
   ret->set_mem_shared_offset(mem_shared_offset_);
+  ret->add_mem_block_hierarchy()->set_mem_block_id(Global<IDMgr>::Get()->NewMemBlockId());
+  if (mem_shared_inplace_block_id_ != -1) {
+    ret->add_mem_block_hierarchy()->set_mem_block_id(mem_shared_inplace_block_id_);
+  }
+}
+
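+// True when both regsts' packed blob descriptions report the same total byte size.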
+bool RegstDesc::HasSameMemSize(const RegstDesc* rhs) {
+  return RtBlobDesc(*(packed_blob_desc_.get())).TotalByteSize()
+         == RtBlobDesc(*(rhs->packed_blob_desc_.get())).TotalByteSize();
 }
 
 bool RegstDesc::HasSameBlobDescs(const RegstDesc* rhs) {
@@ -153,6 +175,34 @@ bool RegstDesc::HasSameBlobDescs(const RegstDesc* rhs) {
   return true;
 }
 
+int64_t RegstDesc::ByteOffsetInPackedBlobDescBody(const LogicalBlobId& lbi) const {
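+  // Replay the runtime layout: serialize this regst, sort its blobs the same way RegstMgr does,
+  // then walk the offsets until `lbi` is found.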
+  RegstDescProto regst_desc_proto;
+  ToProto(&regst_desc_proto);
+  RtRegstDesc rt_regst_desc(regst_desc_proto);
+  std::vector<LbiBlobDescPair> lbi_blob_desc_pairs;
+  for (const auto& pair : lbi2blob_desc_) {
+    LbiBlobDescPair lbi_blob_desc_pair;
+    *lbi_blob_desc_pair.mutable_lbi() = pair.first;
+    pair.second->ToProto(lbi_blob_desc_pair.mutable_blob_desc());
+    lbi_blob_desc_pairs.push_back(lbi_blob_desc_pair);
+  }
+  std::sort(lbi_blob_desc_pairs.begin(), lbi_blob_desc_pairs.end(), CompareLbiBlobDescPair);
+
+  bool found = false;
+  int64_t offset = 0;
+  rt_regst_desc.ForEachBlobDescOffsetInOnRegst(
+      lbi_blob_desc_pairs,
+      [&](const LbiBlobDescPair& lbi_blob_desc_pair, int64_t body_offset, int64_t header_offset) {
+        if (found) { return; }
+        if (lbi_blob_desc_pair.lbi() == lbi) {
+          offset = body_offset;
+          found = true;
+        }
+      });
+  CHECK(found);
+  return offset;
+}
+
 void InitCtrlRegstDesc(int64_t producer_task_id, RegstDescProto* ctrl_regst_proto) {
   CHECK_NOTNULL(ctrl_regst_proto);
   ctrl_regst_proto->set_regst_desc_id(Global<IDMgr>::Get()->NewRegstDescId());
diff --git a/oneflow/core/register/register_desc.h b/oneflow/core/register/register_desc.h
index ed09ec71fb0ccb6ff7cf2e5895194a94c74ab26e..ec473e2b3e42f005b2943825334054376b90d8c9 100644
--- a/oneflow/core/register/register_desc.h
+++ b/oneflow/core/register/register_desc.h
@@ -50,14 +50,20 @@ class RegstDesc final {
   // mem
   const MemoryCase& mem_case() const { return mem_case_; }
   MemoryCase* mut_mem_case() { return &mem_case_; }
+  bool enable_mem_sharing() { return enable_mem_sharing_; }
   void set_enable_mem_sharing(bool enable_mem_sharing) { enable_mem_sharing_ = enable_mem_sharing; }
-  int64_t mem_shared_offset() const { return mem_shared_offset_; }
+  int64_t mem_shared_offset() const;
   void set_mem_shared_offset(int64_t val) { mem_shared_offset_ = val; }
+  int64_t mem_shared_inplace_block_id() const { return mem_shared_inplace_block_id_; }
+  void set_mem_shared_inplace_block_id(int64_t val) { mem_shared_inplace_block_id_ = val; }
   int32_t mem_shared_id() const { return mem_shared_id_; }
   void set_mem_shared_id(int32_t val) { mem_shared_id_ = val; }
+  bool HasSetMemSharedId() { return mem_shared_id_ != -1; }
+  void CopyMemSharedInfoFrom(const RegstDesc*);
 
   const std::shared_ptr<Shape>& data_regst_time_shape() const {
     CHECK(regst_desc_type_.has_data_regst_desc());
+    CHECK(data_regst_time_shape_);
     return data_regst_time_shape_;
   }
   std::shared_ptr<Shape>* mut_data_regst_time_shape() {
@@ -66,12 +72,14 @@ class RegstDesc final {
   }
   RegstDescTypeProto* mut_regst_desc_type() { return &regst_desc_type_; }
   const RegstDescTypeProto& regst_desc_type() const { return regst_desc_type_; }
+  bool HasSameMemSize(const RegstDesc*);
 
   // util
   int32_t MaxColNum() const { return packed_blob_desc_->max_col_num(); }
   void EraseZeroSizeBlob();
   void ToProto(RegstDescProto*) const;
   bool HasSameBlobDescs(const RegstDesc*);
+  int64_t ByteOffsetInPackedBlobDescBody(const LogicalBlobId& lbi) const;
 
  private:
   int64_t regst_desc_id_;
@@ -89,10 +97,28 @@ class RegstDesc final {
   bool enable_mem_sharing_;
   int32_t mem_shared_id_;
   int64_t mem_shared_offset_;
+  int64_t mem_shared_inplace_block_id_;
 
   std::shared_ptr<Shape> data_regst_time_shape_;
 };
 
+inline bool operator==(const MemBlock& lhs, const MemBlock& rhs) {
+  bool ret = (lhs.mem_block_id() == rhs.mem_block_id());
+  if (ret) { CHECK_EQ(lhs.mem_reduce_method(), rhs.mem_reduce_method()); }
+  return ret;
+}
+
 }  // namespace oneflow
 
+namespace std {
+
+template<>
+struct hash<oneflow::MemBlock> final {
+  size_t operator()(const oneflow::MemBlock& mem_block) const {
+    return hash<int64_t>()(mem_block.mem_block_id());
+  }
+};
+
+}  // namespace std
+
 #endif  // ONEFLOW_CORE_REGISTER_REGISTER_DESC_H_
diff --git a/oneflow/core/register/register_desc.proto b/oneflow/core/register/register_desc.proto
index e933c283c8eb445b9c2bb877b71f6fdf2ba7e916..37e6111faf00bad2658ed1492f83a44c6ad39fc6 100644
--- a/oneflow/core/register/register_desc.proto
+++ b/oneflow/core/register/register_desc.proto
@@ -29,6 +29,17 @@ message RegstDescTypeProto {
   }
 }
 
+enum MemReduceMethod {
+  kMemInvalidSharedMethod = 0;
+  kMemSum = 1;
+  kMemMax = 2;
+}
+
+message MemBlock {
+  required int64 mem_block_id = 1;
+  optional MemReduceMethod mem_reduce_method = 2 [default = kMemMax];
+}
+
 message RegstDescProto {
   required int64 regst_desc_id = 1;
   required int64 producer_task_id = 2;
@@ -41,4 +52,6 @@ message RegstDescProto {
   required bool enable_mem_sharing = 9;
   required int32 mem_shared_id = 10;
   required int64 mem_shared_offset = 11;
+  // from bottom to top
+  repeated MemBlock mem_block_hierarchy = 13;
 }
diff --git a/oneflow/core/register/register_manager.cpp b/oneflow/core/register/register_manager.cpp
index 6d79acda4ce17b49e1b14fff8589d7fb91fce069..2ef56a5b6f970d25b04b346faa319a503ae14acb 100644
--- a/oneflow/core/register/register_manager.cpp
+++ b/oneflow/core/register/register_manager.cpp
@@ -75,12 +75,7 @@ void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto,
     for (const LbiBlobDescPair& pair : regst_desc_type.data_regst_desc().lbi2blob_desc()) {
       lbi_pairs.push_back(pair);
     }
-    std::sort(lbi_pairs.begin(), lbi_pairs.end(),
-              [&](const LbiBlobDescPair& lhs, const LbiBlobDescPair& rhs) {
-                return lhs.blob_desc().header().blob_mem_id()
-                           < rhs.blob_desc().header().blob_mem_id()
-                       || lhs.lbi() < rhs.lbi();
-              });
+    std::sort(lbi_pairs.begin(), lbi_pairs.end(), &CompareLbiBlobDescPair);
     CHECK(!lbi_pairs.empty());
     CHECK(main_mem_ptr != nullptr);
   }
@@ -126,23 +121,14 @@ void RegstMgr::NewBlobsInOneRegst(const std::vector<LbiBlobDescPair>& lbis, Regs
     cur_header_pointer = main_mem_ptr;
     cur_body_pointer = main_mem_ptr + packed_blob_desc->ByteSizeOfBlobHeader();
   }
-  int32_t last_blob_mem_id = -1;
-  size_t last_size = 0;
-  for (const LbiBlobDescPair& lbi : lbis) {
-    const RtBlobDesc* blob_desc = rt_regst_desc->GetRtBlobDescFromLbi(lbi.lbi());
-    int32_t cur_blob_mem_id = lbi.blob_desc().header().blob_mem_id();
-    if (cur_blob_mem_id == -1 || cur_blob_mem_id != last_blob_mem_id) {
-      cur_body_pointer += last_size;
-    }
-    std::unique_ptr<Blob> blob_ptr(
-        new Blob(regst, blob_desc, cur_header_pointer, cur_body_pointer));
-    InitOFRecordBlobIfNeed(blob_ptr.get());
-    CHECK(regst->lbi2blob_.emplace(lbi.lbi(), std::move(blob_ptr)).second);
-    cur_header_pointer += blob_desc->ByteSizeOfBlobHeader();
-
-    last_blob_mem_id = cur_blob_mem_id;
-    last_size = blob_desc->ByteSizeOfBlobBody();
-  }
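+  // Place each blob at the header/body offset computed for it within this regst's memory.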
+  rt_regst_desc->ForEachBlobDescOffsetInOnRegst(
+      lbis, [&](const LbiBlobDescPair& lbi, int64_t body_offset, int64_t header_offset) {
+        const RtBlobDesc* blob_desc = rt_regst_desc->GetRtBlobDescFromLbi(lbi.lbi());
+        std::unique_ptr<Blob> blob_ptr(new Blob(
+            regst, blob_desc, cur_header_pointer + header_offset, cur_body_pointer + body_offset));
+        InitOFRecordBlobIfNeed(blob_ptr.get());
+        CHECK(regst->lbi2blob_.emplace(lbi.lbi(), std::move(blob_ptr)).second);
+      });
 }
 
 void RegstMgr::InitOFRecordBlobIfNeed(Blob* blob_ptr) {
diff --git a/oneflow/core/register/runtime_blob_desc.cpp b/oneflow/core/register/runtime_blob_desc.cpp
index 6251d6b736ca0110e3de3ac68d7cb4966d9d1a31..59c1a9f47be64637500b80533bfb50aa518101c9 100644
--- a/oneflow/core/register/runtime_blob_desc.cpp
+++ b/oneflow/core/register/runtime_blob_desc.cpp
@@ -43,6 +43,10 @@ bool RtBlobDesc::has_record_id_in_device_piece_field() const {
   return header_pod_desc_.HasField(FieldKey::kRecordIdInDevicePiece);
 }
 
+bool RtBlobDesc::has_loss_instance_num_field() const {
+  return header_pod_desc_.HasField(FieldKey::kLossInstanceNum);
+}
+
 size_t RtBlobDesc::ByteSizeOfBlobHeader() const { return header_pod_desc_.ByteSize(); }
 
 size_t RtBlobDesc::ByteSizeOfBlobBody() const { return body_desc_.AlignedByteSize(); }
diff --git a/oneflow/core/register/runtime_blob_desc.h b/oneflow/core/register/runtime_blob_desc.h
index 7cb6a8d3f4faa89744b9c026c516f8b8e84b9a55..d84591039e48157d5cb391e061c03f1943b97c89 100644
--- a/oneflow/core/register/runtime_blob_desc.h
+++ b/oneflow/core/register/runtime_blob_desc.h
@@ -29,6 +29,7 @@ class RtBlobDesc {
   bool has_dim1_valid_num_field() const;
   bool has_dim2_valid_num_field() const;
   bool has_record_id_in_device_piece_field() const;
+  bool has_loss_instance_num_field() const;
   const StructPodDesc& header_pod_desc() const { return header_pod_desc_; }
 
   int32_t max_col_num() const { return blob_desc_proto_.header().max_col_num(); }
diff --git a/oneflow/core/register/runtime_register_desc.cpp b/oneflow/core/register/runtime_register_desc.cpp
index 83826b0b355de45d90c5318f83a787a0e0f89018..cbd94e251f49787ab6d443949bd42d8bf989c52d 100644
--- a/oneflow/core/register/runtime_register_desc.cpp
+++ b/oneflow/core/register/runtime_register_desc.cpp
@@ -70,4 +70,25 @@ const Shape& RtRegstDesc::data_regst_time_shape() const {
   return *data_regst_time_shape_;
 }
 
+void RtRegstDesc::ForEachBlobDescOffsetInOnRegst(
+    const std::vector<LbiBlobDescPair>& lbis,
+    const std::function<void(const LbiBlobDescPair&, int64_t body_offset, int64_t header_offset)>&
+        Handler) const {
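+  // Blobs that share a non-negative blob_mem_id reuse the previous body offset;
+  // headers are laid out back to back.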
+  int32_t last_blob_mem_id = -1;
+  size_t last_size = 0;
+  int64_t cur_body_offset = 0;
+  int64_t cur_header_offset = 0;
+  for (const LbiBlobDescPair& lbi : lbis) {
+    const RtBlobDesc* blob_desc = GetRtBlobDescFromLbi(lbi.lbi());
+    int32_t cur_blob_mem_id = lbi.blob_desc().header().blob_mem_id();
+    if (cur_blob_mem_id == -1 || cur_blob_mem_id != last_blob_mem_id) {
+      cur_body_offset += last_size;
+    }
+    Handler(lbi, cur_body_offset, cur_header_offset);
+    cur_header_offset += blob_desc->ByteSizeOfBlobHeader();
+    last_blob_mem_id = cur_blob_mem_id;
+    last_size = blob_desc->ByteSizeOfBlobBody();
+  }
+}
+
 }  // namespace oneflow
diff --git a/oneflow/core/register/runtime_register_desc.h b/oneflow/core/register/runtime_register_desc.h
index 6747b4619a685ea177d28ce22764b2467ecbb550..939e2e5426961698ec5b26468db8b4a51afd647c 100644
--- a/oneflow/core/register/runtime_register_desc.h
+++ b/oneflow/core/register/runtime_register_desc.h
@@ -31,6 +31,11 @@ class RtRegstDesc {
   size_t MainByteSize4OneRegst() const;
   const Shape& data_regst_time_shape() const;
 
+  void ForEachBlobDescOffsetInOnRegst(
+      const std::vector<LbiBlobDescPair>& lbis,
+      const std::function<void(const LbiBlobDescPair&, int64_t body_offset, int64_t header_offset)>&
+          Handler) const;
+
  private:
   int64_t regst_desc_id_;
   int64_t producer_actor_id_;
diff --git a/oneflow/core/thread/gpu_thread.cpp b/oneflow/core/thread/gpu_thread.cpp
index 8e36bffdaaeca2ee6d948171c83fb91d1a024eee..ab13574a6a5b7f9a35a751cd6fde2b9486df1588 100644
--- a/oneflow/core/thread/gpu_thread.cpp
+++ b/oneflow/core/thread/gpu_thread.cpp
@@ -14,7 +14,8 @@ GpuThread::GpuThread(int64_t thrd_id, int64_t dev_id) {
     ctx.cb_event_chan = &cb_event_chan_;
     PollMsgChannel(ctx);
   });
-  cb_event_poller_ = std::thread([this]() {
+  cb_event_poller_ = std::thread([this, dev_id]() {
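+    // Select this thread's CUDA device before synchronizing on events recorded on it.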
+    CudaCheck(cudaSetDevice(dev_id));
     CudaCBEvent cb_event;
     while (cb_event_chan_.Receive(&cb_event) == kChannelStatusSuccess) {
       CudaCheck(cudaEventSynchronize(cb_event.event));