diff --git a/CMakeLists.txt b/CMakeLists.txt index 48557fc2b4a56438ded4944434197a661bc7cf77..0cf82571ebd846aca7ed5aae36adf0d66be17f4b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,11 @@ if (WIN32) set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /D_ITERATOR_DEBUG_LEVEL=0") else() list(APPEND CUDA_NVCC_FLAGS -std=c++11 -w -Wno-deprecated-gpu-targets) + list(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_30,code=\"sm_30,compute_30\") + list(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_52,code=\"sm_52,compute_52\") + list(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_60,code=\"sm_60,compute_60\") + list(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_61,code=\"sm_61,compute_61\") + list(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_70,code=\"sm_70,compute_70\") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Wno-sign-compare -Wno-unused-function") if (RELEASE_VERSION) list(APPEND CUDA_NVCC_FLAGS -O3) diff --git a/cmake/oneflow.cmake b/cmake/oneflow.cmake index f43ab9b9a84577da39bf935c388725e4deb59822..f368c81215365a4f82c147bf2b8b586abdd65b81 100644 --- a/cmake/oneflow.cmake +++ b/cmake/oneflow.cmake @@ -1,5 +1,6 @@ # main cpp list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/oneflow/core/job/oneflow.cpp) +list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/oneflow/core/ndarray/ndarray_reduce_test.cpp) list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/tools/gen_resnet.cpp) list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/tools/gen_alexnet.cpp) list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/tools/gen_googlenet.cpp) diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index e33988991f6994c795b0f2ff241cf176da669982..7e5a054376dea24ed1a320b2dbca0e7ac778adca 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -12,6 +12,7 @@ include(libjpeg-turbo) include(opencv) include(eigen) include(cocoapi) +include(half) if (BUILD_CUDA) set(CUDA_SEPARABLE_COMPILATION ON) @@ -90,6 +91,7 @@ set(oneflow_third_party_dependencies eigen cocoapi_copy_headers_to_destination cocoapi_copy_libs_to_destination + half_copy_headers_to_destination ) include_directories( @@ -104,6 +106,7 @@ include_directories( ${OPENCV_INCLUDE_DIR} ${EIGEN_INCLUDE_DIR} ${COCOAPI_INCLUDE_DIR} + ${HALF_INCLUDE_DIR} ) if (BUILD_CUDA) @@ -124,3 +127,5 @@ if (BUILD_CUDA) ${NCCL_INCLUDE_DIR} ) endif() + +add_definitions(-DHALF_ENABLE_CPP11_USER_LITERALS=0) diff --git a/cmake/third_party/half.cmake b/cmake/third_party/half.cmake new file mode 100644 index 0000000000000000000000000000000000000000..d068bf612d8dfac47d2b770f3bfb6fce61eaafe7 --- /dev/null +++ b/cmake/third_party/half.cmake @@ -0,0 +1,35 @@ +include (ExternalProject) + +set(HALF_INCLUDE_DIR ${THIRD_PARTY_DIR}/half/include) + +set(HALF_URL https://cfhcable.dl.sourceforge.net/project/half/half/1.12.0/half-1.12.0.zip) +set(HALF_BASE_DIR ${CMAKE_CURRENT_BINARY_DIR}/half/src/half) + +set(HALF_HEADERS + "${HALF_BASE_DIR}/include/half.hpp" +) + +if(BUILD_THIRD_PARTY) + +ExternalProject_Add(half + PREFIX half + URL ${HALF_URL} + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + BUILD_IN_SOURCE 1 + INSTALL_COMMAND "" +) + +add_custom_target(half_create_header_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${HALF_INCLUDE_DIR} + DEPENDS half) + +add_custom_target(half_copy_headers_to_destination + DEPENDS half_create_header_dir) + +foreach(header_file ${HALF_HEADERS}) + add_custom_command(TARGET half_copy_headers_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${HALF_INCLUDE_DIR}) +endforeach() +endif(BUILD_THIRD_PARTY) diff 
--git a/oneflow/core/actor/actor.cpp b/oneflow/core/actor/actor.cpp index 8134fda32299958c6f613969826f793540568c89..d3ed86f26320c46418f022307641ca60eb6cd8d8 100644 --- a/oneflow/core/actor/actor.cpp +++ b/oneflow/core/actor/actor.cpp @@ -193,8 +193,11 @@ int Actor::HandlerNormal(const ActorMsg& msg) { Regst* regst = msg.regst(); if (naive_consumed_rs_.HasRegstDescId(regst->regst_desc_id())) { CHECK_EQ(0, naive_consumed_rs_.TryPushBackRegst(regst)); - NormalProcessNaiveReadableRegstMsg( - naive_consumed_rs_.RegstDeq4RegstDescId(regst->regst_desc_id())); + const auto& rdeq = naive_consumed_rs_.RegstDeq4RegstDescId(regst->regst_desc_id()); + CHECK(rdeq.empty() == false); + if (rdeq.front()->regst_desc()->regst_desc_type().has_data_regst_desc()) { + NormalProcessNaiveReadableDataRegstMsg(rdeq); + } } else if (TryUpdtStateAsProducedRegst(regst) == 0) { // do nothing } else { @@ -325,8 +328,9 @@ void Actor::AsyncSendConsumedCtrlRegstMsgToProducer() { CHECK_GE(reg_deq.size(), returned_regst_num); for (size_t i = 0; i < returned_regst_num; ++i) { Regst* regst = reg_deq.at(i); - AsyncSendMsg(ActorMsg::BuildRegstMsgToProducer(actor_id_, regst->producer_actor_id(), regst)); + // must access regst before sending it to producer regst_desc_ids.push_back(regst->regst_desc_id()); + AsyncSendMsg(ActorMsg::BuildRegstMsgToProducer(actor_id_, regst->producer_actor_id(), regst)); } }); naive_consumed_rs_.PopFrontRegsts(regst_desc_ids); @@ -452,8 +456,10 @@ void Actor::AsyncSendRegstMsgToProducer(Regst* regst) { } void Actor::AsyncSendRegstMsgToProducer(Regst* regst, int64_t producer) { + // must access regst before sending it to producer + int64_t regst_desc_id = regst->regst_desc_id(); AsyncSendMsg(ActorMsg::BuildRegstMsgToProducer(actor_id_, producer, regst)); - naive_consumed_rs_.TryPopFrontRegst(regst->regst_desc_id()); + naive_consumed_rs_.TryPopFrontRegst(regst_desc_id); } Regst* Actor::GetSoleProducedRegst4RegstDescId(int64_t regst_desc_id) { diff --git a/oneflow/core/actor/actor.h b/oneflow/core/actor/actor.h index a3dc4d3de0bac88503c7f1975dc991b83f14b36b..5e23a330e37384cfda5ccd9a8ebac9a648c86021 100644 --- a/oneflow/core/actor/actor.h +++ b/oneflow/core/actor/actor.h @@ -128,7 +128,7 @@ class Actor { // Process Msg virtual void NormalProcessCustomizedEordMsg(const ActorMsg&) {} - virtual void NormalProcessNaiveReadableRegstMsg(const std::deque<Regst*>&) {} + virtual void NormalProcessNaiveReadableDataRegstMsg(const std::deque<Regst*>&) {} virtual void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) { UNIMPLEMENTED(); } virtual bool NormalTryProcessReadableMsgFromOtherMachine(const ActorMsg&) { return false; } int TryUpdtStateAsProducedRegst(Regst* regst); diff --git a/oneflow/core/actor/boxing_actor.cpp b/oneflow/core/actor/boxing_actor.cpp index 06394b0c3e21d8fbe60541049b5c9da0957f39a8..88325dec3bf74bfa0973d394403244a30d287c1c 100644 --- a/oneflow/core/actor/boxing_actor.cpp +++ b/oneflow/core/actor/boxing_actor.cpp @@ -9,7 +9,7 @@ void BoxingActor::VirtualActorInit(const TaskProto& task_proto) { OF_SET_MSG_HANDLER(&BoxingActor::HandlerNormal); } -void BoxingActor::NormalProcessNaiveReadableRegstMsg(const std::deque<Regst*>& rq) { +void BoxingActor::NormalProcessNaiveReadableDataRegstMsg(const std::deque<Regst*>& rq) { if (rq.back()->packed_blob()->max_col_num() > 1 && col_id_order_ == ColIdOrder::kUnCertain) { TrySetColIdOrder(rq.back()); } diff --git a/oneflow/core/actor/boxing_actor.h b/oneflow/core/actor/boxing_actor.h index 
88189f89dd9cd02387bd24751869718f88013d5b..32609ecb842cc60dd7247057a693de06bd5a8e47 100644 --- a/oneflow/core/actor/boxing_actor.h +++ b/oneflow/core/actor/boxing_actor.h @@ -14,7 +14,7 @@ class BoxingActor final : public Actor { void VirtualActorInit(const TaskProto&) override; private: - void NormalProcessNaiveReadableRegstMsg(const std::deque<Regst*>&) override; + void NormalProcessNaiveReadableDataRegstMsg(const std::deque<Regst*>&) override; void Act() override; void VirtualAsyncSendNaiveProducedRegstMsgToConsumer(); void VirtualAsyncSendNaiveConsumedRegstMsgToProducer(); diff --git a/oneflow/core/actor/loss_print_compute_actor.h b/oneflow/core/actor/loss_print_compute_actor.h index 551a3d391561fcc2374766f55212c6a07f1f5834..f269d45c42a4d33515f00ea494e5f6691067a3ef 100644 --- a/oneflow/core/actor/loss_print_compute_actor.h +++ b/oneflow/core/actor/loss_print_compute_actor.h @@ -9,9 +9,13 @@ class LossPrintCompActor final : public SinkCompActor { public: OF_DISALLOW_COPY_AND_MOVE(LossPrintCompActor); LossPrintCompActor() = default; - ~LossPrintCompActor() = default; + ~LossPrintCompActor() override = default; private: + void VirtualSinkCompActorInit(const TaskProto&) override { timestamp_ = 0; } + void* NewOther() override { return &timestamp_; } + + double timestamp_ = 0; }; } // namespace oneflow diff --git a/oneflow/core/actor/naive_actor.cpp b/oneflow/core/actor/naive_actor.cpp index 642b5ecc1132b833f7b07d3ba0454b1b6df29ee9..f32db4bb1087f2398cb0068d81c69bd6db909173 100644 --- a/oneflow/core/actor/naive_actor.cpp +++ b/oneflow/core/actor/naive_actor.cpp @@ -12,4 +12,6 @@ void NaiveActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { }); } +REGISTER_ACTOR(TaskType::kReduceIdentity, NaiveActor); + } // namespace oneflow diff --git a/oneflow/core/actor/normal_backward_compute_actor.cpp b/oneflow/core/actor/normal_backward_compute_actor.cpp index 61313aee6362908874ad2ee157a4238664be44ec..577136c906c8592ecc6d9a51eb44dd067cfaa7c7 100644 --- a/oneflow/core/actor/normal_backward_compute_actor.cpp +++ b/oneflow/core/actor/normal_backward_compute_actor.cpp @@ -33,7 +33,7 @@ void NormalBackwardCompActor::ForEachCurCustomizedReadableRegst( } } -void NormalBackwardCompActor::NormalProcessNaiveReadableRegstMsg(const std::deque<Regst*>& rq) { +void NormalBackwardCompActor::NormalProcessNaiveReadableDataRegstMsg(const std::deque<Regst*>& rq) { if (rq.size() == 1 && rq.front()->regst_desc_id() == any_out_diff_regst_desc_id_) { AsyncReturnModelRegstUntilModelVersionIdEqual( GetModelVersionIdFromPieceId(rq.front()->piece_id(), actual_num_of_piece_in_batch_)); diff --git a/oneflow/core/actor/normal_backward_compute_actor.h b/oneflow/core/actor/normal_backward_compute_actor.h index fd0bc969520f71dfefcfc36e3b56f50fcd722026..4cf1ea961351ba6c90b2948819c0aa2bf9a425d7 100644 --- a/oneflow/core/actor/normal_backward_compute_actor.h +++ b/oneflow/core/actor/normal_backward_compute_actor.h @@ -15,7 +15,7 @@ class NormalBackwardCompActor final : public CompActor { private: void ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)>) const override; - void NormalProcessNaiveReadableRegstMsg(const std::deque<Regst*>&) override; + void NormalProcessNaiveReadableDataRegstMsg(const std::deque<Regst*>&) override; void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override; void Act() override; bool IsCustomizedReadReady() override; diff --git a/oneflow/core/actor/reduce_concat_compute_actor.cpp b/oneflow/core/actor/reduce_concat_compute_actor.cpp index
8cedf4dcf37355dc37a78425694c6e7c153c1910..dc7b7bf87b64bc1efc013fe941d8cf34ff539e05 100644 --- a/oneflow/core/actor/reduce_concat_compute_actor.cpp +++ b/oneflow/core/actor/reduce_concat_compute_actor.cpp @@ -2,16 +2,6 @@ namespace oneflow { -void ReduceConcatCompActor::VirtualCompActorInit(const TaskProto& proto) { - InputWiseCompActor::Init(proto); -} - -void ReduceConcatCompActor::SetKernelCtxOther(void** other) { - int64_t in_bn_id = InBnId4RegstDescId(cur_processed_regst_desc_id()); - other_val_ = std::make_pair(in_bn_id, EnableInplace()); - *other = static_cast<void*>(&other_val_); -} - REGISTER_ACTOR(TaskType::kReduceConcat, ReduceConcatCompActor); } // namespace oneflow diff --git a/oneflow/core/actor/reduce_concat_compute_actor.h b/oneflow/core/actor/reduce_concat_compute_actor.h index e482fce7d0df4003be929f0227731fbbfd591528..740c6f28088d92fffb64554a9550f2838797bb37 100644 --- a/oneflow/core/actor/reduce_concat_compute_actor.h +++ b/oneflow/core/actor/reduce_concat_compute_actor.h @@ -1,21 +1,15 @@ #ifndef ONEFLOW_CORE_ACTOR_REDUCE_CONCAT_COMPUTE_ACTOR_H_ #define ONEFLOW_CORE_ACTOR_REDUCE_CONCAT_COMPUTE_ACTOR_H_ -#include "oneflow/core/actor/input_wise_compute_actor.h" +#include "oneflow/core/actor/naive_actor.h" namespace oneflow { -class ReduceConcatCompActor final : public InputWiseCompActor { +class ReduceConcatCompActor final : public NaiveActor { public: OF_DISALLOW_COPY_AND_MOVE(ReduceConcatCompActor); ReduceConcatCompActor() = default; ~ReduceConcatCompActor() = default; - - private: - void VirtualCompActorInit(const TaskProto& proto) override; - void SetKernelCtxOther(void** other) override; - - std::pair<int64_t, bool> other_val_; }; } // namespace oneflow diff --git a/oneflow/core/actor/reduce_split_compute_actor.cpp b/oneflow/core/actor/reduce_split_compute_actor.cpp index d5e58b32fae21025d2ee5743860d13cf44529cbb..e107fa957239b0d9ed3f5460e87c51cb4a29b77a 100644 --- a/oneflow/core/actor/reduce_split_compute_actor.cpp +++ b/oneflow/core/actor/reduce_split_compute_actor.cpp @@ -2,15 +2,6 @@ namespace oneflow { -void ReduceSplitCompActor::VirtualCompActorInit(const TaskProto& proto) { - InputWiseCompActor::Init(proto); -} - -void ReduceSplitCompActor::SetKernelCtxOther(void** other) { - other_val_ = EnableInplace(); - *other = static_cast<void*>(&other_val_); -} - REGISTER_ACTOR(TaskType::kReduceSplit, ReduceSplitCompActor); } // namespace oneflow diff --git a/oneflow/core/actor/reduce_split_compute_actor.h b/oneflow/core/actor/reduce_split_compute_actor.h index 7cbdc0c582b4781b8466aa29d6572ebec5466356..2a3f7752ac66dba825c3c2f1e6eb60153c8d5116 100644 --- a/oneflow/core/actor/reduce_split_compute_actor.h +++ b/oneflow/core/actor/reduce_split_compute_actor.h @@ -1,21 +1,15 @@ #ifndef ONEFLOW_CORE_ACTOR_REDUCE_SPLIT_COMPUTE_ACTOR_H_ #define ONEFLOW_CORE_ACTOR_REDUCE_SPLIT_COMPUTE_ACTOR_H_ -#include "oneflow/core/actor/input_wise_compute_actor.h" +#include "oneflow/core/actor/naive_actor.h" namespace oneflow { -class ReduceSplitCompActor final : public InputWiseCompActor { +class ReduceSplitCompActor final : public NaiveActor { public: OF_DISALLOW_COPY_AND_MOVE(ReduceSplitCompActor); ReduceSplitCompActor() = default; ~ReduceSplitCompActor() = default; - - private: - void VirtualCompActorInit(const TaskProto& proto) override; - void SetKernelCtxOther(void** other) override; - - bool other_val_; }; } // namespace oneflow diff --git a/oneflow/core/comm_network/epoll/epoll_comm_network.cpp b/oneflow/core/comm_network/epoll/epoll_comm_network.cpp index 
ae47cb890022b26616d554644781596c8758cef3..de85b83a808014d8dc4e699675865771db352cc3 100644 --- a/oneflow/core/comm_network/epoll/epoll_comm_network.cpp +++ b/oneflow/core/comm_network/epoll/epoll_comm_network.cpp @@ -114,13 +114,13 @@ void EpollCommNet::InitSockets() { ((this_machine.data_port_agent() != -1) ? (this_machine.data_port_agent()) : (this_listen_port))); } else { - for (this_listen_port = 1024; this_listen_port < MaxVal<uint16_t>(); ++this_listen_port) { + for (this_listen_port = 1024; this_listen_port < GetMaxVal<uint16_t>(); ++this_listen_port) { if (SockListen(listen_sockfd, this_listen_port, total_machine_num) == 0) { PushPort(this_machine_id, this_listen_port); break; } } - CHECK_LT(this_listen_port, MaxVal<uint16_t>()); + CHECK_LT(this_listen_port, GetMaxVal<uint16_t>()); } int32_t src_machine_count = 0; diff --git a/oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp b/oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp index 4ec79c7b2adb50c0407b250499f9ab15551cf606..1a747879b37cc02d0bf8123637bfcc53f96b90f3 100644 --- a/oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp +++ b/oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp @@ -65,7 +65,7 @@ void IBVerbsQP::Connect(const IBVerbsConnectionInfo& peer_info) { qp_attr.ah_attr.grh.dgid.global.interface_id = peer_info.interface_id(); qp_attr.ah_attr.grh.flow_label = 0; qp_attr.ah_attr.grh.sgid_index = 0; - qp_attr.ah_attr.grh.hop_limit = MaxVal<uint8_t>(); + qp_attr.ah_attr.grh.hop_limit = GetMaxVal<uint8_t>(); qp_attr.ah_attr.dlid = peer_info.lid(); qp_attr.ah_attr.sl = 0; qp_attr.ah_attr.src_path_bits = 0; diff --git a/oneflow/core/common/blas.h b/oneflow/core/common/blas.h index 4bbe52b2d9d0db1e5fa4f0754953d883d540a56f..9c8efe25f396bfc7827a15d8b71abf5d7126c686 100644 --- a/oneflow/core/common/blas.h +++ b/oneflow/core/common/blas.h @@ -8,14 +8,16 @@ namespace oneflow { -#define BLAS_NAME_SEQ \ - OF_PP_MAKE_TUPLE_SEQ(dot) \ - OF_PP_MAKE_TUPLE_SEQ(swap) \ - OF_PP_MAKE_TUPLE_SEQ(copy) \ - OF_PP_MAKE_TUPLE_SEQ(axpy) \ - OF_PP_MAKE_TUPLE_SEQ(scal) \ - OF_PP_MAKE_TUPLE_SEQ(gemv) \ - OF_PP_MAKE_TUPLE_SEQ(gemm) +#define BLAS_NAME_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(dot) \ + OF_PP_MAKE_TUPLE_SEQ(swap) \ + OF_PP_MAKE_TUPLE_SEQ(copy) \ + OF_PP_MAKE_TUPLE_SEQ(axpy) \ + OF_PP_MAKE_TUPLE_SEQ(scal) \ + OF_PP_MAKE_TUPLE_SEQ(gemv) \ + OF_PP_MAKE_TUPLE_SEQ(gemm) \ + OF_PP_MAKE_TUPLE_SEQ(gemmBatched) \ + OF_PP_MAKE_TUPLE_SEQ(gemmStridedBatched) #define CBLAS_TEMPLATE(name) \ template<typename T, typename... 
Args> \ diff --git a/oneflow/core/common/data_type.h b/oneflow/core/common/data_type.h index 688afb9e19cc55858bf0cc32c58df7afc06862ec..4c763a606184ef4609f0d6e678f0fa5dbbf93b22 100644 --- a/oneflow/core/common/data_type.h +++ b/oneflow/core/common/data_type.h @@ -95,6 +95,37 @@ TRAIT_CONST_VAR(One, 1); #undef TRAIT_CONST_VAR +template<typename T> +struct MaxVal; +template<typename T> +struct MinVal; + +#define TRAIT_LIMIT_VAL(max_or_min, T, limit_value) \ + template<> \ + struct max_or_min##Val<T> final { \ + static_assert(alignof(int) == alignof(int32_t), "int32_t should be exactly int"); \ + static_assert(alignof(long long) == alignof(int64_t), "int64_t should be exactly long long"); \ + constexpr static T value = limit_value; \ + } + +TRAIT_LIMIT_VAL(Max, int8_t, CHAR_MAX); +TRAIT_LIMIT_VAL(Max, int32_t, INT_MAX); +TRAIT_LIMIT_VAL(Max, uint32_t, UINT_MAX); +TRAIT_LIMIT_VAL(Max, int64_t, LLONG_MAX); +TRAIT_LIMIT_VAL(Max, uint64_t, ULLONG_MAX); +TRAIT_LIMIT_VAL(Max, float, FLT_MAX); +TRAIT_LIMIT_VAL(Max, double, DBL_MAX); + +TRAIT_LIMIT_VAL(Min, int8_t, CHAR_MIN); +TRAIT_LIMIT_VAL(Min, int32_t, INT_MIN); +TRAIT_LIMIT_VAL(Min, uint32_t, 0); +TRAIT_LIMIT_VAL(Min, int64_t, LLONG_MIN); +TRAIT_LIMIT_VAL(Min, uint64_t, 0); +TRAIT_LIMIT_VAL(Min, float, -FLT_MAX); +TRAIT_LIMIT_VAL(Min, double, -DBL_MAX); + +#undef TRAIT_LIMIT_VAL + // Func bool IsIntegralDataType(DataType data_type); diff --git a/oneflow/core/common/protobuf.cpp b/oneflow/core/common/protobuf.cpp index 252f3c5e71f9e235e436f756714415c7eed5f3d7..72f03736aa34a703a9803ba62f87d3f2b025c23f 100644 --- a/oneflow/core/common/protobuf.cpp +++ b/oneflow/core/common/protobuf.cpp @@ -60,6 +60,14 @@ PbMessage* MutableMessageInPbMessage(PbMessage* msg, const std::string& field_na return r->MutableMessage(msg, fd); } +PbMessage* MutableMessageInPbMessage(PbMessage* msg, int field_index) { + const auto* d = const_cast<google::protobuf::Descriptor*>(msg->GetDescriptor()); + const auto* fd = const_cast<PbFd*>(d->FindFieldByNumber(field_index)); + CHECK_NOTNULL(fd); + const auto* r = const_cast<google::protobuf::Reflection*>(msg->GetReflection()); + return r->MutableMessage(msg, fd); +} + #define DEFINE_ADD_VAL_IN_PBRF(cpp_type, pb_type_name) \ template<> \ void AddValInPbRf(PbMessage* msg, const std::string& field_name, const cpp_type& val) { \ diff --git a/oneflow/core/common/protobuf.h b/oneflow/core/common/protobuf.h index 7663e98543fd1a4312578ad6037a37de048d2943..3ff546352f289da9b968188365c5ae255215c5f7 100644 --- a/oneflow/core/common/protobuf.h +++ b/oneflow/core/common/protobuf.h @@ -75,12 +75,19 @@ const PbRpf<T>& GetPbRpfFromPbMessage(const PbMessage& msg, const std::string& f return r->GetRepeatedPtrField<T>(msg, fd); } +template<typename T> +PbRpf<T>* MutPbRpfFromPbMessage(PbMessage* msg, const std::string& field_name) { + PROTOBUF_REFLECTION((*msg), field_name); + return r->MutableRepeatedPtrField<T>(msg, fd); +} + // Set In PbMessage template<typename T> void SetValInPbMessage(PbMessage* msg, const std::string& field_name, const T& val); PbMessage* MutableMessageInPbMessage(PbMessage*, const std::string& field_name); +PbMessage* MutableMessageInPbMessage(PbMessage*, int field_index); // Add In PbMessage RepeatedField diff --git a/oneflow/core/common/shape.cpp b/oneflow/core/common/shape.cpp index b7697b494151c2ae7989efc4df8fb0d7caed0c60..7be79b858ac0ec3528dbf4999ea9d0b312c65c7f 100644 --- a/oneflow/core/common/shape.cpp +++ b/oneflow/core/common/shape.cpp @@ -19,14 +19,20 @@ Shape& Shape::operator=(const Shape& shape) { bool 
Shape::operator==(const Shape& rhs) const { return dim_vec_ == rhs.dim_vec_; } -std::string Shape::DebugStr() const { +std::string Shape::ToString() const { std::stringstream ss; - ss << "{"; - for (int64_t dim : dim_vec_) { ss << dim << ","; } - ss << "(" << elem_cnt_ << ")}"; + int32_t idx = 0; + ss << "("; + for (int64_t dim : dim_vec_) { + ss << dim; + if (++idx != dim_vec_.size() || dim_vec_.size() == 1) { ss << ","; } + } + ss << ")"; return ss.str(); } +std::string Shape::DebugStr() const { return ToString(); } + void Shape::ToProto(ShapeProto* ret) const { *(ret->mutable_dim()) = PbRf<int64_t>(dim_vec_.begin(), dim_vec_.end()); } @@ -57,4 +63,11 @@ std::ostream& operator<<(std::ostream& out, const Shape& shape) { return out; } +Shape Shape::CreateLeftExtendedShape(int num_axes) const { + CHECK_GE(num_axes, NumAxes()); + std::vector<int64_t> dim_vec = this->dim_vec(); + for (int i = 0; i < num_axes - NumAxes(); ++i) { dim_vec.insert(dim_vec.begin(), 1LL); } + return Shape(dim_vec); +} + } // namespace oneflow diff --git a/oneflow/core/common/shape.h b/oneflow/core/common/shape.h index e2ad1277349de0c7ba74db85ec2ec2398173cd0f..161089d45eb89361328c16dd8fabe1c19d21339b 100644 --- a/oneflow/core/common/shape.h +++ b/oneflow/core/common/shape.h @@ -19,6 +19,7 @@ class Shape final { bool operator==(const Shape& rhs) const; bool operator!=(const Shape& rhs) const { return !(*this == rhs); } std::string DebugStr() const; + std::string ToString() const; void ToProto(ShapeProto*) const; @@ -34,6 +35,8 @@ class Shape final { int64_t Count(int64_t begin_axis, int64_t end_axis) const; int64_t Count(int64_t begin_axis) const; + Shape CreateLeftExtendedShape(int num_axes) const; + private: void UpdateElemCnt(); diff --git a/oneflow/core/common/switch_func.h b/oneflow/core/common/switch_func.h index 4373b0160b90b918681757f512cb18c2e6e700b3..a36f232adf799ef25a21f76e2c715f2c4451b73a 100644 --- a/oneflow/core/common/switch_func.h +++ b/oneflow/core/common/switch_func.h @@ -29,27 +29,27 @@ auto SwitchCase(Args&&... args) -> decltype(std::make_tuple(std::forward<Args>(a // CTRV example: (float, DataType::kFloat) // TYPED_CTRV_SEQ example: (DataType, ((float, DataType::kFloat))) -#define MAKE_TYPED_CTRV_SEQ(runtime_value_type, ctrv_pair_seq) (runtime_value_type, ctrv_pair_seq) - #define MAKE_DATA_TYPE_CTRV_SEQ(data_type_seq) MAKE_TYPED_CTRV_SEQ(DataType, data_type_seq) #define MAKE_DEVICE_TYPE_CTRV_SEQ(device_type_seq) \ MAKE_TYPED_CTRV_SEQ(DeviceType, \ - OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPE_SEQ, device_type_seq)) + OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ, device_type_seq)) #define MAKE_NDIM_CTRV_SEQ(ndim_seq) \ - MAKE_TYPED_CTRV_SEQ(int32_t, OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPE_SEQ, ndim_seq)) + MAKE_TYPED_CTRV_SEQ(int32_t, OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ, ndim_seq)) #define MAKE_STRINGIZED_DATA_TYPE_CTRV(data_type_pair) \ (OF_PP_PAIR_FIRST(data_type_pair), OF_PP_STRINGIZE(OF_PP_PAIR_FIRST(data_type_pair))) #define MAKE_STRINGIZED_DATA_TYPE_CTRV_SEQ(data_type_seq) \ (std::string, OF_PP_SEQ_MAP(MAKE_STRINGIZED_DATA_TYPE_CTRV, data_type_seq)) +#define MAKE_TYPED_CTRV_SEQ(runtime_value_type, ctrv_pair_seq) (runtime_value_type, ctrv_pair_seq) + // internal preprocessor macros #define OF_PP_I_MAKE_SWITCH_ENTRY_MAP_PAIR(switch_case, func_args_type, func) \ {switch_case, \ [](func_args_type&&... 
args) { return func(std::forward<func_args_type>(args)...); }}, -#define OF_PP_I_MAKE_REPLICATE_TUPE_SEQ(x) OF_PP_MAKE_TUPLE_SEQ(x, x) +#define OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ(x) OF_PP_MAKE_TUPLE_SEQ(x, x) #define OF_PP_I_MAKE_SWITCH_FUNC_ENTRY_1(make_template_func, func_name, func_args_type, \ switch_case_pair0) \ diff --git a/oneflow/core/common/util.h b/oneflow/core/common/util.h index e9db5351b7868b4cdc7ceccd63938937a8f6357a..4caa56a94396e1a90bdfcd13ff5bd4173f82c4c1 100644 --- a/oneflow/core/common/util.h +++ b/oneflow/core/common/util.h @@ -158,7 +158,9 @@ inline uint32_t NewRandomSeed() { #define DEVICE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU) #endif -#define DIM_SEQ (1)(2)(3)(4)(5)(6)(7)(8) +#define DIM_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(1) \ + OF_PP_MAKE_TUPLE_SEQ(2) OF_PP_MAKE_TUPLE_SEQ(3) OF_PP_MAKE_TUPLE_SEQ(4) OF_PP_MAKE_TUPLE_SEQ(5) #define BOOL_SEQ (true)(false) #define PARALLEL_POLICY_SEQ (ParallelPolicy::kModelParallel)(ParallelPolicy::kDataParallel) @@ -203,12 +205,12 @@ void Erase(T& container, const std::function<bool(const typename T::value_type&) } template<typename T> -inline T MinVal() { +inline T GetMinVal() { return std::numeric_limits<T>::lowest(); } template<typename T> -inline T MaxVal() { +inline T GetMaxVal() { return std::numeric_limits<T>::max(); } @@ -220,6 +222,8 @@ inline T MaxVal() { #if defined(__GNUC__) #define ALWAYS_INLINE __attribute__((always_inline)) +#elif defined(__CUDACC__) +#define ALWAYS_INLINE __forceinline__ #else #define ALWAYS_INLINE inline #endif diff --git a/oneflow/core/device/cuda_util.cpp b/oneflow/core/device/cuda_util.cpp index 7b080712f3b3836fd2462b288dc5077e68959c95..6fddfbb680f5e7aa3be2606642b9e5ef4940d1cb 100644 --- a/oneflow/core/device/cuda_util.cpp +++ b/oneflow/core/device/cuda_util.cpp @@ -45,8 +45,17 @@ const char* CurandGetErrorString(curandStatus_t error) { return "Unknown curand status"; } +cudaDeviceProp global_device_prop; + } // namespace +void InitGlobalCudaDeviceProp() { cudaGetDeviceProperties(&global_device_prop, 0); } + +int32_t GetSMCudaMaxBlocksNum() { + return global_device_prop.multiProcessorCount * global_device_prop.maxThreadsPerMultiProcessor + / kCudaThreadsNumPerBlock; +} + template<> void CudaCheck(cudaError_t error) { CHECK_EQ(error, cudaSuccess) << cudaGetErrorString(error); @@ -73,6 +82,14 @@ size_t GetAvailableGpuMemSize(int dev_id) { return prop.totalGlobalMem; } +cudaDataType_t GetCudaDataType(DataType val) { +#define MAKE_ENTRY(type_cpp, type_cuda) \ + if (val == GetDataType<type_cpp>::value) { return type_cuda; } + OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CUDA_DATA_TYPE_SEQ); +#undef MAKE_ENTRY + UNIMPLEMENTED(); +} + #endif // WITH_CUDA } // namespace oneflow diff --git a/oneflow/core/device/cuda_util.h b/oneflow/core/device/cuda_util.h index 8c830f4201333720dd149d486d125ba1156f2d1b..d1abc1252ebef5854c5995eccad4fb5e65ee3224 100644 --- a/oneflow/core/device/cuda_util.h +++ b/oneflow/core/device/cuda_util.h @@ -19,16 +19,29 @@ template<typename T> void CudaCheck(T error); // CUDA: grid stride looping -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x, step = blockDim.x * gridDim.x; i < (n); \ + i += step) -const int32_t kCudaThreadsNumPerBlock = 512; +const int32_t kCudaThreadsNumPerBlock = 1024; const int32_t kCudaMaxBlocksNum = 4096; +int32_t GetSMCudaMaxBlocksNum(); +void InitGlobalCudaDeviceProp(); + inline int32_t 
BlocksNum4ThreadsNum(const int32_t n) { return std::min((n + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock, kCudaMaxBlocksNum); } +inline int32_t SMBlocksNum4ThreadsNum(const int32_t n) { + return std::min((n + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock, + GetSMCudaMaxBlocksNum()); +} + +#define RUN_CUDA_KERNEL(func, device_ctx_ptr, thread_num, ...) \ + func<<<SMBlocksNum4ThreadsNum(thread_num), kCudaThreadsNumPerBlock, 0, \ + (device_ctx_ptr)->cuda_stream()>>>(__VA_ARGS__) + size_t GetAvailableGpuMemSize(int dev_id); #define CUDA_WORK_TYPE_SEQ \ @@ -38,14 +51,31 @@ size_t GetAvailableGpuMemSize(int dev_id); OF_PP_MAKE_TUPLE_SEQ(kNcclScatter) \ OF_PP_MAKE_TUPLE_SEQ(kNcclGather) \ OF_PP_MAKE_TUPLE_SEQ(kMix) \ + OF_PP_MAKE_TUPLE_SEQ(kReduceCtrl) \ OF_PP_MAKE_TUPLE_SEQ(kMdUpdt) enum class CudaWorkType { #define DECLARE_CUDA_WORK_TYPE(type) type, OF_PP_FOR_EACH_TUPLE(DECLARE_CUDA_WORK_TYPE, CUDA_WORK_TYPE_SEQ) }; + inline size_t GetCudaWorkTypeSize() { return OF_PP_SEQ_SIZE(CUDA_WORK_TYPE_SEQ); } +#define CUDA_DATA_TYPE_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(float, CUDA_R_32F) \ + OF_PP_MAKE_TUPLE_SEQ(double, CUDA_R_64F) + +cudaDataType_t GetCudaDataType(DataType); + +template<typename T> +struct CudaDataType; + +#define SPECIALIZE_CUDA_DATA_TYPE(type_cpp, type_cuda) \ + template<> \ + struct CudaDataType<type_cpp> : std::integral_constant<cudaDataType_t, type_cuda> {}; +OF_PP_FOR_EACH_TUPLE(SPECIALIZE_CUDA_DATA_TYPE, CUDA_DATA_TYPE_SEQ); +#undef SPECIALIZE_CUDA_DATA_TYPE + } // namespace oneflow #endif // WITH_CUDA diff --git a/oneflow/core/graph/accuracy_compute_task_node.cpp b/oneflow/core/graph/accuracy_compute_task_node.cpp index d7b94e3723a29ace9c17897bbb777f78c0337eaa..1185d93a955305498a3605d97b58403e90fc6e4a 100644 --- a/oneflow/core/graph/accuracy_compute_task_node.cpp +++ b/oneflow/core/graph/accuracy_compute_task_node.cpp @@ -6,6 +6,7 @@ namespace oneflow { void AccuracyCompTaskNode::ProduceAllRegstsAndBindEdges() { ProduceRegst("accuracy", false); + ProduceRegst("data_tmp", true); for (TaskEdge* edge : out_edges()) { BindEdgeWithProducedRegst(edge, "accuracy"); } } @@ -27,6 +28,7 @@ void AccuracyCompTaskNode::BuildExecGphAndRegst() { accuracy_regst->AddLbi(accuracy_op->BnInOp2Lbi(obn)); accuracy_node->BindBnWithRegst(obn, accuracy_regst); } + accuracy_node->AddBnToRegstAndBindIt(&Operator::data_tmp_bns, GetProducedRegst("data_tmp")); accuracy_node->InferBlobDescs(parallel_ctx()); } diff --git a/oneflow/core/graph/boxing_task_node.cpp b/oneflow/core/graph/boxing_task_node.cpp index 8b70a7054b3ecb02598ff474c189d4226ba31764..0c33835056420994410c4a3f66e551913d907fb8 100644 --- a/oneflow/core/graph/boxing_task_node.cpp +++ b/oneflow/core/graph/boxing_task_node.cpp @@ -2,6 +2,7 @@ #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/graph/logical_node.h" #include "oneflow/core/graph/logical_graph.h" +#include "oneflow/core/graph/op_graph.h" #include "oneflow/core/operator/operator.h" namespace oneflow { @@ -55,20 +56,20 @@ DEFINE_BLD_BOXING_OP_CONF_METHOD(InBoxingTaskNode, DataConcatAndDataSplit) { conf->mutable_concat_box()->set_axis(0); BoxSplitConf* split_conf = conf->mutable_split_box(); split_conf->set_axis(0); - BalancedSplitter bs(Global<JobDesc>::Get()->PieceSize(), - out_logical->parallel_desc()->parallel_num()); - SetBoxSplitPart(sorted_out_edges, bs, split_conf); + const std::string& out_op_name = out_logical->op_vec().at(0)->op_name(); + SetBoxSplitPart(sorted_out_edges, Global<OpGraph>::Get()->GetBalancedSplitter(out_op_name, lbi), + 
split_conf); } DEFINE_BLD_BOXING_OP_CONF_METHOD(OutBoxingTaskNode, DataConcatAndDataSplit) { conf->mutable_concat_box()->set_axis(0); BoxSplitConf* split_conf = conf->mutable_split_box(); split_conf->set_axis(0); - BalancedSplitter in_bs(Global<JobDesc>::Get()->PieceSize(), - in_logical->parallel_desc()->parallel_num()); + const std::string& in_op_name = in_logical->op_vec().at(0)->op_name(); + BalancedSplitter in_bs = Global<OpGraph>::Get()->GetBalancedSplitter(in_op_name, lbi); Range in_range = in_bs.At(sorted_in_edges.front().parallel_id_min, sorted_in_edges.back().parallel_id_max); - BalancedSplitter out_bs(Global<JobDesc>::Get()->PieceSize(), - out_logical->parallel_desc()->parallel_num()); + const std::string& out_op_name = out_logical->op_vec().at(0)->op_name(); + BalancedSplitter out_bs = Global<OpGraph>::Get()->GetBalancedSplitter(out_op_name, lbi); for (const EdgeInfo& out_edge : sorted_out_edges) { Range out_range = out_bs.At(out_edge.parallel_id_min, out_edge.parallel_id_max); Range intersectant_range = FindIntersectant(in_range, out_range); @@ -82,44 +83,89 @@ DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, DataConcatAndClone) { DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, DataConcatAndModelSplit) { conf->mutable_concat_box()->set_axis(0); BoxSplitConf* split_conf = conf->mutable_split_box(); - split_conf->set_axis(out_logical->GetModelSplitAxis()); - BalancedSplitter bs(out_logical->GetMaxModelSplitNum(), - out_logical->parallel_desc()->parallel_num()); - SetBoxSplitPart(sorted_out_edges, bs, split_conf); + const std::string& out_op_name = out_logical->op_vec().at(0)->op_name(); + split_conf->set_axis(Global<OpGraph>::Get()->GetModelSplitAxis(out_op_name, lbi)); + SetBoxSplitPart(sorted_out_edges, Global<OpGraph>::Get()->GetBalancedSplitter(out_op_name, lbi), + split_conf); } DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, ModelConcatAndDataSplit) { - conf->mutable_concat_box()->set_axis(in_logical->GetModelSplitAxis()); + const std::string& in_op_name = in_logical->op_vec().at(0)->op_name(); + conf->mutable_concat_box()->set_axis(Global<OpGraph>::Get()->GetModelSplitAxis(in_op_name, lbi)); BoxSplitConf* split_conf = conf->mutable_split_box(); split_conf->set_axis(0); - BalancedSplitter bs(Global<JobDesc>::Get()->PieceSize(), - out_logical->parallel_desc()->parallel_num()); - SetBoxSplitPart(sorted_out_edges, bs, split_conf); + + const std::string& out_op_name = out_logical->op_vec().at(0)->op_name(); + SetBoxSplitPart(sorted_out_edges, Global<OpGraph>::Get()->GetBalancedSplitter(out_op_name, lbi), + split_conf); } DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, ModelConcatAndClone) { - conf->mutable_concat_box()->set_axis(in_logical->GetModelSplitAxis()); + const std::string& in_op_name = in_logical->op_vec().at(0)->op_name(); + conf->mutable_concat_box()->set_axis(Global<OpGraph>::Get()->GetModelSplitAxis(in_op_name, lbi)); conf->mutable_clone_box(); } DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, AddAndDataSplit) { conf->mutable_add_box(); BoxSplitConf* split_conf = conf->mutable_split_box(); split_conf->set_axis(0); - BalancedSplitter bs(Global<JobDesc>::Get()->PieceSize(), - out_logical->parallel_desc()->parallel_num()); - SetBoxSplitPart(sorted_out_edges, bs, split_conf); + + const std::string& out_op_name = out_logical->op_vec().at(0)->op_name(); + SetBoxSplitPart(sorted_out_edges, Global<OpGraph>::Get()->GetBalancedSplitter(out_op_name, lbi), + split_conf); } DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, AddAndModelSplit) { conf->mutable_add_box(); 
BoxSplitConf* split_conf = conf->mutable_split_box(); - split_conf->set_axis(out_logical->GetModelSplitAxis()); - BalancedSplitter bs(out_logical->GetMaxModelSplitNum(), - out_logical->parallel_desc()->parallel_num()); - SetBoxSplitPart(sorted_out_edges, bs, split_conf); + const std::string& out_op_name = out_logical->op_vec().at(0)->op_name(); + split_conf->set_axis(Global<OpGraph>::Get()->GetModelSplitAxis(out_op_name, lbi)); + SetBoxSplitPart(sorted_out_edges, Global<OpGraph>::Get()->GetBalancedSplitter(out_op_name, lbi), + split_conf); } DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, AddAndClone) { conf->mutable_add_box(); conf->mutable_clone_box(); } +void SetBoxingOpConfBySbpParallel( + BoxingOpConf* conf, const LogicalBlobId& lbi, const Operator& in_op, const Operator& out_op, + const std::vector<BoxingTaskNode::EdgeInfo>& sorted_edges, + const std::function<SbpParallel(const std::string&, const LogicalBlobId&)>& GetSbpParallel) { + SbpParallel in_sbp = GetSbpParallel(in_op.op_name(), lbi); + if (in_sbp.has_split_parallel()) { + conf->mutable_concat_box()->set_axis(in_sbp.split_parallel().axis()); + } else if (in_sbp.has_partial_sum_parallel()) { + conf->mutable_add_box(); + } else { + UNIMPLEMENTED(); + } + SbpParallel out_sbp = GetSbpParallel(out_op.op_name(), lbi); + if (out_sbp.has_split_parallel()) { + BoxSplitConf* split_conf = conf->mutable_split_box(); + split_conf->set_axis(out_sbp.split_parallel().axis()); + const auto& bs = Global<OpGraph>::Get()->GetBalancedSplitter(out_op.op_name(), lbi); + SetBoxSplitPart(sorted_edges, bs, split_conf); + } else if (out_sbp.has_broadcast_parallel()) { + conf->mutable_clone_box(); + } else { + UNIMPLEMENTED(); + } +} + +DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, FwSbpParallel) { + SetBoxingOpConfBySbpParallel(conf, lbi, *in_logical->SoleOp(), *out_logical->SoleOp(), + sorted_out_edges, + [&](const std::string& op_name, const LogicalBlobId& lbi) { + return Global<OpGraph>::Get()->GetSbpParallel(op_name, lbi); + }); +} + +DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, BwSbpParallel) { + SetBoxingOpConfBySbpParallel( + conf, lbi, *in_logical->SoleOp(), *out_logical->SoleOp(), sorted_out_edges, + [&](const std::string& op_name, const LogicalBlobId& lbi) { + return GetDualSbpParallel(Global<OpGraph>::Get()->GetSbpParallel(op_name, lbi)); + }); +} + void BoxingTaskNode::InitLogical2SortedEdgeInfo( const std::unordered_set<TaskEdge*>& (TaskNode::*GetEdges)() const, TaskEdge* (TaskNode::*SoleEdge)() const, TaskNode* (TaskEdge::*SoleNode)() const, @@ -128,8 +174,8 @@ void BoxingTaskNode::InitLogical2SortedEdgeInfo( for (const TaskEdge* edge : (this->*GetEdges)()) { EdgeInfo edge_info; edge_info.edge = edge; - edge_info.parallel_id_min = MaxVal<int64_t>(); - edge_info.parallel_id_max = MinVal<int64_t>(); + edge_info.parallel_id_min = GetMaxVal<int64_t>(); + edge_info.parallel_id_max = GetMinVal<int64_t>(); std::queue<const TaskNode*> node_queue; node_queue.push((edge->*SoleNode)()); const LogicalNode* logical = nullptr; diff --git a/oneflow/core/graph/boxing_task_node.h b/oneflow/core/graph/boxing_task_node.h index 69dfbc678a8399d65c6a9fd73259358f10f63317..447f56d78313af8aa494eb1cdeb1c246903806b4 100644 --- a/oneflow/core/graph/boxing_task_node.h +++ b/oneflow/core/graph/boxing_task_node.h @@ -42,6 +42,8 @@ class BoxingTaskNode : public TaskNode { DECLARE_BLD_BOXING_OP_CONF_METHOD(AddAndDataSplit); DECLARE_BLD_BOXING_OP_CONF_METHOD(AddAndModelSplit); DECLARE_BLD_BOXING_OP_CONF_METHOD(AddAndClone); + 
DECLARE_BLD_BOXING_OP_CONF_METHOD(FwSbpParallel); + DECLARE_BLD_BOXING_OP_CONF_METHOD(BwSbpParallel); private: void InitLogical2SortedEdgeInfo(const std::unordered_set<TaskEdge*>& (TaskNode::*GetEdges)() diff --git a/oneflow/core/graph/chain_graph.cpp b/oneflow/core/graph/chain_graph.cpp index 5805a3bd014cb1936902e9130de3684cf26fb5dc..8767cd79b3fecbb439611f174571214b776cdc00 100644 --- a/oneflow/core/graph/chain_graph.cpp +++ b/oneflow/core/graph/chain_graph.cpp @@ -1,6 +1,7 @@ #include "oneflow/core/graph/chain_graph.h" #include "oneflow/core/graph/task_graph.h" #include "oneflow/core/graph/task_node.h" +#include "oneflow/core/graph/normal_forward_compute_task_node.h" #include "oneflow/core/thread/thread_pool.h" #include "oneflow/core/common/blocking_counter.h" @@ -66,7 +67,7 @@ void ChainMerger::InitChains() { Chain& cur_chain = chain_list_.back(); cur_chain.nodes = {task_node}; cur_chain.stream_area_id = - std::make_pair(task_node->area_id(), task_node->GlobalWorkStreamId()); + std::make_pair(task_node->AreaId4ChainMerge(), task_node->GlobalWorkStreamId()); cur_chain.ancestors.resize(bitset_num); cur_chain.ancestors_and_this.resize(bitset_num); CarefullySetBitset(&(cur_chain.ancestors_and_this), GetTaskUid(task_node)); @@ -81,7 +82,7 @@ void ChainMerger::InitChains() { bool ChainMerger::DoMerge(std::list<ChainIt>& chains, ChainIt rhs) { CHECK_EQ(rhs->nodes.size(), 1); // rm kMdUpdtArea chain merge - if (rhs->nodes.front()->area_id() == kMdUpdtArea) { return false; } + if (rhs->nodes.front()->AreaId4ChainMerge() == kMdUpdtArea) { return false; } for (auto chains_it = chains.rbegin(); chains_it != chains.rend(); ++chains_it) { ChainIt lhs = *chains_it; if (IsSubset(lhs, rhs)) { @@ -128,6 +129,23 @@ bool ChainMerger::IsSubset(const ChainIt& lhs, const ChainIt& rhs) const { return true; } +bool IsForwardOnlyTaskNode(TaskNode* node) { + auto* fw_node = dynamic_cast<NormalForwardCompTaskNode*>(node); + if (fw_node == nullptr) { return true; } + return fw_node->HasBackwardCompTaskNode() == false; +}; + +bool NoOutRegstConsumedByBwNode(TaskNode* node) { + auto* fw_node = dynamic_cast<NormalForwardCompTaskNode*>(node); + if (fw_node == nullptr) { return false; } + for (TaskEdge* edge : fw_node->out_edges()) { + auto* fw_consumer = dynamic_cast<NormalForwardCompTaskNode*>(edge->dst_node()); + if (fw_consumer == nullptr) { return false; } + if (fw_consumer->HasBackwardCompTaskNode()) { return false; } + } + return true; +}; + } // namespace std::string ChainNode::VisualStr() const { @@ -145,6 +163,7 @@ ChainGraph::ChainGraph(const TaskGraph& task_gph) : task_gph_(task_gph) { std::vector<std::vector<TaskNode*>> chains; GroupTaskNodesByMachineAndCollectAncestors(task_gph, &machine2tasks, &node2ancestors); MergeTaskNodes(machine2tasks, node2ancestors, &chains); + for (auto& task_nodes : chains) { PrioritizeUntrainableTaskNode(&task_nodes); } InitChainNode(chains); InitChainEdge(chains); SetChainId4ChainNode(); @@ -158,7 +177,7 @@ void ChainGraph::GroupTaskNodesByMachineAndCollectAncestors( (*machine2tasks)[node->machine_id()].emplace_back(node); CHECK(node2ancestors->emplace(node, HashSet<TaskNode*>()).second); // to reduce memory consumption - if (node->area_id() == kMdUpdtArea) { return; } + if (node->AreaId4ChainMerge() == kMdUpdtArea) { return; } node->ForEachNodeOnInEdge([&](TaskNode* in_node) { if (IsBackEdge(in_node, node)) { return; } (*node2ancestors)[node].insert(in_node); @@ -191,6 +210,85 @@ void ChainGraph::MergeTaskNodes(const HashMap<int64_t, std::vector<TaskNode*>>& 
counter.WaitUntilCntEqualZero(); } +void ChainGraph::PrioritizeUntrainableTaskNode(std::vector<TaskNode*>* task_nodes) const { + HashSet<TaskNode*> task_nodes_set(task_nodes->begin(), task_nodes->end()); + auto IsInSubset = [&](TaskNode* node) { + return task_nodes_set.find(node) != task_nodes_set.end(); + }; + auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) { + node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) { + if (IsInSubset(node_on_in_edge)) { Handler(node_on_in_edge); } + }); + }; + auto ForEachOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) { + node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) { + if (IsInSubset(node_on_out_edge)) { Handler(node_on_out_edge); } + }); + }; + auto IsSourceNode = [&](TaskNode* node) { + int32_t in_node_num = 0; + ForEachInNode(node, [&](TaskNode* in_node) { ++in_node_num; }); + return in_node_num == 0; + }; + std::list<TaskNode*> starts; + for (TaskNode* node : task_nodes_set) { + if (IsSourceNode(node)) { starts.push_back(node); } + } + task_nodes->clear(); + auto IsPrior = [&](TaskNode* node) { + return IsForwardOnlyTaskNode(node) && NoOutRegstConsumedByBwNode(node); + }; + PartialPriorTopoForEachNode(starts, ForEachInNode, ForEachOutNode, IsPrior, + [&](TaskNode* node) { task_nodes->push_back(node); }); + HashSet<TaskNode*> task_nodes_set_check(task_nodes->begin(), task_nodes->end()); + CHECK(task_nodes_set == task_nodes_set_check); +} + +void ChainGraph::PartialPriorTopoForEachNode( + const std::list<TaskNode*> starts, + const std::function<void(TaskNode*, const std::function<void(TaskNode*)>&)>& ForEachInNode, + const std::function<void(TaskNode*, const std::function<void(TaskNode*)>&)>& ForEachOutNode, + const std::function<bool(TaskNode*)>& IsPrior, + const std::function<void(TaskNode*)>& Handler) const { + // collect prior nodes + HashSet<TaskNode*> prior_nodes; + auto IsTaskNodePrior = [&](TaskNode* node) { + if (!IsPrior(node)) { return false; } + bool is_prior = true; + ForEachInNode(node, [&](TaskNode* in_node) { + is_prior = is_prior && (prior_nodes.find(in_node) != prior_nodes.end()); + }); + return is_prior; + }; + std::list<TaskNode*> nodes; + task_gph_.TopoForEachNode(starts, ForEachInNode, ForEachOutNode, [&](TaskNode* node) { + if (IsTaskNodePrior(node)) { CHECK(prior_nodes.emplace(node).second); } + nodes.push_back(node); + }); + // travel prior nodes; + auto ForEachPriorInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) { + ForEachInNode(node, [&](TaskNode* in_node) { + if (prior_nodes.find(in_node) != prior_nodes.end()) { Handler(in_node); } + }); + }; + auto ForEachPriorOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) { + ForEachOutNode(node, [&](TaskNode* out_node) { + if (prior_nodes.find(out_node) != prior_nodes.end()) { Handler(out_node); } + }); + }; + std::list<TaskNode*> prior_starts; + for (TaskNode* start : starts) { + if (IsTaskNodePrior(start)) { prior_starts.push_back(start); } + } + task_gph_.DfsTopoForEachNodeSortByDistanceToSink(prior_starts, ForEachPriorInNode, + ForEachPriorOutNode, Handler); + // travel other nodes ; + task_gph_.DfsTopoForEachNodeSortByDistanceToSink( + starts, ForEachInNode, ForEachOutNode, [&](TaskNode* node) { + if (prior_nodes.find(node) == prior_nodes.end()) { Handler(node); } + }); +} + void ChainGraph::InitChainNode(const std::vector<std::vector<TaskNode*>>& chains) { for (auto& chain : chains) { ChainNode* chain_node = new ChainNode(chain); diff --git 
a/oneflow/core/graph/chain_graph.h b/oneflow/core/graph/chain_graph.h index 8e097f19d890d559fb135a75455ddabc2cc06b04..59a61ba9dfe00bf9c7800a6dfbe5db62e96e134e 100644 --- a/oneflow/core/graph/chain_graph.h +++ b/oneflow/core/graph/chain_graph.h @@ -78,6 +78,13 @@ class ChainGraph final : public Graph<ChainNode, ChainEdge> { void MergeTaskNodes(const HashMap<int64_t, std::vector<TaskNode*>>& machine2tasks, const HashMap<TaskNode*, HashSet<TaskNode*>>& node2ancestors, std::vector<std::vector<TaskNode*>>* chains) const; + void PrioritizeUntrainableTaskNode(std::vector<TaskNode*>* task_nodes) const; + void PartialPriorTopoForEachNode( + const std::list<TaskNode*> starts, + const std::function<void(TaskNode*, const std::function<void(TaskNode*)>&)>& ForEachInNode, + const std::function<void(TaskNode*, const std::function<void(TaskNode*)>&)>& ForEachOutNode, + const std::function<bool(TaskNode*)>& IsPrior, + const std::function<void(TaskNode*)>& Handler) const; void InitChainNode(const std::vector<std::vector<TaskNode*>>& chains); void InitChainEdge(const std::vector<std::vector<TaskNode*>>& chains); void SetChainId4ChainNode(); diff --git a/oneflow/core/graph/chain_logical_graph.cpp b/oneflow/core/graph/chain_logical_graph.cpp deleted file mode 100644 index 16fcfeab6852dbec45ccef6748c7a17f9dce1173..0000000000000000000000000000000000000000 --- a/oneflow/core/graph/chain_logical_graph.cpp +++ /dev/null @@ -1,186 +0,0 @@ -#include "oneflow/core/operator/fully_connected_op.h" -#include "oneflow/core/graph/chain_logical_graph.h" -#include "oneflow/core/graph/logical_graph.h" - -namespace oneflow { - -struct ChainLogicalGraph::Chain { - std::vector<const LogicalNode*> nodes; - HashSet<const LogicalNode*> ancestors; - HashSet<const LogicalNode*> ancestors_and_this; - HashSet<const LogicalNode*> descendants; - HashSet<const LogicalNode*> descendants_and_this; - bool is_mergeable; - - bool IsParallelDescEqual(const Chain& rhs) const { - CHECK_GT(nodes.size(), 0); - CHECK_GT(rhs.nodes.size(), 0); - return nodes.front()->parallel_desc()->Equal(rhs.nodes.front()->parallel_desc().get()); - } -}; - -ChainLogicalGraph::ChainLogicalGraph(const LogicalGraph& logical_graph) { - std::list<Chain> chain_list; - HashMap<const LogicalNode*, std::list<Chain>::iterator> logical2chain_it; - HashMap<const LogicalNode*, size_t> logical2order_in_topo; - - InitChains(logical_graph, &chain_list, &logical2chain_it, &logical2order_in_topo); - MergeChains(&chain_list, &logical2chain_it); - SortNodesInChains(&chain_list, logical2order_in_topo); - BuildGraph(logical_graph, &chain_list); - ToDotWithAutoFilePath(); -} - -void ChainLogicalGraph::InitChains( - const LogicalGraph& logical_graph, std::list<Chain>* chain_list, - HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it, - HashMap<const LogicalNode*, size_t>* logical2order_in_topo) { - logical_graph.ForEachNode([&](const LogicalNode* node) { - chain_list->emplace_back(); - logical2chain_it->insert({node, --chain_list->end()}); - Chain& chain = chain_list->back(); - chain.nodes = {node}; - chain.is_mergeable = IsLogicalNodeMergeable(node); - size_t order_in_topo = logical2order_in_topo->size(); - logical2order_in_topo->emplace(node, order_in_topo); - }); - - logical_graph.TopoForEachNode([&](const LogicalNode* node) { - auto cur_chain = logical2chain_it->at(node); - for (const LogicalEdge* edge : node->in_edges()) { - LogicalNode* pred_node = edge->src_node(); - auto pred_chain = logical2chain_it->at(pred_node); - 
cur_chain->ancestors.insert(pred_chain->ancestors.begin(), pred_chain->ancestors.end()); - cur_chain->ancestors.insert(pred_node); - } - cur_chain->ancestors_and_this.insert(cur_chain->ancestors.begin(), cur_chain->ancestors.end()); - cur_chain->ancestors_and_this.insert(cur_chain->nodes.begin(), cur_chain->nodes.end()); - }); - - logical_graph.ReverseTopoForEachNode([&](const LogicalNode* node) { - auto cur_chain = logical2chain_it->at(node); - for (const LogicalEdge* edge : node->out_edges()) { - LogicalNode* succ_node = edge->dst_node(); - auto succ_chain = logical2chain_it->at(succ_node); - cur_chain->descendants.insert(succ_chain->descendants.begin(), succ_chain->descendants.end()); - cur_chain->descendants.insert(succ_node); - } - cur_chain->descendants_and_this.insert(cur_chain->descendants.begin(), - cur_chain->descendants.end()); - cur_chain->descendants_and_this.insert(cur_chain->nodes.begin(), cur_chain->nodes.end()); - }); -} - -void ChainLogicalGraph::MergeChains( - std::list<Chain>* chain_list, - HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it) { - while (chain_list->size() > 1 && TryMergeTwoChains(chain_list, logical2chain_it)) {}; -} - -bool ChainLogicalGraph::TryMergeTwoChains( - std::list<Chain>* chain_list, - HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it) { - return TryMergeTwoParallelChains(chain_list, logical2chain_it) - || TryMergeTwoConnectedChains(chain_list, logical2chain_it); -} - -bool ChainLogicalGraph::TryMergeTwoParallelChains( - std::list<Chain>* chain_list, - HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it) { - for (auto lhs = chain_list->begin(); lhs != chain_list->end(); ++lhs) { - if (!lhs->is_mergeable) { continue; } - for (auto rhs = lhs; rhs != chain_list->end(); ++rhs) { - if (lhs == rhs) { continue; } - if (!rhs->is_mergeable) { continue; } - if (!lhs->IsParallelDescEqual(*rhs)) { continue; } - if (lhs->ancestors != rhs->ancestors || lhs->descendants != rhs->descendants) { continue; } - for (const LogicalNode* node : rhs->nodes) { - lhs->nodes.push_back(node); - lhs->ancestors_and_this.insert(node); - lhs->descendants_and_this.insert(node); - logical2chain_it->at(node) = lhs; - } - chain_list->erase(rhs); - return true; - } - } - return false; -} - -bool ChainLogicalGraph::TryMergeTwoConnectedChains( - std::list<Chain>* chain_list, - HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it) { - for (auto succ_chain_it = chain_list->begin(); succ_chain_it != chain_list->end(); - ++succ_chain_it) { - if (!succ_chain_it->is_mergeable) { continue; } - for (const LogicalNode* node_in_succ : succ_chain_it->nodes) { - for (const LogicalEdge* in_edge : node_in_succ->in_edges()) { - auto pred_chain_it = logical2chain_it->at(in_edge->src_node()); - if (pred_chain_it == succ_chain_it) { continue; } - if (!pred_chain_it->is_mergeable) { continue; } - if (!pred_chain_it->IsParallelDescEqual(*succ_chain_it)) { continue; } - if (pred_chain_it->ancestors_and_this != succ_chain_it->ancestors - || pred_chain_it->descendants != succ_chain_it->descendants_and_this) { - continue; - } - for (const LogicalNode* node : succ_chain_it->nodes) { - pred_chain_it->nodes.push_back(node); - pred_chain_it->ancestors_and_this.insert(node); - pred_chain_it->descendants.erase(node); - logical2chain_it->at(node) = pred_chain_it; - } - chain_list->erase(succ_chain_it); - return true; - } - } - } - return false; -} - -void ChainLogicalGraph::SortNodesInChains( - std::list<Chain>* chain_list, 
- const HashMap<const LogicalNode*, size_t>& logical2order_in_topo) { - for (Chain& chain : *chain_list) { - std::sort(chain.nodes.begin(), chain.nodes.end(), - [&](const LogicalNode* a, const LogicalNode* b) { - return logical2order_in_topo.at(a) < logical2order_in_topo.at(b); - }); - } -} - -void ChainLogicalGraph::BuildGraph(const LogicalGraph& logical_graph, - std::list<Chain>* chain_list) { - HashMap<const LogicalNode*, ChainLogicalNode*> logical_node2chain_logical_node; - - for (const Chain& chain : *chain_list) { - ChainLogicalNode* chain_logical_node = NewNode(); - chain_logical_node->mut_logical_nodes() = chain.nodes; - for (const LogicalNode* node : chain.nodes) { - CHECK(logical_node2chain_logical_node.emplace(node, chain_logical_node).second); - } - } - - std::unordered_set<std::pair<ChainLogicalNode*, ChainLogicalNode*>> pred_succ_pairs; - logical_graph.ForEachEdge([&](const LogicalEdge* edge) { - pred_succ_pairs.emplace(logical_node2chain_logical_node.at(edge->src_node()), - logical_node2chain_logical_node.at(edge->dst_node())); - }); - - for (auto& pair : pred_succ_pairs) { - if (pair.first == pair.second) { continue; } - ChainLogicalEdge* edge = NewEdge(); - Connect(pair.first, edge, pair.second); - } -} - -bool ChainLogicalGraph::IsLogicalNodeMergeable(const LogicalNode* logical_node) const { - if (logical_node->parallel_desc()->policy() != kDataParallel) { return false; } - if (!dynamic_cast<const NormalForwardLogicalNode*>(logical_node)) { return false; } - for (const std::shared_ptr<Operator>& op : logical_node->op_vec()) { - if (dynamic_cast<FullyConnectedOp*>(op.get())) { return false; } - if (op->IsRecurrentOp()) { return false; } - } - return true; -} - -} // namespace oneflow diff --git a/oneflow/core/graph/chain_logical_graph.h b/oneflow/core/graph/chain_logical_graph.h deleted file mode 100644 index bb95bcc43368e859ed89c57927ed3e72e6977a2e..0000000000000000000000000000000000000000 --- a/oneflow/core/graph/chain_logical_graph.h +++ /dev/null @@ -1,60 +0,0 @@ -#ifndef ONEFLOW_CORE_GRAPH_CHAIN_LOGICAL_GRAPH_H_ -#define ONEFLOW_CORE_GRAPH_CHAIN_LOGICAL_GRAPH_H_ - -#include "oneflow/core/graph/graph.h" -#include "oneflow/core/graph/logical_graph.h" - -namespace oneflow { - -class ChainLogicalEdge; - -class ChainLogicalNode final : public Node<ChainLogicalNode, ChainLogicalEdge> { - public: - OF_DISALLOW_COPY_AND_MOVE(ChainLogicalNode); - ChainLogicalNode() = default; - ~ChainLogicalNode() override = default; - - const std::vector<const LogicalNode*>& logical_nodes() const { return logical_nodes_; } - std::vector<const LogicalNode*>& mut_logical_nodes() { return logical_nodes_; } - - private: - std::vector<const LogicalNode*> logical_nodes_; -}; - -class ChainLogicalEdge final : public Edge<ChainLogicalNode, ChainLogicalEdge> { - public: - OF_DISALLOW_COPY_AND_MOVE(ChainLogicalEdge); - ChainLogicalEdge() = default; - ~ChainLogicalEdge() override = default; -}; - -class ChainLogicalGraph final : public Graph<ChainLogicalNode, ChainLogicalEdge> { - public: - OF_DISALLOW_COPY_AND_MOVE(ChainLogicalGraph); - explicit ChainLogicalGraph(const LogicalGraph& logical_graph); - ~ChainLogicalGraph() override = default; - - private: - struct Chain; - void InitChains(const LogicalGraph& logical_graph, std::list<Chain>* chain_list, - HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it, - HashMap<const LogicalNode*, size_t>* logical2order_in_topo); - void MergeChains(std::list<Chain>* chain_list, - HashMap<const LogicalNode*, std::list<Chain>::iterator>* 
logical2chain_it); - bool TryMergeTwoChains(std::list<Chain>* chain_list, - HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it); - bool TryMergeTwoParallelChains( - std::list<Chain>* chain_list, - HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it); - bool TryMergeTwoConnectedChains( - std::list<Chain>* chain_list, - HashMap<const LogicalNode*, std::list<Chain>::iterator>* logical2chain_it); - void SortNodesInChains(std::list<Chain>* chain_list, - const HashMap<const LogicalNode*, size_t>& logical2order_in_topo); - void BuildGraph(const LogicalGraph& logical_graph, std::list<Chain>* chain_list); - bool IsLogicalNodeMergeable(const LogicalNode* logical_node) const; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_GRAPH_CHAIN_LOGICAL_GRAPH_H_ diff --git a/oneflow/core/graph/exec_graph.cpp b/oneflow/core/graph/exec_graph.cpp index d3ed963f0af535c3c90841dd882cccbd33552d97..4c94992570626e3237a73be324259a07aa31f67a 100644 --- a/oneflow/core/graph/exec_graph.cpp +++ b/oneflow/core/graph/exec_graph.cpp @@ -1,4 +1,5 @@ #include "oneflow/core/graph/exec_graph.h" +#include "oneflow/core/graph/op_graph.h" namespace oneflow { @@ -51,8 +52,10 @@ void ExecNode::ToProto(bool is_forward, const ParallelContext* parallel_ctx, } void ExecNode::InferBlobDescs(const ParallelContext* parallel_ctx) { - op_->InferBlobDescsIf(GetBlobDesc4BnInOpFunc(), parallel_ctx, + auto GetBlobDesc4BnInOp = GetBlobDesc4BnInOpFunc(); + op_->InferBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx, Global<JobDesc>::Get()->RecordPieceSize(), [this](OpContext* op_ctx) { op_ctx_.reset(op_ctx); }); + Global<OpGraph>::Get()->CheckBlobDescs(op_->op_name(), GetBlobDesc4BnInOp, parallel_ctx); } void ExecNode::InferBwBufBlobDescs(const ParallelContext* parallel_ctx) { diff --git a/oneflow/core/graph/graph.h b/oneflow/core/graph/graph.h index 7b8de5479a85748d8bf5f06c873ea9aabbe53983..7a00416ba74e0314f99d69a103529a1c9f1e91dc 100644 --- a/oneflow/core/graph/graph.h +++ b/oneflow/core/graph/graph.h @@ -45,8 +45,8 @@ class Graph { const std::function<void(NodeType*)>& Handler) const; // Getters - const std::unordered_set<NodeType*>& source_nodes() const; - const std::unordered_set<NodeType*>& sink_nodes() const; + std::list<NodeType*> source_nodes() const; + std::list<NodeType*> sink_nodes() const; NodeType* SoleSourceNode() const; NodeType* SoleSinkNode() const; NodeType* SoleNode() const; @@ -57,7 +57,8 @@ class Graph { // Setters template<typename DerivedNodeType = NodeType> DerivedNodeType* NewNode(); - EdgeType* NewEdge(); + template<class... Args> + EdgeType* NewEdge(Args&&... 
args); void AddAllocatedNode(NodeType*); void AddAllocatedEdge(EdgeType*); void DeleteNode(NodeType*); @@ -79,23 +80,47 @@ void Graph<NodeType, EdgeType>::ForEachNode(std::function<void(NodeType*)> NodeH } template<typename NodeType, typename EdgeType> -void Graph<NodeType, EdgeType>::TopoForEachNode(std::function<void(NodeType*)> NodeHandler) const { - std::list<NodeType*> starts; +std::list<NodeType*> Graph<NodeType, EdgeType>::source_nodes() const { + std::list<NodeType*> ret; + ForEachNode([&](NodeType* node) { + if (node->in_edges().empty()) { ret.push_back(node); } + }); + return ret; +} + +template<typename NodeType, typename EdgeType> +std::list<NodeType*> Graph<NodeType, EdgeType>::sink_nodes() const { + std::list<NodeType*> ret; ForEachNode([&](NodeType* node) { - if (node->in_edges().empty()) { starts.push_back(node); } + if (node->out_edges().empty()) { ret.push_back(node); } }); - TopoForEachNode(starts, &NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge, + return ret; +} + +template<typename NodeType, typename EdgeType> +NodeType* Graph<NodeType, EdgeType>::SoleSourceNode() const { + std::list<NodeType*> source_nodes_list = source_nodes(); + CHECK_EQ(source_nodes_list.size(), 1); + return source_nodes_list.front(); +} + +template<typename NodeType, typename EdgeType> +NodeType* Graph<NodeType, EdgeType>::SoleSinkNode() const { + std::list<NodeType*> sink_nodes_list = sink_nodes(); + CHECK_EQ(sink_nodes_list.size(), 1); + return sink_nodes_list.front(); +} + +template<typename NodeType, typename EdgeType> +void Graph<NodeType, EdgeType>::TopoForEachNode(std::function<void(NodeType*)> NodeHandler) const { + TopoForEachNode(source_nodes(), &NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge, NodeHandler); } template<typename NodeType, typename EdgeType> void Graph<NodeType, EdgeType>::ReverseTopoForEachNode( std::function<void(NodeType*)> NodeHandler) const { - std::list<NodeType*> starts; - ForEachNode([&](NodeType* node) { - if (node->out_edges().empty()) { starts.push_back(node); } - }); - TopoForEachNode(starts, &NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge, + TopoForEachNode(sink_nodes(), &NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge, NodeHandler); } @@ -122,8 +147,9 @@ DerivedNodeType* Graph<NodeType, EdgeType>::NewNode() { } template<typename NodeType, typename EdgeType> -EdgeType* Graph<NodeType, EdgeType>::NewEdge() { - EdgeType* ret = new EdgeType; +template<class... Args> +EdgeType* Graph<NodeType, EdgeType>::NewEdge(Args&&... 
args) { + EdgeType* ret = new EdgeType(std::forward<Args>(args)...); AddAllocatedEdge(ret); return ret; } @@ -163,6 +189,7 @@ template<typename NodeType, typename EdgeType> void Graph<NodeType, EdgeType>::ToDotWithFilePath(const std::string& file_path) { auto log_stream = TeePersistentLogStream::Create(file_path); ToDotWithStream(log_stream); + log_stream->Flush(); } template<typename NodeType, typename EdgeType> @@ -242,17 +269,18 @@ void Graph<NodeType, EdgeType>::DfsTopoForEachNodeSortByDistanceToSink( } HashMap<NodeType*, int64_t> node2distance_to_sink; TopoForEachNode(sinks, ForEachOutNode, ForEachInNode, [&](NodeType* node) { - int64_t distince_to_sink = -1; + int64_t distance_to_sink = -1; ForEachOutNode(node, [&](NodeType* out_node) { - distince_to_sink = std::max(distince_to_sink, node2distance_to_sink[out_node]); + distance_to_sink = std::max(distance_to_sink, node2distance_to_sink[out_node]); }); - node2distance_to_sink[node] = distince_to_sink + 1; + node2distance_to_sink[node] = distance_to_sink + 1; }); auto ForEachOutNodeSortedByDistanceToSink = [&](NodeType* node, const std::function<void(NodeType*)>& Handler) { std::vector<NodeType*> out_nodes; ForEachOutNode(node, [&](NodeType* out_node) { out_nodes.push_back(out_node); }); std::sort(out_nodes.begin(), out_nodes.end(), [&](NodeType* lhs, NodeType* rhs) { + // DfsTopoForEachNode uses a stack, so sort in descending order return node2distance_to_sink.at(lhs) > node2distance_to_sink.at(rhs); }); for (NodeType* out_node : out_nodes) { Handler(out_node); } diff --git a/oneflow/core/graph/logical_graph.cpp b/oneflow/core/graph/logical_graph.cpp index 8ffccbef733b6ed5c519924ddbecd3fed6b48651..8b6aff144b7f379f28b36acfff35cac351a970f1 100644 --- a/oneflow/core/graph/logical_graph.cpp +++ b/oneflow/core/graph/logical_graph.cpp @@ -1,8 +1,8 @@ #include "oneflow/core/graph/logical_graph.h" #include "oneflow/core/graph/task_graph.h" +#include "oneflow/core/graph/op_graph.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/operator/op_conf.pb.h" -#include "oneflow/core/graph/chain_logical_graph.h" #include "oneflow/core/common/balanced_splitter.h" namespace oneflow { @@ -45,7 +45,6 @@ std::function<bool(const LogicalNode*)> MakePredicatorHasActualOutDiff(const Log LogicalGraph::LogicalGraph(bool is_train) { BuildFwStruct(); if (is_train) { GroupNodesForReduceStruct(); } - SetMainModelParallel(); if (is_train) { BuildBwStruct(); } MergeEdge(); SetNodeDataLbi(); @@ -59,24 +58,41 @@ LogicalGraph::LogicalGraph(bool is_train) { } void LogicalGraph::GroupNodesForReduceStruct() { - ChainLogicalGraph chain_logical_graph(*this); - std::vector<std::vector<const LogicalNode*>> fw_node_groups; - chain_logical_graph.ForEachNode( - [&](ChainLogicalNode* node) { fw_node_groups.emplace_back(node->logical_nodes()); }); - for (auto& fw_node_group : fw_node_groups) { - if (fw_node_group.size() < Global<JobDesc>::Get()->reduce_group_size()) { - fw_node_groups_.emplace_back(std::move(fw_node_group)); - } else { - int64_t fw_node_group_size = fw_node_group.size(); - int64_t seg_num = fw_node_group_size / Global<JobDesc>::Get()->reduce_group_size() + 1; - BalancedSplitter bs(fw_node_group_size, seg_num); - FOR_RANGE(int64_t, idx, 0, seg_num) { - std::vector<const LogicalNode*> sub_fw_node_group; - Range range = bs.At(idx); - FOR_RANGE(int64_t, nid, range.begin(), range.end()) { - sub_fw_node_group.emplace_back(fw_node_group[nid]); - } - fw_node_groups_.emplace_back(std::move(sub_fw_node_group)); + // get op model size + HashMap<std::string, size_t>
op_name2model_size; + auto OpName2ModelSize = [&](const std::string& op_name) -> size_t { + if (op_name2model_size.find(op_name) == op_name2model_size.end()) { return 0; } + return op_name2model_size.at(op_name); + }; + Global<OpGraph>::Get()->InferOpModelSize(&op_name2model_size); + size_t model_total_size = 0; + for (const auto& pair : op_name2model_size) { model_total_size += pair.second; } + HashMap<ParallelDesc, std::list<const LogicalNode*>> parallel_desc2fw_group; + size_t avg_size = model_total_size / Global<JobDesc>::Get()->all_reduce_group_num(); + const size_t group_min_size = Global<JobDesc>::Get()->all_reduce_group_min_byte(); + const float group_size_warmup = Global<JobDesc>::Get()->all_reduce_group_size_warmup(); + size_t cur_group_size = group_min_size / group_size_warmup; + auto GetCurGroupSize = [&](int32_t group_id) { + if (cur_group_size < avg_size) { cur_group_size *= group_size_warmup; } + return std::min(cur_group_size, avg_size); + }; + // group fw nodes by parallel desc + ReverseTopoForEachNode([&](LogicalNode* fw_node) { + parallel_desc2fw_group[*fw_node->parallel_desc()].push_front(fw_node); + }); + CHECK_GT(parallel_desc2fw_group.size(), 0); + for (auto& pair : parallel_desc2fw_group) { + fw_node_groups_.emplace_back(std::vector<const LogicalNode*>()); + auto& fw_node_group = pair.second; + size_t cur_group_model_size = 0; + int32_t group_id = 0; + for (const LogicalNode* fw_node : fw_node_group) { + fw_node_groups_.back().emplace_back(fw_node); + cur_group_model_size += OpName2ModelSize(fw_node->SoleOp()->op_name()); + if (cur_group_model_size >= GetCurGroupSize(group_id)) { + fw_node_groups_.emplace_back(std::vector<const LogicalNode*>()); + cur_group_model_size = 0; + ++group_id; } } } @@ -200,20 +216,6 @@ void LogicalGraph::LinkUnpackFw2PackFw( }); } -void LogicalGraph::SetMainModelParallel() { - ForEachNode([](LogicalNode* node) { - if (node->parallel_desc()->policy() == kModelParallel) { node->set_main_model_parallel(node); } - }); - ForEachNode([](LogicalNode* node) { - LogicalNode* pred_node = node; - while (pred_node->SoleOp()->IsElemWiseOp()) { pred_node = pred_node->SoleInEdge()->src_node(); } - if (pred_node != node && pred_node->parallel_desc()->policy() == kModelParallel) { - node->mut_parallel_desc() = pred_node->parallel_desc(); - node->set_main_model_parallel(pred_node); - } - }); -} - void LogicalGraph::BuildBwStruct() { NaiveBuildBwStruct(); AddBackwardClone(); @@ -356,6 +358,7 @@ void LogicalGraph::BuildLossPrintStruct() { reduce_loss_op_conf.set_device_type(loss_op->device_type()); auto reduce_sum_conf = reduce_loss_op_conf.mutable_reduce_sum_conf(); *(reduce_sum_conf->mutable_in_sys()) = loss_op->BnInOp2Lbi("loss"); + reduce_sum_conf->add_axis(0); reduce_sum_conf->set_out("out"); std::shared_ptr<Operator> reduce_loss_op = ConstructOp(reduce_loss_op_conf); loss_logical->mut_op_vec().push_back(reduce_loss_op); @@ -508,8 +511,11 @@ void LogicalGraph::BuildModelStruct(bool is_train) { } } }); - for (auto& fw_node_group : fw_node_groups_) { + for (int i = 0; i < fw_node_groups_.size(); ++i) { + auto& fw_node_group = fw_node_groups_[i]; ReduceCtx group_reduce_ctx; + group_reduce_ctx.order_in_logical_graph = i; + int order_in_reduce_group = 0; for (auto& fw_node : fw_node_group) { auto reduce_ctx_it = fw_node2reduce_ctx.find(fw_node); if (reduce_ctx_it != fw_node2reduce_ctx.end()) { @@ -518,42 +524,55 @@ void LogicalGraph::BuildModelStruct(bool is_train) { group_reduce_ctx.bw_logicals.emplace_back(reduce_ctx.bw_logicals.at(0));
group_reduce_ctx.md_diff_acc_logicals.emplace_back(reduce_ctx.md_diff_acc_logicals.at(0)); group_reduce_ctx.md_updt_logicals.emplace_back(reduce_ctx.md_updt_logicals.at(0)); + auto* md_updt = dynamic_cast<NormalMdUpdtLogicalNode*>(reduce_ctx.md_updt_logicals.at(0)); + md_updt->set_order_in_reduce_group(order_in_reduce_group++); } } - BuildReduceStruct(group_reduce_ctx); + if (group_reduce_ctx.fw_logicals.size() > 0) { BuildReduceStruct(group_reduce_ctx); } } SetupNormalMdUpdtOp(); } void LogicalGraph::BuildReduceStruct(const ReduceCtx& reduce_ctx) { - if (reduce_ctx.fw_logicals.size() > 1) { - std::shared_ptr<const ParallelDesc> src_pd = reduce_ctx.fw_logicals[0]->parallel_desc(); - - OperatorConf reduce_concat_op_conf; - reduce_concat_op_conf.set_name("reduce_concat_" + NewUniqueId()); - reduce_concat_op_conf.set_device_type(src_pd->device_type()); - reduce_concat_op_conf.mutable_reduce_concat_conf()->set_in_num(reduce_ctx.fw_logicals.size()); - LogicalNode* reduce_concat_node = NewNode<ReduceConcatLogicalNode>(); - reduce_concat_node->mut_op_vec() = {ConstructOp(reduce_concat_op_conf)}; - reduce_concat_node->mut_parallel_desc() = src_pd; - - OperatorConf reduce_split_op_conf; - reduce_split_op_conf.set_name("reduce_split_" + NewUniqueId()); - reduce_split_op_conf.set_device_type(src_pd->device_type()); - reduce_split_op_conf.mutable_reduce_split_conf()->set_out_num(reduce_ctx.fw_logicals.size()); - LogicalNode* reduce_split_node = NewNode<ReduceSplitLogicalNode>(); - reduce_split_node->mut_op_vec() = {ConstructOp(reduce_split_op_conf)}; - reduce_split_node->mut_parallel_desc() = src_pd; - - for (auto& md_diff_acc_node : reduce_ctx.md_diff_acc_logicals) { - Connect(md_diff_acc_node, NewEdge(), reduce_concat_node); - } - for (auto& md_updt_node : reduce_ctx.md_updt_logicals) { - Connect(reduce_split_node, NewEdge(), md_updt_node); - } - AddAllReduce(reduce_concat_node, reduce_split_node); - } else if (reduce_ctx.fw_logicals.size() == 1) { - AddAllReduce(reduce_ctx.md_diff_acc_logicals.at(0), reduce_ctx.md_updt_logicals.at(0)); + CHECK_GT(reduce_ctx.fw_logicals.size(), 0); + std::shared_ptr<const ParallelDesc> src_pd = reduce_ctx.fw_logicals[0]->parallel_desc(); + + OperatorConf reduce_concat_op_conf; + reduce_concat_op_conf.set_name("reduce_concat_" + NewUniqueId()); + reduce_concat_op_conf.set_device_type(src_pd->device_type()); + reduce_concat_op_conf.mutable_reduce_concat_conf()->set_in_num(reduce_ctx.fw_logicals.size()); + LogicalNode* reduce_concat_node = NewNode<ReduceConcatLogicalNode>(); + reduce_concat_node->mut_op_vec() = {ConstructOp(reduce_concat_op_conf)}; + reduce_concat_node->mut_parallel_desc() = src_pd; + + // We can not add ctrl edges between all_reduce nodes due to the implementation of nccl. 
+ // So we add ctrl edges between ReduceIdentityTaskNodes which are followed by + // all_reduce nodes; + OperatorConf reduce_identity_conf; + reduce_identity_conf.set_name("reduce_identity_" + NewUniqueId()); + reduce_identity_conf.set_device_type(src_pd->device_type()); + reduce_identity_conf.mutable_reduce_identity_conf(); + auto* reduce_identity_node = NewNode<ReduceIdentityLogicalNode>(); + reduce_identity_node->mut_op_vec() = {ConstructOp(reduce_identity_conf)}; + reduce_identity_node->mut_parallel_desc() = src_pd; + reduce_identity_node->set_order_in_logical_graph(reduce_ctx.order_in_logical_graph); + + OperatorConf reduce_split_op_conf; + reduce_split_op_conf.set_name("reduce_split_" + NewUniqueId()); + reduce_split_op_conf.set_device_type(src_pd->device_type()); + reduce_split_op_conf.mutable_reduce_split_conf()->set_out_num(reduce_ctx.fw_logicals.size()); + auto* reduce_split_node = NewNode<ReduceSplitLogicalNode>(); + reduce_split_node->mut_op_vec() = {ConstructOp(reduce_split_op_conf)}; + reduce_split_node->mut_parallel_desc() = src_pd; + reduce_split_node->set_order_in_logical_graph(reduce_ctx.order_in_logical_graph); + + for (auto& md_diff_acc_node : reduce_ctx.md_diff_acc_logicals) { + Connect(md_diff_acc_node, NewEdge(), reduce_concat_node); + } + Connect(reduce_concat_node, NewEdge(), static_cast<LogicalNode*>(reduce_identity_node)); + AddAllReduce(reduce_identity_node, reduce_split_node); + for (auto& md_updt_node : reduce_ctx.md_updt_logicals) { + Connect(static_cast<LogicalNode*>(reduce_split_node), NewEdge(), md_updt_node); } } @@ -562,7 +581,8 @@ void LogicalGraph::AddAllReduce(LogicalNode* src, LogicalNode* dst) { std::shared_ptr<const ParallelDesc> dst_pd = dst->parallel_desc(); CHECK_EQ(src_pd->parallel_num(), dst_pd->parallel_num()); CHECK_EQ(src_pd->device_type(), dst_pd->device_type()); - if (Global<JobDesc>::Get()->enable_nccl()) { + + if (Global<JobDesc>::Get()->enable_nccl() && src_pd->device_type() == DeviceType::kGPU) { if (src_pd->sorted_machine_ids().size() == 1 || Global<JobDesc>::Get()->use_nccl_inter_node_communication()) { AddNcclAllReduce(src, dst); @@ -694,7 +714,8 @@ NormalMdUpdtLogicalNode* LogicalGraph::BuildNormalMdUpdtAndMdSaveStruct( // for model BuildMdSaveStruct(fw_logical, md_updt_logical); // TODO: remove the following ugly hard coded `if' - if (Global<JobDesc>::Get()->other_conf().train_conf().model_update_conf().has_momentum_conf()) { + if (Global<JobDesc>::Get()->other_conf().train_conf().model_update_conf().has_momentum_conf() + || Global<JobDesc>::Get()->other_conf().train_conf().model_update_conf().has_adam_conf()) { // for forward_model BuildMdSaveStruct(fw_logical, md_updt_logical); } diff --git a/oneflow/core/graph/logical_graph.h b/oneflow/core/graph/logical_graph.h index 2f61c869cb274a575e6c5f9866d87fee3557c453..bb90a6e253fefb846469ee80e751cdd4d832b361 100644 --- a/oneflow/core/graph/logical_graph.h +++ b/oneflow/core/graph/logical_graph.h @@ -28,6 +28,7 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> { std::vector<LogicalEdge*> edges; }; struct ReduceCtx { + int32_t order_in_logical_graph; std::vector<LogicalNode*> fw_logicals; std::vector<LogicalNode*> bw_logicals; std::vector<LogicalNode*> md_diff_acc_logicals; @@ -43,7 +44,6 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> { void LinkUnpackFw2PackFw(const HashMap<std::string, std::vector<LogicalNode*>>& op_name2nodes); void ReConnectToFwClone(LogicalNode* clone_node, const LogicalBlobId& lbi, const std::vector<LogicalEdge*>& edges, 
const std::string& obn); - void SetMainModelParallel(); void BuildBwStruct(); void NaiveBuildBwStruct(); void AddBackwardClone(); diff --git a/oneflow/core/graph/logical_node.cpp b/oneflow/core/graph/logical_node.cpp index 0ceb5e6090a2499b9904d57685f4fc7e364b4b42..e3fcdc25ee2ad236fa9d05d9ef6d44511c8b7704 100644 --- a/oneflow/core/graph/logical_node.cpp +++ b/oneflow/core/graph/logical_node.cpp @@ -23,13 +23,24 @@ #include "oneflow/core/graph/accuracy_accumulate_compute_task_node.h" #include "oneflow/core/graph/accuracy_print_compute_task_node.h" #include "oneflow/core/graph/task_graph.h" +#include "oneflow/core/graph/reduce_identity_task_node.h" +#include "oneflow/core/graph/op_graph.h" namespace oneflow { namespace { +bool HasSoleIdentityOp(const LogicalNode* logical_node) { + const auto& op_conf = logical_node->SoleOp()->op_conf(); + return logical_node->op_vec().size() == 1 + && (op_conf.has_parallel_cast_conf() || op_conf.has_tuple_identity_conf()); +} + BldBoxingOpConfMthd GetBldBoxingOpConfMethodByFwParallelPolicy(const LogicalNode* in_logical, const LogicalNode* out_logical) { + if (HasSoleIdentityOp(in_logical) || HasSoleIdentityOp(out_logical)) { + return &BoxingTaskNode::BldBoxingOpConfWithFwSbpParallel; + } ParallelPolicy in_policy = in_logical->parallel_desc()->policy(); ParallelPolicy out_policy = out_logical->parallel_desc()->policy(); if (in_policy == kDataParallel && out_policy == kDataParallel) { @@ -47,6 +58,9 @@ BldBoxingOpConfMthd GetBldBoxingOpConfMethodByFwParallelPolicy(const LogicalNode } BldBoxingOpConfMthd GetBldBoxingOpConfMethodByBwParallelPolicy(const LogicalNode* in_logical, const LogicalNode* out_logical) { + if (HasSoleIdentityOp(in_logical) || HasSoleIdentityOp(out_logical)) { + return &BoxingTaskNode::BldBoxingOpConfWithBwSbpParallel; + } ParallelPolicy in_policy = in_logical->parallel_desc()->policy(); ParallelPolicy out_policy = out_logical->parallel_desc()->policy(); if (in_policy == kDataParallel && out_policy == kDataParallel) { @@ -237,6 +251,10 @@ void LogicalNode::GenSortedCompTaskNodes( comp_task_node->set_thrd_id(id_mgr->GetGpuMixThrdId(dev_phy_id)); break; } + case CudaWorkType::kReduceCtrl: { + comp_task_node->set_thrd_id(id_mgr->GetGpuReduceCtrlThrdId(dev_phy_id)); + break; + } case CudaWorkType::kMdUpdt: { comp_task_node->set_thrd_id(id_mgr->GetGpuMdUpdtThrdId(dev_phy_id)); break; @@ -259,30 +277,6 @@ void LogicalNode::GenSortedCompTaskNodes( } } -int32_t LogicalNode::GetModelSplitAxis() const { - CHECK_EQ(parallel_desc_->policy(), kModelParallel); - CHECK_NOTNULL(main_model_parallel_); - if (main_model_parallel_ == this) { - int32_t ret = SoleOp()->ModelSplitAxis(); - CHECK_NE(ret, -1); - return ret; - } else { - return main_model_parallel_->GetModelSplitAxis(); - } -} - -int32_t LogicalNode::GetMaxModelSplitNum() const { - CHECK_EQ(parallel_desc_->policy(), kModelParallel); - CHECK_NOTNULL(main_model_parallel_); - if (main_model_parallel_ == this) { - int32_t ret = SoleOp()->MaxModelSplitNum(); - CHECK_NE(ret, -1); - return ret; - } else { - return main_model_parallel_->GetMaxModelSplitNum(); - } -} - bool LogicalNode::HasOpWithCondition(std::function<bool(const Operator*)> cond) const { for (std::shared_ptr<const Operator> op : op_vec_) { if (cond(op.get())) { return true; } @@ -291,12 +285,39 @@ bool LogicalNode::HasOpWithCondition(std::function<bool(const Operator*)> cond) } static bool IsModelParallel121(const LogicalNode* src_node, const LogicalNode* dst_node) { - return src_node->main_model_parallel() == 
dst_node->main_model_parallel(); + if (src_node->parallel_desc()->parallel_num() != dst_node->parallel_desc()->parallel_num()) { + return false; + } + LogicalEdge* connect_edge = nullptr; + for (LogicalEdge* edge : src_node->out_edges()) { + if (edge->dst_node() == dst_node) { connect_edge = edge; } + } + CHECK_NOTNULL(connect_edge); + CHECK_GT(connect_edge->lbis().size(), 0); + const std::string& src_op_name = src_node->SoleOp()->op_name(); + const std::string& dst_op_name = dst_node->SoleOp()->op_name(); + for (const LogicalBlobId& lbi : connect_edge->lbis()) { + const auto& src_sbp = Global<OpGraph>::Get()->GetSbpParallel(src_op_name, lbi); + const auto& dst_sbp = Global<OpGraph>::Get()->GetSbpParallel(dst_op_name, lbi); + if (src_sbp != dst_sbp) { return false; } + } + return true; } BldSubTskGphMthd GetMthdForBldSubTskGph(const LogicalNode* src_node, const LogicalNode* dst_node) { std::shared_ptr<const ParallelDesc> src_pd = src_node->parallel_desc(); std::shared_ptr<const ParallelDesc> dst_pd = dst_node->parallel_desc(); + if (src_node->op_vec().size() == 1 && dst_node->op_vec().size() == 1) { + if (src_node->SoleOp()->op_conf().has_record_load_conf() + && dst_node->SoleOp()->op_conf().has_tick_conf()) { + CHECK(src_pd->parallel_num() == dst_pd->parallel_num()); + CHECK(src_pd->policy() == kDataParallel && dst_pd->policy() == kDataParallel); + } + if (src_node->SoleOp()->op_conf().has_tick_conf() + && dst_node->SoleOp()->op_conf().has_log_counter_conf() == false) { + return &TaskGraph::BldSubTskGphByTickToSource; + } + } if (src_pd->parallel_num() == 1 && dst_pd->parallel_num() == 1) { return &TaskGraph::BldSubTskGphByOneToOne; } @@ -403,6 +424,7 @@ REGISTER_BLD_BOXING_OP_CONF_MTHD("NormalBackward" OF_PP_MAKE_TUPLE_SEQ(MdDiffAcc, kDataBackwardArea) \ OF_PP_MAKE_TUPLE_SEQ(Print, kPrintArea) \ OF_PP_MAKE_TUPLE_SEQ(ReduceConcat, kMdUpdtArea) \ + OF_PP_MAKE_TUPLE_SEQ(ReduceIdentity, kMdUpdtArea) \ OF_PP_MAKE_TUPLE_SEQ(ReduceScatter, kMdUpdtArea) \ OF_PP_MAKE_TUPLE_SEQ(ReduceAdd, kMdUpdtArea) \ OF_PP_MAKE_TUPLE_SEQ(ReduceGather, kMdUpdtArea) \ @@ -425,7 +447,6 @@ BackwardLogicalNode* ForwardLogicalNode::NewBackwardNode() { bw_node_->mut_op_vec() = op_vec(); bw_node_->mut_parallel_desc() = parallel_desc(); bw_node_->fw_node_ = this; - bw_node_->set_main_model_parallel(main_model_parallel()); return bw_node_; } diff --git a/oneflow/core/graph/logical_node.h b/oneflow/core/graph/logical_node.h index 5dc5e057c1ff3cf656017fdd4c8bc92883c26565..844d0ac11bf30b5e7a4e8f8576a4689bc5dbaf49 100644 --- a/oneflow/core/graph/logical_node.h +++ b/oneflow/core/graph/logical_node.h @@ -52,16 +52,11 @@ class LogicalNode : public Node<LogicalNode, LogicalEdge> { std::vector<std::pair<int64_t, CompTaskNode*>>* nodes, std::function<void(CompTaskNode*)>) const; - // model split - LogicalNode* main_model_parallel() const { return main_model_parallel_; } - void set_main_model_parallel(LogicalNode* val) { main_model_parallel_ = val; } - int32_t GetModelSplitAxis() const; - int32_t GetMaxModelSplitNum() const; - virtual int64_t GetAreaId() const = 0; + virtual bool MayConsumeModelDiff() const { return false; } protected: - LogicalNode() : main_model_parallel_(nullptr) {} + LogicalNode() {} virtual CompTaskNode* NewCompTaskNode() const = 0; virtual void FixCompTaskNode(CompTaskNode*) const {} @@ -71,7 +66,6 @@ class LogicalNode : public Node<LogicalNode, LogicalEdge> { std::vector<std::shared_ptr<Operator>> op_vec_; std::shared_ptr<const ParallelDesc> parallel_desc_; std::shared_ptr<const std::vector<LogicalNode*>> 
shared_model_nodes_; - LogicalNode* main_model_parallel_; HashMap<const LogicalNode*, std::vector<LogicalBlobId>> dst2data_lbis_; }; @@ -238,19 +232,31 @@ DECLARE_NAIVE_LOGICAL_NODE(AccuracyPrintLogicalNode); class NormalMdUpdtLogicalNode final : public LogicalNode { public: OF_DISALLOW_COPY_AND_MOVE(NormalMdUpdtLogicalNode); - NormalMdUpdtLogicalNode() : random_seed_(NewRandomSeed()) {} + NormalMdUpdtLogicalNode() : random_seed_(NewRandomSeed()), order_in_reduce_group_(0) {} ~NormalMdUpdtLogicalNode() = default; OVERRIDE_PURE_VIRTUAL_METHOD(); + bool MayConsumeModelDiff() const override { return true; } + + int order_in_reduce_group() const { return order_in_reduce_group_; } + void set_order_in_reduce_group(int order_in_reduce_group) { + order_in_reduce_group_ = order_in_reduce_group; + } private: void FixCompTaskNode(CompTaskNode*) const override; uint32_t random_seed_; + int order_in_reduce_group_; }; DECLARE_NAIVE_LOGICAL_NODE(MdSaveLogicalNode); -DECLARE_NAIVE_LOGICAL_NODE(MdDiffAccLogicalNode); + +class MdDiffAccLogicalNode final : public LogicalNode { + public: + LOGICAL_NODE_BOILERPLATE(MdDiffAccLogicalNode); + bool MayConsumeModelDiff() const override { return true; } +}; class ReduceLogicalNode : public LogicalNode { public: @@ -275,23 +281,41 @@ class ReduceLogicalNode : public LogicalNode { } }; -#define DECLARE_REDUCE_LOGICAL_NODE(name) \ - class name final : public ReduceLogicalNode { \ - public: \ - LOGICAL_NODE_BOILERPLATE(name); \ +#define DECLARE_REDUCE_LOGICAL_NODE(name, may_consume_md_diff) \ + class name final : public ReduceLogicalNode { \ + public: \ + LOGICAL_NODE_BOILERPLATE(name); \ + bool MayConsumeModelDiff() const override { return may_consume_md_diff; } \ } -DECLARE_REDUCE_LOGICAL_NODE(ReduceConcatLogicalNode); -DECLARE_REDUCE_LOGICAL_NODE(ReduceSplitLogicalNode); -DECLARE_REDUCE_LOGICAL_NODE(ReduceScatterLogicalNode); -DECLARE_REDUCE_LOGICAL_NODE(ReduceGatherLogicalNode); -DECLARE_REDUCE_LOGICAL_NODE(ReduceAddLogicalNode); -DECLARE_REDUCE_LOGICAL_NODE(NcclAllReduceLogicalNode); -DECLARE_REDUCE_LOGICAL_NODE(NcclAllGatherLogicalNode); -DECLARE_REDUCE_LOGICAL_NODE(NcclReduceScatterLogicalNode); +DECLARE_REDUCE_LOGICAL_NODE(ReduceConcatLogicalNode, true); +DECLARE_REDUCE_LOGICAL_NODE(ReduceScatterLogicalNode, true); +DECLARE_REDUCE_LOGICAL_NODE(ReduceGatherLogicalNode, false); +DECLARE_REDUCE_LOGICAL_NODE(NcclAllReduceLogicalNode, true); +DECLARE_REDUCE_LOGICAL_NODE(ReduceAddLogicalNode, false); +DECLARE_REDUCE_LOGICAL_NODE(NcclAllGatherLogicalNode, false); +DECLARE_REDUCE_LOGICAL_NODE(NcclReduceScatterLogicalNode, true); DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(RepeatForward); DECLARE_DERIVED_BACKWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(RepeatBackward); + +#define DECLARE_BEFORE_OR_AFTER_ALLREDUCE_REDUCE_NODE(class_name, may_consume_md_diff) \ + class class_name final : public ReduceLogicalNode { \ + public: \ + LOGICAL_NODE_BOILERPLATE(class_name); \ + void set_order_in_logical_graph(int32_t order_in_logical_graph) { \ + order_in_logical_graph_ = order_in_logical_graph; \ + } \ + int32_t order_in_logical_graph() const { return order_in_logical_graph_; } \ + bool MayConsumeModelDiff() const override { return may_consume_md_diff; } \ + \ + private: \ + int32_t order_in_logical_graph_; \ + } + +DECLARE_BEFORE_OR_AFTER_ALLREDUCE_REDUCE_NODE(ReduceIdentityLogicalNode, true); +DECLARE_BEFORE_OR_AFTER_ALLREDUCE_REDUCE_NODE(ReduceSplitLogicalNode, false); + } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_LOGICAL_NODE_H_ diff --git 
a/oneflow/core/graph/loss_compute_task_node.cpp b/oneflow/core/graph/loss_compute_task_node.cpp index d1cb346fe7113b959875d3b4ec66914630e739d2..08a9d3d5fd7f5c56bc74ccbd56882c2413b39bba 100644 --- a/oneflow/core/graph/loss_compute_task_node.cpp +++ b/oneflow/core/graph/loss_compute_task_node.cpp @@ -8,6 +8,7 @@ void LossCompTaskNode::ProduceAllRegstsAndBindEdges() { ProduceRegst("loss", false); ProduceRegst("out", true); ProduceRegst("data_tmp", true, 1, 1); + ProduceRegst("const_buf", false, 1, 1); for (TaskEdge* edge : out_edges()) { const LogicalNode* succ_logical = GetOneSuccLogicalNodeOnEdge(edge); if (succ_logical->TypeName() == "LossAcc") { @@ -37,6 +38,7 @@ void LossCompTaskNode::BuildExecGphAndRegst() { } std::shared_ptr<RegstDesc> data_tmp_regst = GetProducedRegst("data_tmp"); loss_node->AddBnToRegstAndBindIt(&Operator::data_tmp_bns, data_tmp_regst); + loss_node->AddBnToRegstAndBindIt(&Operator::const_buf_bns, GetProducedRegst("const_buf")); if (Global<JobDesc>::Get()->IsTrain()) { BuildRegstWhenTraining(); @@ -48,6 +50,8 @@ void LossCompTaskNode::BuildExecGphAndRegst() { } } mut_exec_gph().TopoForEachNode([this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); }); + mut_exec_gph().TopoForEachNode( + [this](ExecNode* node) { node->FixInDiffBlobDescs(parallel_ctx()); }); } void LossCompTaskNode::BuildRegstWhenTraining() { diff --git a/oneflow/core/graph/normal_backward_compute_task_node.cpp b/oneflow/core/graph/normal_backward_compute_task_node.cpp index 1bfbda612788d610d07990e8ef46d019a2b68f62..6ebb84eeaf091d8f76f4a0f878aa17cd9fa224c6 100644 --- a/oneflow/core/graph/normal_backward_compute_task_node.cpp +++ b/oneflow/core/graph/normal_backward_compute_task_node.cpp @@ -4,17 +4,19 @@ namespace oneflow { +int64_t NormalBackwardCompTaskNode::AreaId4ChainMerge() const { + CHECK_EQ(area_id(), AreaType::kDataBackwardArea); + return AreaType::kDataForwardArea; +} + void NormalBackwardCompTaskNode::ProduceAllRegstsAndBindEdges() { ProduceRegst("in_diff", true); ProduceRegst("activation_diff", true, 1, 1); ProduceRegst("bw_buf", true, 1, 1); for (TaskEdge* edge : out_edges()) { const LogicalNode* succ_logical = GetOneSuccLogicalNodeOnEdge(edge); - if (succ_logical->TypeName() == "MdDiffAcc" || succ_logical->TypeName() == "NormalMdUpdt" - || succ_logical->TypeName() == "ReduceScatter" || succ_logical->TypeName() == "ReduceConcat" - || succ_logical->TypeName() == "NcclAllReduce" - || succ_logical->TypeName() == "NcclReduceScatter") { - edge->AddRegst("model_diff", ProduceRegst("model_diff", true)); + if (succ_logical->MayConsumeModelDiff()) { + edge->AddRegst("model_diff", ProduceRegst("model_diff", true, 1, 1)); type_name4model_related_logical_node_ = succ_logical->TypeName(); } else { BindEdgeWithProducedRegst(edge, "in_diff"); diff --git a/oneflow/core/graph/normal_backward_compute_task_node.h b/oneflow/core/graph/normal_backward_compute_task_node.h index de6a83bbac56532ef790f4e2b51d1f21618b4ce4..d9eda294ea08020699b0f3674e5be787f195799c 100644 --- a/oneflow/core/graph/normal_backward_compute_task_node.h +++ b/oneflow/core/graph/normal_backward_compute_task_node.h @@ -16,6 +16,7 @@ class NormalBackwardCompTaskNode final : public CompTaskNode { void BuildExecGphAndRegst() override; TaskType GetTaskType() const override { return TaskType::kNormalBackward; } void RmUselessConsumeRelationshipToFw(); + int64_t AreaId4ChainMerge() const override; protected: void BuildExecGphAndBindOutDiffRegst(); diff --git a/oneflow/core/graph/normal_forward_compute_task_node.cpp 
b/oneflow/core/graph/normal_forward_compute_task_node.cpp index 0553c646c00ed58532460010c3799a5dcaf326f1..7e42af94b997e05f32b2ef156cf8726c86d26315 100644 --- a/oneflow/core/graph/normal_forward_compute_task_node.cpp +++ b/oneflow/core/graph/normal_forward_compute_task_node.cpp @@ -1,11 +1,39 @@ #include "oneflow/core/graph/normal_forward_compute_task_node.h" #include "oneflow/core/graph/task_graph.h" #include "oneflow/core/graph/logical_node.h" +#include "oneflow/core/operator/variable_op.h" namespace oneflow { +namespace { + +bool IsAllOutputNaive(const LogicalNode* logical_node) { + const Operator* op = logical_node->SoleOp().get(); + if (op->IsAllOutputConst()) { + return false; + } else if (dynamic_cast<const VariableOp*>(op) != nullptr) { + return false; + } else { + return true; + } +} + +} // namespace + +bool NormalForwardCompTaskNode::HasBackwardCompTaskNode() { + for (TaskEdge* edge : out_edges()) { + const LogicalNode* succ_logical = GetOneSuccLogicalNodeOnEdge(edge); + if (succ_logical->TypeName() == "NormalBackward") { return true; } + } + return false; +} + void NormalForwardCompTaskNode::ProduceAllRegstsAndBindEdges() { - ProduceRegst("out", true); + if (IsAllOutputNaive(logical_node())) { + ProduceRegst("out", true); + } else { + ProduceRegst("out", false, 1, 1); + } ProduceRegst("activation", true); ProduceRegst("data_tmp", true); ProduceRegst("fw_buf", true, 1, 1); diff --git a/oneflow/core/graph/normal_forward_compute_task_node.h b/oneflow/core/graph/normal_forward_compute_task_node.h index e3680db4595c85684655fe008f38fd25b2cc7ab9..545acd83b58b76a86c112e74b09fea709800b6da 100644 --- a/oneflow/core/graph/normal_forward_compute_task_node.h +++ b/oneflow/core/graph/normal_forward_compute_task_node.h @@ -16,6 +16,7 @@ class NormalForwardCompTaskNode final : public CompTaskNode { bool IsReadyForBuild() override; TaskType GetTaskType() const override { return TaskType::kNormalForward; } + bool HasBackwardCompTaskNode(); virtual void ToProto(TaskProto*) override; void set_random_seed(int64_t random_seed) { random_seed_ = random_seed; } diff --git a/oneflow/core/graph/normal_model_update_compute_task_node.cpp b/oneflow/core/graph/normal_model_update_compute_task_node.cpp index 6f439735626d0eca98736c3067a97e7ba0f5f60e..83e89c6352e48f75ed0724f498fcd218e7b8f551 100644 --- a/oneflow/core/graph/normal_model_update_compute_task_node.cpp +++ b/oneflow/core/graph/normal_model_update_compute_task_node.cpp @@ -86,9 +86,12 @@ void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() { + "total_instance_num"); op_conf.mutable_normal_mdupdt_conf()->set_model(lbi.op_name() + '/' + lbi.blob_name()); if (Global<JobDesc>::Get()->IsTrain()) { - *(op_conf.mutable_normal_mdupdt_conf()->mutable_user_conf()) = - Global<JobDesc>::Get()->other_conf().train_conf().model_update_conf(); - + if (lbi.blob_name() == "total_instance_num") { + op_conf.mutable_normal_mdupdt_conf()->mutable_user_conf()->mutable_naive_conf(); + } else { + *(op_conf.mutable_normal_mdupdt_conf()->mutable_user_conf()) = + Global<JobDesc>::Get()->other_conf().train_conf().model_update_conf(); + } float primary_lr = Global<JobDesc>::Get()->primary_lr(); float secondary_lr = Global<JobDesc>::Get()->secondary_lr(); if (secondary_lr < 0) { secondary_lr = primary_lr; } @@ -146,6 +149,13 @@ void NormalMdUpdtCompTaskNode::ToProto(TaskProto* task_proto) { task_proto->set_related_init_model_task_id(related_init_model_task_id_); } +void NormalMdUpdtCompTaskNode::FixPackedBlobDescOfProducedRegst() { + std::shared_ptr<RegstDesc> diff_add_out_regst = 
GetProducedRegst("processed_model_diff"); + CHECK(diff_add_out_regst->IsLocked()); + Shape& shape = diff_add_out_regst->MutBlobDesc(GenPackedLbi())->mut_shape(); + shape = Shape({static_cast<int64_t>(RoundUp(shape.elem_cnt(), parallel_ctx()->parallel_num()))}); +} + void NormalMdUpdtCompTaskNode::InferProducedDataRegstTimeShape() { ForEachProducedDataRegst([](const std::string& name, RegstDesc* regst) { if (name == "const_model") { @@ -157,4 +167,21 @@ void NormalMdUpdtCompTaskNode::InferProducedDataRegstTimeShape() { }); } +void NormalMdUpdtCompTaskNode::EnableMemSharingBetweenFirstInAndProcessedMdDiffRegst() { + if (!IsTrainable()) { return; } + ExecNode* diff_add_node = exec_gph().SoleSourceNode(); + RegstDesc* first_in_regst = + diff_add_node->RegstDesc4BnInOp(diff_add_node->op()->input_bns().Get(0)); + RegstDesc* diff_add_out_regst = diff_add_node->RegstDesc4BnInOp(diff_add_node->op()->SoleObn()); + CHECK_EQ(diff_add_out_regst, GetProducedRegst("processed_model_diff").get()); + CHECK(first_in_regst->HasSameMemSize(diff_add_out_regst)); + if (!first_in_regst->HasSetMemSharedId()) { + int64_t mem_shared_id = Global<IDMgr>::Get()->NewMemSharedId(); + first_in_regst->set_enable_mem_sharing(true); + first_in_regst->set_mem_shared_id(mem_shared_id); + first_in_regst->set_mem_shared_offset(0); + } + diff_add_out_regst->CopyMemSharedInfoFrom(first_in_regst); +} + } // namespace oneflow diff --git a/oneflow/core/graph/normal_model_update_compute_task_node.h b/oneflow/core/graph/normal_model_update_compute_task_node.h index 3840232eb7e983b652d45d5e7fac2c2d85741b39..c80c1ec2f443a1bc358af277bbc9f0bbdf4471ad 100644 --- a/oneflow/core/graph/normal_model_update_compute_task_node.h +++ b/oneflow/core/graph/normal_model_update_compute_task_node.h @@ -17,6 +17,7 @@ class NormalMdUpdtCompTaskNode final : public CompTaskNode { bool IsReadyForBuild() override; void BuildExecGphAndRegst() override; void LockRegsts() override; + void EnableMemSharingBetweenFirstInAndProcessedMdDiffRegst(); void set_random_seed(uint32_t val) { random_seed_ = val; } TaskType GetTaskType() const override { return TaskType::kNormalMdUpdt; } @@ -26,6 +27,7 @@ class NormalMdUpdtCompTaskNode final : public CompTaskNode { private: const NormalForwardCompTaskNode* GetForwardTaskNode() const; bool IsTrainable() const; + void FixPackedBlobDescOfProducedRegst() override; void InferProducedDataRegstTimeShape() override; uint32_t random_seed_; int64_t related_init_model_task_id_; diff --git a/oneflow/core/graph/op_graph.cpp b/oneflow/core/graph/op_graph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ba46f2f6df8eb4b1482dbb05ce6f7883d9f3f180 --- /dev/null +++ b/oneflow/core/graph/op_graph.cpp @@ -0,0 +1,615 @@ +#include "oneflow/core/graph/op_graph.h" + +namespace oneflow { + +std::string OpEdge::VisualStr() const { + std::string str; + int32_t idx = 0; + for (const LogicalBlobId& lbi : lbis_) { + if (idx++ > 0) { str += "\\n"; } + str += lbi.blob_name() + ":"; + str += src_node()->NoParallelBlobDesc4Lbi(lbi).shape().ToString(); + } + return str; +} + +bool* OpNode::MutIsModelBlob4Lbi(const LogicalBlobId& lbi) { + CHECK_EQ(ProducerOpNode4Lbi(lbi), this); + return &lbi2is_model_blob_[lbi]; +} +bool OpNode::IsModelBlob4Lbi(const LogicalBlobId& lbi) const { + return ProducerOpNode4Lbi(lbi)->lbi2is_model_blob_.at(lbi); +} + +const SbpParallel& OpNode::SbpParallel4Lbi(const LogicalBlobId& lbi) const { + return lbi2sbp_parallel_.at(lbi); +} + +SbpParallel* OpNode::MutSbpParallel4Lbi(const LogicalBlobId& lbi) { + return 
&lbi2sbp_parallel_[lbi]; +} + +std::string OpNode::VisualStr() const { + std::string str = op().op_name(); + { + for (int64_t machine_id : parallel_desc().sorted_machine_ids()) { + std::string dev_type; + if (parallel_desc().device_type() == DeviceType::kCPU) { + dev_type = "cpu"; + } else if (parallel_desc().device_type() == DeviceType::kGPU) { + dev_type = "gpu"; + } else { + UNIMPLEMENTED(); + } + std::string parallel_desc_str = std::to_string(machine_id) + ":" + dev_type + ":"; + const auto& dev_phy_ids = parallel_desc().sorted_dev_phy_ids(machine_id); + parallel_desc_str += std::to_string(dev_phy_ids.front()); + if (dev_phy_ids.back() > dev_phy_ids.front()) { + parallel_desc_str += "-" + std::to_string(dev_phy_ids.back()); + } + str += "\\n" + parallel_desc_str; + } + } + auto GetTimeShapeStr = [&](const Shape& shape, const std::string& prefix) { + std::string time_shape_str = prefix + ":"; + time_shape_str += shape.ToString(); + return time_shape_str; + }; + if (in_edges().empty() == false) { + str += "\\n" + GetTimeShapeStr(*GetInputBlobTimeShape(), "in_blob_time_shape"); + } + str += "\\n" + GetTimeShapeStr(out_blob_time_shape(), "out_blob_time_shape"); + return str; +} + +const BlobDesc& OpNode::NoParallelBlobDesc4Lbi(const LogicalBlobId& lbi) const { + return lbi2no_parallel_blob_desc_.at(lbi); +} + +const BlobDesc& OpNode::LogicalBlobDesc4Lbi(const LogicalBlobId& lbi) const { + return lbi2logical_blob_desc_.at(lbi); +} + +BlobDesc* OpNode::MutNoParallelBlobDesc(const LogicalBlobId& lbi) { + CHECK_EQ(lbi.op_name(), op().op_name()); + return &lbi2no_parallel_blob_desc_[lbi]; +} + +BlobDesc* OpNode::MutLogicalBlobDesc4Lbi(const LogicalBlobId& lbi) { + CHECK_EQ(lbi.op_name(), op().op_name()); + return &lbi2logical_blob_desc_[lbi]; +} + +BlobDesc* OpNode::NoParallelBlobDesc4BnInOp(const std::string& bn_in_op) { + return ProducerOpNode4BnInOp(bn_in_op)->MutNoParallelBlobDesc(op().BnInOp2Lbi(bn_in_op)); +} + +const Shape* OpNode::GetInputBlobTimeShape(const std::string& bn_in_op) const { + return &SrcNode4InputBnInOp(bn_in_op)->out_blob_time_shape(); +} + +OpNode* OpNode::ProducerOpNode4BnInOp(const std::string& bn_in_op) { + if (ibns_.find(bn_in_op) != ibns_.end()) { return SrcNode4InputBnInOp(bn_in_op); } + return this; +} + +OpNode* OpNode::SrcNode4InputBnInOp(const std::string& bn_in_op) const { + const LogicalBlobId& lbi = op().BnInOp2Lbi(bn_in_op); + CHECK(ibns_.find(bn_in_op) != ibns_.end()); + return SrcNode4InputLbi(lbi); +} + +OpNode* OpNode::ProducerOpNode4Lbi(const LogicalBlobId& lbi) { + OpNode* producer = SrcNode4InputLbi(lbi); + if (producer == nullptr) { producer = this; } + return producer; +} + +const OpNode* OpNode::ProducerOpNode4Lbi(const LogicalBlobId& lbi) const { + const OpNode* producer = SrcNode4InputLbi(lbi); + if (producer == nullptr) { producer = this; } + return producer; +} + +OpNode* OpNode::SrcNode4InputLbi(const LogicalBlobId& lbi) const { + for (OpEdge* edge : in_edges()) { + for (const LogicalBlobId& edge_lbi : edge->lbis()) { + if (lbi == edge_lbi) { return edge->src_node(); } + } + } + return nullptr; +} + +const Shape* OpNode::GetInputBlobTimeShape() const { + if (in_edges().empty()) { UNIMPLEMENTED(); } + OpNode* first_input = (*in_edges().begin())->src_node(); + for (OpEdge* edge : in_edges()) { + CHECK_EQ(first_input->out_blob_time_shape(), edge->src_node()->out_blob_time_shape()); + } + return &first_input->out_blob_time_shape(); +} + +void OpNode::ForEachParallelBlobDesc(const BlobDesc& blob_desc, const SbpParallel& sbp_parallel, + const 
std::function<void(const BlobDesc&)>& Handler) const { + if (sbp_parallel.has_split_parallel()) { + // split BlobDesc + int32_t axis = sbp_parallel.split_parallel().axis(); + CHECK_GE(axis, 0); + CHECK_LT(axis, blob_desc.shape().NumAxes()); + CHECK_GE(blob_desc.shape().At(axis), parallel_desc().parallel_num()); + BalancedSplitter bs(blob_desc.shape().At(axis), parallel_desc().parallel_num()); + FOR_RANGE(int64_t, axis_parallel_id, 0, parallel_desc().parallel_num()) { + BlobDesc sub_blob_desc(blob_desc); + sub_blob_desc.mut_shape().Set(axis, bs.At(axis_parallel_id).size()); + Handler(sub_blob_desc); + } + } else { + CHECK(sbp_parallel.has_broadcast_parallel() || sbp_parallel.has_partial_sum_parallel()); + // broadcast BlobDesc + FOR_RANGE(int64_t, axis_parallel_id, 0, parallel_desc().parallel_num()) { Handler(blob_desc); } + } +} + +void OpNode::ConcatBlobDesc(const std::vector<BlobDesc>& blob_descs, + const SbpParallel& sbp_parallel, + BlobDesc* concatenated_blob_desc) const { + CHECK_EQ(blob_descs.size(), parallel_desc().parallel_num()); + if (sbp_parallel.has_split_parallel()) { + int32_t axis = sbp_parallel.split_parallel().axis(); + // concat BlobDesc + CHECK_GE(axis, 0); + CHECK_LT(axis, blob_descs.at(0).shape().NumAxes()); + int64_t logical_blob_axis_dim = 0; + for (const BlobDesc& blob_desc : blob_descs) { + logical_blob_axis_dim += blob_desc.shape().At(axis); + } + CHECK_GE(logical_blob_axis_dim, parallel_desc().parallel_num()); + BalancedSplitter bs(logical_blob_axis_dim, parallel_desc().parallel_num()); + std::vector<BlobDesc> same_blob_descs(blob_descs); + FOR_RANGE(int64_t, axis_parallel_id, 0, parallel_desc().parallel_num()) { + CHECK_EQ(bs.At(axis_parallel_id).size(), blob_descs.at(axis_parallel_id).shape().At(axis)); + same_blob_descs.at(axis_parallel_id).mut_shape().Set(axis, logical_blob_axis_dim); + } + for (const BlobDesc& blob_desc : same_blob_descs) { CHECK(blob_desc == same_blob_descs.at(0)); } + *concatenated_blob_desc = same_blob_descs.at(0); + } else { + // select first BlobDesc + for (const BlobDesc& blob_desc : blob_descs) { CHECK(blob_desc == blob_descs.at(0)); } + *concatenated_blob_desc = blob_descs.at(0); + } +} + +int64_t OpNode::GetAxisParallelNum( + const std::function<void(bool*, int32_t*, int64_t*)>& GetAxisParallelInfo) const { + bool is_split = false; + int32_t axis = -1; + int64_t axis_parallel_num = 0; + GetAxisParallelInfo(&is_split, &axis, &axis_parallel_num); + return axis_parallel_num; +} + +void OpNode::SplitLogicalInputBlobDesc() { + for (const std::string& bn : op().input_bns()) { + const LogicalBlobId& lbi = op().BnInOp2Lbi(bn); + const BlobDesc& logical_blob_desc = ProducerOpNode4BnInOp(bn)->LogicalBlobDesc4Lbi(lbi); + const SbpParallel& sbp_parallel = SbpParallel4Lbi(lbi); + ForEachParallelBlobDesc(logical_blob_desc, sbp_parallel, [&](const BlobDesc& blob_desc) { + lbi2parallel_id2blob_desc_[lbi].push_back(blob_desc); + }); + CHECK_EQ(lbi2parallel_id2blob_desc_.at(lbi).size(), parallel_desc().parallel_num()); + } +} + +void OpNode::ConcatLogicalOutputBlobDesc() { + for (const std::string& bn : op().output_bns()) { + const LogicalBlobId& lbi = op().BnInOp2Lbi(bn); + const SbpParallel& sbp_parallel = SbpParallel4Lbi(lbi); + ConcatBlobDesc(lbi2parallel_id2blob_desc_.at(lbi), sbp_parallel, MutLogicalBlobDesc4Lbi(lbi)); + } +} + +void OpNode::CheckBlobDescs(const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + int64_t parallel_id = parallel_ctx->parallel_id(); + auto Check = 
[&](const std::string& bn) { + const LogicalBlobId& lbi = op().BnInOp2Lbi(bn); + if (lbi2parallel_id2blob_desc_.find(lbi) == lbi2parallel_id2blob_desc_.end()) { return; } + CHECK_EQ(parallel_ctx->parallel_num(), lbi2parallel_id2blob_desc_.at(lbi).size()); + const BlobDesc& blob_desc = *GetBlobDesc4BnInOp(bn); + CHECK(blob_desc == lbi2parallel_id2blob_desc_.at(lbi).at(parallel_id)); + }; + for (const std::string& bn : op().data_tmp_bns()) { Check(bn); } + for (const std::string& bn : op().fw_buf_bns()) { Check(bn); } + for (const std::string& bn : op().input_bns()) { Check(bn); } + for (const std::string& bn : op().output_bns()) { Check(bn); } + for (const std::string& bn : op().model_bns()) { Check(bn); } + for (const std::string& bn : op().const_model_bns()) { Check(bn); } + for (const std::string& bn : op().const_buf_bns()) { Check(bn); } + for (const std::string& bn : op().forward_model_bns()) { Check(bn); } +} + +void OpGraph::InferOpModelSize(HashMap<std::string, size_t>* op_name2model_size) { + ForEachNode([&](OpNode* op_node) { + size_t model_size = 0; + for (const std::string& model_bn : op_node->op().model_bns()) { + int64_t elem_cnt = op_node->NoParallelBlobDesc4BnInOp(model_bn)->shape().elem_cnt(); + model_size += elem_cnt * GetSizeOfDataType(job_desc_->DefaultDataType()); + model_size = RoundUp(model_size, kCudaAlignSize); + } + size_t parallel_num = op_node->parallel_desc().parallel_num(); + if (op_node->parallel_desc().policy() == ParallelPolicy::kModelParallel) { + model_size = (model_size + parallel_num - 1) / parallel_num; + } + CHECK(op_name2model_size->emplace(op_node->op().op_name(), model_size).second); + }); +} + +void OpGraph::Init() { + InitNodes(); + ForEachNode( + [&](OpNode* node) { CHECK(op_name2op_node_.emplace(node->op().op_name(), node).second); }); + InitEdges(); + FixOpParallelDesc(); + UpdateOpNodeHasInDiff(); + InferTimeShape(); + InferNoParallelBlobDesc(); + InferIsModelBlob(); + InferSbpParallel(); + InferLogicalBlobDesc(); +} + +void OpGraph::InitNodes() { + auto ParallelConf4OpName = MakeGetterParallelConf4OpName(job_desc_->placement()); + for (const auto& op_conf : job_desc_->dlnet_conf().op()) { + OpNode* node = new OpNode(ParallelDesc(*ParallelConf4OpName(op_conf.name())), op_conf); + AddAllocatedNode(node); + } +} + +void OpGraph::InitEdges() { + HashMap<LogicalBlobId, OpNode*> lbi2producer; + ForEachNode([&](OpNode* op_node) { + for (const auto& obn : op_node->op().output_bns()) { + CHECK(lbi2producer.emplace(op_node->op().BnInOp2Lbi(obn), op_node).second); + } + }); + ForEachNode([&](OpNode* op_node) { + HashMap<std::string, std::vector<LogicalBlobId>> producer_name2lbis; + HashMap<std::string, HashMap<LogicalBlobId, std::vector<std::string>>> + consumer_op_name2lbi2ibns; + for (const auto& ibn : op_node->op().input_bns()) { + const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn); + producer_name2lbis[lbi.op_name()].push_back(lbi); + consumer_op_name2lbi2ibns[op_node->op().op_name()][lbi].push_back(ibn); + } + for (const auto& pair : producer_name2lbis) { + const auto& lbis = pair.second; + const auto& lbi2ibns = consumer_op_name2lbi2ibns.at(op_node->op().op_name()); + OpNode* producer = lbi2producer.at(lbis.at(0)); + Connect(producer, NewEdge(lbis, lbi2ibns), op_node); + } + }); +} + +void OpGraph::FixOpParallelDesc() const { + ForEachNode([&](OpNode* node) { node->op().FixParallelDesc(node->mut_parallel_desc()); }); +} + +void OpGraph::UpdateOpNodeHasInDiff() const { + TopoForEachNode([&](OpNode* op_node) { + bool has_diff = false; + for 
(OpEdge* edge : op_node->in_edges()) { + if (edge->src_node()->has_in_diff() || edge->src_node()->has_model_diff()) { + edge->set_has_diff(true); + has_diff = true; + break; + } + } + op_node->set_has_in_diff(has_diff); + }); +} + +void OpGraph::InferTimeShape() const { + TopoForEachNode([&](OpNode* op_node) { + ParallelContext parallel_ctx; + parallel_ctx.set_parallel_id(0); + parallel_ctx.set_parallel_num(op_node->parallel_desc().parallel_num()); + parallel_ctx.set_policy(op_node->parallel_desc().policy()); + auto GetInputBlobTimeShape = [&](const std::string& bn_in_op) { + return op_node->GetInputBlobTimeShape(bn_in_op); + }; + op_node->op().InferOutputBlobTimeShapeIf(GetInputBlobTimeShape, &parallel_ctx, + op_node->mut_out_blob_time_shape()); + }); +} + +void OpGraph::InferNoParallelBlobDesc() const { + TopoForEachNode([&](OpNode* op_node) { + ParallelContext parallel_ctx; + parallel_ctx.set_parallel_id(0); + parallel_ctx.set_parallel_num(1); + parallel_ctx.set_policy(op_node->parallel_desc().policy()); + // the really important data we want to get is: + // a) model blobs' byte size; + // b) number of axes of blobs' body shape; + // Hence the argument record_piece_size can be any positive number; here it's 1 + op_node->op().InferBlobDescsIf( + std::bind(&OpNode::NoParallelBlobDesc4BnInOp, op_node, std::placeholders::_1), + &parallel_ctx, 1, [](OpContext*) {}); + }); +} + +void OpGraph::InferIsModelBlob() const { + TopoForEachNode([&](OpNode* op_node) { + op_node->op().InferIsModelBlob4OutputBlobsIf([&](const std::string& bn) -> bool* { + return op_node->ProducerOpNode4BnInOp(bn)->MutIsModelBlob4Lbi(op_node->op().BnInOp2Lbi(bn)); + }); + }); +} + +void OpGraph::InferSbpParallel() const { + TopoForEachNode([&](OpNode* op_node) { + HashMap<std::string, SbpInferHint> ibn2sbp_infer_hint; + for (const std::string& ibn : op_node->op().input_bns()) { + const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn); + OpNode* producer = op_node->SrcNode4InputBnInOp(ibn); + bool is_model_blob = producer->IsModelBlob4Lbi(lbi); + int64_t parallel_num = op_node->parallel_desc().parallel_num(); + int64_t num_axes = producer->NoParallelBlobDesc4Lbi(lbi).shape().NumAxes(); + const auto& sbp = producer->SbpParallel4Lbi(lbi); + ibn2sbp_infer_hint.emplace(ibn, SbpInferHint(is_model_blob, parallel_num, num_axes, sbp)); + } + auto SbpParallel4BnInOp = [&](const std::string& bn) -> SbpParallel* { + return op_node->MutSbpParallel4Lbi(op_node->op().BnInOp2Lbi(bn)); + }; + auto SbpInferHint4Ibn = [&](const std::string& ibn) -> const SbpInferHint& { + return ibn2sbp_infer_hint.at(ibn); + }; + ParallelContext parallel_ctx; + parallel_ctx.set_parallel_id(0); + parallel_ctx.set_parallel_num(op_node->parallel_desc().parallel_num()); + parallel_ctx.set_policy(op_node->parallel_desc().policy()); + op_node->op().InferInputOutputSbpParallelIf(SbpParallel4BnInOp, SbpInferHint4Ibn, + &parallel_ctx); + }); +} + +void OpGraph::InferLogicalBlobDesc() const { + TopoForEachNode([&](OpNode* op_node) { + auto* lbi2parallel_id2blob_desc = op_node->mut_lbi2parallel_id2blob_desc(); + op_node->SplitLogicalInputBlobDesc(); + int64_t parallel_num = op_node->parallel_desc().parallel_num(); + const auto& input_bns = op_node->op().input_bns(); + FOR_RANGE(int64_t, parallel_id, 0, parallel_num) { + auto BlobDesc4BnInOp = [&](const std::string& bn) -> BlobDesc* { + const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(bn); + if (std::find(input_bns.begin(), input_bns.end(), bn) != input_bns.end()) { + CHECK(lbi2parallel_id2blob_desc->find(lbi) !=
lbi2parallel_id2blob_desc->end()); + CHECK_EQ(lbi2parallel_id2blob_desc->at(lbi).size(), parallel_num); + } else if (lbi2parallel_id2blob_desc->find(lbi) == lbi2parallel_id2blob_desc->end()) { + (*lbi2parallel_id2blob_desc)[lbi].resize(parallel_num); + } else { + CHECK_EQ(lbi2parallel_id2blob_desc->at(lbi).size(), parallel_num); + } + return &(*lbi2parallel_id2blob_desc)[lbi][parallel_id]; + }; + ParallelContext parallel_ctx; + parallel_ctx.set_parallel_id(parallel_id); + parallel_ctx.set_parallel_num(parallel_num); + parallel_ctx.set_policy(op_node->parallel_desc().policy()); + op_node->op().InferBlobDescsIf(BlobDesc4BnInOp, &parallel_ctx, job_desc_->RecordPieceSize(), + [](OpContext*) {}); + } + op_node->ConcatLogicalOutputBlobDesc(); + }); +} + +BalancedSplitter OpGraph::GetBalancedSplitter(const std::string& op_name, + const LogicalBlobId& lbi) const { + OpNode* op_node = op_name2op_node_.at(GetOpNameKey(op_name, lbi)); + const SbpParallel& sbp_parallel = GetSbpParallel(op_name, lbi); + CHECK(sbp_parallel.has_split_parallel()); + int64_t split_num = GetSplitNum(op_name, lbi); + if (IsDataBlob(op_name, lbi)) { + CHECK_EQ(split_num % op_node->parallel_desc().parallel_num(), 0); + } else { + CHECK(IsModelBlob(op_name, lbi)); + CHECK_GE(split_num, op_node->parallel_desc().parallel_num()); + } + return BalancedSplitter(split_num, op_node->parallel_desc().parallel_num()); +} + +int32_t OpGraph::GetModelSplitAxis(const std::string& op_name, const LogicalBlobId& lbi) const { + const SbpParallel& sbp_parallel = GetSbpParallel(op_name, lbi); + CHECK(sbp_parallel.has_split_parallel()); + return sbp_parallel.split_parallel().axis(); +} + +int64_t OpGraph::GetSplitNum(const std::string& op_name, const LogicalBlobId& lbi) const { + OpNode* op_node = op_name2op_node_.at(GetOpNameKey(op_name, lbi)); + const LogicalBlobId& lbi_key = GetLogicalBlobIdKey(op_name, lbi); + const SbpParallel& sbp_parallel = op_node->SbpParallel4Lbi(lbi_key); + CHECK(sbp_parallel.has_split_parallel()); + return op_node->ProducerOpNode4Lbi(lbi)->LogicalBlobDesc4Lbi(lbi_key).shape().At( + sbp_parallel.split_parallel().axis()); +} + +const SbpParallel& OpGraph::GetSbpParallel(const std::string& op_name, + const LogicalBlobId& lbi) const { + return op_name2op_node_.at(GetOpNameKey(op_name, lbi)) + ->SbpParallel4Lbi(GetLogicalBlobIdKey(op_name, lbi)); +} + +DataType OpGraph::GetBlobDataType(const LogicalBlobId& lbi) const { + return op_name2op_node_.at(lbi.op_name()) + ->NoParallelBlobDesc4Lbi(GetLogicalBlobIdKey(lbi.op_name(), lbi)) + .data_type(); +} + +bool OpGraph::IsModelBlob(const std::string& op_name, const LogicalBlobId& lbi) const { + return op_name2op_node_.at(GetOpNameKey(op_name, lbi)) + ->IsModelBlob4Lbi(GetLogicalBlobIdKey(op_name, lbi)); +} + +bool OpGraph::IsDataBlob(const std::string& op_name, const LogicalBlobId& lbi) const { + return !IsModelBlob(op_name, lbi); +} + +void OpGraph::CheckBlobDescs(const std::string& op_name, + const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + if (op_name2op_node_.find(op_name) == op_name2op_node_.end()) { return; } + op_name2op_node_.at(op_name)->CheckBlobDescs(GetBlobDesc4BnInOp, parallel_ctx); +} + +void OpGraph::ForEachPseudoChain( + const std::function<void(const HashSet<OpNode*>&)>& Handler) const { + auto IsReachable = MakePredicatorIsReachable(); + ForEachComponentWithSameDataParallelDescAndTimeShape( + [&](const std::vector<OpNode*>& nodes) { ForEachPseudoChain(nodes, IsReachable, Handler); }); +} + 
+std::function<bool(OpNode* src, OpNode* dst)> OpGraph::MakePredicatorIsReachable() const { + auto node2ancestors_ptr = std::make_shared<HashMap<OpNode*, HashSet<OpNode*>>>(); + TopoForEachNode([&](OpNode* node) { + node->ForEachNodeOnInEdge([&](OpNode* in_node) { + (*node2ancestors_ptr)[node].insert(in_node); + (*node2ancestors_ptr)[node].insert((*node2ancestors_ptr)[in_node].begin(), + (*node2ancestors_ptr)[in_node].end()); + }); + }); + return [node2ancestors_ptr](OpNode* src, OpNode* dst) -> bool { + return node2ancestors_ptr->at(dst).find(src) != node2ancestors_ptr->at(dst).end(); + }; +} + +void OpGraph::ForEachComponentWithSameDataParallelDescAndTimeShape( + const std::function<void(const std::vector<OpNode*>&)>& Handler) const { + auto WithSameDataParallelDescAndTimeShape = [](OpNode* src, OpNode* dst) -> bool { + if (src->parallel_desc().policy() != ParallelPolicy::kDataParallel) { return false; } + if (dst->parallel_desc().policy() != ParallelPolicy::kDataParallel) { return false; } + if (src->in_edges().empty()) { return false; } + if (*src->GetInputBlobTimeShape() != src->out_blob_time_shape()) { return false; } + if (*dst->GetInputBlobTimeShape() != dst->out_blob_time_shape()) { return false; } + return src->parallel_desc() == dst->parallel_desc() + && src->out_blob_time_shape() == dst->out_blob_time_shape(); + }; + auto ForEachNext = [&](OpNode* node, const std::function<void(OpNode*)>& Handler) { + node->ForEachNodeOnInEdge([&](OpNode* in_node) { + if (WithSameDataParallelDescAndTimeShape(in_node, node)) { Handler(in_node); } + }); + node->ForEachNodeOnOutEdge([&](OpNode* out_node) { + if (WithSameDataParallelDescAndTimeShape(node, out_node)) { Handler(out_node); } + }); + }; + HashMap<OpNode*, int32_t> op_node2component_id; + int32_t cur_component_id = 0; + ForEachNode([&](OpNode* start) { + if (op_node2component_id.find(start) != op_node2component_id.end()) { return; } + ++cur_component_id; + BfsForEachNode({start}, ForEachNext, [&](OpNode* node) { + CHECK(op_node2component_id.emplace(node, cur_component_id).second); + }); + }); + HashMap<int32_t, std::vector<OpNode*>> component_id2op_nodes; + for (const auto& pair : op_node2component_id) { + component_id2op_nodes[pair.second].push_back(pair.first); + } + for (const auto& pair : component_id2op_nodes) { Handler(pair.second); } +} + +void OpGraph::ForEachPseudoChain( + const std::vector<OpNode*>& nodes, + const std::function<bool(OpNode* src, OpNode* dst)>& IsReachable, + const std::function<void(const HashSet<OpNode*>&)>& Handler) const { + if (nodes.size() <= 1) { return; } + if (nodes.front()->parallel_desc().device_type() == DeviceType::kCPU) { return; } + if (nodes.front()->parallel_desc().policy() != ParallelPolicy::kDataParallel) { return; } + HashSet<OpNode*> all_nodes(nodes.begin(), nodes.end()); + while (all_nodes.size() > 1) { + HashSet<OpNode*> chain; + ReverseTopoGetPseudoChain(all_nodes, &chain, IsReachable); + Handler(chain); + for (OpNode* node_in_chain : chain) { all_nodes.erase(node_in_chain); } + } +} + +void OpGraph::ReverseTopoGetPseudoChain( + const HashSet<OpNode*>& op_nodes, HashSet<OpNode*>* pseudo_chain_nodes, + const std::function<bool(OpNode* src, OpNode* dst)>& IsReachable) const { + // get sink nodes + std::list<OpNode*> sinks; + auto IsSink = [&](OpNode* node) { + for (OpNode* inner_node : op_nodes) { + if (IsReachable(node, inner_node)) { return false; } + } + return true; + }; + for (OpNode* op_node : op_nodes) { + if (IsSink(op_node)) { sinks.push_back(op_node); } + } + // generate 
connections of subgraph + HashMap<OpNode*, std::vector<OpNode*>> node2in_nodes; + HashMap<OpNode*, std::vector<OpNode*>> node2out_nodes; + auto IsInSubset = [&](OpNode* node) { return op_nodes.find(node) != op_nodes.end(); }; + auto ReachableToAnySink = [&](OpNode* node) { + for (OpNode* sink : sinks) { + if (node == sink) { return true; } + if (IsReachable(node, sink)) { return true; } + } + return false; + }; + auto AnyOutputNodesNotInSubsetAndReachableIntoSink = [&](OpNode* node) { + for (OpEdge* edge : node->out_edges()) { + if (!IsInSubset(edge->dst_node()) && ReachableToAnySink(edge->dst_node())) { return true; } + } + return false; + }; + for (OpNode* node : op_nodes) { + if (AnyOutputNodesNotInSubsetAndReachableIntoSink(node)) { continue; } + node->ForEachNodeOnOutEdge([&](OpNode* out_node) { + if (IsInSubset(out_node)) { + node2in_nodes[out_node].push_back(node); + node2out_nodes[node].push_back(out_node); + } + }); + } + // get chain nodes + auto ForEachInNode = [&](OpNode* node, const std::function<void(OpNode*)>& Handler) { + for (OpNode* in_node : node2in_nodes[node]) { Handler(in_node); } + }; + auto ForEachOutNode = [&](OpNode* node, const std::function<void(OpNode*)>& Handler) { + for (OpNode* out_node : node2out_nodes[node]) { Handler(out_node); } + }; + TopoForEachNode(sinks, ForEachOutNode, ForEachInNode, + [&](OpNode* node) { CHECK(pseudo_chain_nodes->emplace(node).second); }); +} + +std::string OpGraph::GetOpNameKey(const std::string& op_name, const LogicalBlobId& lbi) const { + CHECK(!lbi.has_is_packed_id()); + std::string op_name_key; + if (op_name2op_node_.find(op_name) == op_name2op_node_.end()) { + CHECK(lbi.has_clone_id()); + return lbi.op_name(); + } else { + CHECK(!lbi.has_clone_id()); + return op_name; + } +} + +LogicalBlobId OpGraph::GetLogicalBlobIdKey(const std::string& op_name, + const LogicalBlobId& lbi) const { + CHECK(!lbi.has_is_packed_id()); + if (op_name2op_node_.find(op_name) == op_name2op_node_.end()) { + CHECK(lbi.has_clone_id()); + LogicalBlobId lbi_key; + lbi_key.set_op_name(lbi.op_name()); + lbi_key.set_blob_name(lbi.blob_name()); + return lbi_key; + } else { + CHECK(!lbi.has_clone_id()); + return lbi; + } +} + +} // namespace oneflow diff --git a/oneflow/core/graph/op_graph.h b/oneflow/core/graph/op_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..94d0974e986fb647490e53bc0082e75dc3f134bf --- /dev/null +++ b/oneflow/core/graph/op_graph.h @@ -0,0 +1,160 @@ +#ifndef ONEFLOW_CORE_GRAPH_OP_GRAPH_H_ +#define ONEFLOW_CORE_GRAPH_OP_GRAPH_H_ + +#include "oneflow/core/graph/graph.h" +#include "oneflow/core/job/job_desc.h" +#include "oneflow/core/job/parallel_desc.h" +#include "oneflow/core/operator/operator.h" +#include "oneflow/core/common/balanced_splitter.h" + +namespace oneflow { + +class OpEdge; +class OpGraph; + +class OpNode final : public Node<OpNode, OpEdge> { + public: + OF_DISALLOW_COPY_AND_MOVE(OpNode); + explicit OpNode(const ParallelDesc& parallel_desc, const OperatorConf& op_conf) + : parallel_desc_(parallel_desc), + op_(ConstructOp(op_conf, parallel_desc.device_type())), + ibns_(op_->input_bns().begin(), op_->input_bns().end()), + has_in_diff_(false) {} + ~OpNode() = default; + + // Getters + const Shape& out_blob_time_shape() const { return out_blob_time_shape_; } + const Operator& op() const { return *op_; } + bool HasBackward() const { return has_in_diff() || has_model_diff(); } + bool has_in_diff() const { return has_in_diff_; } + bool has_model_diff() const { return op().model_diff_bns().size() > 0; } + 
void set_has_in_diff(bool has_in_diff) { has_in_diff_ = has_in_diff; } + const ParallelDesc& parallel_desc() const { return parallel_desc_; } + + std::string VisualStr() const override; + + private: + friend class OpGraph; + friend class OpEdge; + // Getters + const BlobDesc& NoParallelBlobDesc4Lbi(const LogicalBlobId& lbi) const; + const BlobDesc& LogicalBlobDesc4Lbi(const LogicalBlobId& lbi) const; + const SbpParallel& SbpParallel4Lbi(const LogicalBlobId& lbi) const; + const Shape* GetInputBlobTimeShape(const std::string& bn_in_op) const; + const Shape* GetInputBlobTimeShape() const; + + // Setters + ParallelDesc* mut_parallel_desc() { return ¶llel_desc_; } + Shape* mut_out_blob_time_shape() { return &out_blob_time_shape_; } + HashMap<LogicalBlobId, std::vector<BlobDesc>>* mut_lbi2parallel_id2blob_desc() { + return &lbi2parallel_id2blob_desc_; + } + bool IsModelBlob4Lbi(const LogicalBlobId& lbi) const; + bool* MutIsModelBlob4Lbi(const LogicalBlobId& lbi); + BlobDesc* NoParallelBlobDesc4BnInOp(const std::string& bn_in_op); + BlobDesc* MutNoParallelBlobDesc(const LogicalBlobId& lbi); + BlobDesc* MutLogicalBlobDesc4Lbi(const LogicalBlobId& lbi); + SbpParallel* MutSbpParallel4Lbi(const LogicalBlobId& lbi); + OpNode* SrcNode4InputBnInOp(const std::string& bn_in_op) const; + OpNode* ProducerOpNode4BnInOp(const std::string& bn_in_op); + OpNode* SrcNode4InputLbi(const LogicalBlobId& lbi) const; + OpNode* ProducerOpNode4Lbi(const LogicalBlobId& lbi); + const OpNode* ProducerOpNode4Lbi(const LogicalBlobId& lbi) const; + + void ForEachParallelBlobDesc(const BlobDesc& blob_desc, const SbpParallel& sbp_parallel, + const std::function<void(const BlobDesc&)>& Handler) const; + int64_t GetAxisParallelNum( + const std::function<void(bool*, int32_t*, int64_t*)>& GetAxisParallelInfo) const; + void ConcatBlobDesc(const std::vector<BlobDesc>& blob_descs, const SbpParallel& sbp_parallel, + BlobDesc* concatenated_blob_desc) const; + void SplitLogicalInputBlobDesc(); + void ConcatLogicalOutputBlobDesc(); + void CheckBlobDescs(const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const; + + ParallelDesc parallel_desc_; + std::shared_ptr<Operator> op_; + HashSet<std::string> ibns_; + bool has_in_diff_; + Shape out_blob_time_shape_; + HashMap<LogicalBlobId, BlobDesc> lbi2no_parallel_blob_desc_; + HashMap<LogicalBlobId, bool> lbi2is_model_blob_; + HashMap<LogicalBlobId, SbpParallel> lbi2sbp_parallel_; + HashMap<LogicalBlobId, std::vector<BlobDesc>> lbi2parallel_id2blob_desc_; + HashMap<LogicalBlobId, BlobDesc> lbi2logical_blob_desc_; +}; + +class OpEdge final : public Edge<OpNode, OpEdge> { + public: + OF_DISALLOW_COPY_AND_MOVE(OpEdge); + explicit OpEdge(const std::vector<LogicalBlobId>& lbis, + const HashMap<LogicalBlobId, std::vector<std::string>>& lbi2ibns) + : lbis_(lbis), lbi2ibns_(lbi2ibns), has_diff_(false) {} + ~OpEdge() = default; + + const std::vector<LogicalBlobId>& lbis() const { return lbis_; } + const HashMap<LogicalBlobId, std::vector<std::string>>& lbi2ibns() const { return lbi2ibns_; } + bool has_diff() const { return has_diff_; } + std::string VisualStr() const override; + + void set_has_diff(bool val) { has_diff_ = val; } + + private: + std::vector<LogicalBlobId> lbis_; + HashMap<LogicalBlobId, std::vector<std::string>> lbi2ibns_; + bool has_diff_; +}; + +class OpGraph final : public Graph<OpNode, OpEdge> { + public: + OF_DISALLOW_COPY_AND_MOVE(OpGraph); + explicit OpGraph(const JobDesc* job_desc) : job_desc_(job_desc) { Init(); } + 
~OpGraph() = default; + + void InferOpModelSize(HashMap<std::string, size_t>* op_name2model_size); + + int32_t GetModelSplitAxis(const std::string& op_name, const LogicalBlobId& lbi) const; + BalancedSplitter GetBalancedSplitter(const std::string& op_name, const LogicalBlobId& lbi) const; + const SbpParallel& GetSbpParallel(const std::string& op_name, const LogicalBlobId& lbi) const; + DataType GetBlobDataType(const LogicalBlobId& lbi) const; + void CheckBlobDescs(const std::string& op_name, + const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const; + + // a set of nodes is called a pseudo chain if they can merge into a chain regardless of the + // connections before their source nodes + void ForEachPseudoChain(const std::function<void(const HashSet<OpNode*>&)>& Handler) const; + + private: + void Init(); + void InitNodes(); + void InitEdges(); + void FixOpParallelDesc() const; + void UpdateOpNodeHasInDiff() const; + void InferTimeShape() const; + void InferNoParallelBlobDesc() const; + void InferIsModelBlob() const; + void InferSbpParallel() const; + void InferLogicalBlobDesc() const; + bool IsModelBlob(const std::string& op_name, const LogicalBlobId& lbi) const; + bool IsDataBlob(const std::string& op_name, const LogicalBlobId& lbi) const; + std::string GetOpNameKey(const std::string& op_name, const LogicalBlobId& lbi) const; + LogicalBlobId GetLogicalBlobIdKey(const std::string& op_name, const LogicalBlobId& lbi) const; + void ForEachPseudoChain(const std::vector<OpNode*>& nodes, + const std::function<bool(OpNode* src, OpNode* dst)>& IsReachable, + const std::function<void(const HashSet<OpNode*>&)>& Handler) const; + void ReverseTopoGetPseudoChain( + const HashSet<OpNode*>& op_nodes, HashSet<OpNode*>* chain, + const std::function<bool(OpNode* src, OpNode* dst)>& IsReachable) const; + std::function<bool(OpNode* src, OpNode* dst)> MakePredicatorIsReachable() const; + void ForEachComponentWithSameDataParallelDescAndTimeShape( + const std::function<void(const std::vector<OpNode*>&)>& Handler) const; + + int64_t GetSplitNum(const std::string& op_name, const LogicalBlobId& lbi) const; + const JobDesc* job_desc_; + HashMap<std::string, OpNode*> op_name2op_node_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_GRAPH_OP_GRAPH_H_ diff --git a/oneflow/core/graph/pack_forward_task_node.cpp b/oneflow/core/graph/pack_forward_task_node.cpp index ff6bdec8b520cf75e49b4627410ddeed2e50018d..36ff141f95e91e62078c530bdc9a3e98ac693113 100644 --- a/oneflow/core/graph/pack_forward_task_node.cpp +++ b/oneflow/core/graph/pack_forward_task_node.cpp @@ -51,8 +51,7 @@ void PackForwardCompTaskNode::BuildExecGphAndRegst() { *out_blob = *in_blob; const PackOp* pack_op = dynamic_cast<const PackOp*>(op.get()); CHECK_NOTNULL(pack_op); - CHECK_EQ(pack_op->GetPackNum(parallel_ctx()->parallel_num()), - related_unpack_in_blob->shape().At(0) / in_blob->shape().At(0)); + CHECK_EQ(pack_op->GetPackNum(), related_unpack_in_blob->shape().At(0) / in_blob->shape().At(0)); out_blob->mut_shape().Set(0, related_unpack_in_blob->shape().At(0)); if (out_blob->has_dim0_valid_num_field()) { out_blob->mut_dim0_inner_shape() = related_unpack_in_blob->dim0_inner_shape(); @@ -66,7 +65,7 @@ void PackForwardCompTaskNode::InferProducedDataRegstTimeShape() { const PackOp* pack_op = dynamic_cast<const PackOp*>(logical_node()->SoleOp().get()); CHECK_NOTNULL(pack_op); - int64_t pack_num = pack_op->GetPackNum(parallel_ctx()->parallel_num()); + int64_t pack_num = pack_op->GetPackNum(); 
CHECK_GT(time_shape_dim_vec.size(), 0); CHECK_EQ(pack_num, time_shape_dim_vec.back()); time_shape_dim_vec.pop_back(); diff --git a/oneflow/core/graph/reduce_comp_task_node_if.h b/oneflow/core/graph/reduce_comp_task_node_if.h index 0131bec319a1f5e9b1d75d518a1f915608b98ad5..988fcc18dca217298bfd8601ef666af480d6cb60 100644 --- a/oneflow/core/graph/reduce_comp_task_node_if.h +++ b/oneflow/core/graph/reduce_comp_task_node_if.h @@ -3,7 +3,7 @@ #include "oneflow/core/register/register_desc.h" #include "oneflow/core/graph/compute_task_node.h" -#include "logical_node.h" +#include "oneflow/core/graph/logical_node.h" namespace oneflow { diff --git a/oneflow/core/graph/reduce_concat_compute_task_node.h b/oneflow/core/graph/reduce_concat_compute_task_node.h index 36a32511178b1cb325f944bcfd3097848c2e7583..00d8c2973c81e14b7ce70f27210059c26fef4e37 100644 --- a/oneflow/core/graph/reduce_concat_compute_task_node.h +++ b/oneflow/core/graph/reduce_concat_compute_task_node.h @@ -16,7 +16,7 @@ class ReduceConcatCompTaskNode final : public CompTaskNode, public ReduceCompTas void ConsumeAllRegsts() override; TaskType GetTaskType() const override { return TaskType::kReduceConcat; } - CudaWorkType GetCudaWorkType() const override { return CudaWorkType::kMix; } + CudaWorkType GetCudaWorkType() const override { return CudaWorkType::kReduceCtrl; } void EnableMemSharingInReduce(const ReduceMemSharingCtx& ctx) override; diff --git a/oneflow/core/graph/reduce_identity_task_node.cpp b/oneflow/core/graph/reduce_identity_task_node.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c889d3cfd0a9a70b78b76a4e9d60320ceec64002 --- /dev/null +++ b/oneflow/core/graph/reduce_identity_task_node.cpp @@ -0,0 +1,37 @@ +#include "oneflow/core/graph/reduce_identity_task_node.h" +#include "oneflow/core/graph/logical_node.h" +#include "oneflow/core/operator/reduce_identity_op.h" + +namespace oneflow { + +void ReduceIdentityCompTaskNode::ProduceAllRegstsAndBindEdges() { + ProduceRegst("out", false); + BindEdgeWithProducedRegst(SoleOutEdge(), "out"); +} + +void ReduceIdentityCompTaskNode::EnableMemSharingInReduce(const ReduceMemSharingCtx& ctx) { + ctx.EnableMemSharing4Regst(GetProducedRegst("out").get(), 0); + ctx.EnableMemSharing4Regst(GetSoleConsumedRegst("in").get(), 0); +} + +void ReduceIdentityCompTaskNode::ConsumeAllRegsts() { + ConsumeRegst("in", SoleInEdge()->GetSoleRegst()); +} + +void ReduceIdentityCompTaskNode::BuildExecGphAndRegst() { + std::shared_ptr<const Operator> op = logical_node()->SoleOp(); + ExecNode* exec_node = mut_exec_gph().NewNode(); + exec_node->mut_op() = op; + exec_node->BindBnWithRegst(op->SoleIbn(), GetSoleConsumedRegst("in")); + + std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out"); + out_regst->AddLbi(op->BnInOp2Lbi(op->SoleObn())); + exec_node->BindBnWithRegst(op->SoleObn(), out_regst); + exec_node->InferBlobDescs(parallel_ctx()); +} + +void ReduceIdentityCompTaskNode::InferProducedDataRegstTimeShape() { + NaiveInferProducedDataRegstTimeShape(); +} + +} // namespace oneflow diff --git a/oneflow/core/graph/reduce_identity_task_node.h b/oneflow/core/graph/reduce_identity_task_node.h new file mode 100644 index 0000000000000000000000000000000000000000..2f2c5d7b794b018af753887e825ccd774230d7e9 --- /dev/null +++ b/oneflow/core/graph/reduce_identity_task_node.h @@ -0,0 +1,31 @@ +#ifndef ONEFLOW_CORE_GRAPH_REDUCE_IDENTITY_TASK_NODE_H_ +#define ONEFLOW_CORE_GRAPH_REDUCE_IDENTITY_TASK_NODE_H_ + +#include "oneflow/core/graph/compute_task_node.h" +#include 
"oneflow/core/graph/logical_node.h" +#include "oneflow/core/graph/pipe_compute_task_node.h" +#include "oneflow/core/graph/reduce_comp_task_node_if.h" + +namespace oneflow { + +class ReduceIdentityCompTaskNode final : public CompTaskNode, public ReduceCompTaskNodeIf { + public: + OF_DISALLOW_COPY_AND_MOVE(ReduceIdentityCompTaskNode); + ReduceIdentityCompTaskNode() = default; + ~ReduceIdentityCompTaskNode() override = default; + + TaskType GetTaskType() const override { return TaskType::kReduceIdentity; } + CudaWorkType GetCudaWorkType() const override { return CudaWorkType::kReduceCtrl; } + void EnableMemSharingInReduce(const ReduceMemSharingCtx& ctx) override; + + void ProduceAllRegstsAndBindEdges() override; + void ConsumeAllRegsts() override; + + private: + void BuildExecGphAndRegst() override; + void InferProducedDataRegstTimeShape() override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_GRAPH_REDUCE_IDENTITY_TASK_NODE_H_ diff --git a/oneflow/core/graph/reduce_split_compute_task_node.cpp b/oneflow/core/graph/reduce_split_compute_task_node.cpp index c31fd0c327c94cec2fa0bb04aabc4137e24e76d8..245aa773fd269a9179673f8fef58c37fb87b91e2 100644 --- a/oneflow/core/graph/reduce_split_compute_task_node.cpp +++ b/oneflow/core/graph/reduce_split_compute_task_node.cpp @@ -5,6 +5,19 @@ namespace oneflow { +namespace { + +int32_t GetDataRegstDescCnt( + const HashMap<std::string, std::shared_ptr<RegstDesc>> name2regst_desc) { + size_t cnt = 0; + for (const auto& pair : name2regst_desc) { + cnt += pair.second->regst_desc_type().has_data_regst_desc(); + } + return cnt; +} + +} // namespace + void ReduceSplitCompTaskNode::ProduceAllRegstsAndBindEdges() { std::vector<EdgeInfo> edge_infos; for (TaskEdge* edge : out_edges()) { @@ -32,17 +45,23 @@ void ReduceSplitCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", this->SoleInEdge()->GetSoleRegst()); } +TaskNode* ReduceSplitCompTaskNode::GetPrevReduceTaskNode(TaskType task_type) { + CHECK(task_type == TaskType::kReduceConcat || task_type == TaskType::kReduceIdentity); + TaskNode* task_node = + FindPredReduceTaskNodeIf([&](TaskNode* node) { return node->GetTaskType() == task_type; }); + CHECK_NOTNULL(task_node); + return task_node; +} + void ReduceSplitCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); std::shared_ptr<Operator> reduce_split_op = this->logical_node()->SoleOp(); node->mut_op() = reduce_split_op; node->BindBnWithRegst(reduce_split_op->SoleIbn(), GetSoleConsumedRegst("in")); - TaskNode* reduce_concat_node = FindPredReduceTaskNodeIf( - [](TaskNode* node) { return node->GetTaskType() == TaskType::kReduceConcat; }); - CHECK(reduce_concat_node); + TaskNode* reduce_concat_node = GetPrevReduceTaskNode(TaskType::kReduceConcat); - CHECK_EQ(reduce_concat_node->consumed_regsts().size(), produced_regsts().size()); + CHECK_EQ(reduce_concat_node->consumed_regsts().size(), GetDataRegstDescCnt(produced_regsts())); FOR_RANGE(size_t, i, 0, reduce_split_op->output_bns().size()) { std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out_" + std::to_string(i)); CHECK(out_regst.get() != nullptr); @@ -53,7 +72,7 @@ void ReduceSplitCompTaskNode::BuildExecGphAndRegst() { } void ReduceSplitCompTaskNode::FixPackedBlobDescOfProducedRegst() { - int64_t out_regst_num = produced_regsts().size(); + int64_t out_regst_num = GetDataRegstDescCnt(produced_regsts()); FOR_RANGE(int64_t, idx, 0, out_regst_num) { std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out_" + std::to_string(idx)); CHECK(out_regst->IsLocked()); @@ -65,7 
+84,7 @@ void ReduceSplitCompTaskNode::FixPackedBlobDescOfProducedRegst() { void ReduceSplitCompTaskNode::EnableMemSharingInReduce(const ReduceMemSharingCtx& ctx) { CHECK_EQ(GetRankCtx().TotalSegmentCount(), 1); - size_t split_num = produced_regsts().size(); + size_t split_num = GetDataRegstDescCnt(produced_regsts()); int64_t offset = 0; FOR_RANGE(int32_t, idx, 0, split_num) { RegstDesc* split_out_regst = GetProducedRegst("out_" + std::to_string(idx)).get(); diff --git a/oneflow/core/graph/reduce_split_compute_task_node.h b/oneflow/core/graph/reduce_split_compute_task_node.h index 8fd05420a2932f42f0f5f061bc2a8118025cf64c..87a0028de529b06278bd4b41226841d3e4d1d07e 100644 --- a/oneflow/core/graph/reduce_split_compute_task_node.h +++ b/oneflow/core/graph/reduce_split_compute_task_node.h @@ -16,9 +16,11 @@ class ReduceSplitCompTaskNode final : public CompTaskNode, public ReduceCompTask void ConsumeAllRegsts() override; TaskType GetTaskType() const override { return TaskType::kReduceSplit; } - CudaWorkType GetCudaWorkType() const override { return CudaWorkType::kMix; } + CudaWorkType GetCudaWorkType() const override { return CudaWorkType::kReduceCtrl; } void EnableMemSharingInReduce(const ReduceMemSharingCtx& ctx) override; + TaskNode* GetPrevReduceTaskNode(TaskType task_type); + private: void BuildExecGphAndRegst() override; void FixPackedBlobDescOfProducedRegst() override; diff --git a/oneflow/core/graph/regst_lifetime_graph.cpp b/oneflow/core/graph/regst_lifetime_graph.cpp index dd7fa852de8c7ee7ba95eb4292e18bfdaf6c89b6..645949abf5eb928a403dae3efa5fc1e51b07af61 100644 --- a/oneflow/core/graph/regst_lifetime_graph.cpp +++ b/oneflow/core/graph/regst_lifetime_graph.cpp @@ -3,17 +3,17 @@ namespace oneflow { RegstLifetimeGraph::RegstLifetimeGraph( - const std::list<const RegstDescProto*>& regst_descs, + const std::vector<const RegstDescProto*>& regst_descs, const std::function<void(const RegstDescProto*, HashSet<int64_t>*)>& ComputeLifetimeActorIds) { - std::list<RegstLifetimeNode*> nodes; + std::vector<RegstLifetimeNode*> nodes; InitNodes(regst_descs, ComputeLifetimeActorIds, &nodes); InitEdges(nodes); } void RegstLifetimeGraph::InitNodes( - const std::list<const RegstDescProto*>& regst_descs, + const std::vector<const RegstDescProto*>& regst_descs, const std::function<void(const RegstDescProto*, HashSet<int64_t>*)>& ComputeLifetimeActorIds, - std::list<RegstLifetimeNode*>* nodes) { + std::vector<RegstLifetimeNode*>* nodes) { for (const RegstDescProto* regst_desc : regst_descs) { auto lifetime_actor_ids = std::make_unique<HashSet<int64_t>>(); ComputeLifetimeActorIds(regst_desc, lifetime_actor_ids.get()); @@ -23,7 +23,7 @@ void RegstLifetimeGraph::InitNodes( } } -void RegstLifetimeGraph::InitEdges(const std::list<RegstLifetimeNode*>& nodes) { +void RegstLifetimeGraph::InitEdges(const std::vector<RegstLifetimeNode*>& nodes) { HashMap<int64_t, HashSet<RegstLifetimeNode*>> task_id2intersected_nodes; for (RegstLifetimeNode* node : nodes) { for (int64_t task_id : node->lifetime_actor_ids()) { @@ -46,25 +46,26 @@ void RegstLifetimeGraph::InitEdges(const std::list<RegstLifetimeNode*>& nodes) { } void RegstLifetimeGraph::ForEachSameColoredRegstDescs( - const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) const { + const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) const { + std::vector<const RegstLifetimeNode*> nodes; + ForEachNode([&](const RegstLifetimeNode* node) { nodes.push_back(node); }); + std::sort(nodes.begin(), nodes.end(), + [&](const 
RegstLifetimeNode* lhs, const RegstLifetimeNode* rhs) { + return lhs->byte_size() > rhs->byte_size(); + }); HashMap<const RegstLifetimeNode*, std::set<int32_t>> node2excluded_color_ids; HashMap<const RegstLifetimeNode*, int32_t> node2color_id; - auto ForEachIntersected = &RegstLifetimeNode::ForEachNodeOnInOutEdge; - ForEachNode([&](const RegstLifetimeNode* start) { - if (node2color_id.find(start) != node2color_id.end()) { return; } - BfsForEachNode({start}, ForEachIntersected, [&](const RegstLifetimeNode* node) { - if (node2color_id.find(node) != node2color_id.end()) { return; } - int32_t color_id = 0; - const auto& excluded_color_ids = node2excluded_color_ids[node]; - for (; excluded_color_ids.find(color_id) != excluded_color_ids.end(); ++color_id) {} - node2color_id[node] = color_id; - (node->*ForEachIntersected)([&](const RegstLifetimeNode* intersected) { - if (node2color_id.find(intersected) != node2color_id.end()) { return; } - node2excluded_color_ids[intersected].insert(color_id); - }); + for (const RegstLifetimeNode* node : nodes) { + int32_t color_id = 0; + const auto& excluded_color_ids = node2excluded_color_ids[node]; + for (; excluded_color_ids.find(color_id) != excluded_color_ids.end(); ++color_id) {} + node2color_id[node] = color_id; + node->ForEachNodeOnInOutEdge([&](const RegstLifetimeNode* intersected) { + if (node2color_id.find(intersected) != node2color_id.end()) { return; } + node2excluded_color_ids[intersected].insert(color_id); }); - }); - HashMap<int32_t, std::list<const RegstDescProto*>> color_id2regst_descs; + } + HashMap<int32_t, std::vector<const RegstDescProto*>> color_id2regst_descs; for (const auto& pair : node2color_id) { color_id2regst_descs[pair.second].push_back(&pair.first->regst_desc()); } diff --git a/oneflow/core/graph/regst_lifetime_graph.h b/oneflow/core/graph/regst_lifetime_graph.h index d5aa92271bf6538ad9793c804ef407e81d835622..d60585c7aac5fb7a9fd9dc36f9dbb9848a6b1754 100644 --- a/oneflow/core/graph/regst_lifetime_graph.h +++ b/oneflow/core/graph/regst_lifetime_graph.h @@ -3,6 +3,7 @@ #include "oneflow/core/graph/graph.h" #include "oneflow/core/register/register_desc.pb.h" +#include "oneflow/core/register/runtime_register_desc.h" namespace oneflow { @@ -20,37 +21,41 @@ class RegstLifetimeNode final : public Node<RegstLifetimeNode, RegstLifetimeEdge OF_DISALLOW_COPY_AND_MOVE(RegstLifetimeNode); RegstLifetimeNode(const RegstDescProto* regst_desc, std::unique_ptr<HashSet<int64_t>>&& lifetime_actor_ids) - : regst_desc_(regst_desc), lifetime_actor_ids_(std::move(lifetime_actor_ids)) {} + : regst_desc_(regst_desc), + lifetime_actor_ids_(std::move(lifetime_actor_ids)), + byte_size_(RtRegstDesc(*regst_desc).TotalMainByteSize4AllRegst()) { + CHECK_EQ(regst_desc->register_num(), 1); + } ~RegstLifetimeNode() = default; int64_t regst_desc_id() const { return regst_desc().regst_desc_id(); } const RegstDescProto& regst_desc() const { return *regst_desc_; } const HashSet<int64_t>& lifetime_actor_ids() const { return *lifetime_actor_ids_; } + size_t byte_size() const { return byte_size_; } private: const RegstDescProto* regst_desc_; std::unique_ptr<HashSet<int64_t>> lifetime_actor_ids_; + size_t byte_size_; }; class RegstLifetimeGraph final : public Graph<const RegstLifetimeNode, RegstLifetimeEdge> { public: OF_DISALLOW_COPY_AND_MOVE(RegstLifetimeGraph); RegstLifetimeGraph( - const std::list<const RegstDescProto*>& regst_descs, + const std::vector<const RegstDescProto*>& regst_descs, const std::function<void(const RegstDescProto*, HashSet<int64_t>*)>& 
ComputeLifetimeActorIds); ~RegstLifetimeGraph() = default; void ForEachSameColoredRegstDescs( - const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) const; + const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) const; private: void InitNodes( - const std::list<const RegstDescProto*>& regst_descs, + const std::vector<const RegstDescProto*>& regst_descs, const std::function<void(const RegstDescProto*, HashSet<int64_t>*)>& ComputeLifetimeActorIds, - std::list<RegstLifetimeNode*>* nodes); - void InitEdges(const std::list<RegstLifetimeNode*>& nodes); - HashMap<const RegstLifetimeNode*, HashSet<const RegstLifetimeNode*>> - regst_lifetime_node2intersected_nodes_; + std::vector<RegstLifetimeNode*>* nodes); + void InitEdges(const std::vector<RegstLifetimeNode*>& nodes); }; } // namespace oneflow diff --git a/oneflow/core/graph/repeat_backward_compute_task_node.cpp b/oneflow/core/graph/repeat_backward_compute_task_node.cpp index e6c32ce9d1558cb27bc77538f8444a3f6002c42e..7474ee42816767e35e3b80a003a99e10a73d1378 100644 --- a/oneflow/core/graph/repeat_backward_compute_task_node.cpp +++ b/oneflow/core/graph/repeat_backward_compute_task_node.cpp @@ -34,7 +34,7 @@ void RepeatBackwardCompTaskNode::InferProducedDataRegstTimeShape() { CHECK(this->logical_node()->SoleOp()->op_conf().has_repeat_conf()); const RepeatOp* repeat_op = dynamic_cast<RepeatOp*>(this->logical_node()->SoleOp().get()); CHECK_NOTNULL(repeat_op); - int32_t repeat_num = repeat_op->GetRepeatNum(parallel_ctx()->parallel_num()); + int32_t repeat_num = repeat_op->GetRepeatNum(); CHECK(!time_shape_dim_vec.empty()); CHECK(time_shape_dim_vec.back() == repeat_num); time_shape_dim_vec.pop_back(); diff --git a/oneflow/core/graph/repeat_forward_compute_task_node.cpp b/oneflow/core/graph/repeat_forward_compute_task_node.cpp index 5022bebc802b8712766bd49868a2a98318818e55..f52509380b4c1ef85bdc21269007a3bf41317052 100644 --- a/oneflow/core/graph/repeat_forward_compute_task_node.cpp +++ b/oneflow/core/graph/repeat_forward_compute_task_node.cpp @@ -33,7 +33,7 @@ void RepeatForwardCompTaskNode::InferProducedDataRegstTimeShape() { GetSoleConsumedRegst("in")->data_regst_time_shape()->dim_vec(); const RepeatOp* repeat_op = dynamic_cast<RepeatOp*>(this->logical_node()->SoleOp().get()); CHECK_NOTNULL(repeat_op); - int32_t repeat_num = repeat_op->GetRepeatNum(parallel_ctx()->parallel_num()); + int32_t repeat_num = repeat_op->GetRepeatNum(); time_shape_dim_vec.push_back(repeat_num); GetProducedRegst("out")->mut_data_regst_time_shape()->reset(new Shape(time_shape_dim_vec)); } diff --git a/oneflow/core/graph/sharable_mem_block_graph.cpp b/oneflow/core/graph/sharable_mem_block_graph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..50f5f7205fa05ec497f09e62094583e0c2a792e7 --- /dev/null +++ b/oneflow/core/graph/sharable_mem_block_graph.cpp @@ -0,0 +1,81 @@ +#include "oneflow/core/graph/sharable_mem_block_graph.h" +#include "oneflow/core/register/register_desc.h" +#include "oneflow/core/register/runtime_register_desc.h" + +namespace oneflow { + +namespace { + +bool IsConsumersAndProducerInSameChain(const RegstDescProto& regst_desc, + const PlanTaskGraph& plan_task_graph) { + auto ChainId4TaskId = [&](int64_t task_id) { + return plan_task_graph.TaskProto4TaskId(task_id)->task_set_info().chain_id(); + }; + int64_t producer_chain_id = ChainId4TaskId(regst_desc.producer_task_id()); + for (int64_t consumer_task_id : regst_desc.consumer_task_id()) { + if (ChainId4TaskId(consumer_task_id) != 
producer_chain_id) { return false; } + } + return true; +} + +} // namespace + +SharableMemBlockGraph::SharableMemBlockGraph( + const PlanTaskGraph& plan_task_gph, + const std::function<bool(const RegstDescProto&)>& IsSharable) { + auto ForEachSharableChainRegstDesc = + [&](const std::function<void(int64_t, const RegstDescProto&)>& Handler) { + for (const TaskProto& task : plan_task_gph.plan().task()) { + for (const auto& pair : task.produced_regst_desc()) { + if (IsConsumersAndProducerInSameChain(pair.second, plan_task_gph) + && IsSharable(pair.second)) { + Handler(task.task_set_info().chain_id(), pair.second); + } + } + } + }; + HashMap<std::pair<int64_t, MemBlock>, HashSet<const RegstDescProto*>> + chain_id7mem_block2regst_descs; + HashSet<int64_t> mem_block_ids_check; + ForEachSharableChainRegstDesc([&](int64_t chain_id, const RegstDescProto& regst_desc) { + int32_t idx = 0; + for (const auto& mem_block : regst_desc.mem_block_hierarchy()) { + if (idx++ == 0) { CHECK(mem_block_ids_check.emplace(mem_block.mem_block_id()).second); } + auto& regst_descs = chain_id7mem_block2regst_descs[std::make_pair(chain_id, mem_block)]; + CHECK(regst_descs.emplace(®st_desc).second); + } + }); + HashMap<std::pair<int64_t, MemBlock>, SharableMemBlockNode*> chain_id7mem_block2node; + for (const auto& pair : chain_id7mem_block2regst_descs) { + auto* node = + new SharableMemBlockNode(pair.first.first, pair.first.second, pair.second, plan_task_gph); + AddAllocatedNode(node); + CHECK(chain_id7mem_block2node.emplace(pair.first, node).second); + } + HashSet<const SharableMemBlockNode*> connected_children; + ForEachSharableChainRegstDesc([&](int64_t chain_id, const RegstDescProto& regst_desc) { + SharableMemBlockNode* child = nullptr; + for (const auto& mem_block : regst_desc.mem_block_hierarchy()) { + auto* parent = chain_id7mem_block2node.at(std::make_pair(chain_id, mem_block)); + if (child != nullptr && connected_children.find(child) == connected_children.end()) { + Connect(parent, NewEdge(), child); + CHECK(connected_children.emplace(child).second); + } + child = parent; + } + }); +} + +void SharableMemBlockGraph::ForEachSourceNodeGroup( + const std::function<int64_t(const SharableMemBlockNode*)>& GroupBy, + const std::function<void(const std::vector<const SharableMemBlockNode*>&)>& Handler) const { + HashMap<int64_t, std::vector<const SharableMemBlockNode*>> group_key2source_nodes; + for (const SharableMemBlockNode* source : source_nodes()) { + group_key2source_nodes[GroupBy(source)].push_back(source); + } + for (const auto& pair : group_key2source_nodes) { + if (pair.second.size() > 1) { Handler(pair.second); } + } +} + +} // namespace oneflow diff --git a/oneflow/core/graph/sharable_mem_block_graph.h b/oneflow/core/graph/sharable_mem_block_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..33354fddb6975dde37be6e1510fab5c0e12c3378 --- /dev/null +++ b/oneflow/core/graph/sharable_mem_block_graph.h @@ -0,0 +1,55 @@ +#ifndef ONEFLOW_CORE_GRAPH_SHARABLE_MEM_BLOCK_GRAPH_H_ +#define ONEFLOW_CORE_GRAPH_SHARABLE_MEM_BLOCK_GRAPH_H_ + +#include "oneflow/core/graph/graph.h" +#include "oneflow/core/register/register_desc.pb.h" +#include "oneflow/core/graph/plan_task_graph.h" + +namespace oneflow { + +class SharableMemBlockEdge; + +class SharableMemBlockNode final : public Node<SharableMemBlockNode, SharableMemBlockEdge> { + public: + OF_DISALLOW_COPY_AND_MOVE(SharableMemBlockNode); + SharableMemBlockNode(int64_t chain_id, const MemBlock& mem_block, + const HashSet<const RegstDescProto*>& 
regst_descs, + const PlanTaskGraph& plan_task_graph) + : chain_id_(chain_id), + mem_block_(mem_block), + regst_descs_(regst_descs.begin(), regst_descs.end()) {} + + ~SharableMemBlockNode() = default; + + int64_t chain_id() const { return chain_id_; } + const std::vector<const RegstDescProto*>& regst_descs() const { return regst_descs_; } + const MemBlock& mem_block() const { return mem_block_; } + + private: + const int64_t chain_id_; + const MemBlock mem_block_; + const std::vector<const RegstDescProto*> regst_descs_; +}; + +class SharableMemBlockEdge final : public Edge<SharableMemBlockNode, SharableMemBlockEdge> { + public: + OF_DISALLOW_COPY_AND_MOVE(SharableMemBlockEdge); + SharableMemBlockEdge() = default; + ~SharableMemBlockEdge() = default; +}; + +class SharableMemBlockGraph final : public Graph<const SharableMemBlockNode, SharableMemBlockEdge> { + public: + OF_DISALLOW_COPY_AND_MOVE(SharableMemBlockGraph); + SharableMemBlockGraph(const PlanTaskGraph& plan_task_gph, + const std::function<bool(const RegstDescProto&)>& IsSharable); + ~SharableMemBlockGraph() = default; + + void ForEachSourceNodeGroup( + const std::function<int64_t(const SharableMemBlockNode*)>& GroupBy, + const std::function<void(const std::vector<const SharableMemBlockNode*>&)>& Handler) const; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_GRAPH_SHARABLE_MEM_BLOCK_GRAPH_H_ diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index e74e65afdc01befb740a23c506e03f81584b79a2..daa1efd3ddfebae1fe2b35f2a1b5eecf27d6661f 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -11,9 +11,83 @@ #include "oneflow/core/graph/reduce_split_compute_task_node.h" #include "oneflow/core/register/runtime_blob_desc.h" #include "oneflow/core/job/thrd_id_generator.h" +#include "oneflow/core/graph/reduce_identity_task_node.h" +#include "oneflow/core/operator/variable_op.h" +#include "oneflow/core/operator/constant_op.h" namespace oneflow { +namespace { + +bool IsConnectToTickOp(const TaskNode* node) { + const auto* comp_task_node = dynamic_cast<const CompTaskNode*>(node); + if (comp_task_node == nullptr) { return false; } + if (comp_task_node->logical_node()->op_vec().size() != 1) { return false; } + const Operator* op = comp_task_node->logical_node()->SoleOp().get(); + if (dynamic_cast<const VariableOp*>(op) != nullptr) { return true; } + if (dynamic_cast<const ConstantOp*>(op) != nullptr) { return true; } + return false; +} + +template<typename ReduceType = ReduceIdentityLogicalNode> +typename std::enable_if<std::is_same<ReduceType, ReduceIdentityLogicalNode>::value + || std::is_same<ReduceType, ReduceSplitLogicalNode>::value, + std::function<int32_t(const LogicalNode*)>>::type +MakeGetterReduceTaskNodeCtrlOrder(const LogicalGraph& logical_graph) { + std::vector<const ReduceType*> logical_nodes; + logical_graph.ForEachNode([&](LogicalNode* node) { + auto* logical_node = dynamic_cast<ReduceType*>(node); + if (logical_node == nullptr) { return; } + logical_nodes.push_back(logical_node); + }); + std::sort(logical_nodes.begin(), logical_nodes.end(), + [](const ReduceType* lhs, const ReduceType* rhs) { + return lhs->order_in_logical_graph() < rhs->order_in_logical_graph(); + }); + auto logical_node2ctrl_order = std::make_shared<HashMap<const LogicalNode*, int32_t>>(); + int32_t lazy_count = Global<JobDesc>::Get()->all_reduce_lazy_ratio() * logical_nodes.size(); + for (int32_t i = 0; i < logical_nodes.size(); ++i) { + int32_t ctrl_order = 0; + if (i > lazy_count) { + 
ctrl_order = -i; + } else { + ctrl_order = i; + } + (*logical_node2ctrl_order)[logical_nodes[i]] = ctrl_order; + } + return [logical_node2ctrl_order](const LogicalNode* identity_node) { + return logical_node2ctrl_order->at(identity_node); + }; +} + +void ForEachDeviceSrcUntrainableNode(const std::vector<NormalForwardCompTaskNode*>& fw_nodes, + const std::function<void(CompTaskNode*)>& Handler) { + HashSet<const TaskNode*> fw_nodes_set(fw_nodes.begin(), fw_nodes.end()); + auto IsSourceTaskNode = [&](NormalForwardCompTaskNode* node) { + for (TaskEdge* edge : node->in_edges()) { + if (fw_nodes_set.find(edge->src_node()) != fw_nodes_set.end()) { return false; } + } + return true; + }; + auto HasBwNode = [&](NormalForwardCompTaskNode* node) { + const auto* fw_logical_node = dynamic_cast<const ForwardLogicalNode*>(node->logical_node()); + return fw_logical_node->bw_node() != nullptr; + }; + for (NormalForwardCompTaskNode* fw_node : fw_nodes) { + if (IsSourceTaskNode(fw_node) && !HasBwNode(fw_node)) { Handler(fw_node); } + } +} + +bool IsTimeShapeContain(const Shape& big_shape, const Shape& small_shape) { + if (big_shape.NumAxes() < small_shape.NumAxes()) { return false; } + FOR_RANGE(int, i, 0, small_shape.NumAxes()) { + if (big_shape.At(i) != small_shape.At(i)) { return false; } + } + return true; +} + +} // namespace + TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) { logical_gph_ = std::move(logical_gph); HashMap<const LogicalNode*, std::vector<CompTaskNode*>> logical2sorted_comp_tasks; @@ -105,13 +179,13 @@ void TaskGraph::AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAll std::function<void(TaskNode* node)> Handler) const { auto ForEachInNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) { node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) { - if (IsBackEdge(node_on_in_edge, node)) return; + if (IsBackEdge(node_on_in_edge, node)) { return; } Handler(const_cast<TaskNode*>(node_on_in_edge)); }); }; auto ForEachOutNode = [&](TaskNode* node, const std::function<void(TaskNode*)>& Handler) { node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) { - if (IsBackEdge(node, node_on_out_edge)) return; + if (IsBackEdge(node, node_on_out_edge)) { return; } Handler(const_cast<TaskNode*>(node_on_out_edge)); }); }; @@ -159,7 +233,8 @@ void TaskGraph::MergeChainAndSetOrderInGraphForEachNode() { void TaskGraph::BuildCtrlRegstDescInSameChain() { HashMap<int64_t, TaskNode*> chain_id2node; - for (auto node : ordered_task_nodes_) { + for (auto* node : ordered_task_nodes_) { + if (IsConnectToTickOp(node)) { continue; } int64_t chain_id = node->chain_id(); auto iter = chain_id2node.find(chain_id); if (iter == chain_id2node.end()) { @@ -171,6 +246,115 @@ void TaskGraph::BuildCtrlRegstDescInSameChain() { } } +void TaskGraph::AddReduceSequenceCtrlEdges() { + HashMap<int64_t, std::vector<ReduceSplitCompTaskNode*>> global_thrd_id2split_nodes; + for (auto* node : ordered_task_nodes_) { + auto* split_node = dynamic_cast<ReduceSplitCompTaskNode*>(node); + if (split_node == nullptr) { continue; } + int64_t global_thrd_id = Global<IDMgr>::Get()->GlobalThrdId4TaskId(split_node->task_id()); + global_thrd_id2split_nodes[global_thrd_id].push_back(split_node); + } + auto GetCtrlOrder = MakeGetterReduceTaskNodeCtrlOrder<ReduceSplitLogicalNode>(*logical_gph_); + for (auto& pair : global_thrd_id2split_nodes) { + auto& split_nodes = pair.second; + std::sort(split_nodes.begin(), split_nodes.end(), + [&](ReduceSplitCompTaskNode* lhs, ReduceSplitCompTaskNode* rhs) { 
+ return GetCtrlOrder(lhs->logical_node()) < GetCtrlOrder(rhs->logical_node()); + }); + ReduceSplitCompTaskNode* prev_split_node = split_nodes.at(0); + for (auto* split_node : split_nodes) { + if (prev_split_node != split_node) { + auto* to_node = split_node->GetPrevReduceTaskNode(TaskType::kReduceIdentity); + TaskNode* from_node = prev_split_node; + if (GetCtrlOrder(split_node->logical_node()) < 0) { + from_node = prev_split_node->GetPrevReduceTaskNode(TaskType::kReduceIdentity); + } + from_node->BuildCtrlRegstDescIfNeed(to_node); + } + prev_split_node = split_node; + } + } +} + +void TaskGraph::AddMdUpdtCtrlEdgesWithinReduceSplitNode() { + auto GetOrderInReduceGroup = [&](NormalMdUpdtCompTaskNode* md_updt_node) { + const auto* logical_node = + dynamic_cast<const NormalMdUpdtLogicalNode*>(md_updt_node->logical_node()); + return logical_node->order_in_reduce_group(); + }; + for (auto* node : ordered_task_nodes_) { + auto* split_node = dynamic_cast<ReduceSplitCompTaskNode*>(node); + if (split_node == nullptr) { continue; } + std::vector<NormalMdUpdtCompTaskNode*> md_updt_nodes; + split_node->ForEachNodeOnOutEdge([&](TaskNode* node) { + auto* md_updt_node = dynamic_cast<NormalMdUpdtCompTaskNode*>(node); + if (md_updt_node == nullptr) { return; } + md_updt_nodes.push_back(md_updt_node); + }); + std::sort(md_updt_nodes.begin(), md_updt_nodes.end(), + [&](NormalMdUpdtCompTaskNode* lhs, NormalMdUpdtCompTaskNode* rhs) { + return GetOrderInReduceGroup(lhs) < GetOrderInReduceGroup(rhs); + }); + NormalMdUpdtCompTaskNode* prev_md_updt = md_updt_nodes.at(0); + for (auto* md_updt_node : md_updt_nodes) { + if (md_updt_node != prev_md_updt) { prev_md_updt->BuildCtrlRegstDescIfNeed(md_updt_node); } + prev_md_updt = md_updt_node; + } + } +} + +void TaskGraph::AddReduceNoBwForwardNodeOverlapingCtrlEdges() { + HashMap<int64_t, std::vector<ReduceIdentityCompTaskNode*>> global_thrd_id2identity_nodes; + HashMap<std::pair<int64_t, int64_t>, std::vector<NormalForwardCompTaskNode*>> + global_dev_phy_id2fw_nodes; + const auto* id_mgr = Global<IDMgr>::Get(); + for (auto* node : ordered_task_nodes_) { + if (id_mgr->GetDeviceTypeFromThrdId(node->thrd_id()) == DeviceType::kCPU) { continue; } + int64_t global_thrd_id = id_mgr->GlobalThrdId4TaskId(node->task_id()); + auto* identity_node = dynamic_cast<ReduceIdentityCompTaskNode*>(node); + auto* fw_node = dynamic_cast<NormalForwardCompTaskNode*>(node); + if (identity_node != nullptr) { + global_thrd_id2identity_nodes[global_thrd_id].push_back(identity_node); + } else if (fw_node != nullptr) { + int64_t dev_phy_id = id_mgr->GetGpuPhyIdFromThrdId(node->thrd_id()); + global_dev_phy_id2fw_nodes[std::make_pair(node->machine_id(), dev_phy_id)].push_back(fw_node); + } else { + // do nothing + } + } + auto GetIdentityNodeOrder = [&](const ReduceIdentityCompTaskNode* id_node) { + const auto* id_logical_node = + dynamic_cast<const ReduceIdentityLogicalNode*>(id_node->logical_node()); + return id_logical_node->order_in_logical_graph(); + }; + for (auto& pair : global_thrd_id2identity_nodes) { + auto& identity_nodes = pair.second; + std::sort(identity_nodes.begin(), identity_nodes.end(), + [&](ReduceIdentityCompTaskNode* lhs, ReduceIdentityCompTaskNode* rhs) { + return GetIdentityNodeOrder(lhs) < GetIdentityNodeOrder(rhs); + }); + auto* first_identity_node = identity_nodes.at(0); + int64_t machine_id = first_identity_node->machine_id(); + int64_t dev_phy_id = id_mgr->GetGpuPhyIdFromThrdId(first_identity_node->thrd_id()); + const auto& fw_nodes = 
global_dev_phy_id2fw_nodes.at(std::make_pair(machine_id, dev_phy_id)); + const Shape& identity_time_shape = + *first_identity_node->GetProducedRegst("out")->data_regst_time_shape(); + ForEachDeviceSrcUntrainableNode(fw_nodes, [&](CompTaskNode* node) { + std::shared_ptr<RegstDesc> regst_desc = node->GetProducedRegst("out"); + if (!regst_desc) { return; } + const Shape& time_shape = *regst_desc->data_regst_time_shape(); + if (!IsTimeShapeContain(time_shape, identity_time_shape)) { return; } + CHECK_EQ(time_shape.elem_cnt() % identity_time_shape.elem_cnt(), 0); + int regst_desc_num = time_shape.elem_cnt() / identity_time_shape.elem_cnt(); + RegstDesc* ctrl_regst_desc = node->BuildCtrlRegstDesc(first_identity_node); + ctrl_regst_desc->UpdtMinRegstNumIfNeed(regst_desc_num); + ctrl_regst_desc->UpdtMaxRegstNumIfNeed(regst_desc_num); + ctrl_regst_desc->mut_regst_desc_type()->mutable_ctrl_regst_desc()->set_returned_regst_num( + regst_desc_num); + }); + } +} + void TaskGraph::EnableMemSharingInReduceStruct() { auto GetPredReduceTaskNode = [](TaskNode* succ) { std::vector<TaskNode*> nodes; @@ -223,6 +407,94 @@ void TaskGraph::EnableMemSharingInReduceStruct() { }); } +void TaskGraph::EnableMemSharingAfterAllManualSetForMdUpdt() { + ForEachNode([&](TaskNode* node) { + auto* updt = dynamic_cast<NormalMdUpdtCompTaskNode*>(node); + if (!updt) { return; } + updt->EnableMemSharingBetweenFirstInAndProcessedMdDiffRegst(); + }); +} + +void TaskGraph::EnableMemSharingInVariableOp() { + ForEachNode([&](TaskNode* node) { + if (node->exec_gph().node_num() != 1) { return; } + auto* variable_op = dynamic_cast<const VariableOp*>(node->exec_gph().SoleNode()->op().get()); + if (variable_op == nullptr) { return; } + std::string model_bn = variable_op->op_conf().variable_conf().model_name(); + auto* fw_task_node = dynamic_cast<NormalForwardCompTaskNode*>(node); + auto* bw_task_node = dynamic_cast<NormalBackwardCompTaskNode*>(node); + if (fw_task_node != nullptr) { + const LogicalBlobId& lbi = variable_op->BnInOp2Lbi(model_bn); + RegstDesc* model_regst = fw_task_node->GetSoleConsumedRegst("model").get(); + CHECK_EQ(model_regst->min_register_num(), 1); + CHECK_EQ(model_regst->max_register_num(), 1); + model_regst->set_enable_mem_sharing(true); + if (model_regst->mem_shared_id() == -1) { + model_regst->set_mem_shared_id(Global<IDMgr>::Get()->NewMemSharedId()); + model_regst->set_mem_shared_offset(0); + } + RegstDesc* out_regst = fw_task_node->GetProducedRegst("out").get(); + CHECK_EQ(out_regst->min_register_num(), 1); + CHECK_EQ(out_regst->max_register_num(), 1); + CHECK_EQ(out_regst->NumOfLbi(), 1); + out_regst->set_enable_mem_sharing(true); + out_regst->set_mem_shared_id(model_regst->mem_shared_id()); + out_regst->set_mem_shared_offset(model_regst->mem_shared_offset() + + model_regst->ByteOffsetInPackedBlobDescBody(lbi)); + variable_op->set_is_fw_inplace(true); + } else if (bw_task_node != nullptr) { + const LogicalBlobId& lbi = variable_op->BnInOp2Lbi(GenDiffBn(model_bn)); + RegstDesc* model_diff_regst = bw_task_node->GetProducedRegst("model_diff").get(); + CHECK_EQ(model_diff_regst->min_register_num(), 1); + CHECK_EQ(model_diff_regst->max_register_num(), 1); + model_diff_regst->set_enable_mem_sharing(true); + if (model_diff_regst->mem_shared_id() == -1) { + model_diff_regst->set_mem_shared_id(Global<IDMgr>::Get()->NewMemSharedId()); + model_diff_regst->set_mem_shared_offset(0); + } + RegstDesc* out_diff_regst = bw_task_node->GetSoleConsumedRegst("out_diff").get(); + if (out_diff_regst->min_register_num() != 1) { return; 
} + if (out_diff_regst->max_register_num() != 1) { return; } + if (out_diff_regst->NumOfLbi() != 1) { return; } + out_diff_regst->set_enable_mem_sharing(true); + out_diff_regst->set_mem_shared_id(model_diff_regst->mem_shared_id()); + out_diff_regst->set_mem_shared_offset( + model_diff_regst->mem_shared_offset() + + model_diff_regst->ByteOffsetInPackedBlobDescBody(lbi)); + variable_op->set_is_bw_inplace(true); + } else { + // do nothing + } + }); +} + +void TaskGraph::EnableInplaceMemSharing() { + AcyclicTopoForEachNode([&](TaskNode* node) { + if (node->exec_gph().node_num() != 1) { return; } + const Operator* op = node->exec_gph().SoleNode()->op().get(); + auto* fw_task_node = dynamic_cast<NormalForwardCompTaskNode*>(node); + auto* bw_task_node = dynamic_cast<NormalBackwardCompTaskNode*>(node); + RegstDesc* input_regst = nullptr; + RegstDesc* output_regst = nullptr; + if (op->IsForwardInplace() && fw_task_node) { + input_regst = fw_task_node->GetSoleConsumedRegst("in").get(); + output_regst = fw_task_node->GetProducedRegst("out").get(); + } else if (op->IsBackwardInplace() && bw_task_node) { + input_regst = bw_task_node->GetSoleConsumedRegst(GenDiffBn("out")).get(); + output_regst = bw_task_node->GetProducedRegst(GenDiffBn("in")).get(); + } else { + // do nothing + return; + } + if (input_regst->NumOfLbi() != 1) { return; } + if (output_regst->NumOfLbi() != 1) { return; } + if (input_regst->mem_shared_inplace_block_id() == -1) { + input_regst->set_mem_shared_inplace_block_id(Global<IDMgr>::Get()->NewMemBlockId()); + } + output_regst->set_mem_shared_inplace_block_id(input_regst->mem_shared_inplace_block_id()); + }); +} + void TaskGraph::RmUselessConsumeRelationshipBetweenFwBw() { for (TaskNode* task_node : ordered_task_nodes_) { auto bw_node = dynamic_cast<NormalBackwardCompTaskNode*>(task_node); @@ -331,6 +603,31 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne) { } } +DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByTickToSource) { + CHECK(src_logical->SoleOp()->op_conf().has_tick_conf()); + HashMap<size_t, CompTaskNode*> machine_id2tick_task; + HashMap<size_t, std::vector<CompTaskNode*>> machine_id2dst_tasks; + for (CompTaskNode* tick_node : sorted_src_comp_tasks) { + machine_id2tick_task[tick_node->machine_id()] = tick_node; + } + for (CompTaskNode* dst_node : sorted_dst_comp_tasks) { + machine_id2dst_tasks[dst_node->machine_id()].push_back(dst_node); + } + + CompTaskNode* first_tick = sorted_src_comp_tasks.at(0); + for (const auto& pair : machine_id2dst_tasks) { + size_t machine_id = pair.first; + for (CompTaskNode* dst_node : pair.second) { + if (machine_id2tick_task.find(machine_id) != machine_id2tick_task.end()) { + Connect<TaskNode>(machine_id2tick_task.at(machine_id), NewEdge(), dst_node); + } else { + TaskNode* next_node = AddCopyCommNetTaskBetween(first_tick, dst_node); + Connect<TaskNode>(first_tick, NewEdge(), next_node); + } + } + } +} + DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphBySelectOneSourceToSoleSink) { CHECK_EQ(sorted_dst_comp_tasks.size(), 1); CompTaskNode* sole_dst_comp_task = sorted_dst_comp_tasks.front(); diff --git a/oneflow/core/graph/task_graph.h b/oneflow/core/graph/task_graph.h index d7609a7e714f79822fae18da0db4e4faefe041b3..db74b89263ad8daa93ac111b63af7805e6b0119b 100644 --- a/oneflow/core/graph/task_graph.h +++ b/oneflow/core/graph/task_graph.h @@ -20,8 +20,14 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> { const char* TypeName() const override { return "TaskGraph"; } void RemoveEmptyRegsts(); void AddOrderingCtrlEdgeInSameChain(); + 
void AddReduceSequenceCtrlEdges(); + void AddMdUpdtCtrlEdgesWithinReduceSplitNode(); + void AddReduceNoBwForwardNodeOverlapingCtrlEdges(); void EnableMemSharingInReduceStruct(); + void EnableMemSharingAfterAllManualSetForMdUpdt(); + void EnableMemSharingInVariableOp(); + void EnableInplaceMemSharing(); void AddOrderCtrlEdgeBetweenCopyAndMdUpdt(); void RmUselessConsumeRelationshipBetweenFwBw(); @@ -32,6 +38,8 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> { DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing); DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne); + DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByRecordLoadToTick); + DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByTickToSource); DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphBySelectOneSourceToSoleSink); DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByReduceScatter2ReduceAdd); DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByReduceAdd2ReduceGather); @@ -40,6 +48,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> { private: void AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode, std::function<void(TaskNode* node)> Handler) const; + void BuildTaskPath( CompTaskNode* src, CompTaskNode* dst, std::function<TaskNode**(CompTaskNode* src, int64_t machine_id, int32_t mem_zone_id)> diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index 6b57c7ce18cff500ab304116bbcd51675cdc1463..674762770b5a716d9d2959982cf8546b400f1735 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -422,5 +422,5 @@ std::map<TaskType, std::string> task_type2color = { {kDecodeRandom, "1"}, {kPackForward, "11"}, {kPackBackward, "12"}, {kUnpackForward, "11"}, {kUnpackBackward, "12"}, {kRepeatForward, "2"}, - {kRepeatBackward, "3"}}; + {kRepeatBackward, "3"}, {kReduceIdentity, "2"}}; } // namespace oneflow diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h index 624de108fdb22cc9e025a9d6a78dae59e591f820..30034659e815aaa1e4982c746ec9ac32b1fb6e40 100644 --- a/oneflow/core/graph/task_node.h +++ b/oneflow/core/graph/task_node.h @@ -47,6 +47,7 @@ class TaskNode : public Node<TaskNode, TaskEdge> { int64_t LocalWorkStreamId() const; int64_t GlobalWorkStreamId() const; int64_t GpuPhyId() const { return Global<IDMgr>::Get()->GetGpuPhyIdFromThrdId(thrd_id_); } + virtual int64_t AreaId4ChainMerge() const { return area_id(); } // Setters void set_machine_id(int64_t val); diff --git a/oneflow/core/graph/unpack_forward_task_node.cpp b/oneflow/core/graph/unpack_forward_task_node.cpp index 655453456477ed08b856d6f7f176c20295c25720..934068d74cbe9e8307185bc81286b44a6035f722 100644 --- a/oneflow/core/graph/unpack_forward_task_node.cpp +++ b/oneflow/core/graph/unpack_forward_task_node.cpp @@ -36,7 +36,7 @@ void UnpackForwardCompTaskNode::InferProducedDataRegstTimeShape() { const UnpackOp* op = dynamic_cast<UnpackOp*>(logical_node()->SoleOp().get()); CHECK_NOTNULL(op); int64_t in_piece_size = in_regst->GetBlobDesc(op->BnInOp2Lbi("in"))->shape().At(0); - int64_t unpack_num = op->GetUnpackNum(parallel_ctx()->parallel_num()); + int64_t unpack_num = op->GetUnpackNum(); CHECK_EQ(0, in_piece_size % unpack_num); time_shape_dim_vec.push_back(unpack_num); *out_regst->mut_data_regst_time_shape() = std::make_shared<Shape>(std::move(time_shape_dim_vec)); diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index 65b551b1d969b2d3e6880a1875496ae65e14fbf7..cb4f2f38fc70300b9dd3e7d93d10d5c6349f7ccc 100644 --- 
a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -1,6 +1,7 @@ #include "oneflow/core/job/compiler.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/device/cudnn_conv_ctx_cache.h" +#include "oneflow/core/graph/op_graph.h" namespace oneflow { @@ -95,7 +96,11 @@ Plan Compiler::DoCompile() { #ifdef WITH_CUDA Global<CudnnConvCtxCache>::New(); #endif + Global<JobDesc>::Get()->FixAndOptimizeDLNet(); const JobDesc* job_desc = Global<JobDesc>::Get(); + TeePersistentLogStream::Create("optimized_job_conf")->Write(job_desc->job_conf()); + Global<OpGraph>::New(job_desc); + Global<OpGraph>::Get()->ToDotWithFilePath("optimized_dlnet_op_graph.dot"); auto logical_gph = std::make_unique<LogicalGraph>(job_desc->IsTrain()); int64_t total_mbn_num = logical_gph->total_mbn_num(); auto task_gph = std::make_unique<TaskGraph>(std::move(logical_gph)); @@ -104,14 +109,24 @@ Plan Compiler::DoCompile() { task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1)); task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1)); task_gph->MdUpdtDelayedTopoForEachNode(&TaskNode::Build); + if (job_desc->IsTrain()) { + task_gph->AddReduceSequenceCtrlEdges(); + task_gph->AddMdUpdtCtrlEdgesWithinReduceSplitNode(); + } task_gph->RemoveEmptyRegsts(); task_gph->AddOrderingCtrlEdgeInSameChain(); if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) { task_gph->EnableMemSharingInReduceStruct(); + task_gph->EnableMemSharingAfterAllManualSetForMdUpdt(); // must last mem shared manual set } + task_gph->EnableInplaceMemSharing(); if (job_desc->IsTrain()) { task_gph->AddOrderCtrlEdgeBetweenCopyAndMdUpdt(); } if (job_desc->IsTrain()) { task_gph->RmUselessConsumeRelationshipBetweenFwBw(); } task_gph->MdUpdtDelayedTopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful); + if (job_desc->IsTrain() && job_desc->enable_mem_sharing()) { + task_gph->EnableMemSharingInVariableOp(); + } + if (job_desc->IsTrain()) { task_gph->AddReduceNoBwForwardNodeOverlapingCtrlEdges(); } Plan plan; task_gph->ForEachNode([&](TaskNode* task_node) { @@ -121,6 +136,7 @@ Plan Compiler::DoCompile() { plan.set_total_mbn_num(total_mbn_num); GenNetTopo(&plan); ToDotFile(plan, "/dot/plan.dot"); + Global<OpGraph>::Delete(); #ifdef WITH_CUDA Global<CudnnConvCtxCache>::Delete(); #endif diff --git a/oneflow/core/job/id_manager.cpp b/oneflow/core/job/id_manager.cpp index 4b3e16d9b3e3186e0e5a4099469742af3548619e..9e17203412d989696a15b34352d80bbf65be9252 100644 --- a/oneflow/core/job/id_manager.cpp +++ b/oneflow/core/job/id_manager.cpp @@ -16,9 +16,12 @@ int64_t IDMgr::GetGpuNcclGatherThrdId(int64_t dev_phy_id) const { int64_t IDMgr::GetGpuMixThrdId(int64_t dev_phy_id) const { return gpu_device_num_ * 5 + dev_phy_id; } -int64_t IDMgr::GetGpuMdUpdtThrdId(int64_t dev_phy_id) const { +int64_t IDMgr::GetGpuReduceCtrlThrdId(int64_t dev_phy_id) const { return gpu_device_num_ * 6 + dev_phy_id; } +int64_t IDMgr::GetGpuMdUpdtThrdId(int64_t dev_phy_id) const { + return gpu_device_num_ * 7 + dev_phy_id; +} int64_t IDMgr::GetCpuDeviceThrdId(int64_t dev_phy_id) const { return gpu_device_num_ * GetCudaWorkTypeSize() + dev_phy_id; } @@ -76,6 +79,11 @@ int64_t IDMgr::GlobalWorkStreamId4ActorId(int64_t actor_id) const { return GlobalWorkStreamId4TaskId(actor_id); } +int64_t IDMgr::GlobalThrdId4TaskId(int64_t task_id) const { + int shift = local_work_stream_id_bit_num_ + task_id_bit_num_; + return (task_id >> shift) << shift; +} + int64_t IDMgr::LocalWorkStreamId4TaskId(int64_t task_id) const { int64_t tmp = (task_id << 
(machine_id_bit_num_ + thread_id_bit_num_)); tmp &= ~(static_cast<int64_t>(1) << 63); @@ -100,6 +108,7 @@ IDMgr::IDMgr() { CHECK_LT(gpu_device_num_ + cpu_device_num_, (static_cast<int64_t>(1) << thread_id_bit_num_) - 3); regst_desc_id_count_ = 0; mem_shared_id_count_ = 0; + mem_block_id_count_ = 0; } int64_t IDMgr::GetMachineThrdId(int64_t machine_id, int64_t thrd_id) { diff --git a/oneflow/core/job/id_manager.h b/oneflow/core/job/id_manager.h index 59dc7ff65ac29cb4e6b9cdb8217593b68841e91f..5cda9288a858deabb9728d75989eb2c136565c85 100644 --- a/oneflow/core/job/id_manager.h +++ b/oneflow/core/job/id_manager.h @@ -19,6 +19,7 @@ class IDMgr final { int64_t GetGpuNcclScatterThrdId(int64_t dev_phy_id) const; int64_t GetGpuNcclGatherThrdId(int64_t dev_phy_id) const; int64_t GetGpuMixThrdId(int64_t dev_phy_id) const; + int64_t GetGpuReduceCtrlThrdId(int64_t dev_phy_id) const; int64_t GetGpuMdUpdtThrdId(int64_t dev_phy_id) const; int64_t GetCpuDeviceThrdId(int64_t dev_phy_id) const; int64_t CommNetThrdId() const; @@ -27,6 +28,7 @@ class IDMgr final { int64_t NewTaskId(int64_t machine_id, int64_t thrd_id, int64_t local_work_stream_id); int64_t NewRegstDescId() { return regst_desc_id_count_++; } int64_t NewMemSharedId() { return mem_shared_id_count_++; } + int64_t NewMemBlockId() { return mem_block_id_count_++; } // MemZoneId int64_t CpuMemZoneId() const { return Global<JobDesc>::Get()->GpuDeviceNum(); } @@ -53,6 +55,10 @@ class IDMgr final { int64_t AllocateLocalWorkStreamId(int64_t machine_id, int64_t thrd_id); int64_t LocalWorkStreamId4TaskId(int64_t task_id) const; int64_t LocalWorkStreamId4ActorId(int64_t actor_id) const; + // global_thread_id + // sign | machine_id | thrd_id | 0 | 0 + // 1 | 10 | 11 | 21 | 21 + int64_t GlobalThrdId4TaskId(int64_t task_id) const; // global_work_stream_id // sign | machine_id | thrd_id | local_work_stream_id | 0 // 1 | 10 | 11 | 21 | 21 @@ -69,6 +75,7 @@ class IDMgr final { int64_t cpu_device_num_; int64_t regst_desc_id_count_; int64_t mem_shared_id_count_; + int64_t mem_block_id_count_; HashMap<int64_t, int64_t> machine_thrd_id2num_of_tasks_; HashMap<int64_t, int64_t> machine_thrd_id2stream_id_cnt_; HashMap<int64_t, int64_t> stream_id2chain_cnt_; diff --git a/oneflow/core/job/improver.cpp b/oneflow/core/job/improver.cpp index 6c3e533644ad35cd92d811f7dd582269a3886edb..b95e8155792bd0e8c3a834ee2095df2bdeeac30a 100644 --- a/oneflow/core/job/improver.cpp +++ b/oneflow/core/job/improver.cpp @@ -7,7 +7,10 @@ #include "oneflow/core/job/profiler.h" #include "oneflow/core/graph/plan_task_graph.h" #include "oneflow/core/graph/regst_lifetime_graph.h" +#include "oneflow/core/graph/sharable_mem_block_graph.h" #include "oneflow/core/actor/act_event_logger.h" +#include "oneflow/core/thread/thread_pool.h" +#include "oneflow/core/common/blocking_counter.h" namespace oneflow { @@ -27,16 +30,10 @@ bool IsConsumersAndProducerInSameChain(const RegstDescProto& regst_desc, return true; } -bool IsSharableRegstWithConsumer(const RegstDescProto& regst_desc, - const std::function<int64_t(int64_t)>& ChainId4TaskId) { - return regst_desc.mem_shared_id() == -1 && regst_desc.consumer_task_id_size() > 0 - && regst_desc.enable_mem_sharing() && regst_desc.register_num() == 1 - && IsConsumersAndProducerInSameChain(regst_desc, ChainId4TaskId); -} - void ForEachSharableStreamRegstDescsWithoutConsumer( - const Plan& plan, const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) { - HashMap<int64_t, std::list<const RegstDescProto*>> global_work_stream_id2regst_descs; + const Plan& 
plan, + const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) { + HashMap<int64_t, std::vector<const RegstDescProto*>> global_work_stream_id2regst_descs; for (const auto& task : plan.task()) { int64_t global_work_stream_id = Global<IDMgr>::Get()->GlobalWorkStreamId4TaskId(task.task_id()); for (const auto& pair : task.produced_regst_desc()) { @@ -50,64 +47,125 @@ void ForEachSharableStreamRegstDescsWithoutConsumer( } } -void ForEachSharableChainRegstDescsWithConsumer( - const Plan& plan, const std::function<int64_t(int64_t)>& ChainId4TaskId, - const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) { - HashMap<int64_t, std::list<const TaskProto*>> chain_id2task_proto; - for (const TaskProto& task : plan.task()) { - chain_id2task_proto[task.task_set_info().chain_id()].push_back(&task); - } - for (const auto& chain_tasks_pair : chain_id2task_proto) { - if (chain_tasks_pair.second.size() == 1) { continue; } - std::list<const RegstDescProto*> regst_descs; - for (const TaskProto* task : chain_tasks_pair.second) { - for (const auto& pair : task->produced_regst_desc()) { - if (IsSharableRegstWithConsumer(pair.second, ChainId4TaskId)) { - regst_descs.push_back(&pair.second); - } - } - } - if (regst_descs.size() > 1) { Handler(regst_descs); } - } -} - void ForEachSameColoredStreamRegstDescWithoutConsumer( - const Plan& plan, const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) { + const Plan& plan, + const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) { auto GetProducerTaskId = [](const RegstDescProto* regst_desc, HashSet<int64_t>* ret_actor_ids) { CHECK(regst_desc->enable_mem_sharing()); ret_actor_ids->insert(regst_desc->producer_task_id()); }; ForEachSharableStreamRegstDescsWithoutConsumer( - plan, [&](const std::list<const RegstDescProto*>& regst_descs) { + plan, [&](const std::vector<const RegstDescProto*>& regst_descs) { RegstLifetimeGraph(regst_descs, GetProducerTaskId).ForEachSameColoredRegstDescs(Handler); }); } +void ForEachSameColoredChainRegstRegstDescs( + const SharableMemBlockGraph& sharable_mem_block_gph, + const std::function<std::vector<const RegstDescProto*>( + const std::vector<const SharableMemBlockNode*>&)>& GetRegstDescs, + const std::function<void(const RegstDescProto*, HashSet<int64_t>*)>& + ComputeLifetimeSameChainActorIds, + const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) { + std::vector<std::vector<const SharableMemBlockNode*>> sharable_mem_blocks_vec; + sharable_mem_block_gph.ForEachSourceNodeGroup( + &SharableMemBlockNode::chain_id, + [&](const std::vector<const SharableMemBlockNode*>& sharable_mem_blocks) { + sharable_mem_blocks_vec.push_back(sharable_mem_blocks); + }); + std::vector<std::vector<std::vector<const RegstDescProto*>>> same_colored_regst_descs_vec( + sharable_mem_blocks_vec.size()); + int64_t cpu_num = std::thread::hardware_concurrency(); + int64_t thread_pool_size = std::min<int64_t>(sharable_mem_blocks_vec.size(), cpu_num); + BlockingCounter counter(sharable_mem_blocks_vec.size()); + ThreadPool thread_pool(thread_pool_size); + FOR_RANGE(int64_t, i, 0, sharable_mem_blocks_vec.size()) { + thread_pool.AddWork([i, &GetRegstDescs, &ComputeLifetimeSameChainActorIds, + &sharable_mem_blocks_vec, &same_colored_regst_descs_vec, &counter]() { + const auto& sharable_mem_blocks = sharable_mem_blocks_vec.at(i); + RegstLifetimeGraph(GetRegstDescs(sharable_mem_blocks), ComputeLifetimeSameChainActorIds) + .ForEachSameColoredRegstDescs([&](const 
std::vector<const RegstDescProto*>& regst_descs) { + same_colored_regst_descs_vec.at(i).push_back(regst_descs); + }); + counter.Decrease(); + }); + } + counter.WaitUntilCntEqualZero(); + for (const auto& regst_descs_vec : same_colored_regst_descs_vec) { + for (const auto& regst_descs : regst_descs_vec) { Handler(regst_descs); } + } +} + void ForEachSameColoredChainRegstDescWithConsumer( const PlanTaskGraph& plan_task_graph, - const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) { + const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) { + // construct SharableMemBlockGraph + auto ChainId4TaskId = [&](int64_t task_id) { + return plan_task_graph.TaskProto4TaskId(task_id)->task_set_info().chain_id(); + }; + auto IsSharableRegstWithConsumer = [&](const RegstDescProto& regst_desc) { + return regst_desc.mem_shared_id() == -1 && regst_desc.consumer_task_id_size() > 0 + && regst_desc.enable_mem_sharing() && regst_desc.register_num() == 1 + && IsConsumersAndProducerInSameChain(regst_desc, ChainId4TaskId); + }; + SharableMemBlockGraph sharable_mem_block_gph(plan_task_graph, IsSharableRegstWithConsumer); + sharable_mem_block_gph.ForEachNode([&](const SharableMemBlockNode* sharable_mem_block) { + CHECK_EQ(sharable_mem_block->mem_block().mem_reduce_method(), MemReduceMethod::kMemMax); + }); + // group regst_descs for pre-colored regst_descs. + // example: + // given dlnet: A -> B -> C -> D -> E -> F -> H -> I, where D is an inplace op. + // Regst(C) and Regst(D) are pre-colored with the same color as a group, which + // then shares memory with other regsts like A, B, E, ... + HashMap<const RegstDescProto*, std::vector<const RegstDescProto*>> header2members; + for (const SharableMemBlockNode* sharable_mem_block : sharable_mem_block_gph.source_nodes()) { + auto regst_descs = sharable_mem_block->regst_descs(); + HashMap<const RegstDescProto*, size_t> regst_desc2mem_size; + for (const RegstDescProto* regst_desc : regst_descs) { + size_t size = RtRegstDesc(*regst_desc).TotalMainByteSize4AllRegst(); + CHECK(regst_desc2mem_size.emplace(regst_desc, size).second); + } + std::sort(regst_descs.begin(), regst_descs.end(), + [&](const RegstDescProto* lhs, const RegstDescProto* rhs) { + return regst_desc2mem_size.at(lhs) > regst_desc2mem_size.at(rhs); + }); + header2members.emplace(regst_descs.at(0), regst_descs); + } + auto GetRegstDescs = [&](const std::vector<const SharableMemBlockNode*>& sharable_mem_blocks) { + std::vector<const RegstDescProto*> ret; + for (const SharableMemBlockNode* sharable_mem_block : sharable_mem_blocks) { + for (const RegstDescProto* regst_desc : sharable_mem_block->regst_descs()) { + if (header2members.find(regst_desc) != header2members.end()) { + ret.push_back(regst_desc); + break; + } + } + } + return ret; + }; auto ComputeLifetimeSameChainActorIds = [&](const RegstDescProto* regst_desc, HashSet<int64_t>* ret_actor_ids) { CHECK(regst_desc->enable_mem_sharing()); ret_actor_ids->clear(); - plan_task_graph.ComputeLifetimeSameChainActorIds(regst_desc, ret_actor_ids); + for (const RegstDescProto* member : header2members.at(regst_desc)) { + plan_task_graph.ComputeLifetimeSameChainActorIds(member, ret_actor_ids); + } }; - auto ChainId4TaskId = [&](int64_t task_id) { - return plan_task_graph.TaskProto4TaskId(task_id)->task_set_info().chain_id(); + auto AppendGroupMembers = [&](const std::vector<const RegstDescProto*>& regst_descs) { + std::vector<const RegstDescProto*> members; + for (const auto* header : regst_descs) { + for (const auto* member : 
header2members.at(header)) { members.push_back(member); } + } + Handler(members); }; - const Plan& plan = plan_task_graph.plan(); - ForEachSharableChainRegstDescsWithConsumer( - plan, ChainId4TaskId, [&](const std::list<const RegstDescProto*>& regst_descs) { - RegstLifetimeGraph(regst_descs, ComputeLifetimeSameChainActorIds) - .ForEachSameColoredRegstDescs(Handler); - }); + ForEachSameColoredChainRegstRegstDescs(sharable_mem_block_gph, GetRegstDescs, + ComputeLifetimeSameChainActorIds, AppendGroupMembers); } void ForEachImprovedMemSharedId(const PlanTaskGraph& plan_task_graph, const std::function<void(int64_t, int64_t)>& Handler) { - using RegstDescs = std::list<const RegstDescProto*>; const Plan& plan = plan_task_graph.plan(); - auto HandleMemSharedId = [&](const RegstDescs& regst_descs) { + auto HandleMemSharedId = [&](const std::vector<const RegstDescProto*>& regst_descs) { int64_t mem_shared_id = Global<IDMgr>::Get()->NewMemSharedId(); for (const RegstDescProto* regst_desc : regst_descs) { Handler(regst_desc->regst_desc_id(), mem_shared_id); diff --git a/oneflow/core/job/init_op_conf.cpp b/oneflow/core/job/init_op_conf.cpp index 9e6b04e2348e60ee088aba910eaba091bad3e212..9587584bab392e767d8ccfd3c872957d6125a3ba 100644 --- a/oneflow/core/job/init_op_conf.cpp +++ b/oneflow/core/job/init_op_conf.cpp @@ -219,6 +219,12 @@ void InitInitializerConf(InitializerConf* initializer, const InitializerConf::Ty initializer->set_allocated_random_normal_conf(random_normal_conf); break; } + case InitializerConf::kTruncatedNormalConf: { + TruncatedNormalInitializerConf* truncated_normal_conf = new TruncatedNormalInitializerConf(); + truncated_normal_conf->set_std(param1); + initializer->set_allocated_truncated_normal_conf(truncated_normal_conf); + break; + } case InitializerConf::kXavierConf: { XavierInitializerConf* xavier_conf = new XavierInitializerConf(); xavier_conf->set_variance_norm(static_cast<VarianceNorm>(static_cast<int>(param1))); @@ -231,6 +237,20 @@ void InitInitializerConf(InitializerConf* initializer, const InitializerConf::Ty initializer->set_allocated_msra_conf(msra_conf); break; } + case InitializerConf::kRangeConf: { + RangeInitializerConf* range_conf = new RangeInitializerConf(); + range_conf->set_start(param1); + range_conf->set_stride(param2); + initializer->set_allocated_range_conf(range_conf); + break; + } + case InitializerConf::kIntRangeConf: { + IntRangeInitializerConf* int_range_conf = new IntRangeInitializerConf(); + int_range_conf->set_start(static_cast<int64_t>(param1)); + int_range_conf->set_stride(static_cast<int64_t>(param2)); + initializer->set_allocated_int_range_conf(int_range_conf); + break; + } case InitializerConf::TYPE_NOT_SET: { LOG(INFO) << "InitializerConf::TYPE_NOT_SET"; break; diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto index 3bdad1aed56426fb36913070aa75132b3fa2835e..07ed54c52c9f196df7a2191a0e67a6ff5c38b104 100644 --- a/oneflow/core/job/job_conf.proto +++ b/oneflow/core/job/job_conf.proto @@ -67,18 +67,30 @@ message OtherConf { optional bool save_downloaded_file_to_local_fs = 109 [default = false]; optional uint64 rdma_mem_block_mbyte = 110 [default = 8]; optional uint64 rdma_recv_msg_buf_mbyte = 111 [default = 6]; + required FileSystemConf data_fs_conf = 112; + required FileSystemConf snapshot_fs_conf = 113; - optional int64 reduce_group_size = 112 [default = 20]; - optional bool collect_act_event = 113 [default = false]; - optional bool enable_mem_sharing = 114 [default = true]; + optional bool collect_act_event = 125 
[default = false]; + optional bool enable_mem_sharing = 126 [default = true]; optional bool enable_write_snapshot = 130 [default = true]; optional bool enable_blob_mem_sharing = 140 [default = true]; optional bool enable_nccl = 142 [default = true]; optional bool use_nccl_inter_node_communication = 143 [default = false]; optional int64 cudnn_buf_limit_mbyte = 144 [default = 1024]; // 1GByte - - required FileSystemConf data_fs_conf = 121; - required FileSystemConf snapshot_fs_conf = 122; + optional int64 all_reduce_group_num = 145 [default = 8]; + // example: + // all_reduce_lazy_ratio = 0.5 + // It means that half of the all_reduce nodes overlap with the forward pass of the next batch + optional float all_reduce_lazy_ratio = 146 [default = 0.5]; + optional int64 all_reduce_group_min_mbyte = 147 [default = 16]; + // example: + // the total weight size is 1024M + // all_reduce_group_num = 8 + // all_reduce_group_min_mbyte = 16 + // all_reduce_group_size_warmup = 2 + // Each node's weight size is [16, 32, 64, 128, 128, 128, 128, 128, 128, 128, 16]. + // You can see that the actual number of reduce groups is slightly bigger than all_reduce_group_num. + optional float all_reduce_group_size_warmup = 148 [default = 2]; oneof JobType { TrainConf train_conf = 200; diff --git a/oneflow/core/job/job_desc.cpp b/oneflow/core/job/job_desc.cpp index abc2ccec121cdf6b3def19f29dc84701366467cd..51d31f7e1237a45735d149bf3cdbd7917a277ef1 100644 --- a/oneflow/core/job/job_desc.cpp +++ b/oneflow/core/job/job_desc.cpp @@ -1,10 +1,146 @@ #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/parallel_desc.h" +#include "oneflow/core/operator/operator.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/persistence/hadoop/hadoop_file_system.h" +#include "oneflow/core/graph/graph.h" +#include "oneflow/core/graph/op_graph.h" namespace oneflow { +std::function<const ParallelConf*(const std::string&)> MakeGetterParallelConf4OpName( + const Placement& placement) { + auto op_name2parallel_conf = std::make_shared<HashMap<std::string, const ParallelConf*>>(); + for (const auto& placement_group : placement.placement_group()) { + for (const std::string& op_name : placement_group.op_set().op_name()) { + const ParallelConf* parallel_conf = &placement_group.parallel_conf(); + CHECK(op_name2parallel_conf->emplace(op_name, parallel_conf).second); + } + } + return [op_name2parallel_conf](const std::string& op_name) { + return op_name2parallel_conf->at(op_name); + }; +} + +std::function<ParallelConf*(const std::string&)> MakeGetterMutParallelConf4OpName( + Placement* placement) { + auto op_name2parallel_conf = std::make_shared<HashMap<std::string, ParallelConf*>>(); + FOR_RANGE(int, idx, 0, placement->placement_group_size()) { + auto* placement_group = placement->mutable_placement_group(idx); + for (const std::string& op_name : placement_group->op_set().op_name()) { + ParallelConf* parallel_conf = placement_group->mutable_parallel_conf(); + CHECK(op_name2parallel_conf->emplace(op_name, parallel_conf).second); + } + } + return [op_name2parallel_conf](const std::string& op_name) { + return op_name2parallel_conf->at(op_name); + }; +} + +namespace { + +std::function<OperatorConf*(const std::string&)> MakeMutableOperatorConf4OpName( + JobConf1* job_conf) { + auto op_name2op_conf = std::make_shared<HashMap<std::string, OperatorConf*>>(); + FOR_RANGE(int, idx, 0, job_conf->net().op_size()) { + OperatorConf* op_conf = job_conf->mutable_net()->mutable_op(idx); + CHECK(op_name2op_conf->emplace(op_conf->name(), op_conf).second); + } + 
return [op_name2op_conf](const std::string& op_name) { return op_name2op_conf->at(op_name); }; +} + +void AddIdentityOp(const std::string& prefix, JobConf1* job_conf, + const HashSet<LogicalBlobId>& input_lbis, + HashMap<LogicalBlobId, LogicalBlobId>* old_lbi2new_lbi, + const ParallelConf& parallel_conf) { + // add tuple identity op + OperatorConf* tuple_identity_op = job_conf->mutable_net()->add_op(); + tuple_identity_op->set_name(prefix + NewUniqueId()); + TupleIdentityOpConf* tuple_identity_op_conf = tuple_identity_op->mutable_tuple_identity_conf(); + int32_t idx = 0; + for (const LogicalBlobId& lbi : input_lbis) { + std::string blob_name = std::string("out_") + std::to_string(idx++); + { + LogicalBlobId output_lbi; + output_lbi.set_op_name(tuple_identity_op->name()); + output_lbi.set_blob_name(blob_name); + CHECK(old_lbi2new_lbi->emplace(lbi, output_lbi).second); + } + tuple_identity_op_conf->add_in(lbi.op_name() + "/" + lbi.blob_name()); + tuple_identity_op_conf->add_out(blob_name); + } + // add placement of tuple identity op + PlacementGroup* p_group = job_conf->mutable_placement()->add_placement_group(); + *(p_group->mutable_op_set()->add_op_name()) = tuple_identity_op->name(); + *(p_group->mutable_parallel_conf()) = parallel_conf; +} + +void SetPbMessageField(PbMessage* pb_msg, const std::string& field, const std::string& old_val, + const std::string& new_val) { + const PbFd* fd = pb_msg->GetDescriptor()->FindFieldByName(field); + if (fd) { + CHECK_EQ(GetValFromPbMessage<std::string>(*pb_msg, field), old_val); + SetValInPbMessage<std::string>(pb_msg, field, new_val); + } else { + const std::pair<std::string, int32_t> prefix_idx = GenUnRepeatedBn(field); + CHECK_EQ(GetPbRpfFromPbMessage<std::string>(*pb_msg, prefix_idx.first).Get(prefix_idx.second), + old_val); + PbRpf<std::string>* rpf = MutPbRpfFromPbMessage<std::string>(pb_msg, prefix_idx.first); + *rpf->Mutable(prefix_idx.second) = new_val; + } +} + +void AddIdentityOpAndReconnect( + const std::string& identity_op_name_prefix, JobConf1* job_conf, + const std::vector<OpEdge*>& op_edges, + const std::function<OperatorConf*(const std::string&)>& MutOperatorConf4OpName, + const ParallelConf& parallel_conf) { + // add identity op + HashSet<LogicalBlobId> lbis; + for (OpEdge* edge : op_edges) { lbis.insert(edge->lbis().begin(), edge->lbis().end()); } + HashMap<LogicalBlobId, LogicalBlobId> old_lbi2new_lbi; + AddIdentityOp(identity_op_name_prefix, job_conf, lbis, &old_lbi2new_lbi, parallel_conf); + // reconnect to identity op + for (OpEdge* edge : op_edges) { + OperatorConf* op_conf = MutOperatorConf4OpName(edge->dst_node()->op().op_name()); + PbMessage* op_type_conf = MutableMessageInPbMessage(op_conf, op_conf->op_type_case()); + for (const LogicalBlobId& lbi : edge->lbis()) { + std::string lbn_check = GenLogicalBlobName(lbi); + std::string identity_out_lbn = GenLogicalBlobName(old_lbi2new_lbi.at(lbi)); + for (const std::string& ibn : edge->lbi2ibns().at(lbi)) { + SetPbMessageField(op_type_conf, ibn, lbn_check, identity_out_lbn); + } + } + } +} + +} // namespace + +int64_t JobDesc::all_reduce_group_min_byte() const { + int64_t ret = job_conf_.other().all_reduce_group_min_mbyte() * 1024 * 1024; + CHECK_GT(ret, 0); + return ret; +} + +float JobDesc::all_reduce_group_size_warmup() const { + float ret = job_conf_.other().all_reduce_group_size_warmup(); + CHECK_GT(ret, 1); + return ret; +} + +int64_t JobDesc::all_reduce_group_num() const { + int64_t ret = job_conf_.other().all_reduce_group_num(); + CHECK_GT(ret, 0); + return ret; +} + 
+float JobDesc::all_reduce_lazy_ratio() const { + float ratio = job_conf_.other().all_reduce_lazy_ratio(); + CHECK_GE(ratio, 0.0); + CHECK_LE(ratio, 1.0); + return ratio; +} + int64_t JobDesc::piece_num_of_experiment_phase() const { return job_conf_.other().exp_run_conf().piece_num_of_experiment_phase(); } @@ -65,8 +201,8 @@ int64_t JobDesc::BatchSize() const { } int64_t JobDesc::NumOfPiecesInBatch() const { if (IsPredict()) { return 1; } - CHECK_EQ(BatchSize() % PieceSize(), 0); - return BatchSize() / PieceSize(); + CHECK_EQ(BatchSize() % RecordPieceSize(), 0); + return BatchSize() / RecordPieceSize(); } float JobDesc::primary_lr() const { CHECK(IsTrain()); @@ -195,6 +331,7 @@ void JobDesc::SplitDecodeOps() { void JobDesc::AddRecordLoadOps() { HashMap<std::pair<std::string, std::string>, std::vector<OperatorConf*>> data_info2decode_ops; HashMap<std::pair<std::string, std::string>, int32_t> data_info2suffix_length; + HashMap<std::pair<std::string, std::string>, const RandomShuffleConf*> data_info2shuffle_conf; size_t op_num = job_conf_.net().op_size(); FOR_RANGE(size_t, idx, 0, op_num) { OperatorConf* op_conf = job_conf_.mutable_net()->mutable_op()->Mutable(idx); @@ -210,6 +347,18 @@ void JobDesc::AddRecordLoadOps() { } else { data_info2suffix_length[data_info] = part_name_suffix_length; } + const RandomShuffleConf* shuffle_conf = + decode_conf.has_random_shuffle_conf() ? &decode_conf.random_shuffle_conf() : nullptr; + if (data_info2shuffle_conf.find(data_info) != data_info2shuffle_conf.end()) { + if (shuffle_conf == nullptr) { + CHECK(data_info2shuffle_conf.at(data_info) == nullptr); + } else { + CHECK(data_info2shuffle_conf.at(data_info) != nullptr); + CHECK_EQ(data_info2shuffle_conf.at(data_info)->buffer_size(), shuffle_conf->buffer_size()); + } + } else { + CHECK(data_info2shuffle_conf.emplace(data_info, shuffle_conf).second); + } } HashMap<std::string, const ParallelConf*> name2parallel_conf; @@ -246,6 +395,9 @@ void JobDesc::AddRecordLoadOps() { record_load_op->set_data_dir(pair.first.first); record_load_op->set_part_name_prefix(pair.first.second); record_load_op->set_part_name_suffix_length(data_info2suffix_length.at(pair.first)); + if (data_info2shuffle_conf.at(pair.first) != nullptr) { + *record_load_op->mutable_random_shuffle_conf() = *data_info2shuffle_conf.at(pair.first); + } PlacementGroup* p_group = job_conf_.mutable_placement()->add_placement_group(); *(p_group->mutable_op_set()->add_op_name()) = record_load_op_name; *(p_group->mutable_parallel_conf()) = *parallel_conf; @@ -261,4 +413,128 @@ void JobDesc::AddRecordLoadOps() { } } +void JobDesc::FixAndOptimizeDLNet() { + FixTickOpIfExists(); + ConvertPseudoChainToChain(); + if (IsTrain()) { AddIdentityOpForAllReduceOverlapingUntrainble(); } +} + +void JobDesc::ConvertPseudoChainToChain() { + auto GetSourceNodesAndEdges = [&](const HashSet<OpNode*>& chain_nodes, + HashSet<OpNode*>* source_nodes, + std::vector<OpEdge*>* source_edges) { + for (OpNode* node : chain_nodes) { + for (OpEdge* edge : node->in_edges()) { + if (chain_nodes.find(edge->src_node()) == chain_nodes.end()) { + source_edges->push_back(edge); + source_nodes->insert(node); + } + } + } + }; + auto MutOperatorConf4OpName = MakeMutableOperatorConf4OpName(&job_conf_); + auto ParallelConf4OpName = MakeGetterParallelConf4OpName(job_conf_.placement()); + OpGraph(this).ForEachPseudoChain([&](const HashSet<OpNode*>& chain_nodes) { + HashSet<OpNode*> source_nodes; + std::vector<OpEdge*> source_edges; + GetSourceNodesAndEdges(chain_nodes, &source_nodes, &source_edges); + 
if (source_edges.size() <= 1) { return; } + if (source_nodes.size() <= 1) { return; } + if (chain_nodes.size() - source_nodes.size() <= 2) { return; } + const OpNode* first_node = *source_nodes.begin(); + if (first_node->parallel_desc().device_type() == DeviceType::kCPU) { return; } + HashMap<bool, std::vector<OpEdge*>> has_diff2source_edges; + for (OpEdge* edge : source_edges) { has_diff2source_edges[edge->has_diff()].push_back(edge); } + for (const auto& pair : has_diff2source_edges) { + HashSet<OpNode*> src_nodes; + HashSet<OpNode*> dst_nodes; + for (OpEdge* edge : pair.second) { + src_nodes.emplace(edge->src_node()); + dst_nodes.emplace(edge->dst_node()); + } + if (src_nodes.size() > 1 && dst_nodes.size() > 1) { + AddIdentityOpAndReconnect("pseudo_chain_header_", &job_conf_, pair.second, + MutOperatorConf4OpName, + *ParallelConf4OpName(first_node->op().op_name())); + } + } + }); +} + +void JobDesc::AddIdentityOpForAllReduceOverlapingUntrainble() { + auto MutOperatorConf4OpName = MakeMutableOperatorConf4OpName(&job_conf_); + auto ParallelConf4OpName = MakeGetterParallelConf4OpName(job_conf_.placement()); + OpGraph(this).TopoForEachNode([&](OpNode* op_node) { + if (op_node->HasBackward()) { return; } + HashMap<bool, std::vector<OpEdge*>> has_bw2out_op_edges; + for (OpEdge* edge : op_node->out_edges()) { + has_bw2out_op_edges[edge->dst_node()->HasBackward()].push_back(edge); + } + if (has_bw2out_op_edges.size() <= 1) { return; } + // only handle op_nodes that: + // a) have no backward node; + // b) have trainable and untrainable consumers; + + // group out_edges by trainable consumers' ParallelDesc + HashMap<ParallelDesc, std::vector<OpEdge*>> consumer_op_pr2edges; + for (OpEdge* edge : has_bw2out_op_edges.at(true)) { + ParallelDesc pr(*ParallelConf4OpName(edge->dst_node()->op().op_name())); + consumer_op_pr2edges[pr].push_back(edge); + } + for (const auto& pair : consumer_op_pr2edges) { + AddIdentityOpAndReconnect( + "all_reduce_overlapping_untrainable_", &job_conf_, pair.second, MutOperatorConf4OpName, + *ParallelConf4OpName(pair.second.at(0)->dst_node()->op().op_name())); + } + }); +} + +void JobDesc::FixTickOpIfExists() { + auto MutParallelConf4OpName = MakeGetterMutParallelConf4OpName(job_conf_.mutable_placement()); + OperatorConf* tick_op_conf = nullptr; + FOR_RANGE(int, idx, 0, job_conf_.mutable_net()->op_size()) { + OperatorConf* op_conf = job_conf_.mutable_net()->mutable_op(idx); + if (op_conf->has_tick_conf()) { + CHECK(tick_op_conf == nullptr); + tick_op_conf = op_conf; + } + } + if (tick_op_conf == nullptr) { return; } + std::map<OperatorConf::OpTypeCase, std::vector<OperatorConf*>> op_type_case2source_op_confs; + FOR_RANGE(int, idx, 0, job_conf_.mutable_net()->op_size()) { + OperatorConf* op_conf = job_conf_.mutable_net()->mutable_op(idx); + if (op_conf == tick_op_conf) { continue; } + DeviceType device_type = ParallelDesc(*MutParallelConf4OpName(op_conf->name())).device_type(); + if (ConstructOp(*op_conf, device_type)->input_bns().size() == 0) { + op_type_case2source_op_confs[op_conf->op_type_case()].push_back(op_conf); + } + } + if (op_type_case2source_op_confs.find(OperatorConf::kRecordLoadConf) + != op_type_case2source_op_confs.end()) { + CHECK_EQ(op_type_case2source_op_confs.size(), 1); + } + // set input of tick op + OperatorConf* source_op_conf = op_type_case2source_op_confs.cbegin()->second.at(0); + ParallelConf* source_parallel_conf = MutParallelConf4OpName(source_op_conf->name()); + DeviceType device_type = ParallelDesc(*source_parallel_conf).device_type(); + 
std::shared_ptr<Operator> source_op = ConstructOp(*source_op_conf, device_type); + CHECK_GE(source_op->output_bns().size(), 1); + LogicalBlobId src_first_output_lbi = source_op->BnInOp2Lbi(source_op->output_bns().Get(0)); + std::string source_op_output_lbn = GenLogicalBlobName(src_first_output_lbi); + CHECK_EQ(tick_op_conf->tick_conf().has_in(), false); + tick_op_conf->mutable_tick_conf()->set_in(source_op_output_lbn); + // fix tick op placement + *MutParallelConf4OpName(tick_op_conf->name()) = *source_parallel_conf; + // add log_counter op connecting to tick op, making tick op always consumed + OperatorConf* tick_log_counter = job_conf_.mutable_net()->add_op(); + tick_log_counter->set_name("tick_log_counter_" + NewUniqueId()); + LogCounterOpConf* tick_log_counter_conf = tick_log_counter->mutable_log_counter_conf(); + tick_log_counter_conf->set_in(tick_op_conf->name() + "/" + tick_op_conf->tick_conf().out()); + tick_log_counter_conf->set_interval(MaxVal<int32_t>::value); + // add placement of tick_log_counter op + PlacementGroup* p_group = job_conf_.mutable_placement()->add_placement_group(); + *(p_group->mutable_op_set()->add_op_name()) = tick_log_counter->name(); + *(p_group->mutable_parallel_conf()) = *source_parallel_conf; +} + } // namespace oneflow diff --git a/oneflow/core/job/job_desc.h b/oneflow/core/job/job_desc.h index a1593ccd7bc2d507d4f666085d91992a30a8ad54..0b1fc41fe2d53b7cc7e97d8c6c849c0e060e71cd 100644 --- a/oneflow/core/job/job_desc.h +++ b/oneflow/core/job/job_desc.h @@ -7,6 +7,7 @@ #include "oneflow/core/job/placement.pb.h" #include "oneflow/core/job/resource.pb.h" #include "oneflow/core/persistence/file_system.h" +#include "oneflow/core/register/logical_blob_id.pb.h" namespace oneflow { @@ -17,6 +18,7 @@ class JobDesc final { ~JobDesc() = default; // Common + const JobConf1& job_conf() const { return job_conf_; } const DLNetConf& dlnet_conf() const { return job_conf_.net(); } const Resource& resource() const { return job_conf_.resource(); } const Placement& placement() const { return job_conf_.placement(); } @@ -35,7 +37,7 @@ class JobDesc final { int32_t MaxMdSaveWorkerNum() const { return job_conf_.resource().max_mdsave_worker_num(); } bool IsTrain() const { return job_conf_.other().has_train_conf(); } bool IsPredict() const { return job_conf_.other().has_predict_conf(); } - int64_t PieceSize() const { return job_conf_.other().piece_size(); } + int64_t RecordPieceSize() const { return job_conf_.other().piece_size(); } int64_t piece_num_of_experiment_phase() const; bool enable_experiment_run() const; float available_zone_mem_ratio() const; @@ -58,7 +60,10 @@ class JobDesc final { bool use_nccl_inter_node_communication() const { return job_conf_.other().use_nccl_inter_node_communication(); } - int64_t reduce_group_size() const { return job_conf_.other().reduce_group_size(); } + int64_t all_reduce_group_num() const; + int64_t all_reduce_group_min_byte() const; + float all_reduce_group_size_warmup() const; + float all_reduce_lazy_ratio() const; int64_t cudnn_buf_limit_mbyte() const { return job_conf_.other().cudnn_buf_limit_mbyte(); } int64_t GetMachineId(const std::string& addr) const; @@ -79,6 +84,9 @@ class JobDesc final { float bias_l2() const; int32_t DataPartNum() const; + // fix and Optimize + void FixAndOptimizeDLNet(); + private: friend class Global<JobDesc>; JobDesc(const std::string& job_conf_filepath); @@ -87,10 +95,19 @@ class JobDesc final { void SanityCheck(); void SplitDecodeOps(); void AddRecordLoadOps(); + void ConvertPseudoChainToChain(); + void 
AddIdentityOpForChainMergeOptimization(); + void AddIdentityOpForAllReduceOverlapingUntrainble(); + void FixTickOpIfExists(); JobConf1 job_conf_; }; +std::function<const ParallelConf*(const std::string&)> MakeGetterParallelConf4OpName( + const Placement& placement); +std::function<ParallelConf*(const std::string&)> MakeGetterMutParallelConf4OpName( + Placement* placement); + } // namespace oneflow #endif // ONEFLOW_CORE_JOB_JOB_DESC_H_ diff --git a/oneflow/core/job/parallel_desc.cpp b/oneflow/core/job/parallel_desc.cpp index e1c5a0ef7284a731d46bd6617fef15c6a2e9c945..37756f65e16ca21e770efbd6e996a8bb2791cbeb 100644 --- a/oneflow/core/job/parallel_desc.cpp +++ b/oneflow/core/job/parallel_desc.cpp @@ -34,7 +34,9 @@ ParallelDesc::ParallelDesc(const ParallelConf& user_conf) { CHECK(device_type_ == DeviceType::kInvalidDevice || device_type_ == DeviceType::kGPU); device_type_ = DeviceType::kGPU; } - sorted_machine_ids_.push_back(mchn_id); + if (machine_id_set.find(mchn_id) == machine_id_set.end()) { + sorted_machine_ids_.push_back(mchn_id); + } int64_t minus_pos = device_id_str.find("-"); if (minus_pos == std::string::npos) { device_id_str = device_id_str + "-" + device_id_str; @@ -54,42 +56,6 @@ ParallelDesc::ParallelDesc(const ParallelConf& user_conf) { SanityCheck(); } -void ParallelDesc::RemoveNeedlessDevice(const std::string& op_name, int32_t max_device_num) { - if (max_device_num >= parallel_num_) { return; } - LOG_IF(WARNING, op_name != "") << "parallel_num of " << op_name - << " is greater than max_device_num " << max_device_num; - int32_t device_cnt = 0; - int64_t max_machine_id = -1; - for (int64_t machine_id : sorted_machine_ids_) { - auto it = machine_id2sorted_dev_phy_ids_.find(machine_id); - int32_t cur_device_num = it->second.size(); - int32_t cur_device_max_num = max_device_num - device_cnt; - if (cur_device_num > cur_device_max_num) { - it->second.erase(it->second.begin() + cur_device_max_num, it->second.end()); - if (it->second.empty()) { - max_machine_id = machine_id - 1; - } else { - max_machine_id = machine_id; - } - break; - } else { - device_cnt += cur_device_num; - } - } - CHECK_NE(max_machine_id, -1); - FOR_EACH(it, sorted_machine_ids_) { - if (*it > max_machine_id) { - sorted_machine_ids_.erase(it, sorted_machine_ids_.end()); - break; - } - } - EraseIf<int64_t, std::vector<int64_t>>(&machine_id2sorted_dev_phy_ids_, - [&](HashMap<int64_t, std::vector<int64_t>>::iterator it) { - return it->first > max_machine_id; - }); - parallel_num_ = max_device_num; -} - void ParallelDesc::RandomSelectOneDeviceAndRemoveTheOthers() { CHECK_GE(parallel_num_, 1); if (parallel_num_ == 1) { return; } diff --git a/oneflow/core/job/parallel_desc.h b/oneflow/core/job/parallel_desc.h index 224d24816a9180ab0f672628e840e522148a6841..b2d3561b5f0a258414d441f788c1b7f7c5bac225 100644 --- a/oneflow/core/job/parallel_desc.h +++ b/oneflow/core/job/parallel_desc.h @@ -11,12 +11,13 @@ namespace oneflow { void ParseDeviceNameConf(const std::string& device_name, int64_t* mchn_id, std::string* device_tag, std::string* device_id_str); -class ParallelDesc { +class ParallelDesc final { public: // OF_DISALLOW_COPY_AND_MOVE(ParallelDesc); ParallelDesc() = delete; ~ParallelDesc() = default; + ParallelDesc(const ParallelDesc&) = default; ParallelDesc(const ParallelConf& user_conf); // Getters @@ -32,13 +33,12 @@ class ParallelDesc { // Setters void set_policy(ParallelPolicy val) { policy_ = val; } void set_device_type(DeviceType device_type) { device_type_ = device_type; } - void RemoveNeedlessDevice(const 
std::string& op_name, int32_t max_device_num); - void RemoveNeedlessDevice(int32_t max_device_num) { RemoveNeedlessDevice("", max_device_num); } void RandomSelectOneDeviceAndRemoveTheOthers(); void UseCPUDevicesOnMaster(); // bool Equal(const ParallelDesc& rhs) const; + bool operator==(const ParallelDesc& rhs) const { return Equal(rhs); } bool Equal(const ParallelDesc* rhs) const { return Equal(*rhs); } private: @@ -58,4 +58,20 @@ std::tuple<int32_t, int32_t> GetPartIdAndPartNumFromParallelCtx( } // namespace oneflow +namespace std { + +template<> +struct hash<oneflow::ParallelDesc> { + size_t operator()(const oneflow::ParallelDesc& pr) const { + std::string str; + for (int machine_id : pr.sorted_machine_ids()) { + str += "::" + std::to_string(machine_id) + ":"; + for (int dev_id : pr.sorted_dev_phy_ids(machine_id)) { str += std::to_string(dev_id) + ","; } + } + return hash<std::string>()(str); + } +}; + +} // namespace std + #endif // ONEFLOW_CORE_JOB_PARALLEL_DESC_H_ diff --git a/oneflow/core/job/runtime.cpp b/oneflow/core/job/runtime.cpp index da7e207b83f20003cf89e1cc8e970cd50836ea49..e886051661bf5c5ec0c0b33dcb3b27e4de42435f 100644 --- a/oneflow/core/job/runtime.cpp +++ b/oneflow/core/job/runtime.cpp @@ -7,6 +7,7 @@ #include "oneflow/core/thread/thread_manager.h" #include "oneflow/core/actor/act_event_logger.h" #include "oneflow/core/graph/task_node.h" +#include "oneflow/core/device/cuda_util.h" namespace oneflow { @@ -103,6 +104,9 @@ void Runtime::NewAllGlobal(const Plan& plan, bool is_experiment_phase) { } #endif } +#ifdef WITH_CUDA + InitGlobalCudaDeviceProp(); +#endif Global<SnapshotMgr>::New(plan); Global<MemoryAllocator>::New(); Global<RegstMgr>::New(plan); diff --git a/oneflow/core/job/sbp_infer_hint.cpp b/oneflow/core/job/sbp_infer_hint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..22117b868227b2d6cd4a3b91d174114e259498b4 --- /dev/null +++ b/oneflow/core/job/sbp_infer_hint.cpp @@ -0,0 +1,40 @@ +#include "oneflow/core/job/sbp_infer_hint.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/common/protobuf.h" + +namespace oneflow { + +int64_t SbpInferHint::parallel_num() const { + CHECK_GT(parallel_num_, 0); + return parallel_num_; +} + +int64_t SbpInferHint::num_axes() const { + CHECK_GT(num_axes_, 0); + return num_axes_; +} + +bool SbpInferHint::has_split_axis() const { + bool ret = sbp_parallel_.has_split_parallel(); + if (ret) { + CHECK_GE(sbp_parallel_.split_parallel().axis(), 0); + CHECK_LT(sbp_parallel_.split_parallel().axis(), num_axes()); + } + return ret; +} + +int64_t SbpInferHint::split_axis() const { + CHECK(has_split_axis()); + return sbp_parallel_.split_parallel().axis(); +} + +bool SbpInferHint::is_model_split() const { return is_model_blob() && has_split_axis(); } +bool SbpInferHint::is_model_broadcast() const { + return is_model_blob() && sbp_parallel_.has_broadcast_parallel(); +} +bool SbpInferHint::is_data_split() const { return is_data_blob() && has_split_axis(); } +bool SbpInferHint::is_data_partial_sum() const { + return is_data_blob() && sbp_parallel_.has_partial_sum_parallel(); +} + +} // namespace oneflow diff --git a/oneflow/core/job/sbp_infer_hint.h b/oneflow/core/job/sbp_infer_hint.h new file mode 100644 index 0000000000000000000000000000000000000000..6e89a0651e11480c50749b961cb4ae94d7b8aaf4 --- /dev/null +++ b/oneflow/core/job/sbp_infer_hint.h @@ -0,0 +1,41 @@ +#ifndef ONEFLOW_CORE_JOB_SBP_INFER_HINT_H_ +#define ONEFLOW_CORE_JOB_SBP_INFER_HINT_H_ + +#include "oneflow/core/job/sbp_parallel.pb.h" + +namespace 
oneflow { + +class SbpInferHint final { + public: + SbpInferHint(bool is_model_blob, int64_t parallel_num, int64_t num_axes, + const SbpParallel& sbp_parallel) + : is_model_blob_(is_model_blob), + parallel_num_(parallel_num), + num_axes_(num_axes), + sbp_parallel_(sbp_parallel) {} + SbpInferHint(const SbpInferHint&) = default; + ~SbpInferHint() = default; + + // Getters + bool is_model_blob() const { return is_model_blob_; } + int64_t parallel_num() const; + int64_t num_axes() const; + int64_t split_axis() const; + bool has_split_axis() const; + bool is_model_split() const; + bool is_model_broadcast() const; + bool is_data_split() const; + bool is_data_partial_sum() const; + bool is_data_blob() const { return !is_model_blob(); } + const SbpParallel& sbp_parallel() const { return sbp_parallel_; } + + private: + const bool is_model_blob_; + const int64_t parallel_num_; + const int64_t num_axes_; + const SbpParallel sbp_parallel_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_JOB_SBP_INFER_HINT_H_ diff --git a/oneflow/core/job/sbp_parallel.cpp b/oneflow/core/job/sbp_parallel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5eb6d089e190749276471f85c4da6448faa036a7 --- /dev/null +++ b/oneflow/core/job/sbp_parallel.cpp @@ -0,0 +1,29 @@ +#include "oneflow/core/job/sbp_parallel.h" +#include "oneflow/core/common/protobuf.h" + +namespace oneflow { + +bool operator==(const SbpParallel& lhs, const SbpParallel& rhs) { + return PbMd().Equivalent(lhs, rhs); +} + +bool operator!=(const SbpParallel& lhs, const SbpParallel& rhs) { return !(lhs == rhs); } + +// S -> S +// P -> C +// C -> P +SbpParallel GetDualSbpParallel(const SbpParallel& sbp_parallel) { + SbpParallel ret(sbp_parallel); + if (sbp_parallel.has_split_parallel()) { + // do nothing + } else if (sbp_parallel.has_broadcast_parallel()) { + ret.mutable_partial_sum_parallel(); + } else if (sbp_parallel.has_partial_sum_parallel()) { + ret.mutable_broadcast_parallel(); + } else { + UNIMPLEMENTED(); + } + return ret; +} + +} // namespace oneflow diff --git a/oneflow/core/job/sbp_parallel.h b/oneflow/core/job/sbp_parallel.h new file mode 100644 index 0000000000000000000000000000000000000000..100cfac3159254b61e92ad93f3db97c5ba30b748 --- /dev/null +++ b/oneflow/core/job/sbp_parallel.h @@ -0,0 +1,14 @@ +#ifndef ONEFLOW_CORE_JOB_SBP_PARALLEL_H_ +#define ONEFLOW_CORE_JOB_SBP_PARALLEL_H_ + +#include "oneflow/core/job/sbp_parallel.pb.h" + +namespace oneflow { + +bool operator==(const SbpParallel& lhs, const SbpParallel& rhs); +bool operator!=(const SbpParallel& lhs, const SbpParallel& rhs); +SbpParallel GetDualSbpParallel(const SbpParallel&); + +} // namespace oneflow + +#endif // ONEFLOW_CORE_JOB_SBP_PARALLEL_H_ diff --git a/oneflow/core/job/sbp_parallel.proto b/oneflow/core/job/sbp_parallel.proto new file mode 100644 index 0000000000000000000000000000000000000000..c89b23c7235a87c96cd6808b55c6284183f3cba4 --- /dev/null +++ b/oneflow/core/job/sbp_parallel.proto @@ -0,0 +1,20 @@ +syntax = "proto2"; +package oneflow; + +message SplitParallel { + required int64 axis = 1; +} + +message BroadcastParallel { +} + +message PartialSumParallel { +} + +message SbpParallel { + oneof parallel_type { + SplitParallel split_parallel = 1; + BroadcastParallel broadcast_parallel = 2; + PartialSumParallel partial_sum_parallel = 3; + } +} diff --git a/oneflow/core/job/task.proto b/oneflow/core/job/task.proto index 4f8d9b5432e0f95f2b2cec570eab8dfd8e76b7f6..7de6d3339f35ebbf52f8709719a93e892a410a5f 100644 --- a/oneflow/core/job/task.proto +++ 
b/oneflow/core/job/task.proto @@ -41,6 +41,7 @@ enum TaskType { kUnpackBackward = 33; kRepeatForward = 34; kRepeatBackward = 35; + kReduceIdentity = 36; }; enum AreaType { diff --git a/oneflow/core/kernel/accuracy_kernel.cpp b/oneflow/core/kernel/accuracy_kernel.cpp index 60b818bbc05356c35b699cce033ff35cf0dc44e1..1f2a496e03eaed78849018a8e135cfb79b98693c 100644 --- a/oneflow/core/kernel/accuracy_kernel.cpp +++ b/oneflow/core/kernel/accuracy_kernel.cpp @@ -1,17 +1,30 @@ #include "oneflow/core/kernel/accuracy_kernel.h" +#include "oneflow/core/ndarray/ndarray_util.h" namespace oneflow { template<DeviceType device_type, typename PredType, typename LabelType> void AccuracyKernel<device_type, PredType, LabelType>::SetAccuracyInstanceNumBlob( const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const { - CHECK_GE(this->op_attribute().input_bns().size(), 2); - this->CheckSameDim0ValidNum(this->op_attribute().input_bns(), BnInOp2Blob); - int64_t dim0_valid_num_sum = - BnInOp2Blob(this->op_attribute().input_bns(0))->CalcDim0ValidNumSum(); - KernelUtil<device_type, PredType>::Set( - ctx.device_ctx, static_cast<PredType>(dim0_valid_num_sum), - BnInOp2Blob("accuracy_instance_num")->mut_dptr<PredType>()); + const Blob* weight = BnInOp2Blob("weight"); + Blob* accuracy_instance_num = BnInOp2Blob("accuracy_instance_num"); + if (weight == nullptr) { + CHECK_GE(this->op_attribute().input_bns().size(), 2); + this->CheckSameDim0ValidNum(this->op_attribute().input_bns(), BnInOp2Blob); + int64_t dim0_valid_num_sum = + BnInOp2Blob(this->op_attribute().input_bns(0))->CalcDim0ValidNumSum(); + KernelUtil<device_type, PredType>::Set(ctx.device_ctx, + static_cast<PredType>(dim0_valid_num_sum), + accuracy_instance_num->mut_dptr<PredType>()); + } else { + Blob* weight_reduce_tmp = BnInOp2Blob("weight_reduce_tmp"); + CHECK_LE(weight->shape().elem_cnt(), weight_reduce_tmp->shape().elem_cnt()); + const int64_t num_instance = weight->shape().elem_cnt(); + NdarrayUtil<device_type, PredType>::ReduceSum( + ctx.device_ctx, XpuVarNdarray<PredType>({1}, accuracy_instance_num->mut_dptr<PredType>()), + XpuVarNdarray<const PredType>({num_instance}, weight->dptr<PredType>()), + XpuVarNdarray<PredType>({num_instance}, weight_reduce_tmp->mut_dptr<PredType>())); + } } template<DeviceType device_type, typename PredType, typename LabelType> @@ -19,6 +32,8 @@ void AccuracyKernel<device_type, PredType, LabelType>::ForwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { const Blob* X = BnInOp2Blob("prediction"); const Blob* label = BnInOp2Blob("label"); + const Blob* weight = BnInOp2Blob("weight"); + if (weight != nullptr) { CHECK_EQ(label->shape().elem_cnt(), weight->shape().elem_cnt()); } Blob* accuracy = BnInOp2Blob("accuracy"); auto kernel_conf = this->kernel_conf(); const int32_t top_k = kernel_conf.op_attribute().op_conf().accuracy_conf().top_k(); @@ -29,7 +44,7 @@ void AccuracyKernel<device_type, PredType, LabelType>::ForwardDataContent( AccuracyKernelUtil<device_type, PredType, LabelType>::Forward( ctx.device_ctx, N, D, top_k, X->dptr<PredType>(), label->dptr<LabelType>(), - accuracy->mut_dptr<PredType>()); + weight ? 
weight->dptr<PredType>() : nullptr, accuracy->mut_dptr<PredType>()); SetAccuracyInstanceNumBlob(ctx, BnInOp2Blob); } @@ -48,8 +63,9 @@ void AccuracyKernel<device_type, PredType, LabelType>::ForwardRecordIdInDevicePi template<typename PredType, typename LabelType> struct AccuracyKernelUtil<DeviceType::kCPU, PredType, LabelType> { static void Forward(DeviceCtx* ctx, const int32_t N, const int32_t D, int32_t top_k, - const PredType* XData, const LabelType* labelData, PredType* accuracyData) { - int correct = 0; + const PredType* XData, const LabelType* labelData, const PredType* weight, + PredType* accuracyData) { + PredType correct = 0; for (int i = 0; i < N; ++i) { auto label_i = labelData[i]; auto label_pred = XData[i * D + label_i]; @@ -60,10 +76,9 @@ struct AccuracyKernelUtil<DeviceType::kCPU, PredType, LabelType> { if (++cnt > top_k) { break; } } } - if (cnt <= top_k) { ++correct; } + if (cnt <= top_k) { correct += weight ? weight[i] : OneVal<PredType>::value; } } - CHECK_LE(correct, N); - *accuracyData = static_cast<PredType>(correct); + *accuracyData = correct; } }; diff --git a/oneflow/core/kernel/accuracy_kernel.cu b/oneflow/core/kernel/accuracy_kernel.cu index 46a6e51e25fe03523cd75082103eae4a1b7575f2..f5671d517652785d9f068af17c49e333eae5a99e 100644 --- a/oneflow/core/kernel/accuracy_kernel.cu +++ b/oneflow/core/kernel/accuracy_kernel.cu @@ -17,10 +17,10 @@ __global__ void AccuracySetZeroKernel(PredType* accuracy) { template<typename PredType, typename LabelType> __global__ void AccuracyComputeKernel(const int32_t N, const int32_t D, const int32_t top_k, const PredType* Xdata, const LabelType* labelData, - PredType* accuracy) { + const PredType* weight, PredType* accuracy) { typedef cub::BlockReduce<int32_t, kCudaThreadsNumPerBlock> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - int32_t correct = 0; + PredType correct = 0; for (int32_t row = blockIdx.x; row < N; row += gridDim.x) { const LabelType label = labelData[row]; const PredType label_pred = Xdata[row * D + label]; @@ -30,20 +30,22 @@ __global__ void AccuracyComputeKernel(const int32_t N, const int32_t D, const in if (pred > label_pred || (pred == label_pred && col <= label)) { ++ngt; } } ngt = BlockReduce(temp_storage).Sum(ngt); - if (ngt <= top_k) { ++correct; } + if (ngt <= top_k) { correct += weight ? 
weight[row] : OneVal<PredType>::value; } __syncthreads(); } - if (threadIdx.x == 0) { gpu_atomic_add(accuracy, static_cast<PredType>(correct)); } + if (threadIdx.x == 0) { gpu_atomic_add(accuracy, correct); } } } // namespace template<typename PredType, typename LabelType> struct AccuracyKernelUtil<DeviceType::kGPU, PredType, LabelType> { static void Forward(DeviceCtx* ctx, const int32_t N, const int32_t D, int32_t top_k, - const PredType* XData, const LabelType* labelData, PredType* accuracyData) { + const PredType* XData, const LabelType* labelData, const PredType* weight, + PredType* accuracyData) { AccuracySetZeroKernel<<<1, 1, 0, ctx->cuda_stream()>>>(accuracyData); AccuracyComputeKernel<<<BlocksNum4ThreadsNum(N), kCudaThreadsNumPerBlock, 0, - ctx->cuda_stream()>>>(N, D, top_k, XData, labelData, accuracyData); + ctx->cuda_stream()>>>(N, D, top_k, XData, labelData, weight, + accuracyData); }; }; #define MAKE_ENTRY(data_type_pair, label_type_pair) \ diff --git a/oneflow/core/kernel/accuracy_kernel.h b/oneflow/core/kernel/accuracy_kernel.h index d3a47ba3a781499b0ba729c48f71a6e6aa482be5..c79a50f061af3616aa6861ce67d88f2777c846d1 100644 --- a/oneflow/core/kernel/accuracy_kernel.h +++ b/oneflow/core/kernel/accuracy_kernel.h @@ -27,7 +27,8 @@ class AccuracyKernel final : public KernelIf<device_type> { template<DeviceType device_type, typename PredType, typename LabelType> struct AccuracyKernelUtil { static void Forward(DeviceCtx* ctx, const int32_t N, const int32_t D, int32_t top_k, - const PredType* XData, const LabelType* labelData, PredType* accuracyData); + const PredType* XData, const LabelType* labelData, const PredType* weight, + PredType* accuracyData); }; } // namespace oneflow diff --git a/oneflow/core/kernel/adam_model_update_kernel.cpp b/oneflow/core/kernel/adam_model_update_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..498c2f688386543190262804f68f8b7de8d5b8d6 --- /dev/null +++ b/oneflow/core/kernel/adam_model_update_kernel.cpp @@ -0,0 +1,114 @@ +#include "oneflow/core/kernel/adam_model_update_kernel.h" +#include "oneflow/core/kernel/normal_model_update_kernel.cuh" + +namespace oneflow { + +namespace { + +template<typename T> +void UpdateMomentEstimate(int64_t n, bool do_bias_correction, T beta, int32_t p, + const T* model_diff, const T* beta_t, T* moment) { + FOR_RANGE(int64_t, i, 0, n) { + // Update biased moment estimate + moment[i] = beta * moment[i] + (1 - beta) * std::pow(model_diff[i], p); + if (do_bias_correction) { + // Correct deviation of moment estimate + moment[i] = moment[i] / (1 - *beta_t); + } + } +} + +} // namespace + +template<DeviceType device_type, typename T> +void AdamMdUpdateKernel<device_type, T>::InitModelBlobsWithRandomSeed( + DeviceCtx* ctx, std::mt19937* random_seed_gen, + std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const auto& adam_conf = this->op_conf().normal_mdupdt_conf().user_conf().adam_conf(); + InitializerConf m_init_conf; + InitializerConf v_init_conf; + m_init_conf.mutable_constant_conf()->set_value(0.0f); + v_init_conf.mutable_constant_conf()->set_value(0.0f); + KernelUtil<device_type, T>::InitializeWithProperConf(ctx, &m_init_conf, 0, BnInOp2Blob("m")); + KernelUtil<device_type, T>::InitializeWithProperConf(ctx, &v_init_conf, 0, BnInOp2Blob("v")); + if (!adam_conf.do_bias_correction()) { return; } + InitializerConf beta1_init_conf; + InitializerConf beta2_init_conf; + beta1_init_conf.mutable_constant_conf()->set_value(adam_conf.beta1()); + 
beta2_init_conf.mutable_constant_conf()->set_value(adam_conf.beta2()); + KernelUtil<device_type, T>::InitializeWithProperConf(ctx, &beta1_init_conf, 0, + BnInOp2Blob("beta1_t")); + KernelUtil<device_type, T>::InitializeWithProperConf(ctx, &beta2_init_conf, 0, + BnInOp2Blob("beta2_t")); +} + +template<DeviceType device_type, typename T> +void AdamMdUpdateKernel<device_type, T>::InitModelBlobsWithDir( + DeviceCtx* ctx, int32_t part_id, int32_t part_num, const std::string& model_load_dir, + std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const auto& adam_conf = this->op_conf().normal_mdupdt_conf().user_conf().adam_conf(); + Blob* m_blob = BnInOp2Blob("m"); + Blob* v_blob = BnInOp2Blob("v"); + KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir, m_blob, "m", + m_blob->shape().At(0), m_blob->shape().Count(1)); + KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir, v_blob, "v", + v_blob->shape().At(0), v_blob->shape().Count(1)); + if (!adam_conf.do_bias_correction()) { return; } + Blob* beta1_t_blob = BnInOp2Blob("beta1_t"); + Blob* beta2_t_blob = BnInOp2Blob("beta2_t"); + KernelUtil<device_type, T>::InitializeWithDir( + ctx, part_id, part_num, model_load_dir, beta1_t_blob, "beta1_t", beta1_t_blob->shape().At(0), + beta1_t_blob->shape().Count(1)); + KernelUtil<device_type, T>::InitializeWithDir( + ctx, part_id, part_num, model_load_dir, beta2_t_blob, "beta2_t", beta2_t_blob->shape().At(0), + beta2_t_blob->shape().Count(1)); +} + +template<DeviceType device_type, typename T> +void AdamMdUpdateKernel<device_type, T>::UpdateModel( + DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, + int64_t next_model_vid, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + Blob* model_blob = BnInOp2Blob("model"); + Blob* m_blob = BnInOp2Blob("m"); + Blob* v_blob = BnInOp2Blob("v"); + Blob* beta1_t_blob = BnInOp2Blob("beta1_t"); + Blob* beta2_t_blob = BnInOp2Blob("beta2_t"); + const AdamModelUpdateConf& adam_conf = + this->op_conf().normal_mdupdt_conf().user_conf().adam_conf(); + if ((next_model_vid != 1) && adam_conf.do_bias_correction()) { + KernelUtil<device_type, T>::Scal(ctx, 1, static_cast<T>(adam_conf.beta1()), + beta1_t_blob->mut_dptr<T>(), 1); + KernelUtil<device_type, T>::Scal(ctx, 1, static_cast<T>(adam_conf.beta2()), + beta2_t_blob->mut_dptr<T>(), 1); + } + AdamMdUpdateKernelUtil<device_type, T>::UpdateModel( + ctx, model_blob->shape().elem_cnt(), batch_instance_num_ptr, learning_rate, l1, l2, + static_cast<T>(adam_conf.beta1()), static_cast<T>(adam_conf.beta2()), + static_cast<T>(adam_conf.epsilon()), adam_conf.do_bias_correction(), next_model_vid, + (beta1_t_blob ? beta1_t_blob->dptr<T>() : nullptr), + (beta2_t_blob ? 
beta2_t_blob->dptr<T>() : nullptr), BnInOp2Blob("model_diff")->mut_dptr<T>(), + model_blob->mut_dptr<T>(), m_blob->mut_dptr<T>(), v_blob->mut_dptr<T>()); +} + +template<typename T> +class AdamMdUpdateKernelUtil<DeviceType::kCPU, T> final { + public: + static void UpdateModel(DeviceCtx* ctx, int64_t n, const T* batch_instance_num_ptr, + T learning_rate, T l1, T l2, T beta1, T beta2, T epsilon, + bool do_bias_correction, int64_t next_model_vid, const T* beta1_t, + const T* beta2_t, T* model_diff, T* model, T* m, T* v) { + // first-order moment + UpdateMomentEstimate<T>(n, do_bias_correction, beta1, 1, model_diff, beta1_t, m); + // second-order moment + UpdateMomentEstimate<T>(n, do_bias_correction, beta2, 2, model_diff, beta2_t, v); + FOR_RANGE(int64_t, i, 0, n) { + model_diff[i] = m[i] / (std::sqrt(v[i]) + epsilon); + T reg_diff = RegularizeDiff(model_diff[i], *batch_instance_num_ptr, l1, l2, model[i]); + model[i] = model[i] - learning_rate * reg_diff; + } + } +}; + +DEFINE_MDUPDT_KERNEL_CREATOR(Adam); + +} // namespace oneflow diff --git a/oneflow/core/kernel/adam_model_update_kernel.cu b/oneflow/core/kernel/adam_model_update_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..a229180fef5bfee509881df0140cf0884616f3c4 --- /dev/null +++ b/oneflow/core/kernel/adam_model_update_kernel.cu @@ -0,0 +1,93 @@ +#include "oneflow/core/device/cuda_util.h" +#include "oneflow/core/kernel/adam_model_update_kernel.h" +#include "oneflow/core/kernel/normal_model_update_kernel.cuh" + +namespace oneflow { + +namespace { + +template<int32_t power> +struct PowUtil; + +template<> +struct PowUtil<1> final { + template<typename T> + __device__ static T pow(const T x) { + return x; + } +}; + +template<> +struct PowUtil<2> final { + template<typename T> + __device__ static T pow(const T x) { + return x * x; + } +}; + +template<bool do_bias_correction, typename T> +__device__ typename std::enable_if<do_bias_correction>::type ScaleMomentum(const T beta_t, + T* moment) { + *moment /= (1 - beta_t); +} + +template<bool do_bias_correction, typename T> +__device__ typename std::enable_if<!do_bias_correction>::type ScaleMomentum(const T beta_t, + T* moment) {} + +template<int32_t power, bool do_bias_correction, typename T> +__device__ void UpdateMomentEstimate(T beta, const T* model_diff, const T* beta_t, T* moment) { + // Update biased moment estimate + *moment = beta * (*moment) + (1 - beta) * PowUtil<power>::pow(*model_diff); + // Correct deviation of moment estimate + ScaleMomentum<do_bias_correction>(*beta_t, moment); +} + +template<typename T> +__device__ void UpdateModel(const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, T epsilon, + T* model_diff, T* model, T* m, T* v) { + *model_diff = *m / (sqrt(*v) + epsilon); + T reg_diff = RegularizeDiff(*model_diff, *batch_instance_num_ptr, l1, l2, *model); + *model = *model - learning_rate * reg_diff; +} + +template<bool do_bias_correction, typename T> +__global__ void UpdateModelGpu(int64_t n, const T* batch_instance_num_ptr, T learning_rate, T l1, + T l2, T beta1, T beta2, T epsilon, const T* beta1_t, + const T* beta2_t, T* model_diff, T* model, T* m, T* v) { + CUDA_1D_KERNEL_LOOP(i, n) { + UpdateMomentEstimate<1, do_bias_correction>(beta1, model_diff + i, beta1_t, m + i); + UpdateMomentEstimate<2, do_bias_correction>(beta2, model_diff + i, beta2_t, v + i); + UpdateModel(batch_instance_num_ptr, learning_rate, l1, l2, epsilon, model_diff + i, model + i, + m + i, v + i); + } +} + +} // namespace + +template<typename T> +class 
AdamMdUpdateKernelUtil<DeviceType::kGPU, T> final { + public: + static void UpdateModel(DeviceCtx* ctx, int64_t n, const T* batch_instance_num_ptr, + T learning_rate, T l1, T l2, T beta1, T beta2, T epsilon, + bool do_bias_correction, int64_t next_model_vid, const T* beta1_t, + const T* beta2_t, T* model_diff, T* model, T* m, T* v) { + if (do_bias_correction) { + UpdateModelGpu<true, T> + <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + n, batch_instance_num_ptr, learning_rate, l1, l2, beta1, beta2, epsilon, beta1_t, + beta2_t, model_diff, model, m, v); + } else { + UpdateModelGpu<false, T> + <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + n, batch_instance_num_ptr, learning_rate, l1, l2, beta1, beta2, epsilon, beta1_t, + beta2_t, model_diff, model, m, v); + } + } +}; + +#define INSTANTIATE_GPU_KERNEL_UTIL(type_cpp, type_proto) \ + template class AdamMdUpdateKernelUtil<DeviceType::kGPU, type_cpp>; +OF_PP_FOR_EACH_TUPLE(INSTANTIATE_GPU_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ) + +} // namespace oneflow diff --git a/oneflow/core/kernel/adam_model_update_kernel.h b/oneflow/core/kernel/adam_model_update_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..3031f7bf2349babc574628dc6de3da6cff67f436 --- /dev/null +++ b/oneflow/core/kernel/adam_model_update_kernel.h @@ -0,0 +1,40 @@ +#ifndef ONEFLOW_CORE_KERNEL_ADAM_MODEL_UPDATE_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_ADAM_MODEL_UPDATE_KERNEL_H_ + +#include "oneflow/core/kernel/normal_model_update_kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class AdamMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T> { + public: + OF_DISALLOW_COPY_AND_MOVE(AdamMdUpdateKernel); + AdamMdUpdateKernel() = default; + ~AdamMdUpdateKernel() = default; + + private: + void InitModelBlobsWithRandomSeed( + DeviceCtx* ctx, std::mt19937* random_seed_gen, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + void InitModelBlobsWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num, + const std::string& model_load_dir, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, + int64_t next_model_vid, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; +}; + +template<DeviceType device_type, typename T> +class AdamMdUpdateKernelUtil final { + public: + static void UpdateModel(DeviceCtx*, int64_t n, const T* batch_instance_num_ptr, T learning_rate, + T l1, T l2, T beta1, T beta2, T epsilon, bool do_bias_correction, + int64_t next_model_vid, const T* beta1_t, const T* beta2_t, T* model_diff, + T* model, T* m, T* v); +}; + +DECLARE_MDUPDT_KERNEL_CREATOR(Adam); + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_ADAM_MODEL_UPDATE_KERNEL_H_ diff --git a/oneflow/core/kernel/batch_gather_kernel.cpp b/oneflow/core/kernel/batch_gather_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..501252a85cd4ac1030474e0e054eaf0c8d414d08 --- /dev/null +++ b/oneflow/core/kernel/batch_gather_kernel.cpp @@ -0,0 +1,118 @@ +#include "oneflow/core/kernel/batch_gather_kernel.h" + +namespace oneflow { + +namespace { + +Shape GetFlatShape(const Shape& shape, const int64_t axis) { + CHECK_GT(shape.NumAxes(), 0); + CHECK_GE(axis, 0); + CHECK_LT(axis, shape.NumAxes()); + return Shape({shape.Count(0, axis), shape.At(axis), shape.Count(axis + 1)}); +} + +template<DeviceType device_type, 
typename T, typename K> +void BatchGatherForward(DeviceCtx* ctx, const Blob* in, const Blob* indices, Blob* out) { + const int64_t axis = indices->shape().NumAxes() - 1; + const Shape flat_out_shape = GetFlatShape(out->shape(), axis); + BatchGatherKernelUtil<device_type, T, K>::Forward(ctx, in->dptr<T>(), indices->dptr<K>(), + flat_out_shape, in->shape().At(axis), + out->mut_dptr<T>()); +} + +template<DeviceType device_type, typename T, typename K> +void BatchGatherBackward(DeviceCtx* ctx, const Blob* out_diff, const Blob* indices, Blob* in_diff) { + Memset<device_type>(ctx, in_diff->mut_dptr<T>(), 0, in_diff->ByteSizeOfDataContentField()); + const int64_t axis = indices->shape().NumAxes() - 1; + const Shape flat_out_diff_shape = GetFlatShape(out_diff->shape(), axis); + BatchGatherKernelUtil<device_type, T, K>::Backward(ctx, out_diff->dptr<T>(), indices->dptr<K>(), + flat_out_diff_shape, in_diff->shape().At(axis), + in_diff->mut_dptr<T>()); +} + +template<DeviceType device_type, typename T> +struct BatchGatherSwitchUtil final { +#define MAKE_BATCH_GATHER_SWITCH_ENTRY(func_name, K) func_name<device_type, T, K> +#define DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC(func_name) \ + DEFINE_STATIC_SWITCH_FUNC(void, func_name, MAKE_BATCH_GATHER_SWITCH_ENTRY, \ + MAKE_DATA_TYPE_CTRV_SEQ(INT_DATA_TYPE_SEQ)); + DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC(BatchGatherForward); + DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC(BatchGatherBackward); +#undef DEFINE_BATCH_GATHER_STATIC_SWITCH_FUNC +#undef MAKE_BATCH_GATHER_SWITCH_ENTRY +}; + +} // namespace + +template<DeviceType device_type, typename T> +const PbMessage& BatchGatherKernel<device_type, T>::GetCustomizedOpConf() const { + return this->op_conf().batch_gather_conf(); +} + +template<DeviceType device_type, typename T> +void BatchGatherKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + BatchGatherSwitchUtil<device_type, T>::SwitchBatchGatherForward( + SwitchCase(BnInOp2Blob("indices")->data_type()), ctx.device_ctx, BnInOp2Blob("in"), + BnInOp2Blob("indices"), BnInOp2Blob("out")); +} + +template<DeviceType device_type, typename T> +void BatchGatherKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + BatchGatherSwitchUtil<device_type, T>::SwitchBatchGatherBackward( + SwitchCase(BnInOp2Blob("indices")->data_type()), ctx.device_ctx, + BnInOp2Blob(GenDiffBn("out")), BnInOp2Blob("indices"), BnInOp2Blob(GenDiffBn("in"))); +} + +template<typename T, typename K> +struct BatchGatherKernelUtil<DeviceType::kCPU, T, K> final { + static void Forward(DeviceCtx* ctx, const T* in, const K* indices, const Shape& flat_out_shape, + const int64_t gather_dim_size, T* out); + static void Backward(DeviceCtx* ctx, const T* out_diff, const K* indices, + const Shape& flat_out_diff_shape, const int64_t gather_dim_size, T* in_diff); +}; + +template<typename T, typename K> +void BatchGatherKernelUtil<DeviceType::kCPU, T, K>::Forward(DeviceCtx* ctx, const T* in, + const K* indices, + const Shape& flat_out_shape, + const int64_t gather_dim_size, T* out) { + const int64_t batch_num = flat_out_shape.At(0); + const int64_t indices_num = flat_out_shape.At(1); + const int64_t instance_size = flat_out_shape.At(2); + FOR_RANGE(int64_t, batch_idx, 0, batch_num) { + FOR_RANGE(int64_t, i, 0, indices_num) { + const K idx = indices[batch_idx * indices_num + i]; + CHECK(idx >= 0 && idx < gather_dim_size); + const T* from = in + batch_idx * 
gather_dim_size * instance_size + idx * instance_size; + T* to = out + batch_idx * indices_num * instance_size + i * instance_size; + std::copy(from, from + instance_size, to); + } + } +} + +template<typename T, typename K> +void BatchGatherKernelUtil<DeviceType::kCPU, T, K>::Backward(DeviceCtx* ctx, const T* out_diff, + const K* indices, + const Shape& flat_out_diff_shape, + const int64_t gather_dim_size, + T* in_diff) { + const int64_t batch_num = flat_out_diff_shape.At(0); + const int64_t indices_num = flat_out_diff_shape.At(1); + const int64_t instance_size = flat_out_diff_shape.At(2); + FOR_RANGE(int64_t, batch_idx, 0, batch_num) { + FOR_RANGE(int64_t, i, 0, indices_num) { + const int64_t idx = indices[batch_idx * indices_num + i]; + CHECK(idx >= 0 && idx < gather_dim_size); + const T* from = out_diff + batch_idx * indices_num * instance_size + i * instance_size; + T* to = in_diff + batch_idx * gather_dim_size * instance_size + idx * instance_size; + std::transform(from, from + instance_size, to, to, std::plus<T>()); + } + } +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kBatchGatherConf, BatchGatherKernel, + FLOATING_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/batch_gather_kernel.cu b/oneflow/core/kernel/batch_gather_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..4292f4fdfaa3d69587db6eda0dacb1d0b91d9fa5 --- /dev/null +++ b/oneflow/core/kernel/batch_gather_kernel.cu @@ -0,0 +1,87 @@ +#include "oneflow/core/kernel/batch_gather_kernel.h" +#include "oneflow/core/kernel/kernel_util.cuh" +#include <assert.h> + +namespace oneflow { + +namespace { + +template<typename K> +__device__ int64_t GetInOffset(const int64_t out_offset, const K* indices, + const int64_t indices_num, const int64_t instance_size, + const int64_t gather_dim_size) { + const int64_t batch_idx = out_offset / (indices_num * instance_size); + const int64_t indices_idx = out_offset % (indices_num * instance_size) / instance_size; + const int64_t inner_idx = out_offset % instance_size; + const int64_t idx = indices[batch_idx * indices_num + indices_idx]; + assert(idx >= 0 && idx < gather_dim_size); + return batch_idx * gather_dim_size * instance_size + idx * instance_size + inner_idx; +} + +template<typename T, typename K> +__global__ void BatchGatherForwardGpu(const int64_t elem_cnt, const T* in, const K* indices, + const int64_t indices_num, const int64_t instance_size, + const int64_t gather_dim_size, T* out) { + CUDA_1D_KERNEL_LOOP(i, elem_cnt) { + out[i] = in[GetInOffset<K>(i, indices, indices_num, instance_size, gather_dim_size)]; + } +} + +template<typename T, typename K> +__global__ void BatchGatherBackwardGpu(const int64_t elem_cnt, const T* out_diff, const K* indices, + const int64_t indices_num, const int64_t instance_size, + const int64_t gather_dim_size, T* in_diff) { + CUDA_1D_KERNEL_LOOP(i, elem_cnt) { + gpu_atomic_add( + in_diff + GetInOffset<K>(i, indices, indices_num, instance_size, gather_dim_size), + out_diff[i]); + } +} + +} // namespace + +template<typename T, typename K> +struct BatchGatherKernelUtil<DeviceType::kGPU, T, K> final { + static void Forward(DeviceCtx* ctx, const T* in, const K* indices, const Shape& flat_out_shape, + const int64_t gather_dim_size, T* out); + static void Backward(DeviceCtx* ctx, const T* out_diff, const K* indices, + const Shape& flat_out_diff_shape, const int64_t gather_dim_size, T* in_diff); +}; + +template<typename T, typename K> +void BatchGatherKernelUtil<DeviceType::kGPU, T, K>::Forward(DeviceCtx* ctx, const 
T* in, + const K* indices, + const Shape& flat_out_shape, + const int64_t gather_dim_size, T* out) { + const int64_t batch_num = flat_out_shape.At(0); + const int64_t indices_num = flat_out_shape.At(1); + const int64_t instance_size = flat_out_shape.At(2); + const int64_t elem_cnt = batch_num * indices_num * instance_size; + BatchGatherForwardGpu<T, K> + <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + elem_cnt, in, indices, indices_num, instance_size, gather_dim_size, out); +} + +template<typename T, typename K> +void BatchGatherKernelUtil<DeviceType::kGPU, T, K>::Backward(DeviceCtx* ctx, const T* out_diff, + const K* indices, + const Shape& flat_out_diff_shape, + const int64_t gather_dim_size, + T* in_diff) { + const int64_t batch_num = flat_out_diff_shape.At(0); + const int64_t indices_num = flat_out_diff_shape.At(1); + const int64_t instance_size = flat_out_diff_shape.At(2); + const int64_t elem_cnt = batch_num * indices_num * instance_size; + BatchGatherBackwardGpu<T, K> + <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + elem_cnt, out_diff, indices, indices_num, instance_size, gather_dim_size, in_diff); +} + +#define MAKE_BATCH_GATHER_KERNEL_UTIL_ENTRY(in_type_pair, index_type_pair) \ + template struct BatchGatherKernelUtil<DeviceType::kGPU, OF_PP_PAIR_FIRST(in_type_pair), \ + OF_PP_PAIR_FIRST(index_type_pair)>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_BATCH_GATHER_KERNEL_UTIL_ENTRY, FLOATING_DATA_TYPE_SEQ, + INT_DATA_TYPE_SEQ); +#undef MAKE_BATCH_GATHER_KERNEL_UTIL_ENTRY + +} // namespace oneflow diff --git a/oneflow/core/kernel/batch_gather_kernel.h b/oneflow/core/kernel/batch_gather_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..57521b14ff054ba8fde0d6aee8457e433e0c964f --- /dev/null +++ b/oneflow/core/kernel/batch_gather_kernel.h @@ -0,0 +1,33 @@ +#ifndef ONEFLOW_CORE_KERNEL_BATCH_GATHER_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_BATCH_GATHER_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class BatchGatherKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(BatchGatherKernel); + BatchGatherKernel() = default; + ~BatchGatherKernel() override = default; + + private: + const PbMessage& GetCustomizedOpConf() const override; + void ForwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + void BackwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; +}; + +template<DeviceType device_type, typename T, typename K> +struct BatchGatherKernelUtil final { + static void Forward(DeviceCtx* ctx, const T* in, const K* indices, const Shape& flat_out_shape, + const int64_t gather_dim_size, T* out); + static void Backward(DeviceCtx* ctx, const T* out_diff, const K* indices, + const Shape& flat_out_diff_shape, const int64_t gather_dim_size, T* in_diff); +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_BATCH_GATHER_KERNEL_H_ diff --git a/oneflow/core/kernel/bias_add_kernel.cpp b/oneflow/core/kernel/bias_add_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f44094296033db870fec0804cb717d978e9056ef --- /dev/null +++ b/oneflow/core/kernel/bias_add_kernel.cpp @@ -0,0 +1,56 @@ +#include "oneflow/core/kernel/bias_add_kernel.h" +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> 
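+// out is first filled with a; a K=1 GEMM then adds bias_multiplier (a column of ones, see
+// InitConstBufBlobs below) times b, which broadcasts the bias b across every row of out.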
+void BiasAddKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* a_blob = BnInOp2Blob("a"); + const Blob* b_blob = BnInOp2Blob("b"); + const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier"); + Blob* out_blob = BnInOp2Blob("out"); + + // out = bias_multiplier * b + a + Memcpy<device_type>(ctx.device_ctx, out_blob->mut_dptr<T>(), a_blob->dptr<T>(), + a_blob->ByteSizeOfDataContentField()); + KernelUtil<device_type, T>::OFGemm(ctx.device_ctx, CblasNoTrans, CblasNoTrans, + out_blob->shape().At(0), out_blob->shape().At(1), 1, + OneVal<T>::value, bias_mul_blob->dptr<T>(), b_blob->dptr<T>(), + OneVal<T>::value, out_blob->mut_dptr<T>()); +} + +template<DeviceType device_type, typename T> +void BiasAddKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* out_diff_blob = BnInOp2Blob("out_diff"); + const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier"); + Blob* a_diff_blob = BnInOp2Blob("a_diff"); + Blob* b_diff_blob = BnInOp2Blob("b_diff"); + + Memcpy<device_type>(ctx.device_ctx, a_diff_blob->mut_dptr<T>(), out_diff_blob->dptr<T>(), + out_diff_blob->ByteSizeOfDataContentField()); + // b_diff = bias_multiplier * out_diff + KernelUtil<device_type, T>::OFGemm( + ctx.device_ctx, CblasTrans, CblasNoTrans, 1, b_diff_blob->shape().At(0), + out_diff_blob->shape().At(0), OneVal<T>::value, bias_mul_blob->dptr<T>(), + out_diff_blob->dptr<T>(), ZeroVal<T>::value, b_diff_blob->mut_dptr<T>()); +} + +template<DeviceType device_type, typename T> +void BiasAddKernel<device_type, T>::InitConstBufBlobs( + DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + InitializerConf bias_multiplier_initializer_conf; + bias_multiplier_initializer_conf.mutable_constant_conf()->set_value(1.0f); + KernelUtil<device_type, T>::InitializeWithConf(ctx, bias_multiplier_initializer_conf, 0, + BnInOp2Blob("bias_multiplier")); +} + +template<DeviceType device_type, typename T> +const PbMessage& BiasAddKernel<device_type, T>::GetCustomizedOpConf() const { + return this->op_conf().bias_add_conf(); +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kBiasAddConf, BiasAddKernel, FLOATING_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/bias_add_kernel.h b/oneflow/core/kernel/bias_add_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..78b5c16cd7430c10c962d52017311d30bf70cfdb --- /dev/null +++ b/oneflow/core/kernel/bias_add_kernel.h @@ -0,0 +1,27 @@ +#ifndef ONEFLOW_CORE_KERNEL_BIAS_ADD_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_BIAS_ADD_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class BiasAddKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(BiasAddKernel); + BiasAddKernel() = default; + ~BiasAddKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void InitConstBufBlobs(DeviceCtx*, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + const PbMessage& GetCustomizedOpConf() const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_BIAS_ADD_KERNEL_H_ diff --git a/oneflow/core/kernel/boxing_kernel.cpp b/oneflow/core/kernel/boxing_kernel.cpp index 
645110161ae2d064ce6211b80d5359a4c2096163..c35d42f0c5a720447465fd0eef9e4a4177c0edbb 100644 --- a/oneflow/core/kernel/boxing_kernel.cpp +++ b/oneflow/core/kernel/boxing_kernel.cpp @@ -351,6 +351,41 @@ void BoxingKernel<T>::SetMaxColId(const KernelCtx& ctx, } } +template<typename T> +void BoxingKernel<T>::ForwardLossInstanceNum( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const PbRpf<std::string>& input_bns = op_attribute().input_bns(); + const PbRpf<std::string>& output_bns = op_attribute().output_bns(); + CHECK_GT(input_bns.size(), 0); + float in_loss_instance_num = BnInOp2Blob(input_bns.Get(0))->loss_instance_num(); + const float loss_instance_num_epsilon = 1e-8; + const BoxingOpConf& conf = op_conf().boxing_conf(); + if (conf.in_box_case() == BoxingOpConf::kConcatBox) { + FOR_RANGE(int32_t, i, 1, input_bns.size()) { + in_loss_instance_num += BnInOp2Blob(input_bns.Get(i))->loss_instance_num(); + } + } else if (conf.in_box_case() == BoxingOpConf::kAddBox) { + FOR_RANGE(int32_t, i, 1, input_bns.size()) { + CHECK_LT(std::fabs(BnInOp2Blob(input_bns.Get(i))->loss_instance_num() - in_loss_instance_num), + loss_instance_num_epsilon); + } + } else { + UNIMPLEMENTED(); + } + if (conf.out_box_case() == BoxingOpConf::kSplitBox) { + const float out_loss_instance_num = in_loss_instance_num / output_bns.size(); + FOR_RANGE(int32_t, i, 0, output_bns.size()) { + BnInOp2Blob(output_bns.Get(i))->set_loss_instance_num(out_loss_instance_num); + } + } else if (conf.out_box_case() == BoxingOpConf::kCloneBox) { + FOR_RANGE(int32_t, i, 0, output_bns.size()) { + BnInOp2Blob(output_bns.Get(i))->set_loss_instance_num(in_loss_instance_num); + } + } else { + UNIMPLEMENTED(); + } +} + ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kBoxingConf, BoxingKernel, ARITHMETIC_DATA_TYPE_SEQ); } // namespace oneflow diff --git a/oneflow/core/kernel/boxing_kernel.h b/oneflow/core/kernel/boxing_kernel.h index f8440525908dd1ecc1cafaafdb756efd75699d08..2f332e89ddbe43f97000f0f0f9b634ccc68abeee 100644 --- a/oneflow/core/kernel/boxing_kernel.h +++ b/oneflow/core/kernel/boxing_kernel.h @@ -33,6 +33,8 @@ class BoxingKernel final : public KernelIf<DeviceType::kCPU> { void SetColId(const KernelCtx&, std::function<Blob*(const std::string&)>) const; void SetMaxColId(const KernelCtx&, std::function<Blob*(const std::string&)>) const; + void ForwardLossInstanceNum(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; PbRpf<std::string> ibn_0_; PbRpf<std::string> obn_0_; diff --git a/oneflow/core/kernel/broadcast_add_kernel.cpp b/oneflow/core/kernel/broadcast_add_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c626c361a83968b37bfa2bb958406db3000d5590 --- /dev/null +++ b/oneflow/core/kernel/broadcast_add_kernel.cpp @@ -0,0 +1,43 @@ +#include "oneflow/core/kernel/broadcast_add_kernel.h" +#include "oneflow/core/ndarray/binary_func.h" +#include "oneflow/core/ndarray/xpu_var_ndarray.h" +#include "oneflow/core/common/preprocessor.h" +#include "oneflow/core/ndarray/ndarray_util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +void BroadcastAddKernel<device_type, T>::ForwardDataContent( + const KernelCtx& kernel_ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* a_blob = BnInOp2Blob("a"); + const Blob* b_blob = BnInOp2Blob("b"); + Blob* out_blob = BnInOp2Blob("out"); + size_t num_axes = out_blob->shape().NumAxes(); + NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncAdd>( + 
kernel_ctx.device_ctx, XpuVarNdarray<T>(out_blob, num_axes), + XpuVarNdarray<const T>(a_blob, num_axes), XpuVarNdarray<const T>(b_blob, num_axes)); +} + +template<DeviceType device_type, typename T> +void BroadcastAddKernel<device_type, T>::BackwardDataContent( + const KernelCtx& kernel_ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* out_diff_blob = BnInOp2Blob("out_diff"); + Blob* a_diff_blob = BnInOp2Blob("a_diff"); + Blob* b_diff_blob = BnInOp2Blob("b_diff"); + Blob* bw_buf_blob = BnInOp2Blob("bw_buf"); + size_t num_axes = out_diff_blob->shape().NumAxes(); + if (a_diff_blob) { + NdarrayUtil<device_type, T>::ReduceSum( + kernel_ctx.device_ctx, XpuVarNdarray<T>(a_diff_blob, num_axes), + XpuVarNdarray<const T>(out_diff_blob, num_axes), XpuVarNdarray<T>(bw_buf_blob, num_axes)); + } + if (b_diff_blob) { + NdarrayUtil<device_type, T>::ReduceSum( + kernel_ctx.device_ctx, XpuVarNdarray<T>(b_diff_blob, num_axes), + XpuVarNdarray<const T>(out_diff_blob, num_axes), XpuVarNdarray<T>(bw_buf_blob, num_axes)); + } +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kBroadcastAddConf, BroadcastAddKernel, + ARITHMETIC_DATA_TYPE_SEQ); +} // namespace oneflow diff --git a/oneflow/core/kernel/broadcast_add_kernel.h b/oneflow/core/kernel/broadcast_add_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..196f056b7707354a354d74da2756836bda7fe110 --- /dev/null +++ b/oneflow/core/kernel/broadcast_add_kernel.h @@ -0,0 +1,24 @@ +#ifndef ONEFLOW_CORE_KERNEL_BROADCAST_ADD_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_BROADCAST_ADD_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class BroadcastAddKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(BroadcastAddKernel); + BroadcastAddKernel() = default; + ~BroadcastAddKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_BROADCAST_ADD_KERNEL_H_ diff --git a/oneflow/core/kernel/broadcast_div_kernel.cpp b/oneflow/core/kernel/broadcast_div_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..06f43a1a998bbe1af9bbb5d06a812a4b95ea3714 --- /dev/null +++ b/oneflow/core/kernel/broadcast_div_kernel.cpp @@ -0,0 +1,101 @@ +#include "oneflow/core/kernel/broadcast_div_kernel.h" +#include "oneflow/core/ndarray/ndarray_util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +void BroadcastDivKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* a = BnInOp2Blob("a"); + const Blob* b = BnInOp2Blob("b"); + Blob* out = BnInOp2Blob("out"); + int64_t n = out->shape().elem_cnt(); + if (a->shape().elem_cnt() == 1) { + CHECK_EQ(n, b->shape().elem_cnt()); + KernelUtil<device_type, T>::Replicate(ctx.device_ctx, n, out->mut_dptr<T>(), a->dptr<T>()); + KernelUtil<device_type, T>::Div(ctx.device_ctx, n, out->dptr<T>(), b->dptr<T>(), + out->mut_dptr<T>()); + } else if (b->shape().elem_cnt() == 1) { + CHECK_EQ(n, a->shape().elem_cnt()); + KernelUtil<device_type, T>::Replicate(ctx.device_ctx, n, out->mut_dptr<T>(), b->dptr<T>()); + KernelUtil<device_type, T>::Div(ctx.device_ctx, n, a->dptr<T>(), out->dptr<T>(), + out->mut_dptr<T>()); + } else { + size_t num_axes = 
out->shape().NumAxes(); + NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncDiv>( + ctx.device_ctx, XpuVarNdarray<T>(out, num_axes), XpuVarNdarray<const T>(a, num_axes), + XpuVarNdarray<const T>(b, num_axes)); + } +} + +template<DeviceType device_type, typename T> +void BroadcastDivKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* a = BnInOp2Blob("a"); + const Blob* b = BnInOp2Blob("b"); + const Blob* out_diff = BnInOp2Blob(GenDiffBn("out")); + Blob* tmp = BnInOp2Blob("bw_buf"); + Blob* a_diff = BnInOp2Blob(GenDiffBn("a")); + Blob* b_diff = BnInOp2Blob(GenDiffBn("b")); + + int64_t n = out_diff->shape().elem_cnt(); + KernelUtil<device_type, T>::Reciprocal(ctx.device_ctx, b->shape().elem_cnt(), b->dptr<T>(), + tmp->mut_dptr<T>()); + if (a->shape().elem_cnt() == 1) { + if (a_diff) { + KernelUtil<device_type, T>::Dot(ctx.device_ctx, n, out_diff->dptr<T>(), 1, tmp->dptr<T>(), 1, + a_diff->mut_dptr<T>()); + } + if (b_diff) { + KernelUtil<device_type, T>::Square(ctx.device_ctx, n, tmp->dptr<T>(), tmp->mut_dptr<T>()); + KernelUtil<device_type, T>::MulByScalar(ctx.device_ctx, n, tmp->dptr<T>(), a->dptr<T>(), + tmp->mut_dptr<T>()); + KernelUtil<device_type, T>::Axpy(ctx.device_ctx, n, static_cast<T>(-2), tmp->dptr<T>(), 1, + tmp->mut_dptr<T>(), 1); + KernelUtil<device_type, T>::Mul(ctx.device_ctx, n, out_diff->dptr<T>(), tmp->dptr<T>(), + b_diff->mut_dptr<T>()); + } + } else if (b->shape().elem_cnt() == 1) { + if (a_diff) { + KernelUtil<device_type, T>::MulByScalar(ctx.device_ctx, n, out_diff->dptr<T>(), + tmp->dptr<T>(), a_diff->mut_dptr<T>()); + } + if (b_diff) { + KernelUtil<device_type, T>::Square(ctx.device_ctx, 1, tmp->dptr<T>(), tmp->mut_dptr<T>()); + KernelUtil<device_type, T>::Axpy(ctx.device_ctx, 1, static_cast<T>(-2), tmp->dptr<T>(), 1, + tmp->mut_dptr<T>(), 1); + KernelUtil<device_type, T>::Dot(ctx.device_ctx, n, out_diff->dptr<T>(), 1, a->dptr<T>(), 1, + b_diff->mut_dptr<T>()); + KernelUtil<device_type, T>::Mul(ctx.device_ctx, 1, tmp->dptr<T>(), b_diff->dptr<T>(), + b_diff->mut_dptr<T>()); + } + } else { + size_t num_axes = out_diff->shape().NumAxes(); + XpuVarNdarray<const T> out_diff_tensor(out_diff, num_axes); + Blob* bw_buf_blob = BnInOp2Blob("bw_buf"); + XpuVarNdarray<const T> const_tmp(out_diff_tensor.shape(), bw_buf_blob->dptr<T>()); + XpuVarNdarray<T> tmp(out_diff_tensor.shape(), bw_buf_blob->mut_dptr<T>()); + if (a_diff) { + NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncDiv>( + ctx.device_ctx, tmp, out_diff_tensor, XpuVarNdarray<const T>(b, num_axes)); + NdarrayUtil<device_type, T>::ReduceSum(ctx.device_ctx, XpuVarNdarray<T>(a_diff, num_axes), + const_tmp, tmp); + } + if (b_diff) { + const Blob* out_blob = BnInOp2Blob("out"); + NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncDiv>( + ctx.device_ctx, tmp, XpuVarNdarray<const T>(out_blob, num_axes), + XpuVarNdarray<const T>(b, num_axes)); + NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncMul>( + ctx.device_ctx, tmp, out_diff_tensor, const_tmp); + NdarrayUtil<device_type, T>::ReduceSum(ctx.device_ctx, XpuVarNdarray<T>(b_diff, num_axes), + const_tmp, tmp); + NdarrayUtil<device_type, T>::template ImplaceApplyUnary<UnaryFuncMinus>( + ctx.device_ctx, XpuVarNdarray<T>(b_diff, num_axes)); + } + } +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kBroadcastDivConf, BroadcastDivKernel, + FLOATING_DATA_TYPE_SEQ); +} // namespace oneflow diff --git a/oneflow/core/kernel/broadcast_div_kernel.h 
b/oneflow/core/kernel/broadcast_div_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..7c8974f41fe4832a22ae36a620ebd03b4a02730e --- /dev/null +++ b/oneflow/core/kernel/broadcast_div_kernel.h @@ -0,0 +1,24 @@ +#ifndef ONEFLOW_CORE_KERNEL_BROADCAST_DIV_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_BROADCAST_DIV_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class BroadcastDivKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(BroadcastDivKernel); + BroadcastDivKernel() = default; + ~BroadcastDivKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_BROADCAST_DIV_KERNEL_H_ diff --git a/oneflow/core/kernel/broadcast_mul_kernel.cpp b/oneflow/core/kernel/broadcast_mul_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e1dd306f45d57d9b463fc657c66a90e60756ac77 --- /dev/null +++ b/oneflow/core/kernel/broadcast_mul_kernel.cpp @@ -0,0 +1,80 @@ +#include "oneflow/core/kernel/broadcast_mul_kernel.h" +#include "oneflow/core/kernel/kernel_util.h" +#include "oneflow/core/ndarray/ndarray_util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +void BroadcastMulKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* a = BnInOp2Blob("a"); + const Blob* b = BnInOp2Blob("b"); + Blob* out = BnInOp2Blob("out"); + int64_t n = out->shape().elem_cnt(); + if (a->shape().elem_cnt() == 1) { + CHECK_EQ(n, b->shape().elem_cnt()); + KernelUtil<device_type, T>::MulByScalar(ctx.device_ctx, n, b->dptr<T>(), a->dptr<T>(), + out->mut_dptr<T>()); + } else if (b->shape().elem_cnt() == 1) { + CHECK_EQ(n, a->shape().elem_cnt()); + KernelUtil<device_type, T>::MulByScalar(ctx.device_ctx, n, a->dptr<T>(), b->dptr<T>(), + out->mut_dptr<T>()); + } else { + size_t num_axes = out->shape().NumAxes(); + NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncMul>( + ctx.device_ctx, XpuVarNdarray<T>(out, num_axes), XpuVarNdarray<const T>(a, num_axes), + XpuVarNdarray<const T>(b, num_axes)); + } +} + +template<DeviceType device_type, typename T> +void BroadcastMulKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* a = BnInOp2Blob("a"); + const Blob* b = BnInOp2Blob("b"); + const Blob* out_diff = BnInOp2Blob(GenDiffBn("out")); + Blob* a_diff = BnInOp2Blob(GenDiffBn("a")); + Blob* b_diff = BnInOp2Blob(GenDiffBn("b")); + int64_t n = out_diff->shape().elem_cnt(); + if (a->shape().elem_cnt() == 1) { + if (a_diff) { + KernelUtil<device_type, T>::Dot(ctx.device_ctx, n, out_diff->dptr<T>(), 1, b->dptr<T>(), 1, + a_diff->mut_dptr<T>()); + } + if (b_diff) { + KernelUtil<device_type, T>::MulByScalar(ctx.device_ctx, n, out_diff->dptr<T>(), a->dptr<T>(), + b_diff->mut_dptr<T>()); + } + } else if (b->shape().elem_cnt() == 1) { + if (a_diff) { + KernelUtil<device_type, T>::MulByScalar(ctx.device_ctx, n, out_diff->dptr<T>(), b->dptr<T>(), + a_diff->mut_dptr<T>()); + } + if (b_diff) { + KernelUtil<device_type, T>::Dot(ctx.device_ctx, n, out_diff->dptr<T>(), 1, a->dptr<T>(), 1, + b_diff->mut_dptr<T>()); + } + } else { + size_t num_axes = 
out_diff->shape().NumAxes(); + XpuVarNdarray<const T> out_diff_tensor(out_diff, num_axes); + Blob* bw_buf_blob = BnInOp2Blob("bw_buf"); + XpuVarNdarray<const T> const_tmp(out_diff_tensor.shape(), bw_buf_blob->dptr<T>()); + XpuVarNdarray<T> tmp(out_diff_tensor.shape(), bw_buf_blob->mut_dptr<T>()); + if (a_diff) { + NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncMul>( + ctx.device_ctx, tmp, out_diff_tensor, XpuVarNdarray<const T>(b, num_axes)); + NdarrayUtil<device_type, T>::ReduceSum(ctx.device_ctx, XpuVarNdarray<T>(a_diff, num_axes), + const_tmp, tmp); + } + if (b_diff) { + NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncMul>( + ctx.device_ctx, tmp, out_diff_tensor, XpuVarNdarray<const T>(a, num_axes)); + NdarrayUtil<device_type, T>::ReduceSum(ctx.device_ctx, XpuVarNdarray<T>(b_diff, num_axes), + const_tmp, tmp); + } + } +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kBroadcastMulConf, BroadcastMulKernel, + FLOATING_DATA_TYPE_SEQ); +} // namespace oneflow diff --git a/oneflow/core/kernel/broadcast_mul_kernel.h b/oneflow/core/kernel/broadcast_mul_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..24f66a20ed1e8b6fa1bc06e5991a93ef4b7017e2 --- /dev/null +++ b/oneflow/core/kernel/broadcast_mul_kernel.h @@ -0,0 +1,24 @@ +#ifndef ONEFLOW_CORE_KERNEL_BROADCAST_MUL_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_BROADCAST_MUL_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class BroadcastMulKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(BroadcastMulKernel); + BroadcastMulKernel() = default; + ~BroadcastMulKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_BROADCAST_MUL_KERNEL_H_ diff --git a/oneflow/core/kernel/broadcast_sub_kernel.cpp b/oneflow/core/kernel/broadcast_sub_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b0d1d51d899c9a4fa0187244b46fd2fc2948553e --- /dev/null +++ b/oneflow/core/kernel/broadcast_sub_kernel.cpp @@ -0,0 +1,44 @@ +#include "oneflow/core/kernel/broadcast_sub_kernel.h" +#include "oneflow/core/ndarray/binary_func.h" +#include "oneflow/core/ndarray/unary_func.h" +#include "oneflow/core/ndarray/ndarray_util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +void BroadcastSubKernel<device_type, T>::ForwardDataContent( + const KernelCtx& kernel_ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* a_blob = BnInOp2Blob("a"); + const Blob* b_blob = BnInOp2Blob("b"); + Blob* out_blob = BnInOp2Blob("out"); + size_t num_axes = out_blob->shape().NumAxes(); + NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncSub>( + kernel_ctx.device_ctx, XpuVarNdarray<T>(out_blob, num_axes), + XpuVarNdarray<const T>(a_blob, num_axes), XpuVarNdarray<const T>(b_blob, num_axes)); +} + +template<DeviceType device_type, typename T> +void BroadcastSubKernel<device_type, T>::BackwardDataContent( + const KernelCtx& kernel_ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* out_diff_blob = BnInOp2Blob("out_diff"); + Blob* bw_buf_blob = BnInOp2Blob("bw_buf"); + Blob* a_diff_blob = BnInOp2Blob("a_diff"); + Blob* b_diff_blob = BnInOp2Blob("b_diff"); + size_t num_axes = 
out_diff_blob->shape().NumAxes(); + if (a_diff_blob) { + NdarrayUtil<device_type, T>::ReduceSum( + kernel_ctx.device_ctx, XpuVarNdarray<T>(a_diff_blob, num_axes), + XpuVarNdarray<const T>(out_diff_blob, num_axes), XpuVarNdarray<T>(bw_buf_blob, num_axes)); + } + if (b_diff_blob) { + NdarrayUtil<device_type, T>::ReduceSum( + kernel_ctx.device_ctx, XpuVarNdarray<T>(b_diff_blob, num_axes), + XpuVarNdarray<const T>(out_diff_blob, num_axes), XpuVarNdarray<T>(bw_buf_blob, num_axes)); + NdarrayUtil<device_type, T>::template ImplaceApplyUnary<UnaryFuncMinus>( + kernel_ctx.device_ctx, XpuVarNdarray<T>(b_diff_blob, num_axes)); + } +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kBroadcastSubConf, BroadcastSubKernel, + FLOATING_DATA_TYPE_SEQ); +} // namespace oneflow diff --git a/oneflow/core/kernel/broadcast_sub_kernel.h b/oneflow/core/kernel/broadcast_sub_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..8f72f4a4bbb0293b73e9a40e6889c904b685c72e --- /dev/null +++ b/oneflow/core/kernel/broadcast_sub_kernel.h @@ -0,0 +1,24 @@ +#ifndef ONEFLOW_CORE_KERNEL_BROADCAST_SUB_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_BROADCAST_SUB_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class BroadcastSubKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(BroadcastSubKernel); + BroadcastSubKernel() = default; + ~BroadcastSubKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_BROADCAST_SUB_KERNEL_H_ diff --git a/oneflow/core/kernel/cast_kernel.cpp b/oneflow/core/kernel/cast_kernel.cpp index 3125df6969ab62819d2d20474167d0de8e5e6f2d..2750b8233e3a4a63d8402723058da7bb2dab4daf 100644 --- a/oneflow/core/kernel/cast_kernel.cpp +++ b/oneflow/core/kernel/cast_kernel.cpp @@ -6,26 +6,47 @@ namespace oneflow { namespace { -template<typename T, typename U> -void CopyBlob(const Blob* src, Blob* dst) { +template<DeviceType device_type, typename T, typename U> +void CopyBlob(DeviceCtx* ctx, const Blob* src, Blob* dst) { CHECK_EQ(src->shape(), dst->shape()); - CopyElem(src->dptr<T>(), dst->mut_dptr<U>(), src->shape().elem_cnt()); + if (device_type == DeviceType::kCPU) { + CopyElem(src->dptr<T>(), dst->mut_dptr<U>(), src->shape().elem_cnt()); + } else if (device_type == DeviceType::kGPU) { + CopyElemOnGpu(ctx, src->dptr<T>(), dst->mut_dptr<U>(), src->shape().elem_cnt()); + } else { + UNIMPLEMENTED(); + } } -#define MAKE_CAST_SWITCH_ENTRY(func_name, T, U) func_name<T, U> -DEFINE_STATIC_SWITCH_FUNC(void, CopyBlob, MAKE_CAST_SWITCH_ENTRY, - MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ), - MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ)); - } // namespace -void CastKernel::ForwardDataContent(const KernelCtx& ctx, - std::function<Blob*(const std::string&)> BnInOp2Blob) const { +#define MAKE_CAST_SWITCH_ENTRY(func_name, T, U) func_name<device_type, T, U> +template<DeviceType device_type> +struct CastUtil final { + DEFINE_STATIC_SWITCH_FUNC(void, CopyBlob, MAKE_CAST_SWITCH_ENTRY, + MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ), + MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ)); +}; + +template<DeviceType device_type> +void CastKernel<device_type>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { const Blob* in_blob = 
BnInOp2Blob("in"); Blob* out_blob = BnInOp2Blob("out"); - SwitchCopyBlob(SwitchCase(in_blob->data_type(), out_blob->data_type()), in_blob, out_blob); + CastUtil<device_type>::SwitchCopyBlob(SwitchCase(in_blob->data_type(), out_blob->data_type()), + ctx.device_ctx, in_blob, out_blob); +} + +template<DeviceType device_type> +void CastKernel<device_type>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out")); + Blob* in_diff_blob = BnInOp2Blob(GenDiffBn("in")); + CastUtil<device_type>::SwitchCopyBlob( + SwitchCase(out_diff_blob->data_type(), in_diff_blob->data_type()), ctx.device_ctx, + out_diff_blob, in_diff_blob); } -REGISTER_KERNEL(OperatorConf::kCastConf, CastKernel); +ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kCastConf, CastKernel); } // namespace oneflow diff --git a/oneflow/core/kernel/cast_kernel.h b/oneflow/core/kernel/cast_kernel.h index 381df7f1d61986df03100506606512a2cca886bf..259397c785ea1d0064ccc280e378e8ad742904c7 100644 --- a/oneflow/core/kernel/cast_kernel.h +++ b/oneflow/core/kernel/cast_kernel.h @@ -5,7 +5,8 @@ namespace oneflow { -class CastKernel final : public KernelIf<DeviceType::kCPU> { +template<DeviceType device_type> +class CastKernel final : public KernelIf<device_type> { public: OF_DISALLOW_COPY_AND_MOVE(CastKernel); CastKernel() = default; @@ -15,9 +16,7 @@ class CastKernel final : public KernelIf<DeviceType::kCPU> { void ForwardDataContent(const KernelCtx&, std::function<Blob*(const std::string&)>) const override; void BackwardDataContent(const KernelCtx&, - std::function<Blob*(const std::string&)>) const override { - TODO(); - } + std::function<Blob*(const std::string&)>) const override; }; } // namespace oneflow diff --git a/oneflow/core/kernel/clone_kernel.cpp b/oneflow/core/kernel/clone_kernel.cpp index 3f1dedda242f93cfcd96f46d1c0e7fa9cd478900..a7e111f5e4a91519e889cb0d69547882543a58f2 100644 --- a/oneflow/core/kernel/clone_kernel.cpp +++ b/oneflow/core/kernel/clone_kernel.cpp @@ -21,7 +21,7 @@ void CloneKernel<device_type, T>::BackwardDataContent( size_t out_num = odbns.size(); if (out_num == 0) return; Blob* in_diff_blob = BnInOp2Blob(this->op_attribute().input_diff_bns(0)); - auto out_diff = [&](int32_t idx) { + auto out_diff = [=](int32_t idx) { return BnInOp2Blob(this->op_attribute().output_diff_bns(idx)); }; static const int kWidth = 8; diff --git a/oneflow/core/kernel/constant_kernel.cpp b/oneflow/core/kernel/constant_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..91ae77632bacf04950b645c729b314a7aa873a31 --- /dev/null +++ b/oneflow/core/kernel/constant_kernel.cpp @@ -0,0 +1,23 @@ +#include "oneflow/core/kernel/constant_kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +void ConstantKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + if (is_init_) { return; } + CHECK(this->kernel_conf().has_constant_conf()); + const ConstantKernelConf& conf = this->kernel_conf().constant_conf(); + KernelUtil<device_type, T>::InitializeWithConf(ctx.device_ctx, conf.initializer(), + conf.random_seed(), BnInOp2Blob("out")); + is_init_ = true; +} + +template<DeviceType device_type, typename T> +const PbMessage& ConstantKernel<device_type, T>::GetCustomizedOpConf() const { + return this->op_conf().constant_conf(); +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kConstantConf, ConstantKernel, ARITHMETIC_DATA_TYPE_SEQ); + 
+} // namespace oneflow diff --git a/oneflow/core/kernel/constant_kernel.h b/oneflow/core/kernel/constant_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..4895ae245338b7a0bac0e086d06d59b162f10323 --- /dev/null +++ b/oneflow/core/kernel/constant_kernel.h @@ -0,0 +1,25 @@ +#ifndef ONEFLOW_CORE_KERNEL_CONSTANT_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_CONSTANT_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class ConstantKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(ConstantKernel); + ConstantKernel() : is_init_(false) {} + ~ConstantKernel() = default; + + private: + void ForwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + const PbMessage& GetCustomizedOpConf() const override; + + mutable bool is_init_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_CONSTANT_KERNEL_H_ diff --git a/oneflow/core/kernel/debug_kernel.cpp b/oneflow/core/kernel/debug_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d561fc9edda88cacb99d9ba6feba9aafba54dd0e --- /dev/null +++ b/oneflow/core/kernel/debug_kernel.cpp @@ -0,0 +1,117 @@ +#include "oneflow/core/kernel/debug_kernel.h" +#include "oneflow/core/operator/op_conf.pb.h" +#include "oneflow/core/record/record.pb.h" +#include "oneflow/core/record/ofrecord_raw_decoder.h" +#include "oneflow/core/record/ofrecord_raw_encoder.h" + +namespace oneflow { + +namespace { + +size_t FeatureSize(const Feature& feature) { + if (feature.has_double_list()) { return feature.double_list().value_size(); } + if (feature.has_float_list()) { return feature.float_list().value_size(); } + if (feature.has_int32_list()) { return feature.int32_list().value_size(); } + UNIMPLEMENTED(); + return 0; +} + +template<typename T> +void Decode(Blob* blob, const Feature& feature) { + OFRecordDecoderImpl<EncodeCase::kRaw, T> decoder; + CHECK_EQ(blob->shape().elem_cnt(), FeatureSize(feature)); + decoder.ReadOneCol(nullptr, feature, BlobConf(), 0, blob->mut_dptr<T>(), blob->shape().elem_cnt(), + []() { return 0; }); +} + +template<typename T> +void EncodeAndDump(const Blob* blob, PersistentOutStream* out_stream) { + Feature feature; + OFRecordEncoderImpl<EncodeCase::kRaw, T> encoder; + encoder.EncodeBlob(nullptr, blob, &feature); + *out_stream << feature; + out_stream->Flush(); +} + +void InitConstFeature(Feature* feature, const std::string& filepath) { + PersistentInStream in_stream(SnapshotFS(), filepath); + int64_t feature_size = -1; + CHECK_EQ(in_stream.Read(reinterpret_cast<char*>(&feature_size), sizeof(int64_t)), 0); + std::unique_ptr<char[]> buffer(new char[feature_size]); + CHECK_EQ(in_stream.Read(buffer.get(), feature_size), 0); + feature->ParseFromArray(buffer.get(), feature_size); +} + +} // namespace + +template<typename T> +void DebugKernel<T>::InitOutStream(std::unique_ptr<PersistentOutStream>* out_stream, + const ParallelContext* parallel_ctx, const std::string& dir) { + const auto& conf = this->op_conf().debug_conf(); + OfCallOnce(dir, SnapshotFS(), &fs::FileSystem::RecursivelyCreateDir); + int32_t part_name_suffix_length = conf.part_name_suffix_length(); + std::string num = std::to_string(parallel_ctx->parallel_id()); + int32_t zero_count = std::max(part_name_suffix_length - static_cast<int32_t>(num.length()), 0); + std::string file_path = + JoinPath(dir, conf.part_name_prefix() + std::string(zero_count, '0') + num); + out_stream->reset(new 
PersistentOutStream(SnapshotFS(), file_path)); +} + +template<typename T> +void DebugKernel<T>::VirtualKernelInit(const ParallelContext* parallel_ctx) { + const auto& conf = this->op_conf().debug_conf(); + if (conf.has_in_blob_dump_dir()) { + InitOutStream(&in_blob_out_stream_, parallel_ctx, conf.in_blob_dump_dir()); + } + if (conf.has_out_diff_blob_dump_dir()) { + InitOutStream(&out_diff_blob_out_stream_, parallel_ctx, conf.out_diff_blob_dump_dir()); + } + if (conf.has_const_out_feature_load_filepath()) { + InitConstFeature(&const_out_blob_feature_, conf.const_out_feature_load_filepath()); + } + if (conf.has_const_in_diff_feature_load_filepath()) { + InitConstFeature(&const_in_diff_blob_feature_, conf.const_in_diff_feature_load_filepath()); + } +} + +template<typename T> +void DebugKernel<T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* in_blob = BnInOp2Blob("in"); + Blob* out_blob = BnInOp2Blob("out"); + const auto& conf = this->op_conf().debug_conf(); + if (conf.has_in_blob_dump_dir()) { EncodeAndDump<T>(in_blob, in_blob_out_stream_.get()); } + if (conf.const_out_case() == DebugOpConf::ConstOutCase::CONST_OUT_NOT_SET) { + out_blob->CopyDataContentFrom(ctx.device_ctx, in_blob); + } else if (conf.has_const_out_feature_load_filepath()) { + Decode<T>(out_blob, const_out_blob_feature_); + } else if (conf.has_const_out_feature()) { + Decode<T>(out_blob, conf.const_out_feature()); + } else { + UNIMPLEMENTED(); + } +} + +template<typename T> +void DebugKernel<T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + Blob* in_diff_blob = BnInOp2Blob("in_diff"); + const Blob* out_diff_blob = BnInOp2Blob("out_diff"); + const auto& conf = this->op_conf().debug_conf(); + if (conf.has_out_diff_blob_dump_dir()) { + EncodeAndDump<T>(out_diff_blob, out_diff_blob_out_stream_.get()); + } + if (conf.const_in_diff_case() == DebugOpConf::ConstInDiffCase::CONST_IN_DIFF_NOT_SET) { + in_diff_blob->CopyDataContentFrom(ctx.device_ctx, out_diff_blob); + } else if (conf.has_const_in_diff_feature_load_filepath()) { + Decode<T>(in_diff_blob, const_in_diff_blob_feature_); + } else if (conf.has_const_in_diff_feature()) { + Decode<T>(in_diff_blob, conf.const_in_diff_feature()); + } else { + UNIMPLEMENTED(); + } +} + +ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kDebugConf, DebugKernel, ARITHMETIC_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/debug_kernel.h b/oneflow/core/kernel/debug_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..45dccfafe650d8ef426fd4ba962d07201eb2443d --- /dev/null +++ b/oneflow/core/kernel/debug_kernel.h @@ -0,0 +1,33 @@ +#ifndef ONEFLOW_CORE_KERNEL_DEBUG_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_DEBUG_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" +#include "oneflow/core/kernel/kernel_context.h" + +namespace oneflow { + +template<typename T> +class DebugKernel final : public KernelIf<DeviceType::kCPU> { + public: + OF_DISALLOW_COPY_AND_MOVE(DebugKernel); + DebugKernel() = default; + ~DebugKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void InitOutStream(std::unique_ptr<PersistentOutStream>* out_stream, + const ParallelContext* parallel_ctx, const std::string& dir); + void VirtualKernelInit(const ParallelContext* 
parallel_ctx); + + std::unique_ptr<PersistentOutStream> in_blob_out_stream_; + std::unique_ptr<PersistentOutStream> out_diff_blob_out_stream_; + Feature const_out_blob_feature_; + Feature const_in_diff_blob_feature_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_DEBUG_KERNEL_H_ diff --git a/oneflow/core/kernel/dot_kernel.cpp b/oneflow/core/kernel/dot_kernel.cpp index 43c5ef3c818df80bd238f7c0e0e29c0dfd0dd4b0..c08acfcf8d467abe146d684a485993d2e19998d6 100644 --- a/oneflow/core/kernel/dot_kernel.cpp +++ b/oneflow/core/kernel/dot_kernel.cpp @@ -19,7 +19,7 @@ void DotKernel<device_type, T>::ForwardDataContent( KernelUtil<device_type, T>::RowSum(ctx.device_ctx, piece_size, dim, tmp_blob->dptr<T>(), out_blob->mut_dptr<T>(), tmp_storage_blob->mut_dptr<T>(), sizeof(T) * piece_size * dim); - if (this->op_conf().matmul_conf().has_bias()) { + if (this->op_conf().dot_conf().has_bias()) { const Blob* bias_blob = BnInOp2Blob("bias"); // out += bias KernelUtil<device_type, T>::Axpy(ctx.device_ctx, piece_size, OneVal<T>::value, @@ -49,7 +49,7 @@ void DotKernel<device_type, T>::BackwardDataContent( tmp_blob->dptr<T>(), weight_blob->dptr<T>(), in_diff_blob->mut_dptr<T>()); - if (this->op_conf().matmul_conf().has_bias()) { + if (this->op_conf().dot_conf().has_bias()) { Blob* bias_diff_blob = BnInOp2Blob("bias_diff"); // bias_diff = out_diff KernelUtil<device_type, T>::Copy(ctx.device_ctx, out_diff_blob->shape().elem_cnt(), diff --git a/oneflow/core/kernel/dropout_kernel.cpp b/oneflow/core/kernel/dropout_kernel.cpp index da2f118db72501687307c1dd4df68af12e8c8597..5bc54419e7282def5b54aecbac6f31f981cb6bf3 100644 --- a/oneflow/core/kernel/dropout_kernel.cpp +++ b/oneflow/core/kernel/dropout_kernel.cpp @@ -37,7 +37,7 @@ void DropoutKernel<device_type, T>::BackwardDataContent( } template<DeviceType device_type, typename T> -void DropoutKernel<device_type, T>::Dropout(DeviceCtx* ctx, const int64_t n, double dropout_rate, +void DropoutKernel<device_type, T>::Dropout(DeviceCtx* ctx, const int64_t n, float dropout_rate, const T* x, float* random_mask, T* y) const { random_generator_->Uniform(n, random_mask); DropoutKernelUtil<device_type, T>::MaskAndScale(ctx, n, dropout_rate, 1 / (1 - dropout_rate), x, @@ -46,7 +46,7 @@ void DropoutKernel<device_type, T>::Dropout(DeviceCtx* ctx, const int64_t n, dou template<DeviceType device_type, typename T> void DropoutKernel<device_type, T>::DropoutBackward(DeviceCtx* ctx, const int64_t n, - double dropout_rate, const T* dy, + float dropout_rate, const T* dy, const float* random_mask, T* dx) const { DropoutKernelUtil<device_type, T>::MaskAndScale(ctx, n, dropout_rate, 1 / (1 - dropout_rate), dy, random_mask, dx); @@ -54,7 +54,7 @@ void DropoutKernel<device_type, T>::DropoutBackward(DeviceCtx* ctx, const int64_ template<typename T> struct DropoutKernelUtil<DeviceType::kCPU, T> final { - static void MaskAndScale(DeviceCtx* ctx, const int64_t n, double threshold, double scale, + static void MaskAndScale(DeviceCtx* ctx, const int64_t n, float threshold, float scale, const T* x, const float* random_mask, T* y) { for (int64_t i = 0; i < n; ++i) { y[i] = x[i] * (random_mask[i] > threshold) * scale; } } diff --git a/oneflow/core/kernel/dropout_kernel.cu b/oneflow/core/kernel/dropout_kernel.cu index c33c535f739c2a3306daf20cd90bd78214017cea..10e6a4ff932d0efa302b1b14c108e036088db7ca 100644 --- a/oneflow/core/kernel/dropout_kernel.cu +++ b/oneflow/core/kernel/dropout_kernel.cu @@ -8,7 +8,7 @@ namespace oneflow { namespace { template<typename T> -__global__ void 
MaskAndScaleGpu(const int64_t n, double threshold, double scale, const T* x, +__global__ void MaskAndScaleGpu(const int64_t n, float threshold, float scale, const T* x, const float* random_mask, T* y) { CUDA_1D_KERNEL_LOOP(i, n) { y[i] = x[i] * (random_mask[i] > threshold) * scale; } } @@ -17,7 +17,7 @@ __global__ void MaskAndScaleGpu(const int64_t n, double threshold, double scale, template<typename T> struct DropoutKernelUtil<DeviceType::kGPU, T> final { - static void MaskAndScale(DeviceCtx* ctx, const int64_t n, double threshold, double scale, + static void MaskAndScale(DeviceCtx* ctx, const int64_t n, float threshold, float scale, const T* x, const float* random_mask, T* y) { MaskAndScaleGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( n, threshold, scale, x, random_mask, y); diff --git a/oneflow/core/kernel/dropout_kernel.h b/oneflow/core/kernel/dropout_kernel.h index e181d1f1218fa8e0995976af28fe509c642f49e0..c40c7df50d9f6bddae7125aaa7f7aa0d4cb7860c 100644 --- a/oneflow/core/kernel/dropout_kernel.h +++ b/oneflow/core/kernel/dropout_kernel.h @@ -23,10 +23,10 @@ class DropoutKernel final : public KernelIf<device_type> { // random_mask = random_uniform(0, 1) // y = dropout(x, random_mask, dropout_rate) - void Dropout(DeviceCtx* ctx, const int64_t n, double dropout_rate, const T* x, float* random_mask, + void Dropout(DeviceCtx* ctx, const int64_t n, float dropout_rate, const T* x, float* random_mask, T* y) const; // y = dropout(x, random_mask) - void DropoutBackward(DeviceCtx* ctx, const int64_t n, double dropout_rate, const T* dy, + void DropoutBackward(DeviceCtx* ctx, const int64_t n, float dropout_rate, const T* dy, const float* random_mask, T* dx) const; std::unique_ptr<RandomGenerator<device_type>> random_generator_; @@ -34,7 +34,7 @@ class DropoutKernel final : public KernelIf<device_type> { template<DeviceType device_type, typename T> struct DropoutKernelUtil final { - static void MaskAndScale(DeviceCtx* ctx, const int64_t n, double threshold, double scale, + static void MaskAndScale(DeviceCtx* ctx, const int64_t n, float threshold, float scale, const T* x, const float* random_mask, T* y); }; diff --git a/oneflow/core/kernel/gather_kernel.cpp b/oneflow/core/kernel/gather_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7e0181e9ae107ae95743a9537408a636fd2bd731 --- /dev/null +++ b/oneflow/core/kernel/gather_kernel.cpp @@ -0,0 +1,113 @@ +#include "oneflow/core/kernel/gather_kernel.h" + +namespace oneflow { + +namespace { + +Shape GetFlatShape(const Shape& shape, int64_t axis) { + CHECK_GT(shape.NumAxes(), 0); + CHECK_GE(axis, 0); + CHECK_LT(axis, shape.NumAxes()); + return Shape({shape.Count(0, axis), shape.At(axis), shape.Count(axis + 1)}); +} + +template<DeviceType device_type, typename T, typename K> +void GatherForward(DeviceCtx* ctx, const Blob* indices, const Blob* in, int64_t axis, Blob* out) { + const Shape flat_in_shape = GetFlatShape(in->shape(), axis); + GatherKernelUtil<device_type, T, K>::Forward(ctx, indices->dptr<K>(), indices->shape().elem_cnt(), + in->dptr<T>(), flat_in_shape, out->mut_dptr<T>()); +} + +template<DeviceType device_type, typename T, typename K> +void GatherBackward(DeviceCtx* ctx, const Blob* indices, const Blob* out_diff, int64_t axis, + Blob* in_diff) { + Memset<device_type>(ctx, in_diff->mut_dptr<T>(), 0, in_diff->ByteSizeOfDataContentField()); + const Shape flat_in_shape = GetFlatShape(in_diff->shape(), axis); + GatherKernelUtil<device_type, T, K>::Backward(ctx, indices->dptr<K>(), + 
indices->shape().elem_cnt(), out_diff->dptr<T>(), + flat_in_shape, in_diff->mut_dptr<T>()); +} + +template<DeviceType device_type, typename T> +struct GatherSwitchUtil final { +#define MAKE_GATHER_SWITCH_ENTRY(func_name, K) func_name<device_type, T, K> +#define DEFINE_GATHER_STATIC_SWITCH_FUNC(func_name) \ + DEFINE_STATIC_SWITCH_FUNC(void, func_name, MAKE_GATHER_SWITCH_ENTRY, \ + MAKE_DATA_TYPE_CTRV_SEQ(INT_DATA_TYPE_SEQ)); + DEFINE_GATHER_STATIC_SWITCH_FUNC(GatherForward); + DEFINE_GATHER_STATIC_SWITCH_FUNC(GatherBackward); +#undef DEFINE_GATHER_STATIC_SWITCH_FUNC +#undef MAKE_GATHER_SWITCH_ENTRY +}; + +} // namespace + +template<DeviceType device_type, typename T> +const PbMessage& GatherKernel<device_type, T>::GetCustomizedOpConf() const { + return this->op_conf().gather_conf(); +} + +template<DeviceType device_type, typename T> +void GatherKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + GatherSwitchUtil<device_type, T>::SwitchGatherForward( + SwitchCase(BnInOp2Blob("indices")->data_type()), ctx.device_ctx, BnInOp2Blob("indices"), + BnInOp2Blob("in"), this->kernel_conf().gather_conf().axis(), BnInOp2Blob("out")); +} + +template<DeviceType device_type, typename T> +void GatherKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + GatherSwitchUtil<device_type, T>::SwitchGatherBackward( + SwitchCase(BnInOp2Blob("indices")->data_type()), ctx.device_ctx, BnInOp2Blob("indices"), + BnInOp2Blob(GenDiffBn("out")), this->kernel_conf().gather_conf().axis(), + BnInOp2Blob(GenDiffBn("in"))); +} + +template<typename T, typename K> +struct GatherKernelUtil<DeviceType::kCPU, T, K> final { + static void Forward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* in, + const Shape& flat_in_shape, T* out); + static void Backward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* out_diff, + const Shape& flat_in_shape, T* in_diff); +}; + +template<typename T, typename K> +void GatherKernelUtil<DeviceType::kCPU, T, K>::Forward(DeviceCtx* ctx, const K* indices, + int64_t num_indices, const T* in, + const Shape& flat_in_shape, T* out) { + const int64_t outer_dim_size = flat_in_shape.At(0); + const int64_t gather_dim_size = flat_in_shape.At(1); + const int64_t inner_dim_size = flat_in_shape.At(2); + FOR_RANGE(int64_t, outer_idx, 0, outer_dim_size) { + FOR_RANGE(int64_t, i, 0, num_indices) { + const int64_t idx = indices[i]; + CHECK(idx >= 0 && idx < gather_dim_size); + const T* from = in + outer_idx * gather_dim_size * inner_dim_size + idx * inner_dim_size; + T* to = out + outer_idx * num_indices * inner_dim_size + i * inner_dim_size; + std::copy(from, from + inner_dim_size, to); + } + } +} + +template<typename T, typename K> +void GatherKernelUtil<DeviceType::kCPU, T, K>::Backward(DeviceCtx* ctx, const K* indices, + int64_t num_indices, const T* out_diff, + const Shape& flat_in_shape, T* in_diff) { + const int64_t outer_dim_size = flat_in_shape.At(0); + const int64_t gather_dim_size = flat_in_shape.At(1); + const int64_t inner_dim_size = flat_in_shape.At(2); + FOR_RANGE(int64_t, outer_idx, 0, outer_dim_size) { + FOR_RANGE(int64_t, i, 0, num_indices) { + const int64_t idx = indices[i]; + CHECK(idx >= 0 && idx < gather_dim_size); + const T* from = out_diff + outer_idx * num_indices * inner_dim_size + i * inner_dim_size; + T* to = in_diff + outer_idx * gather_dim_size * inner_dim_size + idx * inner_dim_size; + 
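// scatter-add: accumulate this out_diff slice into the in_diff row selected by idx (indices may repeat)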
std::transform(from, from + inner_dim_size, to, to, std::plus<T>()); + } + } +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kGatherConf, GatherKernel, FLOATING_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/gather_kernel.cu b/oneflow/core/kernel/gather_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..2f4bddfbab73cde764fce7088c2dec5555ded853 --- /dev/null +++ b/oneflow/core/kernel/gather_kernel.cu @@ -0,0 +1,79 @@ +#include "oneflow/core/kernel/gather_kernel.h" +#include "oneflow/core/kernel/kernel_util.cuh" +#include <assert.h> + +namespace oneflow { + +namespace { + +template<typename K> +__device__ int64_t get_in_offset(int64_t out_offset, const K* indices, int64_t num_indices, + int64_t gather_dim_size, int64_t inner_dim_size) { + const int64_t outer_dim_elem_cnt = num_indices * inner_dim_size; + const int64_t outer_idx = out_offset / outer_dim_elem_cnt; + const int64_t indices_idx = out_offset % outer_dim_elem_cnt / inner_dim_size; + const int64_t inner_idx = out_offset % inner_dim_size; + const int64_t idx = indices[indices_idx]; + assert(idx >= 0 && idx < gather_dim_size); + return outer_idx * gather_dim_size * inner_dim_size + idx * inner_dim_size + inner_idx; +} + +template<typename T, typename K> +__global__ void GatherForwardGpu(int64_t elem_cnt, const K* indices, int64_t num_indices, + const T* in, int64_t gather_dim_size, int64_t inner_dim_size, + T* out) { + CUDA_1D_KERNEL_LOOP(i, elem_cnt) { + out[i] = in[get_in_offset<K>(i, indices, num_indices, gather_dim_size, inner_dim_size)]; + } +} + +template<typename T, typename K> +__global__ void GatherBackwardGpu(int64_t elem_cnt, const K* indices, int64_t num_indices, + const T* out_diff, int64_t gather_dim_size, + int64_t inner_dim_size, T* in_diff) { + CUDA_1D_KERNEL_LOOP(i, elem_cnt) { + gpu_atomic_add( + in_diff + get_in_offset<K>(i, indices, num_indices, gather_dim_size, inner_dim_size), + out_diff[i]); + } +} + +} // namespace + +template<typename T, typename K> +struct GatherKernelUtil<DeviceType::kGPU, T, K> final { + static void Forward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* in, + const Shape& flat_in_shape, T* out); + static void Backward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* out_diff, + const Shape& flat_in_shape, T* in_diff); +}; + +template<typename T, typename K> +void GatherKernelUtil<DeviceType::kGPU, T, K>::Forward(DeviceCtx* ctx, const K* indices, + int64_t num_indices, const T* in, + const Shape& flat_in_shape, T* out) { + const int64_t elem_cnt = flat_in_shape.At(0) * num_indices * flat_in_shape.At(2); + GatherForwardGpu<T, K> + <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + elem_cnt, indices, num_indices, in, flat_in_shape.At(1), flat_in_shape.At(2), out); +} + +template<typename T, typename K> +void GatherKernelUtil<DeviceType::kGPU, T, K>::Backward(DeviceCtx* ctx, const K* indices, + int64_t num_indices, const T* out_diff, + const Shape& flat_in_shape, T* in_diff) { + const int64_t elem_cnt = flat_in_shape.At(0) * num_indices * flat_in_shape.At(2); + GatherBackwardGpu<T, K> + <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + elem_cnt, indices, num_indices, out_diff, flat_in_shape.At(1), flat_in_shape.At(2), + in_diff); +} + +#define MAKE_GATHER_KERNEL_UTIL_ENTRY(in_type_pair, index_type_pair) \ + template struct GatherKernelUtil<DeviceType::kGPU, OF_PP_PAIR_FIRST(in_type_pair), \ + OF_PP_PAIR_FIRST(index_type_pair)>; 
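+// explicitly instantiate the GPU gather util for every (floating-point value type, integer index type) pair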
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_GATHER_KERNEL_UTIL_ENTRY, FLOATING_DATA_TYPE_SEQ, + INT_DATA_TYPE_SEQ); +#undef MAKE_GATHER_KERNEL_UTIL_ENTRY + +} // namespace oneflow diff --git a/oneflow/core/kernel/gather_kernel.h b/oneflow/core/kernel/gather_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..ae711a7f64d274fbd2a40c6e37d357ab65d1b367 --- /dev/null +++ b/oneflow/core/kernel/gather_kernel.h @@ -0,0 +1,33 @@ +#ifndef ONEFLOW_CORE_KERNEL_GATHER_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_GATHER_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class GatherKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(GatherKernel); + GatherKernel() = default; + ~GatherKernel() override = default; + + private: + const PbMessage& GetCustomizedOpConf() const override; + void ForwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + void BackwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; +}; + +template<DeviceType device_type, typename T, typename K> +struct GatherKernelUtil final { + static void Forward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* in, + const Shape& flat_in_shape, T* out); + static void Backward(DeviceCtx* ctx, const K* indices, int64_t num_indices, const T* out_diff, + const Shape& flat_in_shape, T* in_diff); +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_GATHER_KERNEL_H_ diff --git a/oneflow/core/kernel/gelu_kernel.cpp b/oneflow/core/kernel/gelu_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..be1b273a70af24914372e3e76ae967f034c9e818 --- /dev/null +++ b/oneflow/core/kernel/gelu_kernel.cpp @@ -0,0 +1,52 @@ +#include "oneflow/core/kernel/gelu_kernel.h" +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +void GeluKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* in_blob = BnInOp2Blob("in"); + GeluKernelUtil<device_type, T>::GeluForward(ctx.device_ctx, in_blob->static_shape().elem_cnt(), + in_blob->dptr<T>(), + BnInOp2Blob("out")->mut_dptr<T>()); +} + +template<DeviceType device_type, typename T> +void GeluKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* in_blob = BnInOp2Blob("in"); + GeluKernelUtil<device_type, T>::GeluBackward( + ctx.device_ctx, in_blob->static_shape().elem_cnt(), in_blob->dptr<T>(), + BnInOp2Blob("out_diff")->dptr<T>(), BnInOp2Blob("in_diff")->mut_dptr<T>()); +} + +template<DeviceType device_type, typename T> +const PbMessage& GeluKernel<device_type, T>::GetCustomizedOpConf() const { + return this->op_conf().gelu_conf(); +} + +template<typename T> +struct GeluKernelUtil<DeviceType::kCPU, T> { + static void GeluForward(DeviceCtx* ctx, const int64_t n, const T* x, T* y) { + T inv_sqrt2 = std::sqrt(0.5); + FOR_RANGE(int32_t, i, 0, n) { y[i] = 0.5 * x[i] * (1.0 + std::erf(inv_sqrt2 * x[i])); } + } + + static void GeluBackward(DeviceCtx* ctx, const int64_t n, const T* x, const T* dy, T* dx) { + T inv_sqrt2 = std::sqrt(0.5); + T coef = std::sqrt(2.0 / std::acos(-1.0)); + FOR_RANGE(int32_t, i, 0, n) { + dx[i] = 0.5 * (1.0 + std::erf(inv_sqrt2 * x[i]) + x[i] * coef * std::exp(-0.5 * x[i] * x[i])) + * dy[i]; + } 
+ } +}; + +#define INSTANTIATE_GELU_KERNEL_UTIL(type_cpp, type_proto) \ + template struct GeluKernelUtil<DeviceType::kCPU, type_cpp>; +OF_PP_FOR_EACH_TUPLE(INSTANTIATE_GELU_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ) + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kGeluConf, GeluKernel, FLOATING_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/gelu_kernel.cu b/oneflow/core/kernel/gelu_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..a36e506eb07eb5b982d3fc82d523c87a633ee387 --- /dev/null +++ b/oneflow/core/kernel/gelu_kernel.cu @@ -0,0 +1,69 @@ +#include "oneflow/core/kernel/gelu_kernel.h" +#include "oneflow/core/kernel/kernel_util.cuh" + +namespace oneflow { + +namespace { + +template<typename T> +__global__ void GeluForwardGpu(const int64_t n, const T* x, const T inv_sqrt2, T* y) { + UNIMPLEMENTED(); +} + +template<typename T> +__global__ void GeluBackwardGpu(const int64_t n, const T* x, const T* dy, const T inv_sqrt2, + const T coef, T* dx) { + UNIMPLEMENTED(); +} + +template<> +__global__ void GeluForwardGpu(const int64_t n, const float* x, const float inv_sqrt2, float* y) { + CUDA_1D_KERNEL_LOOP(i, n) { y[i] = 0.5f * x[i] * (1.0f + erff(inv_sqrt2 * x[i])); } +} + +template<> +__global__ void GeluBackwardGpu(const int64_t n, const float* x, const float* dy, + const float inv_sqrt2, const float coef, float* dx) { + CUDA_1D_KERNEL_LOOP(i, n) { + dx[i] = + 0.5f * (1.0f + erff(inv_sqrt2 * x[i]) + x[i] * coef * expf(-0.5f * x[i] * x[i])) * dy[i]; + } +} + +template<> +__global__ void GeluForwardGpu(const int64_t n, const double* x, const double inv_sqrt2, + double* y) { + CUDA_1D_KERNEL_LOOP(i, n) { y[i] = 0.5 * x[i] * (1.0 + erf(inv_sqrt2 * x[i])); } +} + +template<> +__global__ void GeluBackwardGpu(const int64_t n, const double* x, const double* dy, + const double inv_sqrt2, const double coef, double* dx) { + CUDA_1D_KERNEL_LOOP(i, n) { + dx[i] = 0.5 * (1.0 + erf(inv_sqrt2 * x[i]) + x[i] * coef * exp(-0.5 * x[i] * x[i])) * dy[i]; + } +} + +} // namespace + +template<typename T> +struct GeluKernelUtil<DeviceType::kGPU, T> { + static void GeluForward(DeviceCtx* ctx, const int64_t n, const T* x, T* y) { + const T inv_sqrt2 = sqrt(0.5); + GeluForwardGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + n, x, inv_sqrt2, y); + } + + static void GeluBackward(DeviceCtx* ctx, const int64_t n, const T* x, const T* dy, T* dx) { + const T inv_sqrt2 = sqrt(0.5); + const T coef = sqrt(2.0 / acos(-1.0)); + GeluBackwardGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + n, x, dy, inv_sqrt2, coef, dx); + } +}; + +#define INSTANTIATE_GELU_KERNEL_UTIL(type_cpp, type_proto) \ + template struct GeluKernelUtil<DeviceType::kGPU, type_cpp>; +OF_PP_FOR_EACH_TUPLE(INSTANTIATE_GELU_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ) + +} // namespace oneflow diff --git a/oneflow/core/kernel/gelu_kernel.h b/oneflow/core/kernel/gelu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..a6859c6009086c253ef792c0cc6b08585193f6f2 --- /dev/null +++ b/oneflow/core/kernel/gelu_kernel.h @@ -0,0 +1,32 @@ +#ifndef ONEFLOW_CORE_KERNEL_GELU_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_GELU_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class GeluKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(GeluKernel); + GeluKernel() = default; + ~GeluKernel() = default; + + private: + void ForwardDataContent(const 
KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + const PbMessage& GetCustomizedOpConf() const override; +}; + +template<DeviceType device_type, typename T> +struct GeluKernelUtil { + static void GeluForward(DeviceCtx* ctx, const int64_t n, const T* x, T* y); + + static void GeluBackward(DeviceCtx* ctx, const int64_t n, const T* x, const T* dy, T* dx); +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_GELU_KERNEL_H_ diff --git a/oneflow/core/kernel/hinge_loss_kernel.h b/oneflow/core/kernel/hinge_loss_kernel.h index befa52b31f2e9dc80873fc3e04556071dd604d2f..606e996880d8cf20c4647e76a57f759ecb6fedd0 100644 --- a/oneflow/core/kernel/hinge_loss_kernel.h +++ b/oneflow/core/kernel/hinge_loss_kernel.h @@ -6,7 +6,7 @@ namespace oneflow { template<DeviceType device_type, typename PredType, typename LabelType> -class HingeLossKernel final : public LossKernel<device_type, PredType, LabelType> { +class HingeLossKernel final : public LossKernel<device_type, PredType> { public: OF_DISALLOW_COPY_AND_MOVE(HingeLossKernel); HingeLossKernel() = default; diff --git a/oneflow/core/kernel/identity_kernel.cpp b/oneflow/core/kernel/identity_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a76b5c8fabc1f191c99aa9a5c01581294af3ff07 --- /dev/null +++ b/oneflow/core/kernel/identity_kernel.cpp @@ -0,0 +1,19 @@ +#include "oneflow/core/kernel/identity_kernel.h" + +namespace oneflow { + +template<DeviceType device_type> +void IdentityKernel<device_type>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + BnInOp2Blob("out")->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob("in")); +} + +template<DeviceType device_type> +void IdentityKernel<device_type>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + BnInOp2Blob(GenDiffBn("in"))->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob(GenDiffBn("out"))); +} + +ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kParallelCastConf, IdentityKernel); + +} // namespace oneflow diff --git a/oneflow/core/kernel/identity_kernel.h b/oneflow/core/kernel/identity_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..22eef17ef6e6240d8389adc4d89904e478a1b498 --- /dev/null +++ b/oneflow/core/kernel/identity_kernel.h @@ -0,0 +1,25 @@ +#ifndef ONEFLOW_CORE_KERNEL_IDENTITY_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_IDENTITY_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" +#include "oneflow/core/kernel/kernel_context.h" + +namespace oneflow { + +template<DeviceType device_type> +class IdentityKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(IdentityKernel); + IdentityKernel() = default; + ~IdentityKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_IDENTITY_KERNEL_H_ diff --git a/oneflow/core/kernel/identity_loss_kernel.cpp b/oneflow/core/kernel/identity_loss_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..02345d8307c15896e21413213907c0f1b28efec3 --- /dev/null +++ b/oneflow/core/kernel/identity_loss_kernel.cpp @@ -0,0 +1,39 @@ +#include "oneflow/core/kernel/identity_loss_kernel.h" + 
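+// IdentityLossKernel forwards the prediction blob unchanged as the loss and fills the
+// prediction diff from the const buffer "ones" set up in InitConstBufBlobs.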
+namespace oneflow { + +template<DeviceType device_type, typename PredType> +void IdentityLossKernel<device_type, PredType>::InitConstBufBlobs( + DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + InitializerConf initializer; + initializer.mutable_constant_conf()->set_value(1.0); + KernelUtil<device_type, PredType>::InitializeWithConf(ctx, initializer, 0, BnInOp2Blob("ones")); +} + +template<DeviceType device_type, typename PredType> +void IdentityLossKernel<device_type, PredType>::VirtualLossForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* prediction = BnInOp2Blob("prediction"); + const Blob* ones = BnInOp2Blob("ones"); + Blob* loss = BnInOp2Blob("loss"); + Blob* prediction_diff = BnInOp2Blob(GenDiffBn("prediction")); + loss->CopyDataContentFrom(ctx.device_ctx, prediction); + prediction_diff->CopyDataContentFrom(ctx.device_ctx, ones); +} + +template<DeviceType device_type, typename PredType> +const LossKernelConf& IdentityLossKernel<device_type, PredType>::GetLossKernelConf( + const KernelConf& kernel_conf) const { + return kernel_conf.identity_loss_conf().loss_conf(); +} + +template<DeviceType device_type, typename PredType> +int64_t IdentityLossKernel<device_type, PredType>::CalcLossInstanceNum( + const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const { + return BnInOp2Blob("prediction")->shape().elem_cnt(); +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kIdentityLossConf, IdentityLossKernel, + FLOATING_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/identity_loss_kernel.h b/oneflow/core/kernel/identity_loss_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..1f3ae8c62953e2551e1ce22130a321d3bee730e7 --- /dev/null +++ b/oneflow/core/kernel/identity_loss_kernel.h @@ -0,0 +1,27 @@ +#ifndef ONEFLOW_CORE_KERNEL_IDENTITY_LOSS_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_IDENTITY_LOSS_KERNEL_H_ + +#include "oneflow/core/kernel/loss_kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename PredType> +class IdentityLossKernel final : public LossKernel<device_type, PredType> { + public: + OF_DISALLOW_COPY_AND_MOVE(IdentityLossKernel); + IdentityLossKernel() = default; + ~IdentityLossKernel() = default; + + private: + void VirtualLossForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + const LossKernelConf& GetLossKernelConf(const KernelConf& kernel_conf) const override; + void InitConstBufBlobs(DeviceCtx* ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const; + int64_t CalcLossInstanceNum(const KernelCtx& ctx, + const std::function<Blob*(const std::string&)>& BnInOp2Blob) const; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_IDENTITY_LOSS_KERNEL_H_ diff --git a/oneflow/core/kernel/kernel.cpp b/oneflow/core/kernel/kernel.cpp index 95f0ab8095b9c0f8641723da3e1ad182a5d4aa89..ac473c14cad2b43a8280f3d0910959a973cba740 100644 --- a/oneflow/core/kernel/kernel.cpp +++ b/oneflow/core/kernel/kernel.cpp @@ -25,6 +25,44 @@ void ClearBlobDim0ValidNumIfNeed(const PbRpf<std::string>& bns, } } +void CheckLossInstanceNumField(const PbRpf<std::string>& bns, + const std::function<Blob*(const std::string&)>& BnInOp2Blob, + bool expected) { + for (const std::string& bn : bns) { + const Blob* blob = BnInOp2Blob(bn); + if (blob != nullptr) { CHECK_EQ(blob->has_loss_instance_num_field(), expected); } + } +} + +bool NeedCopyLossInstanceNum(const 
PbRpf<std::string>& from_bns, const PbRpf<std::string>& to_bns, + const std::function<Blob*(const std::string&)>& BnInOp2Blob) { + const auto& first_bn_has_loss_instance_num_it = + std::find_if(from_bns.cbegin(), from_bns.cend(), [&BnInOp2Blob](const std::string& bn) { + const Blob* blob = BnInOp2Blob(bn); + return blob != nullptr && blob->has_loss_instance_num_field(); + }); + const bool need_copy_loss_instance_num = first_bn_has_loss_instance_num_it != from_bns.end(); + CheckLossInstanceNumField(from_bns, BnInOp2Blob, need_copy_loss_instance_num); + CheckLossInstanceNumField(to_bns, BnInOp2Blob, need_copy_loss_instance_num); + return need_copy_loss_instance_num; +} + +void NaiveCopyLossInstanceNum(const PbRpf<std::string>& from_bns, const PbRpf<std::string>& to_bns, + const std::function<Blob*(const std::string&)>& BnInOp2Blob) { + CHECK_GT(from_bns.size(), 0); + CHECK(BnInOp2Blob(from_bns.Get(0))->has_loss_instance_num_field()); + const float loss_instance_num = BnInOp2Blob(from_bns.Get(0))->loss_instance_num(); + const float loss_instance_num_epsilon = 1e-8; + FOR_RANGE(int32_t, i, 1, from_bns.size()) { + CHECK_LT(std::fabs(BnInOp2Blob(from_bns.Get(i))->loss_instance_num() - loss_instance_num), + loss_instance_num_epsilon); + } + FOR_RANGE(int32_t, i, 0, to_bns.size()) { + Blob* blob = BnInOp2Blob(to_bns.Get(i)); + if (blob != nullptr) { blob->set_loss_instance_num(loss_instance_num); } + } +} + } // namespace void Kernel::Init(const ParallelContext* parallel_ctx, const KernelConf& kernel_conf, @@ -94,6 +132,7 @@ void Kernel::Forward(const KernelCtx& ctx, CHECK(!kernel_conf_.need_do_opaque_header()); ForwardDim0ValidNum(ctx, BnInOp2Blob); } + if (NeedForwardLossInstanceNum(ctx, BnInOp2Blob)) { ForwardLossInstanceNum(ctx, BnInOp2Blob); } if (HasEmptyShapeBlob(op_attribute().input_bns(), BnInOp2Blob) && !NeedForwardIfBlobEmpty()) { ClearBlobDim0ValidNumIfNeed(op_attribute().output_bns(), BnInOp2Blob); return; @@ -135,6 +174,7 @@ void Kernel::Backward(const KernelCtx& ctx, CHECK(!kernel_conf_.need_do_opaque_header()); BackwardInDiffDim0ValidNum(ctx, BnInOp2Blob); } + BackwardInDiffLossInstanceNum(ctx, BnInOp2Blob); if (HasEmptyShapeBlob(op_attribute().output_diff_bns(), BnInOp2Blob) && !NeedBackwardIfBlobEmpty()) { ClearBlobDim0ValidNumIfNeed(op_attribute().input_diff_bns(), BnInOp2Blob); @@ -204,6 +244,19 @@ void KernelIf<device_type>::ForwardRecordIdInDevicePiece( op_attribute().output_bns(), &Blob::CopyRecordIdInDevicePieceFrom); } +template<DeviceType device_type> +void KernelIf<device_type>::ForwardLossInstanceNum( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + NaiveCopyLossInstanceNum(op_attribute().input_bns(), op_attribute().output_bns(), BnInOp2Blob); +} + +template<DeviceType device_type> +bool KernelIf<device_type>::NeedForwardLossInstanceNum( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + return NeedCopyLossInstanceNum(op_attribute().input_bns(), op_attribute().output_bns(), + BnInOp2Blob); +} + template<DeviceType device_type> void KernelIf<device_type>::BackwardModelDiffDim0ValidNum( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { @@ -233,6 +286,15 @@ void KernelIf<device_type>::BackwardInDiffDim0ValidNum( input_diff_bns, &Blob::CopyDim0ValidNumFrom); } +template<DeviceType device_type> +void KernelIf<device_type>::BackwardInDiffLossInstanceNum( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + 
CHECK(NeedCopyLossInstanceNum(op_attribute().output_diff_bns(), op_attribute().input_diff_bns(), + BnInOp2Blob)); + NaiveCopyLossInstanceNum(op_attribute().output_diff_bns(), op_attribute().input_diff_bns(), + BnInOp2Blob); +} + template<DeviceType device_type> void KernelIf<device_type>::ForwardPackedHeader( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { @@ -257,10 +319,10 @@ template<DeviceType device_type, typename T> void KernelIfWithModel<device_type, T>::SetTotalInstanceNumDiffBlob( const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const { CHECK_GE(this->op_attribute().model_bns().size(), 2); - int64_t dim0_valid_num_sum = - BnInOp2Blob(this->op_attribute().output_diff_bns(0))->CalcDim0ValidNumSum(); + const float loss_instance_num = + BnInOp2Blob(this->op_attribute().output_diff_bns(0))->loss_instance_num(); Blob* total_instance_num_diff_blob = BnInOp2Blob("total_instance_num_diff"); - KernelUtil<device_type, T>::Set(ctx.device_ctx, static_cast<T>(dim0_valid_num_sum), + KernelUtil<device_type, T>::Set(ctx.device_ctx, static_cast<T>(loss_instance_num), total_instance_num_diff_blob->mut_dptr<T>()); } diff --git a/oneflow/core/kernel/kernel.h b/oneflow/core/kernel/kernel.h index 9a008a640acd11420cbe3efade654c45afaf8d7e..8884b1409b3aebc855dfcb6c235ae44f66d8e3a5 100644 --- a/oneflow/core/kernel/kernel.h +++ b/oneflow/core/kernel/kernel.h @@ -76,6 +76,14 @@ class Kernel { const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { UNIMPLEMENTED(); } + virtual bool NeedForwardLossInstanceNum( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + UNIMPLEMENTED(); + } + virtual void ForwardLossInstanceNum(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const { + UNIMPLEMENTED(); + } virtual void ForwardPackedHeader(const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { UNIMPLEMENTED(); @@ -100,6 +108,10 @@ class Kernel { const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { UNIMPLEMENTED(); } + virtual void BackwardInDiffLossInstanceNum( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + UNIMPLEMENTED(); + } virtual void BackwardModelDiffDim0ValidNum( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { UNIMPLEMENTED(); @@ -158,7 +170,10 @@ class KernelIf : public Kernel { const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; virtual void ForwardRecordIdInDevicePiece( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; - + virtual void ForwardLossInstanceNum( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + virtual bool NeedForwardLossInstanceNum( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; virtual void ForwardPackedHeader( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; virtual void BackwardDataId(const KernelCtx& ctx, @@ -167,6 +182,8 @@ class KernelIf : public Kernel { std::function<Blob*(const std::string&)> BnInOp2Blob) const override; virtual void BackwardInDiffDim0ValidNum( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + virtual void BackwardInDiffLossInstanceNum( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) 
const override; virtual void BackwardModelDiffDim0ValidNum( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; void CopyField(DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob, diff --git a/oneflow/core/kernel/kernel.proto b/oneflow/core/kernel/kernel.proto index 7027c63408b716fafa1be8e6c8905734f5eb6bb3..ed19b1817491e0a0d6bd04736f45f9224f560c01 100644 --- a/oneflow/core/kernel/kernel.proto +++ b/oneflow/core/kernel/kernel.proto @@ -41,6 +41,14 @@ message SparseCrossEntropyLossKernelConf { required LossKernelConf loss_conf = 1; } +message SigmoidCrossEntropyLossKernelConf { + required LossKernelConf loss_conf = 1; +} + +message IdentityLossKernelConf { + required LossKernelConf loss_conf = 1; +} + message PoolingKernelConf { required ShapeProto in = 1; required ShapeProto out = 2; @@ -61,7 +69,7 @@ message MaxPoolingKernelConf { } message ReduceSumKernelConf { - optional int32 axis = 1; + required ShapeProto kept_dims_shape = 1; } message SoftmaxKernelConf { @@ -142,6 +150,28 @@ message HingeLossKernelConf { required LossKernelConf loss_conf = 1; } +message SliceKernelConf { + required ShapeProto in_shape = 1; +} + +message ConstantKernelConf { + required InitializerConf initializer = 1; + required uint32 random_seed = 2; +} + +message GatherKernelConf { + required int64 axis = 1; +} + +message VariableKernelConf { + required bool is_fw_inplace = 1; + required bool is_bw_inplace = 2; +} + +message RecordLoadKernelConf { + required int64 device_piece_size = 1; +} + message KernelConf { required OpAttribute op_attribute = 1; required bool is_forward = 2; @@ -177,5 +207,12 @@ message KernelConf { ReduceConcatKernelConf reduce_concat_conf = 351; ReduceSplitKernelConf reduce_split_conf = 352; AccuracyKernelConf accuracy_conf = 401; + SliceKernelConf slice_conf = 402; + ConstantKernelConf constant_conf = 403; + SigmoidCrossEntropyLossKernelConf sigmoid_cross_entropy_loss_conf = 404; + IdentityLossKernelConf identity_loss_conf = 405; + GatherKernelConf gather_conf = 406; + VariableKernelConf variable_conf = 407; + RecordLoadKernelConf record_load_conf = 408; } } diff --git a/oneflow/core/kernel/kernel_util.cpp b/oneflow/core/kernel/kernel_util.cpp index 86503c0b940c2d389498ea49a1fb008654177492..ca8eed6f5c7b4bf228d88a7463330005646c8cb2 100644 --- a/oneflow/core/kernel/kernel_util.cpp +++ b/oneflow/core/kernel/kernel_util.cpp @@ -13,7 +13,7 @@ void RngUniform(const int64_t elem_cnt, const T min, const T max, uint32_t rando CHECK(dptr); CHECK_LE(min, max); std::mt19937 generator(random_seed); - std::uniform_real_distribution<T> random_distribution(min, std::nextafter(max, MaxVal<T>())); + std::uniform_real_distribution<T> random_distribution(min, std::nextafter(max, GetMaxVal<T>())); for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = random_distribution(generator); } } @@ -24,7 +24,7 @@ void RngIntUniform(const int64_t elem_cnt, const T min, const T max, uint32_t ra CHECK(dptr); CHECK_LE(min, max); std::mt19937 generator(random_seed); - std::uniform_int_distribution<T> random_distribution(min, std::nextafter(max, MaxVal<T>())); + std::uniform_int_distribution<T> random_distribution(min, std::nextafter(max, GetMaxVal<T>())); for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = random_distribution(generator); } } @@ -39,6 +39,24 @@ void RngNormal(const int64_t elem_cnt, const T mean, const T std, uint32_t rando for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = random_distribution(generator); } } +template<typename T> +void 
RngTruncatedNormal(const int64_t elem_cnt, const T std, uint32_t random_seed, T* dptr) { + CHECK_GE(elem_cnt, 0); + CHECK(dptr); + CHECK_GT(std, 0.0); + T truncated_value = 2 * std; + std::mt19937 generator(random_seed); + std::normal_distribution<T> random_distribution(0, std); + int64_t index = 0; + while (true) { + T val = random_distribution(generator); + if (std::abs(val) < truncated_value) { + dptr[index++] = val; + if (index >= elem_cnt) { break; } + } + } +} + template<typename T> void ConstantInitializer(const T& value, Blob* blob) { T* dptr = blob->mut_dptr<T>(); @@ -71,6 +89,14 @@ void RandomNormalInitializer(const RandomNormalInitializerConf& initializer_conf static_cast<T>(initializer_conf.std()), random_seed, blob->mut_dptr<T>()); } +template<typename T> +void TruncatedNormalInitializer(const TruncatedNormalInitializerConf& initializer_conf, + uint32_t random_seed, Blob* blob) { + CHECK(blob->shape().elem_cnt()); + RngTruncatedNormal<T>(blob->shape().elem_cnt(), static_cast<T>(initializer_conf.std()), + random_seed, blob->mut_dptr<T>()); +} + template<typename T> T GenInitialFan(VarianceNorm variance_norm, Blob* blob, const std::string& data_format) { T fan = ZeroVal<T>::value; @@ -116,6 +142,44 @@ void MsraInitializer(const MsraInitializerConf& initializer_conf, uint32_t rando blob->mut_dptr<T>()); } +template<typename T> +void RangeInitializer(int64_t outer_size, int64_t idx_dim_size, int64_t inner_size, T start, + T stride, T* out) { + FOR_RANGE(int64_t, i, 0, outer_size) { + FOR_RANGE(int64_t, j, 0, idx_dim_size) { + FOR_RANGE(int64_t, k, 0, inner_size) { + *(out + i * idx_dim_size * inner_size + j * inner_size + k) = start + j * stride; + } + } + } +} + +template<typename T, typename RangeInitializerConfT> +void RangeInitializer(const RangeInitializerConfT& initializer_conf, uint32_t random_seed, + Blob* blob) { + CHECK_GT(blob->shape().NumAxes(), 0); + const int64_t axis = initializer_conf.axis() < 0 + ? 
blob->shape().NumAxes() + initializer_conf.axis() + : initializer_conf.axis(); + CHECK_GE(axis, 0); + CHECK_LT(axis, blob->shape().NumAxes()); + RangeInitializer<T>(blob->shape().Count(0, axis), blob->shape().At(axis), + blob->shape().Count(axis + 1), static_cast<T>(initializer_conf.start()), + static_cast<T>(initializer_conf.stride()), blob->mut_dptr<T>()); +} + +template<typename T> +void RangeInitializer(const RangeInitializerConf& initializer_conf, uint32_t random_seed, + Blob* blob) { + RangeInitializer<T, RangeInitializerConf>(initializer_conf, random_seed, blob); +} + +template<typename T> +void IntSequenceInitializer(const IntRangeInitializerConf& initializer_conf, uint32_t random_seed, + Blob* blob) { + RangeInitializer<T, IntRangeInitializerConf>(initializer_conf, random_seed, blob); +} + void ComputeOffset(const int32_t num_axes, const int64_t* shape, const int32_t* permutation, std::vector<int64_t>& offset) { offset.resize(num_axes); @@ -157,6 +221,7 @@ void Memcpy<DeviceType::kCPU>(DeviceCtx* ctx, void* dst, const void* src, size_t #endif ) { + if (dst == src) { return; } memcpy(dst, src, sz); } @@ -271,6 +336,15 @@ KU_IF_METHOD InitializeWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num in_stream.Read(blob->mut_dptr<char>(), blob_size); } KU_IF_METHOD Set(DeviceCtx* ctx, const T value, T* addr) { *addr = value; } +KU_IF_METHOD Replicate(DeviceCtx* ctx, const int64_t n, T* y, const T* x) { + for (int64_t i = 0; i < n; ++i) { y[i] = *x; } +} +KU_IF_METHOD AddByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z) { + for (int64_t i = 0; i < n; ++i) { z[i] = x[i] + y; } +} +KU_IF_METHOD MulByScalarPara(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z) { + for (int64_t i = 0; i < n; ++i) { z[i] = x[i] * y; } +} #define KU_FLOATING_METHOD \ template<typename T> \ @@ -307,6 +381,19 @@ KU_FLOATING_METHOD Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, const int ldc) { cblas_gemm<T>(order, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } +KU_FLOATING_METHOD BatchedGemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE trans_a, + const enum CBLAS_TRANSPOSE trans_b, int batch_size, int m, int n, + int k, const T alpha, const T* a, const T* b, const T beta, T* c, + T** buf) { + const int a_stride = m * k; + const int b_stride = k * n; + const int c_stride = m * n; + FOR_RANGE(int32_t, i, 0, batch_size) { + KernelUtil<DeviceType::kCPU, T>::OFGemm(ctx, trans_a, trans_b, m, n, k, alpha, a + i * a_stride, + b + i * b_stride, beta, c + i * c_stride); + } +} KU_FLOATING_METHOD Exp(DeviceCtx* ctx, const int64_t n, const T* x, T* y) { for (int64_t i = 0; i < n; ++i) { y[i] = std::exp(x[i]); } @@ -314,12 +401,27 @@ KU_FLOATING_METHOD Exp(DeviceCtx* ctx, const int64_t n, const T* x, T* y) { KU_FLOATING_METHOD Div(DeviceCtx* ctx, const int64_t n, T* x, const T* alpha) { for (int64_t i = 0; i < n; ++i) { x[i] = x[i] / (*alpha); } } -KU_FLOATING_METHOD Div(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) { - for (int64_t i = 0; i < n; ++i) { z[i] = x[i] / y[i]; } +KU_FLOATING_METHOD Div(DeviceCtx* ctx, const int64_t n, T* x, const T alpha) { + for (int64_t i = 0; i < n; ++i) { x[i] = x[i] / alpha; } } KU_FLOATING_METHOD Mul(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) { for (int64_t i = 0; i < n; ++i) { z[i] = x[i] * y[i]; } } +KU_FLOATING_METHOD Div(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) { + for (int64_t i = 0; i < n; ++i) { z[i] = x[i] / y[i]; } +} 
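+// Element-wise helpers added below: Square, Sqrt, MulByScalar and Reciprocal all work over n
+// elements; MulByScalar reads its scalar operand from memory (y[0]), unlike MulByScalarPara
+// above, which takes the scalar by value.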
+KU_FLOATING_METHOD Square(DeviceCtx* ctx, const int64_t n, const T* x, T* y) { + for (int64_t i = 0; i < n; ++i) { y[i] = x[i] * x[i]; } +} +KU_FLOATING_METHOD Sqrt(DeviceCtx* ctx, const int64_t n, const T* x, T* y) { + for (int64_t i = 0; i < n; ++i) { y[i] = std::sqrt(x[i]); } +} +KU_FLOATING_METHOD MulByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) { + for (int64_t i = 0; i < n; ++i) { z[i] = x[i] * y[0]; } +} +KU_FLOATING_METHOD Reciprocal(DeviceCtx* ctx, const int n, const T* x, T* y) { + for (int64_t i = 0; i < n; ++i) { y[i] = static_cast<T>(1.0) / x[i]; } +} KU_FLOATING_METHOD Rsqrt(DeviceCtx* ctx, const int64_t n, T* x, const float epsilon) { for (int64_t i = 0; i < n; ++i) { x[i] = 1.0 / std::sqrt(x[i] + epsilon); } } @@ -411,10 +513,14 @@ KU_FLOATING_METHOD InitializeWithConf(DeviceCtx* ctx, const InitializerConf& ini RandomUniformInitializer<T>(initializer_conf.random_uniform_conf(), random_seed, blob); } else if (initializer_conf.has_random_normal_conf()) { RandomNormalInitializer<T>(initializer_conf.random_normal_conf(), random_seed, blob); + } else if (initializer_conf.has_truncated_normal_conf()) { + TruncatedNormalInitializer<T>(initializer_conf.truncated_normal_conf(), random_seed, blob); } else if (initializer_conf.has_xavier_conf()) { XavierInitializer<T>(initializer_conf.xavier_conf(), random_seed, blob, data_format); } else if (initializer_conf.has_msra_conf()) { MsraInitializer<T>(initializer_conf.msra_conf(), random_seed, blob, data_format); + } else if (initializer_conf.has_range_conf()) { + RangeInitializer<T>(initializer_conf.range_conf(), random_seed, blob); } else { UNIMPLEMENTED(); } @@ -438,6 +544,8 @@ KU_INTEGRAL_METHOD InitializeWithConf(DeviceCtx* ctx, const InitializerConf& ini ConstantInitializer<T>(static_cast<T>(initializer_conf.constant_int_conf().value()), blob); } else if (initializer_conf.has_random_uniform_int_conf()) { RandomIntUniformInitializer<T>(initializer_conf.random_uniform_int_conf(), random_seed, blob); + } else if (initializer_conf.has_int_range_conf()) { + IntSequenceInitializer<T>(initializer_conf.int_range_conf(), random_seed, blob); } else { UNIMPLEMENTED(); } diff --git a/oneflow/core/kernel/kernel_util.cu b/oneflow/core/kernel/kernel_util.cu index b0ef96d5623bdeebf6dc001fa6b06da24527d1e0..bc2677ce1fd68887624a600f741f601a59ccabd8 100644 --- a/oneflow/core/kernel/kernel_util.cu +++ b/oneflow/core/kernel/kernel_util.cu @@ -20,15 +20,60 @@ __global__ void ExpGpu(const int64_t n, const T* x, T* y) { } template<typename T> -__global__ void DivGpu(const int64_t n, T* x, const T* alpha_ptr) { +__global__ void DivByConstParaPtrGpu(const int64_t n, T* x, const T* alpha_ptr) { CUDA_1D_KERNEL_LOOP(i, n) { x[i] = x[i] / (*alpha_ptr); } } +template<typename T> +__global__ void DivGpu(const int64_t n, const T* x, const T* y, T* z) { + CUDA_1D_KERNEL_LOOP(i, n) { z[i] = x[i] / y[i]; } +} + +template<typename T> +__global__ void DivByConstParaGpu(const int64_t n, T* x, const T alpha) { + CUDA_1D_KERNEL_LOOP(i, n) { x[i] = x[i] / alpha; } +} + +template<typename T> +__global__ void ReplicateGpu(const int64_t n, T* y, const T* x) { + CUDA_1D_KERNEL_LOOP(i, n) { y[i] = *x; } +} + template<typename T> __global__ void MulGpu(const int64_t n, const T* x, const T* y, T* z) { CUDA_1D_KERNEL_LOOP(i, n) { z[i] = x[i] * y[i]; } } +template<typename T> +__global__ void MulByScalarGpu(const int64_t n, const T* x, const T* y, T* z) { + CUDA_1D_KERNEL_LOOP(i, n) { z[i] = x[i] * y[0]; } +} + +template<typename T> +__global__ void 
AddByScalarGpu(const int64_t n, const T* x, const T y, T* z) { + CUDA_1D_KERNEL_LOOP(i, n) { z[i] = x[i] + y; } +} + +template<typename T> +__global__ void MulByScalarParaGpu(const int64_t n, const T* x, const T y, T* z) { + CUDA_1D_KERNEL_LOOP(i, n) { z[i] = x[i] * y; } +} + +template<typename T> +__global__ void ReciprocalGpu(const int64_t n, const T* x, T* y) { + CUDA_1D_KERNEL_LOOP(i, n) { y[i] = static_cast<T>(1.0) / x[i]; } +} + +template<typename T> +__global__ void SquareGpu(const int64_t n, const T* x, T* y) { + CUDA_1D_KERNEL_LOOP(i, n) { y[i] = x[i] * x[i]; } +} + +template<typename T> +__global__ void SqrtGpu(const int64_t n, const T* x, T* y) { + CUDA_1D_KERNEL_LOOP(i, n) { y[i] = std::sqrt(x[i]); } +} + template<typename T> __global__ void AxpyGpu(const int n, const T alpha, const T* x, const int incx, T* y, const int incy) { @@ -142,37 +187,11 @@ cublasOperation_t CblasTrans2CublasTrans(CBLAS_TRANSPOSE trans) { return cublas_trans; } -const int32_t kMaxDim = OF_PP_SEQ_SIZE(DIM_SEQ); - +template<int32_t NDIMS> struct Int32Array { - int32_t val[kMaxDim]; -}; - -struct Int64Array { - int64_t val[kMaxDim]; + int32_t val[NDIMS]; }; -__device__ void ComputeOffset(const int32_t num_axis, const int64_t* x_dims, - const int32_t* permutation, int64_t* x_strides) { - int64_t buff[kMaxDim]; - int64_t cur_stride = 1; - for (int32_t i = num_axis - 1; i >= 0; --i) { - buff[i] = cur_stride; -#if __CUDA_ARCH__ >= 350 - cur_stride *= __ldg(x_dims + i); -#else - cur_stride *= x_dims[i]; -#endif - } - for (int32_t i = 0; i < num_axis; ++i) { -#if __CUDA_ARCH__ >= 350 - x_strides[i] = buff[__ldg(permutation + i)]; -#else - x_strides[i] = buff[permutation[i]]; -#endif - } -} - template<typename T> __global__ void CopyColsRegionGpu(const int64_t row_num, const int64_t col_num, const T* x, const int64_t x_col_offset, const int64_t x_lda, T* y, @@ -184,35 +203,29 @@ __global__ void CopyColsRegionGpu(const int64_t row_num, const int64_t col_num, } } -__device__ int64_t GetXIndex(const int32_t num_axis, const int64_t* y_shape, - const int64_t* x_strides, int64_t y_idx) { - int64_t x_idx = 0; - for (int32_t i = num_axis - 1; i >= 0 && y_idx > 0; --i) { +template<int32_t NDIMS> +__device__ int32_t GetXIndex(const int32_t* y_shape, const int32_t* x_strides, int32_t y_idx) { + int32_t x_idx = 0; + for (int32_t i = NDIMS - 1; i >= 0; --i) { x_idx += (y_idx % y_shape[i]) * x_strides[i]; y_idx /= y_shape[i]; } return x_idx; } -template<typename T> -__global__ void TransposeGpu(const int32_t num_axis, const Int64Array x_shape, - const Int64Array y_shape, const Int32Array permutation, - const int64_t elem_cnt, const T* x, T* y) { - __shared__ int64_t x_strides[kMaxDim]; - __shared__ int64_t x_dims_shared[kMaxDim]; - __shared__ int64_t y_dims_shared[kMaxDim]; - __shared__ int32_t perm_shared[kMaxDim]; +template<int32_t NDIMS, typename T> +__global__ void TransposeGpu(const Int32Array<NDIMS> y_shape, const Int32Array<NDIMS> x_strides, + const int32_t elem_cnt, const T* x, T* y) { + __shared__ int32_t x_strides_shared[NDIMS]; + __shared__ int32_t y_dims_shared[NDIMS]; const int32_t tid = threadIdx.x; - if (tid < num_axis) { - x_dims_shared[tid] = x_shape.val[tid]; + if (tid < NDIMS) { y_dims_shared[tid] = y_shape.val[tid]; - perm_shared[tid] = permutation.val[tid]; + x_strides_shared[tid] = x_strides.val[tid]; } __syncthreads(); - if (tid == 0) { ComputeOffset(num_axis, x_dims_shared, perm_shared, x_strides); } - __syncthreads(); CUDA_1D_KERNEL_LOOP(y_idx, elem_cnt) { - const int64_t x_idx = GetXIndex(num_axis, 
y_dims_shared, x_strides, y_idx); + const int32_t x_idx = GetXIndex<NDIMS>(y_dims_shared, x_strides_shared, y_idx); #if __CUDA_ARCH__ >= 350 y[y_idx] = __ldg(x + x_idx); #else @@ -221,6 +234,32 @@ __global__ void TransposeGpu(const int32_t num_axis, const Int64Array x_shape, } } +template<int32_t NDIMS, typename T> +void Transpose(DeviceCtx* ctx, const Shape& x_shape, const Shape& y_shape, + const PbRf<int32_t>& permutation, const int64_t elem_cnt, const T* x, T* y) { + CHECK_LE(y_shape.elem_cnt(), MaxVal<int32_t>::value); + Int32Array<NDIMS> y_shape_struct; + FOR_RANGE(int32_t, i, 0, NDIMS) { y_shape_struct.val[i] = y_shape.At(i); } + Int32Array<NDIMS> x_strides; + int32_t buff[NDIMS]; + int32_t cur_stride = 1; + for (int32_t i = NDIMS - 1; i >= 0; --i) { + buff[i] = cur_stride; + cur_stride *= x_shape.At(i); + } + for (int32_t i = 0; i < NDIMS; ++i) { x_strides.val[i] = buff[permutation[i]]; } + TransposeGpu<NDIMS, T> + <<<SMBlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + y_shape_struct, x_strides, elem_cnt, x, y); +} + +template<typename T> +struct TransposeUtil final { +#define MAKE_TRANSPOSE_SWITCH_ENTRY(func_name, NDIMS) func_name<NDIMS, T> + DEFINE_STATIC_SWITCH_FUNC(void, Transpose, MAKE_TRANSPOSE_SWITCH_ENTRY, + MAKE_NDIM_CTRV_SEQ(DIM_SEQ)); +}; + template<typename T, T (*reduce_core_func)(const T, const T)> __device__ void MatrixShrinkCols(const size_t row_num, const size_t thread_col_num, const T* x, const size_t x_col_num, const size_t x_lda, T* y, @@ -274,11 +313,25 @@ void MatrixRowReduce(DeviceCtx* ctx, const size_t row_num, const size_t col_num, ctx->cuda_stream()>>>(row_num, col_num, x, y, static_cast<T*>(temp_storage), temp_col_num); } +template<typename T> +__global__ void AssignStridedAddrGpu(T** dev_ptrs, T* start_ptr, int32_t stride_len, + int32_t stride_num) { + CUDA_1D_KERNEL_LOOP(i, stride_num) { dev_ptrs[i] = start_ptr + i * stride_len; } +} + +template<typename T> +void AssignStridedAddr(DeviceCtx* ctx, T** dev_ptrs, T* start_ptr, int stride_len, int stride_num) { + AssignStridedAddrGpu<T> + <<<BlocksNum4ThreadsNum(stride_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + dev_ptrs, start_ptr, stride_len, stride_num); +} + } // namespace template<> void Memcpy<DeviceType::kGPU>(DeviceCtx* ctx, void* dst, const void* src, size_t sz, cudaMemcpyKind kind) { + if (dst == src) { return; } CudaCheck(cudaMemcpyAsync(dst, src, sz, kind, ctx->cuda_stream())); } @@ -299,21 +352,6 @@ size_t GetTmpSizeForReduceSum(DataType data_type, int64_t sum_elem_num) { #undef MAKE_CUB_DEVICE_REDUCE_SWITCH_ENTRY -// create temporary host blob store initializer result -#define BEFORE_CPU_INITIALIZE() \ - RtBlobDesc blob_desc(blob->blob_desc().blob_desc_proto()); \ - char* host_raw_dptr = nullptr; \ - CudaCheck(cudaMallocHost(&host_raw_dptr, blob->TotalByteSize())); \ - std::unique_ptr<Blob> host_blob; \ - host_blob.reset(new Blob(nullptr, &blob_desc, host_raw_dptr)); - -// asynchronous copy to device -#define AFTER_CPU_INITIALIZE() \ - Memcpy<DeviceType::kGPU>(ctx, blob->mut_dptr(), host_blob->dptr(), \ - blob->ByteSizeOfDataContentField(), cudaMemcpyHostToDevice); \ - CudaCheck(cudaStreamSynchronize(ctx->cuda_stream())); \ - CudaCheck(cudaFreeHost(host_raw_dptr)); - #define KU_IF_METHOD \ template<typename T, typename Derived> \ void GpuKernelUtilIf<T, Derived>:: @@ -343,21 +381,15 @@ KU_IF_METHOD RowSum(DeviceCtx* ctx, const int64_t row_num, const int64_t col_num void* temp_storage, const size_t temp_storage_bytes) { MatrixRowReduce<T, 
ReduceCoreAdd>(ctx, row_num, col_num, x, y, temp_storage, temp_storage_bytes); } + KU_IF_METHOD Transpose(DeviceCtx* ctx, const int32_t num_axis, const Shape& x_shape, const Shape& y_shape, const PbRf<int32_t>& permutation, const int64_t elem_cnt, const T* x, T* y) { - CHECK_LE(num_axis, kMaxDim); - Int64Array x_shape_struct; - Int64Array y_shape_struct; - Int32Array perm_struct; - FOR_RANGE(int32_t, i, 0, num_axis) { - x_shape_struct.val[i] = x_shape.At(i); - y_shape_struct.val[i] = y_shape.At(i); - perm_struct.val[i] = permutation[i]; - } - TransposeGpu<T> - <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( - num_axis, x_shape_struct, y_shape_struct, perm_struct, elem_cnt, x, y); + CHECK_LE(y_shape.elem_cnt(), MaxVal<int32_t>::value); + CHECK_EQ(num_axis, y_shape.NumAxes()); + CHECK_EQ(num_axis, x_shape.NumAxes()); + TransposeUtil<T>::SwitchTranspose(SwitchCase(num_axis), ctx, x_shape, y_shape, permutation, + elem_cnt, x, y); } KU_IF_METHOD InitializeWithConf(DeviceCtx* ctx, const InitializerConf& initializer_conf, @@ -388,6 +420,18 @@ KU_IF_METHOD InitializeWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num KU_IF_METHOD Set(DeviceCtx* ctx, const T value, T* addr) { gpu_set<T><<<1, 1, 0, ctx->cuda_stream()>>>(value, addr); } +KU_IF_METHOD Replicate(DeviceCtx* ctx, const int64_t n, T* y, const T* x) { + ReplicateGpu<T> + <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, y, x); +} +KU_IF_METHOD AddByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z) { + AddByScalarGpu<T> + <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y, z); +} +KU_IF_METHOD MulByScalarPara(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z) { + MulByScalarParaGpu<T> + <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y, z); +} #define KU_FLOATING_METHOD \ template<typename T> \ @@ -432,18 +476,75 @@ KU_FLOATING_METHOD Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, cublas_gemm<T>(ctx->cublas_pmh_handle(), cublas_trans_b, cublas_trans_a, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc); } +KU_FLOATING_METHOD BatchedGemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE trans_a, + const enum CBLAS_TRANSPOSE trans_b, int batch_size, int m, int n, + int k, const T alpha, const T* a, const T* b, const T beta, T* c, + T** buf) { + const int a_stride = m * k; + const int b_stride = k * n; + const int c_stride = m * n; + const int lda = (trans_a == CblasNoTrans) ? k : m; + const int ldb = (trans_b == CblasNoTrans) ? 
n : k; + const int ldc = n; + cublasOperation_t cublas_trans_a = CblasTrans2CublasTrans(trans_a); + cublasOperation_t cublas_trans_b = CblasTrans2CublasTrans(trans_b); + T** dev_a_ptrs = buf; + T** dev_b_ptrs = buf + batch_size; + T** dev_c_ptrs = buf + 2 * batch_size; + AssignStridedAddr<T>(ctx, dev_a_ptrs, const_cast<T*>(a), a_stride, batch_size); + AssignStridedAddr<T>(ctx, dev_b_ptrs, const_cast<T*>(b), b_stride, batch_size); + AssignStridedAddr<T>(ctx, dev_c_ptrs, c, c_stride, batch_size); +#if CUDA_VERSION >= 9010 + cudaDataType_t data_type = CudaDataType<T>::value; + cublasGemmBatchedEx(ctx->cublas_pmh_handle(), cublas_trans_b, cublas_trans_a, n, m, k, + reinterpret_cast<const void*>(&alpha), + reinterpret_cast<const void**>(const_cast<const T**>(dev_b_ptrs)), data_type, + ldb, reinterpret_cast<const void**>(const_cast<const T**>(dev_a_ptrs)), + data_type, lda, reinterpret_cast<const void*>(&beta), + reinterpret_cast<void**>(dev_c_ptrs), data_type, ldc, batch_size, data_type, + CUBLAS_GEMM_DEFAULT); +#else + cublas_gemmBatched<T>(ctx->cublas_pmh_handle(), cublas_trans_b, cublas_trans_a, n, m, k, &alpha, + const_cast<const T**>(dev_b_ptrs), ldb, const_cast<const T**>(dev_a_ptrs), + lda, &beta, dev_c_ptrs, ldc, batch_size); +#endif +} KU_FLOATING_METHOD Exp(DeviceCtx* ctx, const int64_t n, const T* x, T* y) { ExpGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y); } KU_FLOATING_METHOD Div(DeviceCtx* ctx, const int64_t n, T* x, const T* alpha) { - DivGpu<T> + DivByConstParaPtrGpu<T> + <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, alpha); +} +KU_FLOATING_METHOD Div(DeviceCtx* ctx, const int64_t n, T* x, const T alpha) { + DivByConstParaGpu<T> <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, alpha); } +KU_FLOATING_METHOD Div(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) { + DivGpu<T> + <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y, z); +} KU_FLOATING_METHOD Mul(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) { MulGpu<T> <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y, z); } +KU_FLOATING_METHOD MulByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) { + MulByScalarGpu<T> + <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y, z); +} +KU_FLOATING_METHOD Reciprocal(DeviceCtx* ctx, const int n, const T* x, T* y) { + ReciprocalGpu<T> + <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y); +} +KU_FLOATING_METHOD Square(DeviceCtx* ctx, const int64_t n, const T* x, T* y) { + SquareGpu<T> + <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y); +} +KU_FLOATING_METHOD Sqrt(DeviceCtx* ctx, const int64_t n, const T* x, T* y) { + SqrtGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y); +} KU_FLOATING_METHOD Rsqrt(DeviceCtx* ctx, const int64_t n, T* x, const float epsilon) { RsqrtGpu<T> <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, epsilon); @@ -594,4 +695,28 @@ __device__ double gpu_atomic_max(double* address, const double val) { return __longlong_as_double(old); } +template<typename T, typename U> +__global__ void CastOnGpu(const T* in, U* out, int64_t elem_num) { + CUDA_1D_KERNEL_LOOP(i, elem_num) { out[i] = static_cast<U>(in[i]); } +} + +template<typename T, typename U> +void 
CopyElemOnGpu(DeviceCtx* ctx, const T* in_dptr, U* out_dptr, int64_t elem_num) { + if (std::is_same<T, U>::value) { + Memcpy<DeviceType::kGPU>(ctx, out_dptr, in_dptr, elem_num * sizeof(T)); + } else { + CastOnGpu<T, U> + <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + in_dptr, out_dptr, elem_num); + } +} + +#define INSTANTIATE_COPY_ELEM_ON_GPU(T, U) \ + template void CopyElemOnGpu(DeviceCtx* ctx, const T* in_dptr, U* out_dptr, int64_t elem_num); + +#define MAKE_COPY_ELEM_ON_GPU_ENTRY(TPair, UPair) \ + INSTANTIATE_COPY_ELEM_ON_GPU(OF_PP_PAIR_FIRST(TPair), OF_PP_PAIR_FIRST(UPair)) + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_COPY_ELEM_ON_GPU_ENTRY, POD_DATA_TYPE_SEQ, POD_DATA_TYPE_SEQ) + } // namespace oneflow diff --git a/oneflow/core/kernel/kernel_util.h b/oneflow/core/kernel/kernel_util.h index 158c2dc0efc71a1c28a5267193d071bdb2340892..ebab13fe3e7147eec9cb137aba44f65097e26f1d 100644 --- a/oneflow/core/kernel/kernel_util.h +++ b/oneflow/core/kernel/kernel_util.h @@ -41,7 +41,7 @@ template<DeviceType device_type> void Memset(DeviceCtx*, void* dst, const char value, size_t sz); #if defined(__CUDACC__) -#define OF_DEVICE_FUNC __device__ +#define OF_DEVICE_FUNC __device__ __host__ __forceinline__ #else #define OF_DEVICE_FUNC #endif @@ -88,6 +88,14 @@ struct KernelUtilIf { c->mut_dptr<T>()); } + static void OFBatchedGemm(DeviceCtx* ctx, enum CBLAS_TRANSPOSE trans_a, + enum CBLAS_TRANSPOSE trans_b, const int batch_size, const int m, + const int n, const int k, const T alpha, const T* a, const T* b, + const T beta, T* c, T** buf) { + Derived::BatchedGemm(ctx, CblasRowMajor, trans_a, trans_b, batch_size, m, n, k, alpha, a, b, + beta, c, buf); + } + static void InitializeWithProperConf(DeviceCtx* ctx, const InitializerConf* initializer_conf, uint32_t random_seed, Blob* blob, const std::string& data_format = "") { @@ -141,6 +149,9 @@ struct CpuKernelUtilIf { const std::string& bn_in_op, int32_t dim_num, int64_t num_in_each_dim); static void Set(DeviceCtx* ctx, const T value, T* addr); + static void Replicate(DeviceCtx* ctx, const int64_t n, T* y, const T* x); + static void AddByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z); + static void MulByScalarPara(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z); }; // CPU, Floating @@ -162,11 +173,20 @@ struct KernelUtil<DeviceType::kCPU, T, typename std::enable_if<IsFloating<T>::va const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const T alpha, const T* a, const int lda, const T* b, const int ldb, const T beta, T* c, const int ldc); + static void BatchedGemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b, + int batch_size, int m, int n, int k, const T alpha, const T* a, + const T* b, const T beta, T* c, T** buf); static void Exp(DeviceCtx* ctx, const int64_t n, const T* x, T* y); static void Div(DeviceCtx* ctx, const int64_t n, T* x, const T* alpha); + static void Div(DeviceCtx* ctx, const int64_t n, T* x, const T alpha); static void Div(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z); static void Mul(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z); + static void MulByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z); + static void Reciprocal(DeviceCtx* ctx, const int n, const T* x, T* y); + static void Square(DeviceCtx* ctx, const int64_t n, const T* x, T* y); + static void Sqrt(DeviceCtx* ctx, const int64_t n, const T* x, T* y); static void 
Rsqrt(DeviceCtx* ctx, const int64_t n, T* x, const float epsilon); static void Powx(DeviceCtx* ctx, const int64_t n, const T* x, const float power, T* y); @@ -247,6 +267,9 @@ struct GpuKernelUtilIf { const std::string& bn_in_op, int32_t dim_num, int64_t num_in_each_dim); static void Set(DeviceCtx* ctx, const T value, T* addr); + static void Replicate(DeviceCtx* ctx, const int64_t n, T* y, const T* x); + static void AddByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z); + static void MulByScalarPara(DeviceCtx* ctx, const int64_t n, const T* x, const T y, T* z); }; // GPU, Floating @@ -270,10 +293,20 @@ struct KernelUtil<DeviceType::kGPU, T, typename std::enable_if<IsFloating<T>::va const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const T alpha, const T* a, const int lda, const T* b, const int ldb, const T beta, T* c, const int ldc); + static void BatchedGemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b, + int batch_size, int m, int n, int k, const T alpha, const T* a, + const T* b, const T beta, T* c, T** buf); static void Exp(DeviceCtx* ctx, const int64_t n, const T* x, T* y); static void Div(DeviceCtx* ctx, const int64_t n, T* x, const T* alpha); + static void Div(DeviceCtx* ctx, const int64_t n, T* x, const T alpha); + static void Div(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z); static void Mul(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z); + static void MulByScalar(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z); + static void Reciprocal(DeviceCtx* ctx, const int n, const T* x, T* y); + static void Square(DeviceCtx* ctx, const int64_t n, const T* x, T* y); + static void Sqrt(DeviceCtx* ctx, const int64_t n, const T* x, T* y); static void Rsqrt(DeviceCtx* ctx, const int64_t n, T* x, const float epsilon); static void Sigmoid(DeviceCtx* ctx, int64_t n, const T* x, T* y); @@ -495,6 +528,24 @@ typename std::enable_if<!std::is_same<T, U>::value>::type CopyElem(const T* in_d FOR_RANGE(int64_t, i, 0, elem_num) { *(out_dptr++) = static_cast<U>(*(in_dptr++)); } } +template<typename T, typename U> +void CopyElemOnGpu(DeviceCtx* ctx, const T* in_dptr, U* out_dptr, int64_t elem_num); + +// create temporary host blob store initializer result +#define BEFORE_CPU_INITIALIZE() \ + RtBlobDesc blob_desc(blob->blob_desc().blob_desc_proto()); \ + char* host_raw_dptr = nullptr; \ + CudaCheck(cudaMallocHost(&host_raw_dptr, blob->TotalByteSize())); \ + std::unique_ptr<Blob> host_blob; \ + host_blob.reset(new Blob(nullptr, &blob_desc, host_raw_dptr)); + +// asynchronous copy to device +#define AFTER_CPU_INITIALIZE() \ + Memcpy<DeviceType::kGPU>(ctx, blob->mut_dptr(), host_blob->dptr(), \ + blob->ByteSizeOfDataContentField(), cudaMemcpyHostToDevice); \ + CudaCheck(cudaStreamSynchronize(ctx->cuda_stream())); \ + CudaCheck(cudaFreeHost(host_raw_dptr)); + } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_KERNEL_UTIL_H_ diff --git a/oneflow/core/kernel/layer_norm_kernel.cpp b/oneflow/core/kernel/layer_norm_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..35d28c095071ce7ac324182c4b83e0cd4408bcbe --- /dev/null +++ b/oneflow/core/kernel/layer_norm_kernel.cpp @@ -0,0 +1,171 @@ +#include "oneflow/core/kernel/layer_norm_kernel.h" +#include "oneflow/core/kernel/kernel_util.h" +#include "oneflow/core/ndarray/ndarray_util.h" + +namespace oneflow { + +namespace { + +InitializerConf ConstantInitializerConf(float val) { + 
InitializerConf conf; + conf.mutable_constant_conf()->set_value(val); + return conf; +} + +InitializerConf OnesInitializerConf() { return ConstantInitializerConf(1.0f); } + +InitializerConf ZerosInitializerConf() { return ConstantInitializerConf(0.0f); } + +} // namespace + +template<DeviceType device_type, typename T> +void LayerNormKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const LayerNormOpConf& conf = this->op_conf().layer_norm_conf(); + const Blob* in = BnInOp2Blob("in"); + const Blob* bn_scale = BnInOp2Blob("cudnn_bn_scale_ones"); + const Blob* bn_bias = BnInOp2Blob("cudnn_bn_bias_zeros"); + Blob* out = BnInOp2Blob("out"); + Blob* normalize_out = conf.scale() ? BnInOp2Blob("normalize_out") : out; + Blob* mean = BnInOp2Blob("cudnn_bn_mean"); + Blob* inv_variance = BnInOp2Blob("cudnn_bn_inv_variance"); + LayerNormKernelUtil<device_type, T>::NormalizeForward( + ctx.device_ctx, in, bn_scale, bn_bias, conf.epsilon(), normalize_out, mean, inv_variance); + if (conf.scale()) { + const Blob* gamma = BnInOp2Blob("gamma"); + const int64_t m = gamma->shape().elem_cnt(); + CHECK_EQ(out->shape().elem_cnt() % m, 0); + const int64_t n = out->shape().elem_cnt() / m; + NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncMul>( + ctx.device_ctx, XpuVarNdarray<T>({n, m}, out->mut_dptr<T>()), + XpuVarNdarray<const T>({n, m}, normalize_out->dptr<T>()), + XpuVarNdarray<const T>({1, m}, gamma->dptr<T>())); + } + if (conf.center()) { + const Blob* beta = BnInOp2Blob("beta"); + const int64_t m = beta->shape().elem_cnt(); + CHECK_EQ(out->shape().elem_cnt() % m, 0); + const int64_t n = out->shape().elem_cnt() / m; + NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncAdd>( + ctx.device_ctx, XpuVarNdarray<T>({n, m}, out->mut_dptr<T>()), + XpuVarNdarray<const T>({n, m}, out->dptr<T>()), + XpuVarNdarray<const T>({1, m}, beta->dptr<T>())); + } +} + +template<DeviceType device_type, typename T> +void LayerNormKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const LayerNormOpConf& conf = this->op_conf().layer_norm_conf(); + const Blob* out_diff = BnInOp2Blob(GenDiffBn("out")); + Blob* in_diff = BnInOp2Blob(GenDiffBn("in")); + Blob* bw_buf = BnInOp2Blob("bw_reduce_buf"); + if (conf.center() && this->op_conf().trainable()) { + Blob* beta_diff = BnInOp2Blob(GenDiffBn("beta")); + const int64_t m = beta_diff->shape().elem_cnt(); + CHECK_EQ(out_diff->shape().elem_cnt() % m, 0); + const int64_t n = out_diff->shape().elem_cnt() / m; + NdarrayUtil<device_type, T>::ReduceSum(ctx.device_ctx, + XpuVarNdarray<T>({1, m}, beta_diff->mut_dptr<T>()), + XpuVarNdarray<const T>({n, m}, out_diff->dptr<T>()), + XpuVarNdarray<T>({n, m}, bw_buf->mut_dptr<T>())); + } + if (conf.scale()) { + Blob* normalize_out = BnInOp2Blob("normalize_out"); + const Blob* gamma = BnInOp2Blob("gamma"); + Blob* gamma_diff = BnInOp2Blob(GenDiffBn("gamma")); + const int64_t m = gamma_diff->shape().elem_cnt(); + CHECK_EQ(out_diff->shape().elem_cnt() % m, 0); + const int64_t n = out_diff->shape().elem_cnt() / m; + if (this->op_conf().trainable()) { + NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncMul>( + ctx.device_ctx, XpuVarNdarray<T>({n, m}, bw_buf->mut_dptr<T>()), + XpuVarNdarray<const T>({n, m}, normalize_out->dptr<T>()), + XpuVarNdarray<const T>({n, m}, out_diff->dptr<T>())); + NdarrayUtil<device_type, T>::ReduceSum(ctx.device_ctx, + 
XpuVarNdarray<T>({1, m}, gamma_diff->mut_dptr<T>()), + XpuVarNdarray<const T>({n, m}, bw_buf->dptr<T>()), + XpuVarNdarray<T>({n, m}, bw_buf->mut_dptr<T>())); + } + if (in_diff) { + NdarrayUtil<device_type, T>::template BroadcastApply<BinaryFuncMul>( + ctx.device_ctx, XpuVarNdarray<T>({n, m}, normalize_out->mut_dptr<T>()), + XpuVarNdarray<const T>({n, m}, out_diff->dptr<T>()), + XpuVarNdarray<const T>({1, m}, gamma->dptr<T>())); + } + } + if (in_diff) { + const Blob* in = BnInOp2Blob("in"); + const Blob* normalize_out_diff = conf.scale() ? BnInOp2Blob("normalize_out") : out_diff; + const Blob* bn_scale = BnInOp2Blob("cudnn_bn_scale_ones"); + const Blob* mean = BnInOp2Blob("cudnn_bn_mean"); + const Blob* inv_variance = BnInOp2Blob("cudnn_bn_inv_variance"); + Blob* bn_scale_diff = BnInOp2Blob("cudnn_bn_scale_diff_buf"); + Blob* bn_bias_diff = BnInOp2Blob("cudnn_bn_bias_diff_buf"); + LayerNormKernelUtil<device_type, T>::NormalizeBackward( + ctx.device_ctx, in, bn_scale, mean, inv_variance, normalize_out_diff, conf.epsilon(), + in_diff, bn_scale_diff, bn_bias_diff); + } +} + +template<DeviceType device_type, typename T> +void LayerNormKernel<device_type, T>::InitModelBlobsWithRandomSeed( + DeviceCtx* ctx, std::mt19937*, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const LayerNormOpConf& conf = this->op_conf().layer_norm_conf(); + if (conf.scale()) { + InitializerConf ones_initializer = OnesInitializerConf(); + KernelUtil<device_type, T>::InitializeWithConf(ctx, ones_initializer, 0, BnInOp2Blob("gamma")); + } + if (conf.center()) { + InitializerConf zeros_initializer = ZerosInitializerConf(); + KernelUtil<device_type, T>::InitializeWithConf(ctx, zeros_initializer, 0, BnInOp2Blob("beta")); + } +} + +template<DeviceType device_type, typename T> +void LayerNormKernel<device_type, T>::InitModelBlobsWithDir( + DeviceCtx* ctx, int32_t part_id, int32_t part_num, const std::string& model_load_dir, + std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const LayerNormOpConf& conf = this->op_conf().layer_norm_conf(); + if (conf.scale()) { + Blob* gamma = BnInOp2Blob("gamma"); + KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir, gamma, + "gamma", gamma->shape().At(0), + gamma->shape().Count(1)); + } + if (conf.center()) { + Blob* beta = BnInOp2Blob("beta"); + KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir, beta, + "beta", beta->shape().At(0), + beta->shape().Count(1)); + } +} + +template<DeviceType device_type, typename T> +void LayerNormKernel<device_type, T>::InitConstBufBlobs( + DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + InitializerConf ones_initializer = OnesInitializerConf(); + KernelUtil<device_type, T>::InitializeWithConf(ctx, ones_initializer, 0, + BnInOp2Blob("cudnn_bn_scale_ones")); + InitializerConf zeros_initializer = ZerosInitializerConf(); + KernelUtil<device_type, T>::InitializeWithConf(ctx, zeros_initializer, 0, + BnInOp2Blob("cudnn_bn_bias_zeros")); +} + +template<typename T> +struct LayerNormKernelUtil<DeviceType::kCPU, T> { + static void NormalizeForward(const DeviceCtx* ctx, const Blob* in, const Blob* scale, + const Blob* bias, double epsilon, Blob* out, Blob* mean, + Blob* inv_variance) { + UNIMPLEMENTED(); + } + static void NormalizeBackward(const DeviceCtx* ctx, const Blob* in, const Blob* scale, + const Blob* mean, const Blob* inv_variance, const Blob* out_diff, + double epsilon, Blob* in_diff, Blob* scale_diff, Blob* bias_diff) { + 
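+    // Layer normalization currently has no CPU implementation; only the GPU
+    // specialization (backed by cuDNN batch normalization, see
+    // layer_norm_kernel.cu) is provided, so the CPU path aborts.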
UNIMPLEMENTED(); + } +}; + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kLayerNormConf, LayerNormKernel, FLOATING_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/layer_norm_kernel.cu b/oneflow/core/kernel/layer_norm_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..5fe0da7e86ec0a746aa2a256c7ff049983175b3a --- /dev/null +++ b/oneflow/core/kernel/layer_norm_kernel.cu @@ -0,0 +1,85 @@ +#include "oneflow/core/kernel/layer_norm_kernel.h" + +namespace oneflow { + +namespace { + +class LayerNormCudnnBnCtx final { + public: + LayerNormCudnnBnCtx(const Shape& data_shape, const Shape& param_shape, DataType data_type) { + const int64_t cudnn_c = param_shape.elem_cnt(); + CHECK_EQ(data_shape.elem_cnt() % cudnn_c, 0); + const int64_t cudnn_w = data_shape.elem_cnt() / cudnn_c; + CHECK_LT(cudnn_c, MaxVal<int32_t>::value); + CHECK_LT(cudnn_w, MaxVal<int32_t>::value); + data_tensor_desc_.reset(new CudnnTensorDesc(CUDNN_TENSOR_NCHW, data_type, 1, + static_cast<int32_t>(cudnn_c), 1, + static_cast<int32_t>(cudnn_w))); + param_tensor_desc_.reset( + new CudnnTensorDesc(CUDNN_TENSOR_NCHW, data_type, 1, static_cast<int32_t>(cudnn_c), 1, 1)); +#if (CUDNN_VERSION >= 7000) + mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; +#else + mode_ = CUDNN_BATCHNORM_SPATIAL; +#endif + } + ~LayerNormCudnnBnCtx() = default; + + const cudnnTensorDescriptor_t& data_tensor_desc() const { return data_tensor_desc_->Get(); } + const cudnnTensorDescriptor_t& param_tensor_desc() const { return param_tensor_desc_->Get(); } + cudnnBatchNormMode_t mode() const { return mode_; }; + + private: + std::unique_ptr<CudnnTensorDesc> data_tensor_desc_; + std::unique_ptr<CudnnTensorDesc> param_tensor_desc_; + cudnnBatchNormMode_t mode_; +}; + +} // namespace + +template<typename T> +struct LayerNormKernelUtil<DeviceType::kGPU, T> { + static void NormalizeForward(const DeviceCtx* ctx, const Blob* in, const Blob* scale, + const Blob* bias, double epsilon, Blob* out, Blob* mean, + Blob* inv_variance); + static void NormalizeBackward(const DeviceCtx* ctx, const Blob* in, const Blob* scale, + const Blob* mean, const Blob* inv_variance, const Blob* out_diff, + double epsilon, Blob* in_diff, Blob* scale_diff, Blob* bias_diff); +}; + +template<typename T> +void LayerNormKernelUtil<DeviceType::kGPU, T>::NormalizeForward(const DeviceCtx* ctx, + const Blob* in, const Blob* scale, + const Blob* bias, double epsilon, + Blob* out, Blob* mean, + Blob* inv_variance) { + CHECK_GE(epsilon, CUDNN_BN_MIN_EPSILON); + LayerNormCudnnBnCtx bn_ctx(in->static_shape(), mean->shape(), in->data_type()); + CudaCheck(cudnnBatchNormalizationForwardTraining( + ctx->cudnn_handle(), bn_ctx.mode(), OnePtr<T>::value, ZeroPtr<T>::value, + bn_ctx.data_tensor_desc(), in->dptr<T>(), bn_ctx.data_tensor_desc(), out->mut_dptr<T>(), + bn_ctx.param_tensor_desc(), scale->dptr<T>(), bias->dptr<T>(), 1.0, nullptr, nullptr, epsilon, + mean->mut_dptr<T>(), inv_variance->mut_dptr<T>())); +} + +template<typename T> +void LayerNormKernelUtil<DeviceType::kGPU, T>::NormalizeBackward( + const DeviceCtx* ctx, const Blob* in, const Blob* scale, const Blob* mean, + const Blob* inv_variance, const Blob* out_diff, double epsilon, Blob* in_diff, Blob* scale_diff, + Blob* bias_diff) { + CHECK_GE(epsilon, CUDNN_BN_MIN_EPSILON); + LayerNormCudnnBnCtx bn_ctx(in->static_shape(), mean->shape(), in->data_type()); + CudaCheck(cudnnBatchNormalizationBackward( + ctx->cudnn_handle(), bn_ctx.mode(), OnePtr<T>::value, ZeroPtr<T>::value, OnePtr<T>::value, + ZeroPtr<T>::value, 
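+      // The four One/Zero constants above are cuDNN's alphaDataDiff/betaDataDiff
+      // and alphaParamDiff/betaParamDiff blending factors; the x, dy, dx tensors
+      // and the bn scale/bias gradient buffers follow.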
bn_ctx.data_tensor_desc(), in->dptr<T>(), bn_ctx.data_tensor_desc(), + out_diff->dptr<T>(), bn_ctx.data_tensor_desc(), in_diff->mut_dptr<T>(), + bn_ctx.param_tensor_desc(), scale->dptr<T>(), scale_diff->mut_dptr<T>(), + bias_diff->mut_dptr<T>(), epsilon, mean->dptr<T>(), inv_variance->dptr<T>())); +} + +#define INSTANTIATE_LAYER_NORM_KERNEL_UTIL_GPU(type_cpp, type_proto) \ + template struct LayerNormKernelUtil<DeviceType::kGPU, type_cpp>; +OF_PP_FOR_EACH_TUPLE(INSTANTIATE_LAYER_NORM_KERNEL_UTIL_GPU, FLOATING_DATA_TYPE_SEQ) +#undef INSTANTIATE_LAYER_NORM_KERNEL_UTIL_GPU + +} // namespace oneflow diff --git a/oneflow/core/kernel/layer_norm_kernel.h b/oneflow/core/kernel/layer_norm_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..1241934fbb498974afaa5062ac11adbe0fd2c65e --- /dev/null +++ b/oneflow/core/kernel/layer_norm_kernel.h @@ -0,0 +1,45 @@ +#ifndef ONEFLOW_CORE_KERNEL_LAYER_NORM_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_LAYER_NORM_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class LayerNormKernel final : public KernelIfWithModel<device_type, T>, + public KernelIfWithActivation<device_type, T> { + public: + OF_DISALLOW_COPY_AND_MOVE(LayerNormKernel); + LayerNormKernel() = default; + ~LayerNormKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void InitModelBlobsWithRandomSeed(DeviceCtx*, std::mt19937*, + std::function<Blob*(const std::string&)>) const override; + void InitModelBlobsWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num, + const std::string& model_load_dir, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + void InitConstBufBlobs(DeviceCtx* ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const; + const PbMessage& GetCustomizedOpConf() const override { + return this->op_conf().layer_norm_conf(); + } +}; + +template<DeviceType device_type, typename T> +struct LayerNormKernelUtil { + static void NormalizeForward(const DeviceCtx* ctx, const Blob* in, const Blob* scale, + const Blob* bias, double epsilon, Blob* out, Blob* mean, + Blob* inv_variance); + static void NormalizeBackward(const DeviceCtx* ctx, const Blob* in, const Blob* scale, + const Blob* mean, const Blob* inv_variance, const Blob* out_diff, + double epsilon, Blob* in_diff, Blob* scale_diff, Blob* bias_diff); +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_LAYER_NORM_KERNEL_H_ diff --git a/oneflow/core/kernel/loss_kernel.cpp b/oneflow/core/kernel/loss_kernel.cpp index 7f84396a62f395a393f094f18688568484530857..bea078dead4fcc77e0be7640aac1677215c42bf9 100644 --- a/oneflow/core/kernel/loss_kernel.cpp +++ b/oneflow/core/kernel/loss_kernel.cpp @@ -1,23 +1,24 @@ #include "oneflow/core/kernel/loss_kernel.h" +#include "oneflow/core/ndarray/ndarray_util.h" namespace oneflow { -template<DeviceType device_type, typename PredType, typename LabelType> -void LossKernel<device_type, PredType, LabelType>::SetLossInstanceNumBlob( +template<DeviceType device_type, typename PredType> +void LossKernel<device_type, PredType>::SetLossInstanceNum( const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const { - CHECK_GE(this->op_attribute().input_bns().size(), 2); - // already did CheckSameDim0ValidNum in Kernel::Forward - int64_t dim0_valid_num_sum = - 
BnInOp2Blob(this->op_attribute().input_bns(0))->CalcDim0ValidNumSum(); - KernelUtil<device_type, PredType>::Set(ctx.device_ctx, static_cast<PredType>(dim0_valid_num_sum), + const int64_t loss_instance_num = CalcLossInstanceNum(ctx, BnInOp2Blob); + KernelUtil<device_type, PredType>::Set(ctx.device_ctx, static_cast<PredType>(loss_instance_num), BnInOp2Blob("loss_instance_num")->mut_dptr<PredType>()); + CHECK(BnInOp2Blob(GenDiffBn("prediction"))->has_loss_instance_num_field()); + BnInOp2Blob(GenDiffBn("prediction")) + ->set_loss_instance_num(static_cast<float>(loss_instance_num)); } -template<DeviceType device_type, typename PredType, typename LabelType> -void LossKernel<device_type, PredType, LabelType>::ForwardDataContent( +template<DeviceType device_type, typename PredType> +void LossKernel<device_type, PredType>::ForwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { VirtualLossForwardDataContent(ctx, BnInOp2Blob); - SetLossInstanceNumBlob(ctx, BnInOp2Blob); + SetLossInstanceNum(ctx, BnInOp2Blob); const LossKernelConf& conf = GetLossKernelConf(this->kernel_conf()); int64_t n = BnInOp2Blob("prediction")->shape().At(0); @@ -30,8 +31,11 @@ void LossKernel<device_type, PredType, LabelType>::ForwardDataContent( if (weight_blob != nullptr) { PredType* weight = weight_blob->mut_dptr<PredType>(); if (weight_blob->shape().elem_cnt() == n) { - KernelUtil<device_type, PredType>::Mul(ctx.device_ctx, n, weight, prediction_diff, - prediction_diff); + const int64_t m = prediction_diff_blob->shape().Count(1); + NdarrayUtil<device_type, PredType>::template BroadcastApply<BinaryFuncMul>( + ctx.device_ctx, XpuVarNdarray<PredType>({n, m}, prediction_diff), + XpuVarNdarray<const PredType>({n, 1}, weight), + XpuVarNdarray<const PredType>({n, m}, prediction_diff)); } else if (weight_blob->shape().elem_cnt() == 1) { KernelUtil<device_type, PredType>::Scal(ctx.device_ctx, n, weight, prediction_diff, 1); } else { @@ -53,32 +57,38 @@ void LossKernel<device_type, PredType, LabelType>::ForwardDataContent( } } -template<DeviceType device_type, typename PredType, typename LabelType> -void LossKernel<device_type, PredType, LabelType>::ForwardDataId( +template<DeviceType device_type, typename PredType> +void LossKernel<device_type, PredType>::ForwardDataId( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { BnInOp2Blob("loss")->CopyDataIdFrom(ctx.device_ctx, BnInOp2Blob("prediction")); } -template<DeviceType device_type, typename PredType, typename LabelType> -void LossKernel<device_type, PredType, LabelType>::ForwardColNum( +template<DeviceType device_type, typename PredType> +void LossKernel<device_type, PredType>::ForwardColNum( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { BnInOp2Blob(GenDiffBn("prediction"))->CopyColNumFrom(ctx.device_ctx, BnInOp2Blob("prediction")); } -template<DeviceType device_type, typename PredType, typename LabelType> -void LossKernel<device_type, PredType, LabelType>::ForwardDim0ValidNum( +template<DeviceType device_type, typename PredType> +void LossKernel<device_type, PredType>::ForwardDim0ValidNum( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { BnInOp2Blob(GenDiffBn("prediction")) ->CopyDim0ValidNumFrom(ctx.device_ctx, BnInOp2Blob("prediction")); BnInOp2Blob("loss")->CopyDim0ValidNumFrom(ctx.device_ctx, BnInOp2Blob("prediction")); } -template<DeviceType device_type, typename PredType, typename LabelType> -void LossKernel<device_type, 
PredType, LabelType>::ForwardRecordIdInDevicePiece( +template<DeviceType device_type, typename PredType> +void LossKernel<device_type, PredType>::ForwardRecordIdInDevicePiece( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { // do nothing } +template<DeviceType device_type, typename PredType> +int64_t LossKernel<device_type, PredType>::CalcLossInstanceNum( + const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const { + return BnInOp2Blob("prediction")->CalcDim0ValidNumSum(); +} + template<typename T> struct LossKernelUtil<DeviceType::kCPU, T> { static void ComputeReductionCoefficient(DeviceCtx* ctx, int64_t data_num, int64_t weight_length, @@ -120,11 +130,9 @@ struct LossKernelUtil<DeviceType::kCPU, T> { OF_PP_FOR_EACH_TUPLE(MAKE_LOSS_KERNEL_UTIL_ENTRY, FLOATING_DATA_TYPE_SEQ) -#define MAKE_LOSS_ENTRY(device_type, data_type_pair, label_type_pair) \ - template class LossKernel<device_type, OF_PP_PAIR_FIRST(data_type_pair), \ - OF_PP_PAIR_FIRST(label_type_pair)>; +#define MAKE_LOSS_ENTRY(device_type, data_type_pair) \ + template class LossKernel<device_type, OF_PP_PAIR_FIRST(data_type_pair)>; -OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_LOSS_ENTRY, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ, - ARITHMETIC_DATA_TYPE_SEQ) +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_LOSS_ENTRY, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ) } // namespace oneflow diff --git a/oneflow/core/kernel/loss_kernel.h b/oneflow/core/kernel/loss_kernel.h index 3e42b906aadc6359252981ef00ff36587a1c0291..7905587308f5ca37a74b9c953acc074c81eaed6f 100644 --- a/oneflow/core/kernel/loss_kernel.h +++ b/oneflow/core/kernel/loss_kernel.h @@ -5,7 +5,7 @@ namespace oneflow { -template<DeviceType device_type, typename PredType, typename LabelType> +template<DeviceType device_type, typename PredType> class LossKernel : public KernelIf<device_type> { public: OF_DISALLOW_COPY_AND_MOVE(LossKernel); @@ -16,6 +16,8 @@ class LossKernel : public KernelIf<device_type> { virtual void VirtualLossForwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const = 0; virtual const LossKernelConf& GetLossKernelConf(const KernelConf& kernel_conf) const = 0; + virtual int64_t CalcLossInstanceNum( + const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const; private: void ForwardDataContent(const KernelCtx& ctx, @@ -28,8 +30,8 @@ class LossKernel : public KernelIf<device_type> { std::function<Blob*(const std::string&)> BnInOp2Blob) const override; void ForwardRecordIdInDevicePiece( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; - void SetLossInstanceNumBlob(const KernelCtx& ctx, - const std::function<Blob*(const std::string&)>& BnInOp2Blob) const; + void SetLossInstanceNum(const KernelCtx& ctx, + const std::function<Blob*(const std::string&)>& BnInOp2Blob) const; }; template<DeviceType device_type, typename T> diff --git a/oneflow/core/kernel/loss_print_kernel.cpp b/oneflow/core/kernel/loss_print_kernel.cpp index 7b6b26875403d49b884bea9575b57f39582a5fbb..c7cdfce1807928ae12f1c48ae75ff457604c6c02 100644 --- a/oneflow/core/kernel/loss_print_kernel.cpp +++ b/oneflow/core/kernel/loss_print_kernel.cpp @@ -20,7 +20,15 @@ void LossPrintKernel<T>::Forward(const KernelCtx& kernel_ctx, } loss_reduced /= reduction_coefficient; const char* loss_op_name = op_conf().name().c_str() + LossPrintPrefix.length(); - LOG(INFO) << loss_op_name << ":" << loss_reduced; + double* prev_ts = 
static_cast<double*>(kernel_ctx.other); + const double cur_ts = GetCurTime() / 1e9; + if (*prev_ts == 0) { + LOG(INFO) << loss_op_name << ":" << loss_reduced; + } else { + LOG(INFO) << loss_op_name << ":" << std::fixed << std::setprecision(3) << loss_reduced << " (" + << (cur_ts - *prev_ts) << " sec)"; + } + *prev_ts = cur_ts; } template<typename T> diff --git a/oneflow/core/kernel/matmul_kernel.cpp b/oneflow/core/kernel/matmul_kernel.cpp index e47f6e0bfd6c066ad90cd82df7382dff5f6abcb3..4a0130e566109f08e9e71f8ab257400fa40cf262 100644 --- a/oneflow/core/kernel/matmul_kernel.cpp +++ b/oneflow/core/kernel/matmul_kernel.cpp @@ -6,59 +6,98 @@ namespace oneflow { template<DeviceType device_type, typename T> void MatmulKernel<device_type, T>::ForwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { - const Blob* in_blob = BnInOp2Blob("in"); - const Blob* weight_blob = BnInOp2Blob("weight"); + const Blob* a_blob = BnInOp2Blob("a"); + const Blob* b_blob = BnInOp2Blob("b"); + Blob* fw_buf_blob = BnInOp2Blob("fw_buf"); Blob* out_blob = BnInOp2Blob("out"); - // out = in * weight' - KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasNoTrans, CblasTrans, OneVal<T>::value, - ZeroVal<T>::value, in_blob, weight_blob, out_blob); - if (this->op_conf().matmul_conf().has_bias()) { - const Blob* bias_blob = BnInOp2Blob("bias"); - const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier"); - // out = bias_multiplier * bias + out - KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasNoTrans, CblasNoTrans, - OneVal<T>::value, OneVal<T>::value, bias_mul_blob, - bias_blob, out_blob); + bool transpose_a = this->op_conf().matmul_conf().transpose_a(); + bool transpose_b = this->op_conf().matmul_conf().transpose_b(); + if (a_blob->static_shape().dim_vec().size() == 2) { + Calc2DMatMul(ctx.device_ctx, a_blob, transpose_a, b_blob, transpose_b, out_blob, false); + } else { + CalcBatchMatMul(ctx.device_ctx, a_blob, transpose_a, b_blob, transpose_b, out_blob, fw_buf_blob, + false); } } template<DeviceType device_type, typename T> void MatmulKernel<device_type, T>::BackwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* a_blob = BnInOp2Blob("a"); + const Blob* b_blob = BnInOp2Blob("b"); const Blob* out_diff_blob = BnInOp2Blob("out_diff"); - const Blob* in_blob = BnInOp2Blob("in"); - Blob* in_diff_blob = BnInOp2Blob("in_diff"); - const Blob* weight_blob = BnInOp2Blob("weight"); - Blob* weight_diff_blob = BnInOp2Blob("weight_diff"); - // weight_diff = out_diff * in' - KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasTrans, CblasNoTrans, OneVal<T>::value, - ZeroVal<T>::value, out_diff_blob, in_blob, weight_diff_blob); - // in_diff = out_diff * weight - KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasNoTrans, CblasNoTrans, OneVal<T>::value, - ZeroVal<T>::value, out_diff_blob, weight_blob, in_diff_blob); - if (this->op_conf().matmul_conf().has_bias()) { - const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier"); - Blob* bias_diff_blob = BnInOp2Blob("bias_diff"); - // bias_diff = bias_multiplier' * out_diff - KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasTrans, CblasNoTrans, OneVal<T>::value, - ZeroVal<T>::value, bias_mul_blob, out_diff_blob, - bias_diff_blob); + Blob* a_diff_blob = BnInOp2Blob("a_diff"); + Blob* b_diff_blob = BnInOp2Blob("b_diff"); + Blob* bw_buf_blob = BnInOp2Blob("bw_buf"); + bool transpose_a = this->op_conf().matmul_conf().transpose_a(); + bool transpose_b = 
this->op_conf().matmul_conf().transpose_b(); + // trans_a trans_b a_diff b_diff + // T T b'g' g'a' + // T F bg' ag + // F T gb g'a + // F F gb' a'g + if (a_blob->static_shape().dim_vec().size() == 2) { + Calc2DMatMul(ctx.device_ctx, b_blob, !(transpose_a ^ transpose_b), out_diff_blob, transpose_a, + a_diff_blob, !transpose_a); + Calc2DMatMul(ctx.device_ctx, a_blob, !(transpose_a ^ transpose_b), out_diff_blob, transpose_b, + b_diff_blob, transpose_b); + } else { + CalcBatchMatMul(ctx.device_ctx, b_blob, !(transpose_a ^ transpose_b), out_diff_blob, + transpose_a, a_diff_blob, bw_buf_blob, !transpose_a); + CalcBatchMatMul(ctx.device_ctx, a_blob, !(transpose_a ^ transpose_b), out_diff_blob, + transpose_b, b_diff_blob, bw_buf_blob, transpose_b); } } template<DeviceType device_type, typename T> -void MatmulKernel<device_type, T>::InitConstBufBlobs( - DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { - if (!this->op_conf().matmul_conf().has_bias()) { return; } - InitializerConf bias_multiplier_initializer_conf; - bias_multiplier_initializer_conf.mutable_constant_conf()->set_value(1.0f); - KernelUtil<device_type, T>::InitializeWithConf(ctx, bias_multiplier_initializer_conf, 0, - BnInOp2Blob("bias_multiplier")); +const PbMessage& MatmulKernel<device_type, T>::GetCustomizedOpConf() const { + return this->op_conf().matmul_conf(); } template<DeviceType device_type, typename T> -const PbMessage& MatmulKernel<device_type, T>::GetCustomizedOpConf() const { - return this->op_conf().matmul_conf(); +void MatmulKernel<device_type, T>::Calc2DMatMul(DeviceCtx* ctx, const Blob* a, bool trans_a, + const Blob* b, bool trans_b, Blob* c, + bool swap_in) const { + CBLAS_TRANSPOSE blas_trans_a = trans_a ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE blas_trans_b = trans_b ? CblasTrans : CblasNoTrans; + if (swap_in) { + KernelUtil<device_type, T>::BlobGemm(ctx, blas_trans_b, blas_trans_a, OneVal<T>::value, + ZeroVal<T>::value, b, a, c); + } else { + KernelUtil<device_type, T>::BlobGemm(ctx, blas_trans_a, blas_trans_b, OneVal<T>::value, + ZeroVal<T>::value, a, b, c); + } +} + +template<DeviceType device_type, typename T> +void MatmulKernel<device_type, T>::CalcBatchMatMul(DeviceCtx* ctx, const Blob* a, bool trans_a, + const Blob* b, bool trans_b, Blob* c, Blob* buf, + bool swap_in) const { + if (swap_in) { + CalcBatchMatMul(ctx, b, trans_b, a, trans_a, c, buf); + } else { + CalcBatchMatMul(ctx, a, trans_a, b, trans_b, c, buf); + } +} + +template<DeviceType device_type, typename T> +void MatmulKernel<device_type, T>::CalcBatchMatMul(DeviceCtx* ctx, const Blob* a, bool trans_a, + const Blob* b, bool trans_b, Blob* c, + Blob* buf) const { + CBLAS_TRANSPOSE blas_trans_a = trans_a ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE blas_trans_b = trans_b ? CblasTrans : CblasNoTrans; + int32_t dim_num = a->shape().dim_vec().size(); + int32_t batch_size = a->shape().Count(0, dim_num - 2); + int m = trans_a ? a->shape().At(dim_num - 1) : a->shape().At(dim_num - 2); + int k = trans_a ? a->shape().At(dim_num - 2) : a->shape().At(dim_num - 1); + int n = trans_b ? 
b->shape().At(dim_num - 2) : b->shape().At(dim_num - 1); + const T* a_dptr = a->dptr<T>(); + const T* b_dptr = b->dptr<T>(); + T* c_dptr = c->mut_dptr<T>(); + T** buf_dptr = reinterpret_cast<T**>(buf->mut_dptr<int64_t>()); + KernelUtil<device_type, T>::OFBatchedGemm(ctx, blas_trans_a, blas_trans_b, batch_size, m, n, k, + OneVal<T>::value, a_dptr, b_dptr, ZeroVal<T>::value, + c_dptr, buf_dptr); } ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kMatmulConf, MatmulKernel, FLOATING_DATA_TYPE_SEQ); diff --git a/oneflow/core/kernel/matmul_kernel.h b/oneflow/core/kernel/matmul_kernel.h index 4a82c3da53b4dab006fb322335c2549cc84f3a0b..6917a10b6f3f2cdbd9947d77efbe58d282378f18 100644 --- a/oneflow/core/kernel/matmul_kernel.h +++ b/oneflow/core/kernel/matmul_kernel.h @@ -16,8 +16,12 @@ class MatmulKernel final : public KernelIfWithModel<device_type, T> { std::function<Blob*(const std::string&)>) const override; void BackwardDataContent(const KernelCtx&, std::function<Blob*(const std::string&)>) const override; - void InitConstBufBlobs(DeviceCtx*, - std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + void Calc2DMatMul(DeviceCtx* ctx, const Blob* a, bool trans_a, const Blob* b, bool trans_b, + Blob* c, bool swap_in) const; + void CalcBatchMatMul(DeviceCtx* ctx, const Blob* a, bool trans_a, const Blob* b, bool trans_b, + Blob* c, Blob* buf, bool swap_in) const; + void CalcBatchMatMul(DeviceCtx* ctx, const Blob* a, bool trans_a, const Blob* b, bool trans_b, + Blob* c, Blob* buf) const; const PbMessage& GetCustomizedOpConf() const override; }; diff --git a/oneflow/core/kernel/max_pooling_kernel.cpp b/oneflow/core/kernel/max_pooling_kernel.cpp index bf32b21b39b13780ac601f4a3c66959a8241a5dd..0ea7e8f2fd0fae727b82bce1e22a07dd5ac07909 100644 --- a/oneflow/core/kernel/max_pooling_kernel.cpp +++ b/oneflow/core/kernel/max_pooling_kernel.cpp @@ -4,7 +4,7 @@ namespace oneflow { template<typename T> T MaxPoolingKernel<DeviceType::kCPU, T>::ForwardInitialize() const { - return MinVal<T>(); + return GetMinVal<T>(); } template<typename T> diff --git a/oneflow/core/kernel/mean_kernel.cpp b/oneflow/core/kernel/mean_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..79f5e84c32cd701403069b56fb65643dfb14dffe --- /dev/null +++ b/oneflow/core/kernel/mean_kernel.cpp @@ -0,0 +1,45 @@ +#include "oneflow/core/kernel/mean_kernel.h" +#include "oneflow/core/kernel/kernel_util.h" +#include "oneflow/core/ndarray/ndarray_util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +void MeanKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* in_blob = BnInOp2Blob("in"); + Blob* out_blob = BnInOp2Blob("out"); + size_t count = in_blob->shape().elem_cnt() / out_blob->shape().elem_cnt(); + Blob* fw_tmp_blob = BnInOp2Blob("fw_tmp"); + size_t num_axes = in_blob->shape().NumAxes(); + NdarrayUtil<device_type, T>::ReduceSum(ctx.device_ctx, XpuVarNdarray<T>(out_blob, num_axes), + XpuVarNdarray<const T>(in_blob, num_axes), + XpuVarNdarray<T>(fw_tmp_blob, num_axes)); + + KernelUtil<device_type, T>::Div(ctx.device_ctx, out_blob->shape().elem_cnt(), + out_blob->mut_dptr<T>(), static_cast<T>(count)); +} + +template<DeviceType device_type, typename T> +void MeanKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out")); + CHECK_EQ(1, 
out_diff_blob->shape().dim_vec().back()); + Blob* in_diff_blob = BnInOp2Blob(GenDiffBn("in")); + Blob* bw_tmp_blob = BnInOp2Blob("bw_tmp"); + + Memcpy<device_type>(ctx.device_ctx, bw_tmp_blob->mut_dptr(), out_diff_blob->dptr(), + out_diff_blob->ByteSizeOfDataContentField()); + size_t count = in_diff_blob->shape().elem_cnt() / out_diff_blob->shape().elem_cnt(); + // bw_tmp = out_diff/M + KernelUtil<device_type, T>::Div(ctx.device_ctx, bw_tmp_blob->shape().elem_cnt(), + bw_tmp_blob->mut_dptr<T>(), static_cast<T>(count)); + size_t num_axes = in_diff_blob->shape().NumAxes(); + NdarrayUtil<device_type, T>::template BroadcastApply<UnaryFuncIdentity>( + ctx.device_ctx, XpuVarNdarray<T>(in_diff_blob, num_axes), + XpuVarNdarray<const T>(out_diff_blob->shape(), bw_tmp_blob->dptr<T>())); +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kMeanConf, MeanKernel, FLOATING_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/mean_kernel.h b/oneflow/core/kernel/mean_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..2f4d45d39054d726e57d26feee194f4f2e9f6e52 --- /dev/null +++ b/oneflow/core/kernel/mean_kernel.h @@ -0,0 +1,25 @@ +#ifndef ONEFLOW_CORE_KERNEL_MEAN_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_MEAN_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class MeanKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(MeanKernel); + MeanKernel() = default; + ~MeanKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + const PbMessage& GetCustomizedOpConf() const override { return this->op_conf().mean_conf(); } +}; + +} // namespace oneflow + +#endif // ONEFLOE_CORE_KERNEL_MEAN_KERNEL_H_ diff --git a/oneflow/core/kernel/normal_model_update_kernel.cpp b/oneflow/core/kernel/normal_model_update_kernel.cpp index 1857bde209b69520f4ef1ed654191f5d3d6707b5..1d6c1cedeac1e5f4567861a68bd27844e0139455 100644 --- a/oneflow/core/kernel/normal_model_update_kernel.cpp +++ b/oneflow/core/kernel/normal_model_update_kernel.cpp @@ -3,6 +3,7 @@ #include "oneflow/core/kernel/momentum_model_update_kernel.h" #include "oneflow/core/kernel/rmsprop_model_update_kernel.h" #include "oneflow/core/kernel/lars_model_update_kernel.h" +#include "oneflow/core/kernel/adam_model_update_kernel.h" namespace oneflow { @@ -13,13 +14,17 @@ void NormalMdUpdateKernel<device_type, T>::Forward( int64_t cur_batch_num = next_model_vid - 1; const NormalModelUpdateOpUserConf& conf = this->op_conf().normal_mdupdt_conf().user_conf(); float learning_rate = this->op_conf().normal_mdupdt_conf().learning_rate(); + const T* batch_instance_num_ptr = BnInOp2Blob("total_instance_num_diff")->dptr<T>(); + if (conf.has_clip_conf()) { + ClipGradient(ctx.device_ctx, cur_batch_num, conf.clip_conf(), batch_instance_num_ptr, + BnInOp2Blob); + } if (TriggerWarmup(conf, learning_rate, cur_batch_num)) { learning_rate = GetWarmupLearningRate(conf.warmup_conf(), learning_rate, cur_batch_num); } else if (conf.has_learning_rate_decay()) { learning_rate = GetDecayedLearningRate(conf.learning_rate_decay(), learning_rate, cur_batch_num); } - const T* batch_instance_num_ptr = BnInOp2Blob("total_instance_num_diff")->dptr<T>(); float l1 = this->op_conf().normal_mdupdt_conf().l1(); float l2 = this->op_conf().normal_mdupdt_conf().l2(); 
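+  // UpdateModel is implemented by the concrete optimizer kernel selected in
+  // CreateMdUpdtKernel below (e.g. momentum, RMSProp, LARS, or the newly added
+  // Adam); it consumes the (possibly clipped) model diff via BnInOp2Blob
+  // together with the effective learning rate and the l1/l2 coefficients.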
UpdateModel(ctx.device_ctx, batch_instance_num_ptr, static_cast<T>(learning_rate), @@ -43,6 +48,8 @@ Kernel* CreateMdUpdtKernel(const KernelConf& kernel_conf) { return CreateRMSPropMdUpdtKernel(kernel_conf); } else if (user_conf.has_lars_conf()) { return CreateLARSMdUpdtKernel(kernel_conf); + } else if (user_conf.has_adam_conf()) { + return CreateAdamMdUpdtKernel(kernel_conf); } else { UNIMPLEMENTED(); } @@ -138,7 +145,7 @@ double ConstantWarmupLearningRate(const ConstantWarmupConf& conf, double lr, double LinearWarmupLearningRate(const LinearWarmupConf& conf, double lr, int64_t cur_batch_num) { CHECK_GT(conf.warmup_batches(), 0); - CHECK_GT(conf.start_multiplier(), 0); + CHECK_GE(conf.start_multiplier(), 0); CHECK_LT(conf.start_multiplier(), 1); double start_multiplier = conf.start_multiplier(); double multiplier = 1.0; @@ -149,9 +156,32 @@ double LinearWarmupLearningRate(const LinearWarmupConf& conf, double lr, int64_t return lr * multiplier; } +template<DeviceType device_type, typename T> +void ClipByGlobalNorm(DeviceCtx* ctx, const int64_t cur_batch_num, const ClipByGlobalNormConf& conf, + const T* batch_instance_num_ptr, + std::function<Blob*(const std::string&)> BnInOp2Blob) { + int64_t n = BnInOp2Blob("model_diff")->shape().elem_cnt(); + T* model_diff = BnInOp2Blob("model_diff")->mut_dptr<T>(); + T* global_norm_ptr = BnInOp2Blob("data_tmp")->mut_dptr<T>(); + if (conf.has_global_norm()) { + KernelUtil<device_type, T>::Set(ctx, static_cast<T>(conf.global_norm()), global_norm_ptr); + } else { + // The Dot does not read the result, so the global_norm need not be initialized. + KernelUtil<device_type, T>::Dot(ctx, n, model_diff, 1, model_diff, 1, global_norm_ptr); + KernelUtil<device_type, T>::Sqrt(ctx, 1, global_norm_ptr, global_norm_ptr); + KernelUtil<device_type, T>::Div(ctx, 1, global_norm_ptr, batch_instance_num_ptr); + } + T* ratio_ptr = BnInOp2Blob("data_tmp")->mut_dptr<T>(); + NormalMdUpdateKernelUtil<device_type, T>::CmptClipRatioByGlobalNorm( + ctx, global_norm_ptr, static_cast<T>(conf.clip_norm()), ratio_ptr); + KernelUtil<device_type, T>::Scal(ctx, n, ratio_ptr, model_diff, 1); +} + } // namespace -bool TriggerWarmup(const NormalModelUpdateOpUserConf& conf, double lr, int64_t cur_batch_num) { +template<DeviceType device_type, typename T> +bool NormalMdUpdateKernel<device_type, T>::TriggerWarmup(const NormalModelUpdateOpUserConf& conf, + double lr, int64_t cur_batch_num) const { if (!conf.has_warmup_conf()) { return false; } const WarmupConf& warmup_conf = conf.warmup_conf(); if (warmup_conf.has_constant_conf()) { @@ -163,7 +193,10 @@ bool TriggerWarmup(const NormalModelUpdateOpUserConf& conf, double lr, int64_t c } } -double GetWarmupLearningRate(const WarmupConf& conf, double lr, int64_t cur_batch_num) { +template<DeviceType device_type, typename T> +double NormalMdUpdateKernel<device_type, T>::GetWarmupLearningRate(const WarmupConf& conf, + double lr, + int64_t cur_batch_num) const { if (conf.has_constant_conf()) { return ConstantWarmupLearningRate(conf.constant_conf(), lr, cur_batch_num); } else if (conf.has_linear_conf()) { @@ -173,7 +206,21 @@ double GetWarmupLearningRate(const WarmupConf& conf, double lr, int64_t cur_batc } } -double GetDecayedLearningRate(const LearningRateDecayConf& conf, double lr, int64_t cur_batch_num) { +template<DeviceType device_type, typename T> +void NormalMdUpdateKernel<device_type, T>::ClipGradient( + DeviceCtx* ctx, const int64_t cur_batch_num, const ClipConf& conf, + const T* batch_instance_num_ptr, std::function<Blob*(const std::string&)> 
BnInOp2Blob) const { + if (conf.has_clip_by_global_norm()) { + ClipByGlobalNorm<device_type, T>(ctx, cur_batch_num, conf.clip_by_global_norm(), + batch_instance_num_ptr, BnInOp2Blob); + } else { + UNIMPLEMENTED(); + } +} + +template<DeviceType device_type, typename T> +double NormalMdUpdateKernel<device_type, T>::GetDecayedLearningRate( + const LearningRateDecayConf& conf, double lr, int64_t cur_batch_num) const { if (conf.has_exponential_conf()) { return ExponentialDecayedLearningRate(conf.exponential_conf(), lr, cur_batch_num); } else if (conf.has_inverse_time_conf()) { @@ -193,6 +240,15 @@ double GetDecayedLearningRate(const LearningRateDecayConf& conf, double lr, int6 } } +template<typename T> +class NormalMdUpdateKernelUtil<DeviceType::kCPU, T> final { + public: + static void CmptClipRatioByGlobalNorm(DeviceCtx* ctx, const T* global_norm_ptr, T clip_norm, + T* ratio_ptr) { + *ratio_ptr = clip_norm / std::max(*global_norm_ptr, clip_norm); + } +}; + REGISTER_KERNEL_CREATOR(OperatorConf::kNormalMdupdtConf, CreateMdUpdtKernel); } // namespace oneflow diff --git a/oneflow/core/kernel/normal_model_update_kernel.cu b/oneflow/core/kernel/normal_model_update_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..d8ee75da9cae581ae9b71cb540e956da04849041 --- /dev/null +++ b/oneflow/core/kernel/normal_model_update_kernel.cu @@ -0,0 +1,29 @@ +#include "oneflow/core/device/cuda_util.h" +#include "oneflow/core/kernel/normal_model_update_kernel.h" + +namespace oneflow { + +namespace { + +template<typename T> +__global__ void CmptClipRatioByGlobalNormGpu(const T* global_norm_ptr, T clip_norm, T* ratio_ptr) { + *ratio_ptr = clip_norm / max(*global_norm_ptr, clip_norm); +} + +} // namespace + +template<typename T> +class NormalMdUpdateKernelUtil<DeviceType::kGPU, T> final { + public: + static void CmptClipRatioByGlobalNorm(DeviceCtx* ctx, const T* global_norm_ptr, T clip_norm, + T* ratio_ptr) { + CmptClipRatioByGlobalNormGpu<T> + <<<1, 1, 0, ctx->cuda_stream()>>>(global_norm_ptr, clip_norm, ratio_ptr); + } +}; + +#define INSTANTIATE_GPU_KERNEL_UTIL(type_cpp, type_proto) \ + template class NormalMdUpdateKernelUtil<DeviceType::kGPU, type_cpp>; +OF_PP_FOR_EACH_TUPLE(INSTANTIATE_GPU_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ) + +} // namespace oneflow diff --git a/oneflow/core/kernel/normal_model_update_kernel.h b/oneflow/core/kernel/normal_model_update_kernel.h index f3f8ff1901a21fba79518bf312bd0d296801dcb3..ae5accae68144baf2db3ac891b0818c5614a6623 100644 --- a/oneflow/core/kernel/normal_model_update_kernel.h +++ b/oneflow/core/kernel/normal_model_update_kernel.h @@ -19,11 +19,24 @@ class NormalMdUpdateKernel : public KernelIf<device_type> { virtual void UpdateModel(DeviceCtx* ctx, const T* batch_instance_num_ptr, T learning_rate, T l1, T l2, int64_t next_model_vid, std::function<Blob*(const std::string&)> BnInOp2Blob) const = 0; + + private: + bool TriggerWarmup(const NormalModelUpdateOpUserConf& conf, double lr, + int64_t cur_batch_num) const; + double GetWarmupLearningRate(const WarmupConf&, double lr, int64_t cur_batch_num) const; + double GetDecayedLearningRate(const LearningRateDecayConf&, double lr, + int64_t cur_batch_num) const; + void ClipGradient(DeviceCtx* ctx, const int64_t cur_batch_num, const ClipConf& conf, + const T* batch_instance_num_ptr, + std::function<Blob*(const std::string&)> BnInOp2Blob) const; }; -bool TriggerWarmup(const NormalModelUpdateOpUserConf& conf, double lr, int64_t cur_batch_num); -double GetWarmupLearningRate(const WarmupConf&, double lr, int64_t 
cur_batch_num); -double GetDecayedLearningRate(const LearningRateDecayConf&, double lr, int64_t cur_batch_num); +template<DeviceType device_type, typename T> +class NormalMdUpdateKernelUtil final { + public: + static void CmptClipRatioByGlobalNorm(DeviceCtx* ctx, const T* global_norm_ptr, T clip_norm, + T* ratio_ptr); +}; #define DECLARE_MDUPDT_KERNEL_CREATOR(x) Kernel* Create##x##MdUpdtKernel(const KernelConf&); diff --git a/oneflow/core/kernel/one_hot_kernel.cpp b/oneflow/core/kernel/one_hot_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dd9ba25925680945dd01a3ac3161ff50ba0ca7a6 --- /dev/null +++ b/oneflow/core/kernel/one_hot_kernel.cpp @@ -0,0 +1,57 @@ +#include "oneflow/core/kernel/one_hot_kernel.h" + +namespace oneflow { + +namespace { + +template<DeviceType device_type, typename T, typename K> +void OneHot(DeviceCtx* ctx, const Blob* indices, Blob* out) { + const int64_t depth = out->shape().At(out->shape().NumAxes() - 1); + OneHotKernelUtil<device_type, T, K>::Encode(ctx, indices->dptr<K>(), indices->shape().elem_cnt(), + depth, out->mut_dptr<T>()); +} + +} // namespace + +template<DeviceType device_type, typename T> +struct OneHotUtil final { +#define MAKE_ONE_HOT_SWITCH_ENTRY(func_name, K) func_name<device_type, T, K> + DEFINE_STATIC_SWITCH_FUNC(void, OneHot, MAKE_ONE_HOT_SWITCH_ENTRY, + MAKE_DATA_TYPE_CTRV_SEQ(INT_DATA_TYPE_SEQ)); +#undef MAKE_ONE_HOT_SWITCH_ENTRY +}; + +template<DeviceType device_type, typename T> +const PbMessage& OneHotKernel<device_type, T>::GetCustomizedOpConf() const { + return this->op_conf().one_hot_conf(); +} + +template<DeviceType device_type, typename T> +void OneHotKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* indices = BnInOp2Blob("indices"); + Blob* out = BnInOp2Blob("out"); + OneHotUtil<device_type, T>::SwitchOneHot(SwitchCase(indices->data_type()), ctx.device_ctx, + indices, out); +} + +template<typename T, typename K> +struct OneHotKernelUtil<DeviceType::kCPU, T, K> final { + static void Encode(DeviceCtx* ctx, const K* indices, int64_t num_indices, int64_t depth, T* out); +}; + +template<typename T, typename K> +void OneHotKernelUtil<DeviceType::kCPU, T, K>::Encode(DeviceCtx* ctx, const K* indices, + int64_t num_indices, int64_t depth, T* out) { + Memset<kCPU>(ctx, out, 0, num_indices * depth * sizeof(T)); + FOR_RANGE(int64_t, i, 0, num_indices) { + const K idx = indices[i]; + CHECK_GE(idx, 0); + CHECK_LT(idx, depth); + out[i * depth + idx] = OneVal<T>::value; + } +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kOneHotConf, OneHotKernel, ARITHMETIC_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/one_hot_kernel.cu b/oneflow/core/kernel/one_hot_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..c996d9cea850659dcb2fb5ccd09e3529dc01f6b7 --- /dev/null +++ b/oneflow/core/kernel/one_hot_kernel.cu @@ -0,0 +1,43 @@ +#include "oneflow/core/kernel/one_hot_kernel.h" +#include "oneflow/core/kernel/kernel_util.cuh" +#include <assert.h> + +namespace oneflow { + +namespace { + +template<typename T, typename K> +__global__ void OneHotEncodeGpu(int64_t elem_cnt, const K* indices, int64_t depth, T* out) { + CUDA_1D_KERNEL_LOOP(i, elem_cnt) { + const int64_t row = i / depth; + const int64_t col = i % depth; + const int64_t idx = indices[row]; + assert(idx >= 0 && idx < depth); + out[i] = (idx == col); + } +} + +} // namespace + +template<typename T, typename K> +struct 
OneHotKernelUtil<DeviceType::kGPU, T, K> final { + static void Encode(DeviceCtx* ctx, const K* indices, int64_t num_indices, int64_t depth, T* out); +}; + +template<typename T, typename K> +void OneHotKernelUtil<DeviceType::kGPU, T, K>::Encode(DeviceCtx* ctx, const K* indices, + int64_t num_indices, int64_t depth, T* out) { + const int64_t elem_cnt = num_indices * depth; + OneHotEncodeGpu<T, K> + <<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + elem_cnt, indices, depth, out); +} + +#define INSTANTIATE_ONE_HOT_KERNEL_UTIL_GPU(data_type_pair, index_type_pair) \ + template struct OneHotKernelUtil<DeviceType::kGPU, OF_PP_PAIR_FIRST(data_type_pair), \ + OF_PP_PAIR_FIRST(index_type_pair)>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_ONE_HOT_KERNEL_UTIL_GPU, ARITHMETIC_DATA_TYPE_SEQ, + INT_DATA_TYPE_SEQ); +#undef INSTANTIATE_ONE_HOT_KERNEL_UTIL_GPU + +} // namespace oneflow diff --git a/oneflow/core/kernel/one_hot_kernel.h b/oneflow/core/kernel/one_hot_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..5d5c8bd7fe7d744125d0b09b85f4097705faf5f4 --- /dev/null +++ b/oneflow/core/kernel/one_hot_kernel.h @@ -0,0 +1,28 @@ +#ifndef ONEFLOW_CORE_KERNEL_ONE_HOT_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_ONE_HOT_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class OneHotKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(OneHotKernel) + OneHotKernel() = default; + ~OneHotKernel() = default; + + private: + const PbMessage& GetCustomizedOpConf() const override; + void ForwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; +}; + +template<DeviceType device_type, typename T, typename K> +struct OneHotKernelUtil final { + static void Encode(DeviceCtx* ctx, const K* indices, int64_t num_indices, int64_t depth, T* out); +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_ONE_HOT_KERNEL_H_ diff --git a/oneflow/core/kernel/opkernel_test_case.cpp b/oneflow/core/kernel/opkernel_test_case.cpp index 798d18796c4f1b83893eccebc0747747d31b8fd6..5c0544bdd3f3494d3b7e20ec1e68c12b3b734a37 100644 --- a/oneflow/core/kernel/opkernel_test_case.cpp +++ b/oneflow/core/kernel/opkernel_test_case.cpp @@ -420,8 +420,8 @@ void OpKernelTestCase::InferBlobDesc(std::shared_ptr<Operator>* op, OpContext** op_conf_.set_device_type(default_device_type_); *op = ConstructOp(op_conf_); if (NeedInferBlobDescs(op->get())) { - (*op)->InferBlobDescs(MakeGetterBnInOp2BlobDesc(), ¶llel_ctx_, - [&](OpContext* op_ctx) { *op_context = op_ctx; }); + (*op)->InferBlobDescsIf(MakeGetterBnInOp2BlobDesc(), ¶llel_ctx_, 1, + [&](OpContext* op_ctx) { *op_context = op_ctx; }); } } diff --git a/oneflow/core/kernel/pack_kernel.h b/oneflow/core/kernel/pack_kernel.h index 66e5c0be9e69a4d08a4b1d439651d86f275a4309..0c73cc66bbbcacab1722e7347dc32d64254c0cf9 100644 --- a/oneflow/core/kernel/pack_kernel.h +++ b/oneflow/core/kernel/pack_kernel.h @@ -31,6 +31,10 @@ class PackKernel final : public KernelIf<device_type> { const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; void ForwardDataId(const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + void BackwardInDiffLossInstanceNum( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override { + UNIMPLEMENTED(); + } }; } // namespace oneflow diff --git 
a/oneflow/core/kernel/random_generator.cpp b/oneflow/core/kernel/random_generator.cpp index f47526032f26d766e2b64756d7228fc4a9ad24f8..c939d075b5164015e5e610ed35b1a64d7b525bc2 100644 --- a/oneflow/core/kernel/random_generator.cpp +++ b/oneflow/core/kernel/random_generator.cpp @@ -14,7 +14,7 @@ void RandomGenerator<DeviceType::kCPU>::Uniform(const int64_t elem_cnt, const T CHECK_GE(elem_cnt, 0); CHECK(dptr); CHECK_LE(min, max); - std::uniform_real_distribution<T> random_distribution(min, std::nextafter(max, MaxVal<T>())); + std::uniform_real_distribution<T> random_distribution(min, std::nextafter(max, GetMaxVal<T>())); for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = random_distribution(mt19937_generator_); } } diff --git a/oneflow/core/kernel/record_load_kernel.cpp b/oneflow/core/kernel/record_load_kernel.cpp index 02c479a5c7795731d2652cd84a92cc3e8b60bec4..7cd3f403023a2a937b71dffab45d4e1ccf6af16b 100644 --- a/oneflow/core/kernel/record_load_kernel.cpp +++ b/oneflow/core/kernel/record_load_kernel.cpp @@ -21,24 +21,39 @@ void RecordLoadKernel::VirtualKernelInit(const ParallelContext* parallel_ctx) { int32_t zero_count = std::max(part_name_suffix_length - static_cast<int32_t>(num.length()), 0); data_paths.push_back(JoinPath(data_dir, part_name_prefix + std::string(zero_count, '0') + num)); } + piece_size_in_one_loader_ = kernel_conf().record_load_conf().device_piece_size(); if (Global<JobDesc>::Get()->IsTrain()) { + const size_t num_max_read = + static_cast<size_t>(piece_size_in_one_loader_ * Global<JobDesc>::Get()->TotalBatchNum() + * Global<JobDesc>::Get()->NumOfPiecesInBatch()); in_stream_.reset(new PersistentInStream( DataFS(), data_paths, true, Global<JobDesc>::Get()->save_downloaded_file_to_local_fs())); + if (record_load_conf.has_random_shuffle_conf()) { + const int32_t shuffle_buffer_size = record_load_conf.random_shuffle_conf().buffer_size(); + CHECK_GT(shuffle_buffer_size, 0); + record_reader_.reset(new RandomShuffleOFRecordReader( + in_stream_.get(), static_cast<size_t>(shuffle_buffer_size), num_max_read)); + } else { + record_reader_.reset(new NaiveOFRecordReader(in_stream_.get(), num_max_read)); + } } else { in_stream_.reset(new PersistentInStream(DataFS(), data_paths, false, false)); + if (record_load_conf.has_random_shuffle_conf()) { + const int32_t shuffle_buffer_size = record_load_conf.random_shuffle_conf().buffer_size(); + CHECK_GT(shuffle_buffer_size, 0); + record_reader_.reset(new RandomShuffleOFRecordReader( + in_stream_.get(), static_cast<size_t>(shuffle_buffer_size))); + } else { + record_reader_.reset(new NaiveOFRecordReader(in_stream_.get())); + } } - int64_t global_piece_size = Global<JobDesc>::Get()->PieceSize(); - CHECK_EQ(global_piece_size % parallel_ctx->parallel_num(), 0); - piece_size_in_one_loader_ = global_piece_size / parallel_ctx->parallel_num(); } void RecordLoadKernel::Forward(const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { auto status = static_cast<RecordLoadStatus*>(ctx.other); - Blob* out_blob = BnInOp2Blob("out"); - RecordBlob<OFRecord> record_blob(out_blob); - record_blob.ReadFrom(in_stream_.get()); - status->record_num = record_blob.record_num(); + status->record_num = record_reader_->Read(static_cast<size_t>(piece_size_in_one_loader_), + BnInOp2Blob("out")->mut_dptr<OFRecord>()); if (status->record_num < piece_size_in_one_loader_) { status->is_eof = true; } } diff --git a/oneflow/core/kernel/record_load_kernel.h b/oneflow/core/kernel/record_load_kernel.h index 
d21b26cf8cc80fb6d31f93d4ba7af50dbff880f7..42472836da1db1daff10ed4725cd3de0e5b0d37b 100644 --- a/oneflow/core/kernel/record_load_kernel.h +++ b/oneflow/core/kernel/record_load_kernel.h @@ -3,6 +3,7 @@ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/persistence/persistent_in_stream.h" +#include "oneflow/core/record/ofrecord_reader.h" namespace oneflow { @@ -15,7 +16,7 @@ class RecordLoadKernel final : public KernelIf<DeviceType::kCPU> { public: OF_DISALLOW_COPY_AND_MOVE(RecordLoadKernel); RecordLoadKernel() = default; - ~RecordLoadKernel() = default; + ~RecordLoadKernel() override = default; private: void VirtualKernelInit(const ParallelContext*) override; @@ -23,6 +24,7 @@ class RecordLoadKernel final : public KernelIf<DeviceType::kCPU> { std::function<Blob*(const std::string&)> BnInOp2Blob) const override; std::unique_ptr<PersistentInStream> in_stream_; + std::unique_ptr<OFRecordReader> record_reader_; int64_t piece_size_in_one_loader_; }; diff --git a/oneflow/core/kernel/reduce_concat_kernel.cpp b/oneflow/core/kernel/reduce_concat_kernel.cpp index f2884fe55973eab00a51c144ce7e90331a8f708d..f4baf582ad985b0cc24734399758ef0e65eca343 100644 --- a/oneflow/core/kernel/reduce_concat_kernel.cpp +++ b/oneflow/core/kernel/reduce_concat_kernel.cpp @@ -5,16 +5,15 @@ namespace oneflow { template<DeviceType device_type> void ReduceConcatKernel<device_type>::ForwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { - const auto* other_val = static_cast<std::pair<int64_t, bool>*>(ctx.other); - int64_t in_bn_id = other_val->first; - bool is_inplace = other_val->second; - if (is_inplace) { return; } + if (device_type == DeviceType::kGPU && Global<JobDesc>::Get()->enable_mem_sharing()) { return; } Blob* out_blob = BnInOp2Blob("out"); char* dst_cur_dptr = out_blob->mut_dptr<char>(); - dst_cur_dptr += this->kernel_conf().reduce_concat_conf().data_offset().Get(in_bn_id); - Blob* in_blob = BnInOp2Blob(this->op_attribute().input_bns().Get(in_bn_id)); - size_t in_byte_size = in_blob->ByteSizeOfDataContentField(); - Memcpy<device_type>(ctx.device_ctx, dst_cur_dptr, in_blob->dptr<char>(), in_byte_size); + FOR_RANGE(int, in_bn_id, 0, this->op_attribute().input_bns().size()) { + dst_cur_dptr += this->kernel_conf().reduce_concat_conf().data_offset().Get(in_bn_id); + Blob* in_blob = BnInOp2Blob(this->op_attribute().input_bns().Get(in_bn_id)); + size_t in_byte_size = in_blob->ByteSizeOfDataContentField(); + Memcpy<device_type>(ctx.device_ctx, dst_cur_dptr, in_blob->dptr<char>(), in_byte_size); + } } ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kReduceConcatConf, ReduceConcatKernel); diff --git a/oneflow/core/kernel/reduce_identity_kernel.cpp b/oneflow/core/kernel/reduce_identity_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b6dab91c4cec96ac725ac8a1291de5a2465349a4 --- /dev/null +++ b/oneflow/core/kernel/reduce_identity_kernel.cpp @@ -0,0 +1,17 @@ +#include "oneflow/core/kernel/reduce_identity_kernel.h" + +namespace oneflow { + +template<DeviceType device_type> +void ReduceIdentityKernel<device_type>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* in_blob = BnInOp2Blob("in"); + Blob* out_blob = BnInOp2Blob("out"); + CHECK_EQ(out_blob->ByteSizeOfDataContentField(), in_blob->ByteSizeOfDataContentField()); + Memcpy<device_type>(ctx.device_ctx, out_blob->mut_dptr(), in_blob->dptr(), + out_blob->ByteSizeOfDataContentField()); +} + 
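+// ReduceIdentityKernel is a pure pass-through: the output blob is a
+// byte-for-byte copy of the input, so the op adds no computation of its own.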
+ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kReduceIdentityConf, ReduceIdentityKernel); + +} // namespace oneflow diff --git a/oneflow/core/kernel/reduce_identity_kernel.h b/oneflow/core/kernel/reduce_identity_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..1994fa1d4e98becf5adcde1eadcc64cd36312376 --- /dev/null +++ b/oneflow/core/kernel/reduce_identity_kernel.h @@ -0,0 +1,25 @@ +#ifndef ONEFLOW_CORE_KERNEL_REDUCE_IDENTITY_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_REDUCE_IDENTITY_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type> +class ReduceIdentityKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(ReduceIdentityKernel); + ReduceIdentityKernel() = default; + ~ReduceIdentityKernel() = default; + + private: + void ForwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + const PbMessage& GetCustomizedOpConf() const override { + return this->op_conf().reduce_identity_conf(); + } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_REDUCE_IDENTITY_KERNEL_H_ diff --git a/oneflow/core/kernel/reduce_mean_kernel.cpp b/oneflow/core/kernel/reduce_mean_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1b42148e1c0adde893e716a93c023d7c1ab3f855 --- /dev/null +++ b/oneflow/core/kernel/reduce_mean_kernel.cpp @@ -0,0 +1,43 @@ +#include "oneflow/core/kernel/reduce_mean_kernel.h" +#include "oneflow/core/ndarray/ndarray_util.h" +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +void ReduceMeanKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* in_blob = BnInOp2Blob("in"); + Blob* out_blob = BnInOp2Blob("out"); + Blob* fw_tmp_blob = BnInOp2Blob("fw_tmp"); + size_t count = in_blob->shape().elem_cnt() / out_blob->shape().elem_cnt(); + NdarrayUtil<device_type, T>::ReduceSum( + ctx.device_ctx, + XpuVarNdarray<T>(Shape(this->kernel_conf().reduce_sum_conf().kept_dims_shape()), + out_blob->mut_dptr<T>()), + XpuVarNdarray<const T>(in_blob, in_blob->shape().NumAxes()), + XpuVarNdarray<T>(fw_tmp_blob, in_blob->shape().NumAxes())); + KernelUtil<device_type, T>::Div(ctx.device_ctx, out_blob->shape().elem_cnt(), + out_blob->mut_dptr<T>(), static_cast<T>(count)); +} + +template<DeviceType device_type, typename T> +void ReduceMeanKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* out_diff_blob = BnInOp2Blob("out_diff"); + Blob* in_diff_blob = BnInOp2Blob("in_diff"); + Blob* bw_tmp_blob = BnInOp2Blob("bw_tmp"); + size_t count = in_diff_blob->shape().elem_cnt() / out_diff_blob->shape().elem_cnt(); + Memcpy<device_type>(ctx.device_ctx, bw_tmp_blob->mut_dptr(), out_diff_blob->dptr(), + out_diff_blob->ByteSizeOfDataContentField()); + KernelUtil<device_type, T>::Div(ctx.device_ctx, bw_tmp_blob->shape().elem_cnt(), + bw_tmp_blob->mut_dptr<T>(), static_cast<T>(count)); + NdarrayUtil<device_type, T>::BroadcastTo( + ctx.device_ctx, XpuVarNdarray<T>(in_diff_blob, in_diff_blob->shape().NumAxes()), + XpuVarNdarray<const T>(Shape(this->kernel_conf().reduce_sum_conf().kept_dims_shape()), + bw_tmp_blob->dptr<T>())); +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kReduceMeanConf, ReduceMeanKernel, FLOATING_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git 
a/oneflow/core/kernel/reduce_mean_kernel.h b/oneflow/core/kernel/reduce_mean_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..e3541bf50fb697fcf21db1d3ac46b93a2ecdb128 --- /dev/null +++ b/oneflow/core/kernel/reduce_mean_kernel.h @@ -0,0 +1,25 @@ +#ifndef ONEFLOW_CORE_KERNEL_REDUCE_MEAN_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_REDUCE_MEAN_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class ReduceMeanKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(ReduceMeanKernel); + ReduceMeanKernel() = default; + ~ReduceMeanKernel() = default; + + private: + void ForwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + + void BackwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_REDUCE_MEAN_KERNEL_H_ diff --git a/oneflow/core/kernel/reduce_split_kernel.cpp b/oneflow/core/kernel/reduce_split_kernel.cpp index 847c7b2ead2bbcd94992831efda0694e813353b8..1e2fefc57a3a23f0ee59fa0db4b3f86c8dcf65e2 100644 --- a/oneflow/core/kernel/reduce_split_kernel.cpp +++ b/oneflow/core/kernel/reduce_split_kernel.cpp @@ -5,8 +5,7 @@ namespace oneflow { template<DeviceType device_type> void ReduceSplitKernel<device_type>::ForwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { - bool is_inplace = *static_cast<bool*>(ctx.other); - if (is_inplace) { return; } + if (device_type == DeviceType::kGPU && Global<JobDesc>::Get()->enable_mem_sharing()) { return; } const Blob* in_blob = BnInOp2Blob("in"); const char* src_cur_dptr = in_blob->dptr<char>(); for (const std::string& obn : this->op_attribute().output_bns()) { diff --git a/oneflow/core/kernel/reduce_sum_kernel.cpp b/oneflow/core/kernel/reduce_sum_kernel.cpp index 0da31f15a962ea43609b6b6d730d4df9c6526a5f..ff3ea04a6b18a416aa97c8e15b34e6f58c3606fd 100644 --- a/oneflow/core/kernel/reduce_sum_kernel.cpp +++ b/oneflow/core/kernel/reduce_sum_kernel.cpp @@ -1,4 +1,5 @@ #include "oneflow/core/kernel/reduce_sum_kernel.h" +#include "oneflow/core/ndarray/ndarray_util.h" namespace oneflow { @@ -8,25 +9,12 @@ void ReduceSumKernel<device_type, T>::ForwardDataContent( const Blob* in_blob = BnInOp2Blob("in"); Blob* out_blob = BnInOp2Blob("out"); Blob* fw_tmp_blob = BnInOp2Blob("fw_tmp"); - if (this->kernel_conf().reduce_sum_conf().has_axis() == false) { - KernelUtil<device_type, T>::Sum(ctx.device_ctx, in_blob->shape().elem_cnt(), in_blob->dptr<T>(), - out_blob->mut_dptr<T>(), fw_tmp_blob->mut_dptr<T>(), - fw_tmp_blob->ByteSizeOfDataContentField()); - return; - } - int32_t axis = this->kernel_conf().reduce_sum_conf().axis(); - int64_t lhs_num = in_blob->shape().Count(0, axis); - int64_t middle_num = in_blob->shape().At(axis); - int64_t rhs_num = in_blob->shape().Count(axis + 1); - FOR_RANGE(int64_t, lhs_i, 0, lhs_num) { - const T* src_ptr = in_blob->dptr<T>() + lhs_i * middle_num * rhs_num; - T* dst_ptr = out_blob->mut_dptr<T>() + lhs_i * rhs_num; - Memcpy<device_type>(ctx.device_ctx, dst_ptr, src_ptr, rhs_num * sizeof(T)); - FOR_RANGE(int64_t, middle_i, 1, middle_num) { - KernelUtil<device_type, T>::Axpy(ctx.device_ctx, rhs_num, 1.0f, src_ptr + middle_i * rhs_num, - 1, dst_ptr, 1); - } - } + NdarrayUtil<device_type, T>::ReduceSum( + ctx.device_ctx, + XpuVarNdarray<T>(Shape(this->kernel_conf().reduce_sum_conf().kept_dims_shape()), + 
out_blob->mut_dptr<T>()), + XpuVarNdarray<const T>(in_blob, in_blob->shape().NumAxes()), + XpuVarNdarray<T>(fw_tmp_blob, in_blob->shape().NumAxes())); } template<DeviceType device_type, typename T> @@ -34,28 +22,10 @@ void ReduceSumKernel<device_type, T>::BackwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { const Blob* out_diff_blob = BnInOp2Blob("out_diff"); Blob* in_diff_blob = BnInOp2Blob("in_diff"); - - if (this->kernel_conf().reduce_sum_conf().has_axis() == false) { - T* dst_ptr = in_diff_blob->mut_dptr<T>(); - const T* src_ptr = out_diff_blob->dptr<T>(); - FOR_RANGE(int64_t, i, 0, in_diff_blob->shape().Count(0)) { - Memcpy<device_type>(ctx.device_ctx, dst_ptr++, src_ptr, sizeof(T)); - } - return; - } - - int32_t axis = this->kernel_conf().reduce_sum_conf().axis(); - int64_t lhs_num = in_diff_blob->shape().Count(0, axis); - int64_t middle_num = in_diff_blob->shape().At(axis); - int64_t rhs_num = in_diff_blob->shape().Count(axis + 1); - FOR_RANGE(int64_t, lhs_i, 0, lhs_num) { - const T* src_ptr = out_diff_blob->dptr<T>() + lhs_i * rhs_num; - T* dst_ptr = in_diff_blob->mut_dptr<T>() + lhs_i * middle_num * rhs_num; - FOR_RANGE(int64_t, middle_i, 0, middle_num) { - Memcpy<device_type>(ctx.device_ctx, dst_ptr, src_ptr, rhs_num * sizeof(T)); - dst_ptr += rhs_num; - } - } + NdarrayUtil<device_type, T>::BroadcastTo( + ctx.device_ctx, XpuVarNdarray<T>(in_diff_blob, in_diff_blob->shape().NumAxes()), + XpuVarNdarray<const T>(Shape(this->kernel_conf().reduce_sum_conf().kept_dims_shape()), + out_diff_blob->dptr<T>())); } ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kReduceSumConf, ReduceSumKernel, ARITHMETIC_DATA_TYPE_SEQ); diff --git a/oneflow/core/kernel/repeat_kernel.h b/oneflow/core/kernel/repeat_kernel.h index ea4a1864e27756a70e2397d0ab1ba60930d5df8d..54116c5d3e4ceb737bd4f8cb9d76b17960551255 100644 --- a/oneflow/core/kernel/repeat_kernel.h +++ b/oneflow/core/kernel/repeat_kernel.h @@ -17,6 +17,10 @@ class RepeatKernel final : public KernelIf<device_type> { std::function<Blob*(const std::string&)>) const override; void BackwardDataContent(const KernelCtx&, std::function<Blob*(const std::string&)>) const override; + void BackwardInDiffLossInstanceNum( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override { + UNIMPLEMENTED(); + } }; } // namespace oneflow diff --git a/oneflow/core/kernel/rsqrt_kernel.cpp b/oneflow/core/kernel/rsqrt_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1489deeb653ea5f735ff2420fc15326a493b36c7 --- /dev/null +++ b/oneflow/core/kernel/rsqrt_kernel.cpp @@ -0,0 +1,19 @@ +#include "oneflow/core/kernel/rsqrt_kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +void RsqrtKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + TODO(); +} + +template<DeviceType device_type, typename T> +void RsqrtKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + TODO(); +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kRsqrtConf, RsqrtKernel, FLOATING_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/rsqrt_kernel.h b/oneflow/core/kernel/rsqrt_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..618de4c693040e9f1e2ab80778de78e48bab4a50 --- /dev/null +++ b/oneflow/core/kernel/rsqrt_kernel.h @@ -0,0 +1,25 @@ +#ifndef 
ONEFLOW_CORE_KERNEL_RSQRT_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_RSQRT_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" +#include "oneflow/core/kernel/kernel_context.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class RsqrtKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(RsqrtKernel); + RsqrtKernel() = default; + ~RsqrtKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_RSQRT_KERNEL_H_ diff --git a/oneflow/core/kernel/scalar_add_kernel.cpp b/oneflow/core/kernel/scalar_add_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dbee81b59086d4288a9112e6dd50a005204da441 --- /dev/null +++ b/oneflow/core/kernel/scalar_add_kernel.cpp @@ -0,0 +1,37 @@ +#include "oneflow/core/kernel/scalar_add_kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +void ScalarAddKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* in_blob = BnInOp2Blob("in"); + Blob* out_blob = BnInOp2Blob("out"); + T scalar_operand = 0; + const auto& conf = this->op_conf().scalar_add_conf(); + if (IsIntegral<T>::value) { + CHECK(conf.has_int_operand()); + scalar_operand = static_cast<T>(conf.int_operand()); + } else if (IsFloating<T>::value) { + CHECK(conf.has_float_operand()); + scalar_operand = static_cast<T>(conf.float_operand()); + } else { + UNIMPLEMENTED(); + } + KernelUtil<device_type, T>::AddByScalar(ctx.device_ctx, out_blob->shape().elem_cnt(), + in_blob->dptr<T>(), scalar_operand, + out_blob->mut_dptr<T>()); +} + +template<DeviceType device_type, typename T> +void ScalarAddKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out")); + Blob* in_diff_blob = BnInOp2Blob(GenDiffBn("in")); + Memcpy<device_type>(ctx.device_ctx, in_diff_blob->mut_dptr<T>(), out_diff_blob->dptr<T>(), + out_diff_blob->ByteSizeOfDataContentField()); +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kScalarAddConf, ScalarAddKernel, ARITHMETIC_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/scalar_add_kernel.h b/oneflow/core/kernel/scalar_add_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..9a34ec190805571047931bd838dd76205aa5980d --- /dev/null +++ b/oneflow/core/kernel/scalar_add_kernel.h @@ -0,0 +1,27 @@ +#ifndef ONEFLOW_CORE_KERNEL_SCALAR_ADD_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_SCALAR_ADD_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class ScalarAddKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(ScalarAddKernel); + ScalarAddKernel() = default; + ~ScalarAddKernel() = default; + + private: + void ForwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + void BackwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + const PbMessage& GetCustomizedOpConf() const override { + return this->op_conf().scalar_add_conf(); + } +}; + +} // namespace oneflow + +#endif // 
ONEFLOW_CORE_KERNEL_SCALAR_ADD_KERNEL_H_ diff --git a/oneflow/core/kernel/scalar_mul_kernel.cpp b/oneflow/core/kernel/scalar_mul_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e7ed70acb5ac26a4bddba91c90cdfdc6829551f0 --- /dev/null +++ b/oneflow/core/kernel/scalar_mul_kernel.cpp @@ -0,0 +1,49 @@ +#include "oneflow/core/kernel/scalar_mul_kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +void ScalarMulKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* in_blob = BnInOp2Blob("in"); + Blob* out_blob = BnInOp2Blob("out"); + T scalar_operand = 0; + const auto& conf = this->op_conf().scalar_mul_conf(); + if (IsIntegral<T>::value) { + CHECK(conf.has_int_operand()); + scalar_operand = static_cast<T>(conf.int_operand()); + } else if (IsFloating<T>::value) { + CHECK(conf.has_float_operand()); + scalar_operand = static_cast<T>(conf.float_operand()); + } else { + UNIMPLEMENTED(); + } + KernelUtil<device_type, T>::MulByScalarPara(ctx.device_ctx, out_blob->shape().elem_cnt(), + in_blob->dptr<T>(), scalar_operand, + out_blob->mut_dptr<T>()); +} + +template<DeviceType device_type, typename T> +void ScalarMulKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out")); + Blob* in_diff_blob = BnInOp2Blob(GenDiffBn("in")); + T scalar_operand = 0; + const auto& conf = this->op_conf().scalar_mul_conf(); + if (IsIntegral<T>::value) { + CHECK(conf.has_int_operand()); + scalar_operand = static_cast<T>(conf.int_operand()); + } else if (IsFloating<T>::value) { + CHECK(conf.has_float_operand()); + scalar_operand = static_cast<T>(conf.float_operand()); + } else { + UNIMPLEMENTED(); + } + KernelUtil<device_type, T>::MulByScalarPara(ctx.device_ctx, in_diff_blob->shape().elem_cnt(), + out_diff_blob->dptr<T>(), scalar_operand, + in_diff_blob->mut_dptr<T>()); +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kScalarMulConf, ScalarMulKernel, ARITHMETIC_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/scalar_mul_kernel.h b/oneflow/core/kernel/scalar_mul_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..2dca5e04c661c9118741bad419ebb5971af8fb47 --- /dev/null +++ b/oneflow/core/kernel/scalar_mul_kernel.h @@ -0,0 +1,27 @@ +#ifndef ONEFLOW_CORE_KERNEL_SCALAR_MUL_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_SCALAR_MUL_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class ScalarMulKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(ScalarMulKernel); + ScalarMulKernel() = default; + ~ScalarMulKernel() = default; + + private: + void ForwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + void BackwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + const PbMessage& GetCustomizedOpConf() const override { + return this->op_conf().scalar_mul_conf(); + } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_SCALAR_MUL_KERNEL_H_ diff --git a/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.cpp b/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2203faa82cae07340423d04182f53518a05127a8 --- /dev/null +++ 
b/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.cpp @@ -0,0 +1,110 @@ +#include "oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.h" +#include "oneflow/core/kernel/sparse_cross_entropy_loss_kernel.h" +#include "oneflow/core/kernel/softmax_kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename PredType, typename LabelType> +void SigmoidCrossEntropyLossKernel<device_type, PredType, LabelType>::VirtualLossForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const SigmoidCrossEntropyLossOpConf& conf = this->op_conf().sigmoid_cross_entropy_loss_conf(); + const Blob* prediction = BnInOp2Blob("prediction"); + const Blob* label = BnInOp2Blob("label"); + Blob* loss_buf = BnInOp2Blob("loss_buf"); + Blob* tmp_storage = BnInOp2Blob("sum_buf"); + const size_t tmp_storage_byte_size = static_cast<size_t>(tmp_storage->shape().elem_cnt()); + Blob* count = BnInOp2Blob("count"); + Blob* label_num = BnInOp2Blob("label_num"); + Blob* loss = BnInOp2Blob("loss"); + Blob* prediction_diff = BnInOp2Blob(GenDiffBn("prediction")); + int64_t data_dim = label->shape().Count(1); + int64_t data_offset = 0; + if (prediction_diff != nullptr) { + Memset<device_type>(ctx.device_ctx, prediction_diff->mut_dptr<PredType>(), 0, + prediction_diff->ByteSizeOfDataContentField()); + } + FOR_RANGE(int64_t, data_index, 0, prediction->shape().At(0)) { + data_offset = data_dim * data_index; + const PredType* prediction_offset = prediction->dptr<PredType>() + data_offset; + const LabelType* label_offset = label->dptr<LabelType>() + data_offset; + SigmoidCrossEntropyLossKernelUtil<device_type, PredType, LabelType>::Forward( + ctx.device_ctx, conf, data_dim, prediction_offset, label_offset, + loss_buf->mut_dptr<PredType>(), tmp_storage->mut_dptr<PredType>(), tmp_storage_byte_size, + count->mut_dptr<PredType>(), label_num->mut_dptr<PredType>(), + loss->mut_dptr<PredType>() + data_index); + if (prediction_diff != nullptr) { + SigmoidCrossEntropyLossKernelUtil<device_type, PredType, LabelType>::Backward( + ctx.device_ctx, conf, data_dim, prediction_offset, label_offset, + label_num->dptr<PredType>(), prediction_diff->mut_dptr<PredType>() + data_offset); + } + } +} + +template<DeviceType device_type, typename PredType, typename LabelType> +const LossKernelConf& +SigmoidCrossEntropyLossKernel<device_type, PredType, LabelType>::GetLossKernelConf( + const KernelConf& kernel_conf) const { + return kernel_conf.sigmoid_cross_entropy_loss_conf().loss_conf(); +} + +template<typename PredType, typename LabelType> +struct SigmoidCrossEntropyLossKernelUtil<DeviceType::kCPU, PredType, LabelType> { + static void Forward(DeviceCtx* ctx, const SigmoidCrossEntropyLossOpConf& conf, const int64_t n, + const PredType* prediction, const LabelType* label, PredType* loss_buf, + PredType* tmp_storage, const size_t tmp_storage_byte_size, PredType* count, + PredType* label_num, PredType* loss) { + loss_buf[0] = 0; + loss[0] = 0; + label_num[0] = 0; + FOR_RANGE(int64_t, index, 0, n) { + if (label[index] != -1) { + loss_buf[0] += + -1 * prediction[index] * (label[index] - (prediction[index] >= 0)) + + logf(1 + expf(prediction[index] - 2 * prediction[index] * (prediction[index] >= 0))); + label_num[0] += 1; + } + } + loss_buf[0] *= static_cast<PredType>(conf.scale()); + if (conf.normalize()) { + if (label_num[0] == 0) { label_num[0] = 1e-5; } + loss[0] = loss_buf[0] / label_num[0]; + } + } + + static void Backward(DeviceCtx* ctx, const SigmoidCrossEntropyLossOpConf& conf, const 
int64_t n, + const PredType* prediction, const LabelType* label, + const PredType* label_num, PredType* pred_diff) { + FOR_RANGE(int64_t, index, 0, n) { + if (label[index] != -1) { + pred_diff[index] = 1.f / (1.f + expf(-prediction[index])) - label[index]; + pred_diff[index] *= static_cast<PredType>(conf.scale()); + if (conf.normalize()) { pred_diff[index] /= label_num[0]; } + } + } + } +}; + +namespace { + +Kernel* CreateSigmoidCrossEntropyLossKernel(const KernelConf& kernel_conf) { + static const HashMap<std::string, std::function<Kernel*()>> creators = { +#define SIGMOID_CROSS_ENTROPY_LOSS_KERNEL_ENTRY(device_type, pred_type_pair, label_type_pair) \ + {GetHashKey(device_type, OF_PP_PAIR_SECOND(pred_type_pair), OF_PP_PAIR_SECOND(label_type_pair)), \ + []() { \ + return new SigmoidCrossEntropyLossKernel<device_type, OF_PP_PAIR_FIRST(pred_type_pair), \ + OF_PP_PAIR_FIRST(label_type_pair)>(); \ + }}, + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(SIGMOID_CROSS_ENTROPY_LOSS_KERNEL_ENTRY, DEVICE_TYPE_SEQ, + FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ)}; + return creators.at( + GetHashKey(kernel_conf.op_attribute().op_conf().device_type(), + kernel_conf.sigmoid_cross_entropy_loss_conf().loss_conf().prediction_type(), + kernel_conf.sigmoid_cross_entropy_loss_conf().loss_conf().label_type()))(); +} + +} // namespace + +REGISTER_KERNEL_CREATOR(OperatorConf::kSigmoidCrossEntropyLossConf, + CreateSigmoidCrossEntropyLossKernel); + +} // namespace oneflow diff --git a/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.cu b/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..fe7355e06195a3f586d25c4a2033d9c6e5bb8a76 --- /dev/null +++ b/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.cu @@ -0,0 +1,83 @@ +#include "oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.h" + +namespace oneflow { + +namespace { +template<typename PredType> +__global__ void NoSmallerThan(const int n, PredType* x, const float floor_val) { + CUDA_1D_KERNEL_LOOP(index, n) { x[index] = (x[index] > floor_val) ? 
x[index] : floor_val; } +} + +template<typename PredType, typename LabelType> +__global__ void SigmoidCrossEntropyLossForward(const int64_t n, const PredType* prediction, + const LabelType* label, PredType* loss_buf, + PredType* count) { + CUDA_1D_KERNEL_LOOP(index, n) { + if (label[index] == -1) { + loss_buf[index] = 0.f; + count[index] = 0.f; + } else { + loss_buf[index] = + -1.f * prediction[index] * (label[index] - (prediction[index] >= 0)) + + logf(1 + expf(prediction[index] - 2 * prediction[index] * (prediction[index] >= 0))); + count[index] = 1.f; + } + } +} + +template<typename PredType, typename LabelType> +__global__ void SigmoidCrossEntropyLossBackward(const int64_t n, const PredType* prediction, + const LabelType* label, PredType* pred_diff) { + CUDA_1D_KERNEL_LOOP(index, n) { + if (label[index] == -1) { + pred_diff[index] = 0.f; + } else { + pred_diff[index] = 1.f / (1.f + expf(-prediction[index])) - label[index]; + } + } +} +} // namespace + +template<typename PredType, typename LabelType> +struct SigmoidCrossEntropyLossKernelUtil<DeviceType::kGPU, PredType, LabelType> { + static void Forward(DeviceCtx* ctx, const SigmoidCrossEntropyLossOpConf& conf, const int64_t n, + const PredType* prediction, const LabelType* label, PredType* loss_buf, + PredType* tmp_storage, const size_t tmp_storage_byte_size, PredType* count, + PredType* label_num, PredType* loss) { + SigmoidCrossEntropyLossForward<PredType> + <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + n, prediction, label, loss_buf, count); + KernelUtil<DeviceType::kGPU, PredType>::Sum(ctx, n, loss_buf, loss, tmp_storage, + tmp_storage_byte_size); + if (conf.normalize()) { + KernelUtil<DeviceType::kGPU, PredType>::Sum(ctx, n, count, label_num, tmp_storage, + tmp_storage_byte_size); + NoSmallerThan<PredType> + <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + 1, label_num, 1e-5); + KernelUtil<DeviceType::kGPU, PredType>::Div(ctx, 1, loss, label_num); + } + KernelUtil<DeviceType::kGPU, PredType>::Scal(ctx, 1, static_cast<PredType>(conf.scale()), loss, + 1); + } + + static void Backward(DeviceCtx* ctx, const SigmoidCrossEntropyLossOpConf& conf, const int64_t n, + const PredType* prediction, const LabelType* label, + const PredType* label_num, PredType* pred_diff) { + SigmoidCrossEntropyLossBackward<PredType> + <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( + n, prediction, label, pred_diff); + KernelUtil<DeviceType::kGPU, PredType>::Scal(ctx, n, static_cast<PredType>(conf.scale()), + pred_diff, 1); + if (conf.normalize()) { + KernelUtil<DeviceType::kGPU, PredType>::Div(ctx, n, pred_diff, label_num); + } + } +}; + +#define INSTANTIATE_SIGMOID_CROSS_ENTROPY_LOSS_KERNEL_UTIL(data_type_pair, label_type_pair) \ + template struct SigmoidCrossEntropyLossKernelUtil< \ + DeviceType::kGPU, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(label_type_pair)>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SIGMOID_CROSS_ENTROPY_LOSS_KERNEL_UTIL, + FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ) +} // namespace oneflow diff --git a/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.h b/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..13435a3ab52719ff2e11b5ad36ec576829339546 --- /dev/null +++ b/oneflow/core/kernel/sigmoid_cross_entropy_loss_kernel.h @@ -0,0 +1,34 @@ +#ifndef ONEFLOW_CORE_KERNEL_SIGMOID_CROSS_ENTROPY_LOSS_KERNEL_H_ +#define 
ONEFLOW_CORE_KERNEL_SIGMOID_CROSS_ENTROPY_LOSS_KERNEL_H_ + +#include "oneflow/core/kernel/loss_kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename PredType, typename LabelType> +class SigmoidCrossEntropyLossKernel final : public LossKernel<device_type, PredType> { + public: + OF_DISALLOW_COPY_AND_MOVE(SigmoidCrossEntropyLossKernel); + SigmoidCrossEntropyLossKernel() = default; + ~SigmoidCrossEntropyLossKernel() = default; + + private: + void VirtualLossForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + const LossKernelConf& GetLossKernelConf(const KernelConf& kernel_conf) const override; +}; + +template<DeviceType device_type, typename PredType, typename LabelType> +struct SigmoidCrossEntropyLossKernelUtil { + static void Forward(DeviceCtx* ctx, const SigmoidCrossEntropyLossOpConf& conf, const int64_t n, + const PredType* prediction, const LabelType* label, PredType* loss_buf, + PredType* tmp_storage, const size_t tmp_storage_byte_size, PredType* count, + PredType* label_num, PredType* loss); + static void Backward(DeviceCtx* ctx, const SigmoidCrossEntropyLossOpConf& conf, const int64_t n, + const PredType* prediction, const LabelType* label, + const PredType* label_num, PredType* pred_diff); +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_SIGMOID_CROSS_ENTROPY_LOSS_KERNEL_H_ diff --git a/oneflow/core/kernel/slice_kernel.cpp b/oneflow/core/kernel/slice_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3ddbbecd65f426a65c59f68f604adbdbd711293a --- /dev/null +++ b/oneflow/core/kernel/slice_kernel.cpp @@ -0,0 +1,127 @@ +#include "oneflow/core/kernel/slice_kernel.h" +#include "oneflow/core/ndarray/ndarray_helper.h" + +namespace oneflow { + +namespace { + +int64_t GetStart(const DimSliceConf& conf) { + CHECK_GT(conf.stride(), 0); + return conf.has_start() ? conf.start() : Slice::kStart; +} + +int64_t GetEnd(const DimSliceConf& conf) { + CHECK_GT(conf.stride(), 0); + return conf.has_end() ? 
conf.end() : Slice::kEnd; +} + +int64_t GetStride(const DimSliceConf& conf) { return conf.stride(); } + +} // namespace + +template<typename T> +void SliceKernel<DeviceType::kCPU, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const SliceOpConf& conf = this->op_conf().slice_conf(); + const Blob* in_blob = BnInOp2Blob("in"); + Blob* out_blob = BnInOp2Blob("out"); + CHECK_EQ(in_blob->shape().NumAxes(), out_blob->shape().NumAxes()); + + switch (out_blob->shape().NumAxes()) { +// clang-format off +#define MAKE_CASE(num_axes) \ + case num_axes: { \ + NdArraySliceUtil<T, num_axes>::Forward(ctx.device_ctx, conf.dim_slice_conf(), in_blob, \ + out_blob); \ + break; \ + } + MAKE_CASE(2); + MAKE_CASE(3); +#undef MAKE_CASE + // clang-format on + default: { UNIMPLEMENTED(); } + } +} + +template<typename T> +void SliceKernel<DeviceType::kCPU, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const SliceOpConf& conf = this->op_conf().slice_conf(); + const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out")); + Blob* in_diff_blob = BnInOp2Blob(GenDiffBn("in")); + CHECK_EQ(out_diff_blob->shape().NumAxes(), in_diff_blob->shape().NumAxes()); + + Memset<DeviceType::kCPU>(ctx.device_ctx, in_diff_blob->mut_dptr<T>(), 0, + in_diff_blob->ByteSizeOfDataContentField()); + + switch (in_diff_blob->shape().NumAxes()) { +// clang-format off +#define MAKE_CASE(num_axes) \ + case num_axes: { \ + NdArraySliceUtil<T, num_axes>::Backward(ctx.device_ctx, conf.dim_slice_conf(), \ + out_diff_blob, in_diff_blob); \ + break; \ + } + MAKE_CASE(2); + MAKE_CASE(3); +#undef MAKE_CASE + // clang-format on + default: { UNIMPLEMENTED(); } + } +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kSliceConf, SliceKernel, ARITHMETIC_DATA_TYPE_SEQ); + +template<typename T> +struct NdArraySliceUtil<T, 2> final { + static void Forward(DeviceCtx* device_ctx, const PbRpf<DimSliceConf>& rep_dim_slice, + const Blob* in_blob, Blob* out_blob) { + NdArrayHelper<T, 2> ndarray; + auto&& in_ndarray = ndarray.Var(in_blob->shape(), const_cast<T*>(in_blob->dptr<T>())); + auto&& out_ndarray = ndarray.Var(out_blob->shape(), out_blob->mut_dptr<T>()); + out_ndarray.Assign(in_ndarray({}, {GetStart(rep_dim_slice.Get(0)), GetEnd(rep_dim_slice.Get(0)), + GetStride(rep_dim_slice.Get(0))})); + } + + static void Backward(DeviceCtx* device_ctx, const PbRpf<DimSliceConf>& rep_dim_slice, + const Blob* out_diff_blob, Blob* in_diff_blob) { + NdArrayHelper<T, 2> ndarray; + auto&& out_diff_ndarray = + ndarray.Var(out_diff_blob->shape(), const_cast<T*>(out_diff_blob->dptr<T>())); + auto&& in_diff_ndarray = ndarray.Var(in_diff_blob->shape(), in_diff_blob->mut_dptr<T>()); + in_diff_ndarray({}, {GetStart(rep_dim_slice.Get(0)), GetEnd(rep_dim_slice.Get(0)), + GetStride(rep_dim_slice.Get(0))}) + .Assign(out_diff_ndarray({}, {})); + } +}; + +template<typename T> +struct NdArraySliceUtil<T, 3> final { + static void Forward(DeviceCtx* device_ctx, const PbRpf<DimSliceConf>& rep_dim_slice, + const Blob* in_blob, Blob* out_blob) { + NdArrayHelper<T, 3> ndarray; + auto&& in_ndarray = ndarray.Var(in_blob->shape(), const_cast<T*>(in_blob->dptr<T>())); + auto&& out_ndarray = ndarray.Var(out_blob->shape(), out_blob->mut_dptr<T>()); + out_ndarray.Assign(in_ndarray({}, + {GetStart(rep_dim_slice.Get(0)), GetEnd(rep_dim_slice.Get(0)), + GetStride(rep_dim_slice.Get(0))}, + {GetStart(rep_dim_slice.Get(1)), GetEnd(rep_dim_slice.Get(1)), + GetStride(rep_dim_slice.Get(1))})); + 
} + + static void Backward(DeviceCtx* device_ctx, const PbRpf<DimSliceConf>& rep_dim_slice, + const Blob* out_diff_blob, Blob* in_diff_blob) { + NdArrayHelper<T, 3> ndarray; + auto&& out_diff_ndarray = + ndarray.Var(out_diff_blob->shape(), const_cast<T*>(out_diff_blob->dptr<T>())); + auto&& in_diff_ndarray = ndarray.Var(in_diff_blob->shape(), in_diff_blob->mut_dptr<T>()); + in_diff_ndarray({}, + {GetStart(rep_dim_slice.Get(0)), GetEnd(rep_dim_slice.Get(0)), + GetStride(rep_dim_slice.Get(0))}, + {GetStart(rep_dim_slice.Get(1)), GetEnd(rep_dim_slice.Get(1)), + GetStride(rep_dim_slice.Get(1))}) + .Assign(out_diff_ndarray({}, {}, {})); + } +}; + +} // namespace oneflow diff --git a/oneflow/core/kernel/slice_kernel.cu b/oneflow/core/kernel/slice_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..515b36e6864079b9d5e5d0192ff3d2c03ba577a9 --- /dev/null +++ b/oneflow/core/kernel/slice_kernel.cu @@ -0,0 +1,87 @@ +#include "oneflow/core/kernel/slice_kernel.h" + +namespace oneflow { + +namespace { + +template<typename T> +__global__ void SliceForwardGpu(const int64_t n, const int64_t* offset, const T* entire, T* slice) { + CUDA_1D_KERNEL_LOOP(i, n) { slice[i] = entire[offset[i]]; } +} + +template<typename T> +__global__ void SliceBackwardGpu(const int64_t n, const int64_t* offset, const T* slice, + T* entire) { + CUDA_1D_KERNEL_LOOP(i, n) { entire[offset[i]] = slice[i]; } +} + +} // namespace + +template<typename T> +void SliceKernel<DeviceType::kGPU, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* in_blob = BnInOp2Blob("in"); + const Blob* offset_blob = BnInOp2Blob("out_to_in_offset"); + Blob* out_blob = BnInOp2Blob("out"); + const int64_t num_output = out_blob->shape().elem_cnt(); + SliceForwardGpu<T><<<BlocksNum4ThreadsNum(num_output), kCudaThreadsNumPerBlock, 0, + ctx.device_ctx->cuda_stream()>>>( + num_output, offset_blob->dptr<int64_t>(), in_blob->dptr<T>(), out_blob->mut_dptr<T>()); +} + +template<typename T> +void SliceKernel<DeviceType::kGPU, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out")); + const Blob* offset_blob = BnInOp2Blob("out_to_in_offset"); + Blob* in_diff_blob = BnInOp2Blob(GenDiffBn("in")); + const int64_t num_output = out_diff_blob->shape().elem_cnt(); + Memset<DeviceType::kGPU>(ctx.device_ctx, in_diff_blob->mut_dptr<T>(), 0, + in_diff_blob->ByteSizeOfDataContentField()); + SliceBackwardGpu<T><<<BlocksNum4ThreadsNum(num_output), kCudaThreadsNumPerBlock, 0, + ctx.device_ctx->cuda_stream()>>>(num_output, offset_blob->dptr<int64_t>(), + out_diff_blob->dptr<T>(), + in_diff_blob->mut_dptr<T>()); +} + +template<typename T> +void SliceKernel<DeviceType::kGPU, T>::InitConstBufBlobs( + DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + Shape in_shape(this->kernel_conf().slice_conf().in_shape()); + InitOut2InOffsetFromHost(ctx, in_shape, BnInOp2Blob("out_to_in_offset")); +} + +template<typename T> +void SliceKernel<DeviceType::kGPU, T>::InitOut2InOffsetFromHost(DeviceCtx* ctx, + const Shape& in_shape, + Blob* blob) const { + const SliceOpConf& conf = op_conf().slice_conf(); + BEFORE_CPU_INITIALIZE(); + int64_t* host_blob_ptr = host_blob->mut_dptr<int64_t>(); + FOR_RANGE(int64_t, i, 0, host_blob->shape().elem_cnt()) { + int64_t offset = 0; + int64_t index = i; + FOR_RANGE(int64_t, j, 0, host_blob->shape().NumAxes()) { + const int64_t 
dim_elem_cnt = host_blob->shape().Count(j + 1); + const int64_t dim_i = index / dim_elem_cnt; + index = index % dim_elem_cnt; + int64_t start = 0; + int64_t stride = 1; + if (j > 0) { + const DimSliceConf& dim_slice_conf = conf.dim_slice_conf(j - 1); + if (dim_slice_conf.has_start()) { start = dim_slice_conf.start(); } + if (start < 0) { start += host_blob->shape().At(j); } + stride = dim_slice_conf.stride(); + } + offset += (start + dim_i * stride) * in_shape.Count(j + 1); + } + host_blob_ptr[i] = offset; + } + AFTER_CPU_INITIALIZE(); +} + +#define INSTANTIATE_GPU_SLICE_KERNEL(type_cpp, type_proto) \ + template struct SliceKernel<DeviceType::kGPU, type_cpp>; +OF_PP_FOR_EACH_TUPLE(INSTANTIATE_GPU_SLICE_KERNEL, ARITHMETIC_DATA_TYPE_SEQ) + +} // namespace oneflow diff --git a/oneflow/core/kernel/slice_kernel.h b/oneflow/core/kernel/slice_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..aa438f10fd1110566889b0cd12f881dcf9477c60 --- /dev/null +++ b/oneflow/core/kernel/slice_kernel.h @@ -0,0 +1,48 @@ +#ifndef ONEFLOW_CORE_KERNEL_SLICE_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_SLICE_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class SliceKernel; + +template<typename T> +class SliceKernel<DeviceType::kCPU, T> final : public KernelIf<DeviceType::kCPU> { + public: + OF_DISALLOW_COPY_AND_MOVE(SliceKernel); + SliceKernel() = default; + ~SliceKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; +}; + +template<typename T> +class SliceKernel<DeviceType::kGPU, T> final : public KernelIf<DeviceType::kGPU> { + public: + OF_DISALLOW_COPY_AND_MOVE(SliceKernel); + SliceKernel() = default; + ~SliceKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void InitConstBufBlobs(DeviceCtx*, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + + void InitOut2InOffsetFromHost(DeviceCtx* ctx, const Shape& in_shape, Blob* blob) const; +}; + +template<typename T, size_t NDIMS> +struct NdArraySliceUtil; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_SLICE_KERNEL_H_ diff --git a/oneflow/core/kernel/softmax_kernel.cpp b/oneflow/core/kernel/softmax_kernel.cpp index 896a7d5ccf094961d1b547231fa0a81189777fc2..f68327c908f83f386d4de9451ce99e0e9b88fbd4 100644 --- a/oneflow/core/kernel/softmax_kernel.cpp +++ b/oneflow/core/kernel/softmax_kernel.cpp @@ -1,6 +1,7 @@ #include "oneflow/core/kernel/softmax_kernel.h" #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/kernel/transpose_kernel.h" +#include "oneflow/core/ndarray/ndarray_util.h" namespace oneflow { @@ -14,7 +15,10 @@ void SoftmaxComputeDiff(DeviceCtx* ctx, const int64_t n, const int64_t w, const // dot product | get dot product sum_vec[i] from out[i] * out_diff[i] T* tmp = in_diff; KernelUtil<device_type, T>::Mul(ctx, n * w, out, out_diff, tmp); - KernelUtil<device_type, T>::RowSum(ctx, n, w, tmp, sum_vec, temp_storage, temp_storage_bytes); + NdarrayUtil<device_type, T>::ReduceSum( + ctx, XpuVarNdarray<T>({n, 1}, sum_vec), XpuVarNdarray<const T>({n, w}, tmp), + XpuVarNdarray<T>({static_cast<int64_t>(temp_storage_bytes / sizeof(T))}, + 
reinterpret_cast<T*>(temp_storage))); // copy out_diff to in_diff KernelUtil<device_type, T>::Copy(ctx, n * w, out_diff, 1, in_diff, 1); // sub | in_diff[i][j] -= sum_vec[i] @@ -25,6 +29,32 @@ void SoftmaxComputeDiff(DeviceCtx* ctx, const int64_t n, const int64_t w, const } // namespace +template<DeviceType device_type, typename T> +void SoftmaxComputeProb(DeviceCtx* ctx, const int64_t n, const int64_t w, const T* in, T* tmp, + T* prob, void* temp_storage, const size_t temp_storage_bytes) { + // copy in blob to prob blob + KernelUtil<device_type, T>::Copy(ctx, n * w, in, 1, prob, 1); + // max | calculate max of every sample vector prob[i], store in tmp[i] + // the prob[i] now is store the data of in[i] + NdarrayUtil<device_type, T>::ReduceMax( + ctx, XpuVarNdarray<T>({n, 1}, tmp), XpuVarNdarray<const T>({n, w}, prob), + XpuVarNdarray<T>({static_cast<int64_t>(temp_storage_bytes / sizeof(T))}, + reinterpret_cast<T*>(temp_storage))); + // sub | every element of prob blob subract the max value of the same sample + SoftmaxKernelUtil<device_type, T>::Sub(ctx, n, w, prob, tmp); + // exp | exponentiation every element + KernelUtil<device_type, T>::Exp(ctx, n * w, prob, prob); + // sum | calculate sum of every sample vector prob[i], store in tmp[i] + // the prob[i] now is store the tmp data after exp + NdarrayUtil<device_type, T>::ReduceSum( + ctx, XpuVarNdarray<T>({n, 1}, tmp), XpuVarNdarray<const T>({n, w}, prob), + XpuVarNdarray<T>({static_cast<int64_t>(temp_storage_bytes / sizeof(T))}, + reinterpret_cast<T*>(temp_storage))); + // div | every element of prob[i] divided by the data of tmp[i] (the sum + // value) + SoftmaxKernelUtil<device_type, T>::Div(ctx, n, w, prob, tmp); +} + template<DeviceType device_type, typename T> void SoftmaxKernel<device_type, T>::ForwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { diff --git a/oneflow/core/kernel/softmax_kernel.h b/oneflow/core/kernel/softmax_kernel.h index 8f1903aa6d5440c0a6831b1f4359d37ad697d7f7..c7f6bf15db54549f31d3d143f7c9cd10f1998319 100644 --- a/oneflow/core/kernel/softmax_kernel.h +++ b/oneflow/core/kernel/softmax_kernel.h @@ -32,23 +32,7 @@ struct SoftmaxKernelUtil { template<DeviceType device_type, typename T> void SoftmaxComputeProb(DeviceCtx* ctx, const int64_t n, const int64_t w, const T* in, T* tmp, - T* prob, void* temp_storage, const size_t temp_storage_bytes) { - // copy in blob to prob blob - KernelUtil<device_type, T>::Copy(ctx, n * w, in, 1, prob, 1); - // max | calculate max of every sample vector prob[i], store in tmp[i] - // the prob[i] now is store the data of in[i] - KernelUtil<device_type, T>::RowMax(ctx, n, w, prob, tmp, temp_storage, temp_storage_bytes); - // sub | every element of prob blob subract the max value of the same sample - SoftmaxKernelUtil<device_type, T>::Sub(ctx, n, w, prob, tmp); - // exp | exponentiation every element - KernelUtil<device_type, T>::Exp(ctx, n * w, prob, prob); - // sum | calculate sum of every sample vector prob[i], store in tmp[i] - // the prob[i] now is store the tmp data after exp - KernelUtil<device_type, T>::RowSum(ctx, n, w, prob, tmp, temp_storage, temp_storage_bytes); - // div | every element of prob[i] divided by the data of tmp[i] (the sum - // value) - SoftmaxKernelUtil<device_type, T>::Div(ctx, n, w, prob, tmp); -} + T* prob, void* temp_storage, const size_t temp_storage_bytes); } // namespace oneflow diff --git a/oneflow/core/kernel/sparse_cross_entropy_kernel.cpp b/oneflow/core/kernel/sparse_cross_entropy_kernel.cpp new file 
mode 100644 index 0000000000000000000000000000000000000000..e5adee9ea9f54e3e712fb97e57b8095e8c6f6d6b --- /dev/null +++ b/oneflow/core/kernel/sparse_cross_entropy_kernel.cpp @@ -0,0 +1,66 @@ +#include "oneflow/core/kernel/sparse_cross_entropy_kernel.h" +#include "oneflow/core/kernel/sparse_cross_entropy_kernel_util.h" + +namespace oneflow { + +namespace { + +template<DeviceType device_type, typename T, typename K> +void Forward(DeviceCtx* ctx, const Blob* prediction, const Blob* label, Blob* out) { + const int64_t num_instances = label->shape().elem_cnt(); + CHECK_EQ(prediction->shape().elem_cnt() % num_instances, 0); + const int64_t num_classes = prediction->shape().elem_cnt() / num_instances; + SparseCrossEntropyKernelUtil<device_type, T, K>::ComputeEntropy( + ctx, num_instances, num_classes, prediction->dptr<T>(), label->dptr<K>(), out->mut_dptr<T>()); +} + +template<DeviceType device_type, typename T, typename K> +void Backward(DeviceCtx* ctx, const Blob* prediction, const Blob* label, const Blob* out_diff, + Blob* prediction_diff) { + const int64_t num_instances = label->shape().elem_cnt(); + CHECK_EQ(prediction->shape().elem_cnt() % num_instances, 0); + const int64_t num_classes = prediction->shape().elem_cnt() / num_instances; + Memset<device_type>(ctx, prediction_diff->mut_dptr<T>(), 0, + prediction_diff->ByteSizeOfDataContentField()); + SparseCrossEntropyKernelUtil<device_type, T, K>::ComputeDiff( + ctx, num_instances, num_classes, prediction->dptr<T>(), label->dptr<K>(), out_diff->dptr<T>(), + prediction_diff->mut_dptr<T>()); +} + +template<DeviceType device_type, typename T> +struct SparseCrossEntropyUntil final { +#define MAKE_CROSS_ENTROPY_SWITCH_ENTRY(func_name, K) func_name<device_type, T, K> + DEFINE_STATIC_SWITCH_FUNC(void, Forward, MAKE_CROSS_ENTROPY_SWITCH_ENTRY, + MAKE_DATA_TYPE_CTRV_SEQ(INT_DATA_TYPE_SEQ)); + DEFINE_STATIC_SWITCH_FUNC(void, Backward, MAKE_CROSS_ENTROPY_SWITCH_ENTRY, + MAKE_DATA_TYPE_CTRV_SEQ(INT_DATA_TYPE_SEQ)); +#undef MAKE_CROSS_ENTROPY_SWITCH_ENTRY +}; + +} // namespace + +template<DeviceType device_type, typename T> +void SparseCrossEntropyKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* prediction = BnInOp2Blob("prediction"); + const Blob* label = BnInOp2Blob("label"); + Blob* out = BnInOp2Blob("out"); + SparseCrossEntropyUntil<device_type, T>::SwitchForward(SwitchCase(label->data_type()), + ctx.device_ctx, prediction, label, out); +} + +template<DeviceType device_type, typename T> +void SparseCrossEntropyKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* prediction = BnInOp2Blob("prediction"); + const Blob* label = BnInOp2Blob("label"); + const Blob* out_diff = BnInOp2Blob(GenDiffBn("out")); + Blob* prediction_diff = BnInOp2Blob(GenDiffBn("prediction")); + SparseCrossEntropyUntil<device_type, T>::SwitchBackward( + SwitchCase(label->data_type()), ctx.device_ctx, prediction, label, out_diff, prediction_diff); +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kSparseCrossEntropyConf, SparseCrossEntropyKernel, + FLOATING_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/sparse_cross_entropy_kernel.h b/oneflow/core/kernel/sparse_cross_entropy_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..6dca30ebb54e0e8c546e72f835fe73028485b9ed --- /dev/null +++ b/oneflow/core/kernel/sparse_cross_entropy_kernel.h @@ -0,0 +1,24 @@ 
+#ifndef ONEFLOW_CORE_KERNEL_SPARSE_CROSS_ENTROPY_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_SPARSE_CROSS_ENTROPY_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class SparseCrossEntropyKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(SparseCrossEntropyKernel); + SparseCrossEntropyKernel() = default; + ~SparseCrossEntropyKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_SPARSE_CROSS_ENTROPY_KERNEL_H_ diff --git a/oneflow/core/kernel/sparse_cross_entropy_kernel_util.cpp b/oneflow/core/kernel/sparse_cross_entropy_kernel_util.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c80393b76e1c3d0f8448bd98cb33e6a39e249e64 --- /dev/null +++ b/oneflow/core/kernel/sparse_cross_entropy_kernel_util.cpp @@ -0,0 +1,46 @@ +#include "oneflow/core/kernel/sparse_cross_entropy_kernel_util.h" +#include "oneflow/core/kernel/kernel_util.cuh" + +namespace oneflow { + +template<typename T, typename K> +struct SparseCrossEntropyKernelUtil<DeviceType::kCPU, T, K> { + static void ComputeEntropy(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x, + const K* labels, T* y) { + FOR_RANGE(int64_t, i, 0, num_instances) { + K label = labels[i]; + CHECK_GE(label, 0); + CHECK_LT(label, num_classes); + y[i] = -SafeLog(x[i * num_classes + label]); + } + } + + static void ComputeDiff(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x, + const K* labels, T* dx) { + FOR_RANGE(int64_t, i, 0, num_instances) { + K label = labels[i]; + CHECK_GE(label, 0); + CHECK_LT(label, num_classes); + dx[i * num_classes + label] = -1 / MaxWithLogThreshold(x[i * num_classes + label]); + } + } + + static void ComputeDiff(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x, + const K* labels, const T* dy, T* dx) { + FOR_RANGE(int64_t, i, 0, num_instances) { + K label = labels[i]; + CHECK_GE(label, 0); + CHECK_LT(label, num_classes); + dx[i * num_classes + label] = -dy[i] / MaxWithLogThreshold(x[i * num_classes + label]); + } + } +}; + +#define INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CPU(data_type_pair, index_type_pair) \ + template struct SparseCrossEntropyKernelUtil<DeviceType::kCPU, OF_PP_PAIR_FIRST(data_type_pair), \ + OF_PP_PAIR_FIRST(index_type_pair)>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CPU, + FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ); +#undef INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_CPU + +} // namespace oneflow diff --git a/oneflow/core/kernel/sparse_cross_entropy_kernel_util.cu b/oneflow/core/kernel/sparse_cross_entropy_kernel_util.cu new file mode 100644 index 0000000000000000000000000000000000000000..b62ac198af9dcf922a3413e81f6a37d3b0c5c437 --- /dev/null +++ b/oneflow/core/kernel/sparse_cross_entropy_kernel_util.cu @@ -0,0 +1,71 @@ +#include "oneflow/core/kernel/sparse_cross_entropy_kernel_util.h" +#include "oneflow/core/kernel/kernel_util.cuh" + +namespace oneflow { + +namespace { + +template<typename T, typename K> +__global__ void ComputeEntropyGpu(int64_t num_instances, int64_t num_classes, const T* x, + const K* labels, T* y) { + CUDA_1D_KERNEL_LOOP(i, num_instances) { + K label = labels[i]; + assert(label >= 0); + assert(label < num_classes); + 
y[i] = -SafeLog(x[i * num_classes + label]); + } +} + +template<typename T, typename K> +__global__ void ComputeDiffGpu(int64_t num_instances, int64_t num_classes, const T* x, + const K* labels, T* dx) { + CUDA_1D_KERNEL_LOOP(i, num_instances) { + K label = labels[i]; + assert(label >= 0); + assert(label < num_classes); + dx[i * num_classes + label] = -1 / MaxWithLogThreshold(x[i * num_classes + label]); + } +} + +template<typename T, typename K> +__global__ void ComputeDiffGpu(int64_t num_instances, int64_t num_classes, const T* x, + const K* labels, const T* dy, T* dx) { + CUDA_1D_KERNEL_LOOP(i, num_instances) { + K label = labels[i]; + assert(label >= 0); + assert(label < num_classes); + dx[i * num_classes + label] = -dy[i] / MaxWithLogThreshold(x[i * num_classes + label]); + } +} + +} // namespace + +template<typename T, typename K> +struct SparseCrossEntropyKernelUtil<DeviceType::kGPU, T, K> { + static void ComputeEntropy(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x, + const K* labels, T* y) { + ComputeEntropyGpu<<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock, 0, + ctx->cuda_stream()>>>(num_instances, num_classes, x, labels, y); + } + + static void ComputeDiff(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x, + const K* labels, T* dx) { + ComputeDiffGpu<<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock, 0, + ctx->cuda_stream()>>>(num_instances, num_classes, x, labels, dx); + } + + static void ComputeDiff(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x, + const K* labels, const T* dy, T* dx) { + ComputeDiffGpu<<<BlocksNum4ThreadsNum(num_instances), kCudaThreadsNumPerBlock, 0, + ctx->cuda_stream()>>>(num_instances, num_classes, x, labels, dy, dx); + } +}; + +#define INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_GPU(data_type_pair, index_type_pair) \ + template struct SparseCrossEntropyKernelUtil<DeviceType::kGPU, OF_PP_PAIR_FIRST(data_type_pair), \ + OF_PP_PAIR_FIRST(index_type_pair)>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_GPU, + FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ); +#undef INSTANTIATE_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_GPU + +} // namespace oneflow diff --git a/oneflow/core/kernel/sparse_cross_entropy_kernel_util.h b/oneflow/core/kernel/sparse_cross_entropy_kernel_util.h new file mode 100644 index 0000000000000000000000000000000000000000..4676e813b4c146d11f8e8ae7d7c30f095c496304 --- /dev/null +++ b/oneflow/core/kernel/sparse_cross_entropy_kernel_util.h @@ -0,0 +1,20 @@ +#ifndef ONEFLOW_CORE_KERNEL_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_H_ +#define ONEFLOW_CORE_KERNEL_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_H_ + +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T, typename K> +struct SparseCrossEntropyKernelUtil { + static void ComputeEntropy(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x, + const K* labels, T* y); + static void ComputeDiff(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x, + const K* labels, T* dx); + static void ComputeDiff(DeviceCtx* ctx, int64_t num_instances, int64_t num_classes, const T* x, + const K* labels, const T* dy, T* dx); +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_SPARSE_CROSS_ENTROPY_KERNEL_UTIL_H_ diff --git a/oneflow/core/kernel/sparse_cross_entropy_loss_kernel.h b/oneflow/core/kernel/sparse_cross_entropy_loss_kernel.h index 
5ea7cf1d5fcbb4f62f63e8f598100e184d0e209d..f52a09b0392a4d939a29a0d1e70860f1043193c8 100644 --- a/oneflow/core/kernel/sparse_cross_entropy_loss_kernel.h +++ b/oneflow/core/kernel/sparse_cross_entropy_loss_kernel.h @@ -6,7 +6,7 @@ namespace oneflow { template<DeviceType device_type, typename PredType, typename LabelType> -class SparseCrossEntropyLossKernel final : public LossKernel<device_type, PredType, LabelType> { +class SparseCrossEntropyLossKernel final : public LossKernel<device_type, PredType> { public: OF_DISALLOW_COPY_AND_MOVE(SparseCrossEntropyLossKernel); SparseCrossEntropyLossKernel() = default; diff --git a/oneflow/core/kernel/sparse_softmax_cross_entropy_loss_kernel.h b/oneflow/core/kernel/sparse_softmax_cross_entropy_loss_kernel.h index c0f7d03d1a0dd5d311c19d027917a1d6d0fbbb16..60c34cdb6091bcd4b2922c046d962bb1581a6cbb 100644 --- a/oneflow/core/kernel/sparse_softmax_cross_entropy_loss_kernel.h +++ b/oneflow/core/kernel/sparse_softmax_cross_entropy_loss_kernel.h @@ -6,8 +6,7 @@ namespace oneflow { template<DeviceType device_type, typename PredType, typename LabelType> -class SparseSoftmaxCrossEntropyLossKernel final - : public LossKernel<device_type, PredType, LabelType> { +class SparseSoftmaxCrossEntropyLossKernel final : public LossKernel<device_type, PredType> { public: OF_DISALLOW_COPY_AND_MOVE(SparseSoftmaxCrossEntropyLossKernel); SparseSoftmaxCrossEntropyLossKernel() = default; diff --git a/oneflow/core/kernel/sqrt_kernel.cpp b/oneflow/core/kernel/sqrt_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aa51d2ce0f0dcfc2b16c85d3df6f916105a76e6f --- /dev/null +++ b/oneflow/core/kernel/sqrt_kernel.cpp @@ -0,0 +1,29 @@ +#include "oneflow/core/kernel/sqrt_kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +void SqrtKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + Blob* in_blob = BnInOp2Blob("in"); + Blob* out_blob = BnInOp2Blob("out"); + KernelUtil<device_type, T>::Sqrt(ctx.device_ctx, in_blob->static_shape().elem_cnt(), + in_blob->dptr<T>(), out_blob->mut_dptr<T>()); +} + +template<DeviceType device_type, typename T> +void SqrtKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + Blob* out_blob = BnInOp2Blob("out"); + Blob* out_diff_blob = BnInOp2Blob("out_diff"); + Blob* in_diff_blob = BnInOp2Blob("in_diff"); + KernelUtil<device_type, T>::Div(ctx.device_ctx, out_blob->static_shape().elem_cnt(), + out_diff_blob->dptr<T>(), out_blob->dptr<T>(), + in_diff_blob->mut_dptr<T>()); + KernelUtil<device_type, T>::Scal(ctx.device_ctx, out_blob->static_shape().elem_cnt(), + static_cast<T>(0.5), in_diff_blob->mut_dptr<T>(), 1); +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kSqrtConf, SqrtKernel, FLOATING_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/sqrt_kernel.h b/oneflow/core/kernel/sqrt_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..ae206ef4bf56c98a00d443489c2d53d81356e445 --- /dev/null +++ b/oneflow/core/kernel/sqrt_kernel.h @@ -0,0 +1,25 @@ +#ifndef ONEFLOW_CORE_KERNEL_SQRT_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_SQRT_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" +#include "oneflow/core/kernel/kernel_context.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class SqrtKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(SqrtKernel); + SqrtKernel() = 
default; + ~SqrtKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_SQRT_KERNEL_H_ diff --git a/oneflow/core/kernel/square_kernel.cpp b/oneflow/core/kernel/square_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f0dd0eb7d5dd3d7fb8d10b5277ce38f588a44b95 --- /dev/null +++ b/oneflow/core/kernel/square_kernel.cpp @@ -0,0 +1,29 @@ +#include "oneflow/core/kernel/square_kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +void SquareKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + Blob* in_blob = BnInOp2Blob("in"); + Blob* out_blob = BnInOp2Blob("out"); + KernelUtil<device_type, T>::Square(ctx.device_ctx, in_blob->static_shape().elem_cnt(), + in_blob->dptr<T>(), out_blob->mut_dptr<T>()); +} + +template<DeviceType device_type, typename T> +void SquareKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + Blob* in_blob = BnInOp2Blob("in"); + Blob* out_diff_blob = BnInOp2Blob("out_diff"); + Blob* in_diff_blob = BnInOp2Blob("in_diff"); + KernelUtil<device_type, T>::Mul(ctx.device_ctx, in_blob->static_shape().elem_cnt(), + out_diff_blob->dptr<T>(), in_blob->dptr<T>(), + in_diff_blob->mut_dptr<T>()); + KernelUtil<device_type, T>::Scal(ctx.device_ctx, in_blob->static_shape().elem_cnt(), + static_cast<T>(2), in_diff_blob->mut_dptr<T>(), 1); +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kSquareConf, SquareKernel, FLOATING_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/square_kernel.h b/oneflow/core/kernel/square_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..af578ccf48a342c42fd20b50a9be7ad32de8d205 --- /dev/null +++ b/oneflow/core/kernel/square_kernel.h @@ -0,0 +1,25 @@ +#ifndef ONEFLOW_CORE_KERNEL_SQUARE_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_SQUARE_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" +#include "oneflow/core/kernel/kernel_context.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class SquareKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(SquareKernel); + SquareKernel() = default; + ~SquareKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_SQUARE_KERNEL_H_ diff --git a/oneflow/core/kernel/tick_kernel.cpp b/oneflow/core/kernel/tick_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..28dee87c66b46eb9be8b57fcebb9d8278da26ddb --- /dev/null +++ b/oneflow/core/kernel/tick_kernel.cpp @@ -0,0 +1,7 @@ +#include "oneflow/core/kernel/tick_kernel.h" + +namespace oneflow { + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kTickConf, TickKernel, ARITHMETIC_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/tick_kernel.h b/oneflow/core/kernel/tick_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..dfe8f36201b528da0bc16b55e1a4442e6fc8053e --- /dev/null +++ b/oneflow/core/kernel/tick_kernel.h @@ -0,0 
+1,27 @@ +#ifndef ONEFLOW_CORE_KERNEL_TICK_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_TICK_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class TickKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(TickKernel); + TickKernel() = default; + ~TickKernel() = default; + + private: + void ForwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override {} + void BackwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override { + UNIMPLEMENTED(); + } + const PbMessage& GetCustomizedOpConf() const override { return this->op_conf().tick_conf(); } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_TICK_KERNEL_H_ diff --git a/oneflow/core/kernel/top_k_kernel.cpp b/oneflow/core/kernel/top_k_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b83dfb12675d67e91f8d1910717de9fda010607f --- /dev/null +++ b/oneflow/core/kernel/top_k_kernel.cpp @@ -0,0 +1,40 @@ +#include "oneflow/core/kernel/top_k_kernel.h" + +namespace oneflow { + +template<typename T> +void TopKKernel<T>::ForwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* in_blob = BnInOp2Blob("in"); + Blob* fw_buf_blob = BnInOp2Blob("fw_buf"); + Blob* out_blob = BnInOp2Blob("out"); + + CHECK_LE(in_blob->shape().elem_cnt(), GetMaxVal<int32_t>()); + const int32_t instance_size = static_cast<int32_t>(in_blob->shape().dim_vec().back()); + const int32_t instance_num = static_cast<int32_t>(in_blob->shape().elem_cnt() / instance_size); + const T* in = in_blob->dptr<T>(); + int32_t* fw_buf = fw_buf_blob->mut_dptr<int32_t>(); + int32_t* out = out_blob->mut_dptr<int32_t>(); + const auto& conf = this->op_conf().top_k_conf(); + const int32_t k = conf.k(); + FOR_RANGE(int32_t, i, 0, instance_num) { + std::iota(fw_buf, fw_buf + instance_size, 0); + const int32_t offset = i * instance_size; + auto comp = [&](const int32_t lhs, const int32_t rhs) { + const T l = in[offset + lhs]; + const T r = in[offset + rhs]; + if (l == r) { + return lhs < rhs; + } else { + return l > r; + } + }; + std::nth_element(fw_buf, fw_buf + k, fw_buf + instance_size, comp); + if (k > 1 && conf.sorted()) { std::sort(fw_buf, fw_buf + k, comp); } + std::copy(fw_buf, fw_buf + k, out + i * k); + } +} + +ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kTopKConf, TopKKernel, FLOATING_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/top_k_kernel.h b/oneflow/core/kernel/top_k_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..98fb94d467d5ef77fb3d4b72694c79e63c2afde7 --- /dev/null +++ b/oneflow/core/kernel/top_k_kernel.h @@ -0,0 +1,23 @@ +#ifndef ONEFLOW_CORE_KERNEL_TOP_K_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_TOP_K_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" +#include "oneflow/core/kernel/kernel_context.h" + +namespace oneflow { + +template<typename T> +class TopKKernel final : public KernelIf<DeviceType::kCPU> { + public: + OF_DISALLOW_COPY_AND_MOVE(TopKKernel); + TopKKernel() = default; + ~TopKKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_TOP_K_KERNEL_H_ diff --git a/oneflow/core/kernel/tuple_identity_kernel.cpp b/oneflow/core/kernel/tuple_identity_kernel.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..442f4cb716d490148edef4ce489ab13476715d0e --- /dev/null +++ b/oneflow/core/kernel/tuple_identity_kernel.cpp @@ -0,0 +1,31 @@ +#include "oneflow/core/kernel/tuple_identity_kernel.h" + +namespace oneflow { + +template<DeviceType device_type> +void TupleIdentityKernel<device_type>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const auto& input_bns = this->op_attribute().input_bns(); + const auto& output_bns = this->op_attribute().output_bns(); + CHECK_EQ(input_bns.size(), output_bns.size()); + FOR_RANGE(int, i, 0, input_bns.size()) { + Blob* out_blob = BnInOp2Blob(output_bns.Get(i)); + out_blob->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob(input_bns.Get(i))); + } +} + +template<DeviceType device_type> +void TupleIdentityKernel<device_type>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const auto& input_diff_bns = this->op_attribute().input_diff_bns(); + const auto& output_diff_bns = this->op_attribute().output_diff_bns(); + CHECK_EQ(input_diff_bns.size(), output_diff_bns.size()); + FOR_RANGE(int, i, 0, output_diff_bns.size()) { + Blob* in_diff_blob = BnInOp2Blob(input_diff_bns.Get(i)); + in_diff_blob->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob(output_diff_bns.Get(i))); + } +} + +ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kTupleIdentityConf, TupleIdentityKernel); + +} // namespace oneflow diff --git a/oneflow/core/kernel/tuple_identity_kernel.h b/oneflow/core/kernel/tuple_identity_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..d2d953e6fe2041604c2d82dc855c64200dbbf57f --- /dev/null +++ b/oneflow/core/kernel/tuple_identity_kernel.h @@ -0,0 +1,25 @@ +#ifndef ONEFLOW_CORE_KERNEL_TUPLE_IDENTITY_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_TUPLE_IDENTITY_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" +#include "oneflow/core/kernel/kernel_context.h" + +namespace oneflow { + +template<DeviceType device_type> +class TupleIdentityKernel final : public KernelIf<device_type> { + public: + OF_DISALLOW_COPY_AND_MOVE(TupleIdentityKernel); + TupleIdentityKernel() = default; + ~TupleIdentityKernel() = default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_TUPLE_IDENTITY_KERNEL_H_ diff --git a/oneflow/core/kernel/unpack_kernel.h b/oneflow/core/kernel/unpack_kernel.h index c82b085ef3f27625148bee306db85e1507d891c6..1411672182860813239cbda53d93c04eeb03836b 100644 --- a/oneflow/core/kernel/unpack_kernel.h +++ b/oneflow/core/kernel/unpack_kernel.h @@ -31,6 +31,10 @@ class UnpackKernel final : public KernelIf<device_type> { const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; void BackwardInDiffDim0ValidNum( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + void BackwardInDiffLossInstanceNum( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const override { + UNIMPLEMENTED(); + } }; } // namespace oneflow diff --git a/oneflow/core/kernel/variable_kernel.cpp b/oneflow/core/kernel/variable_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b32360299a1deb6e45fd4962402cdc90ab136134 --- /dev/null +++ b/oneflow/core/kernel/variable_kernel.cpp @@ 
-0,0 +1,70 @@ +#include "oneflow/core/kernel/variable_kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +void VariableKernel<device_type, T>::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const Blob* model_blob = BnInOp2Blob(ModelName()); + Blob* out_blob = BnInOp2Blob("out"); + if ((this->op_conf().trainable() && *tick_ % Global<JobDesc>::Get()->NumOfPiecesInBatch() == 0) + || (this->op_conf().trainable() == false && *tick_ == 0)) { + if (this->kernel_conf().variable_conf().is_fw_inplace()) { + CHECK_EQ(out_blob->dptr(), model_blob->dptr()); + } else { + CHECK_NE(out_blob->dptr(), model_blob->dptr()); + out_blob->CopyDataContentFrom(ctx.device_ctx, model_blob); + } + } else { + // do nothing + } + ++(*tick_); +} + +template<DeviceType device_type, typename T> +void VariableKernel<device_type, T>::BackwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + CHECK(this->op_conf().trainable()); + const Blob* out_diff_blob = BnInOp2Blob(GenDiffBn("out")); + Blob* model_diff_blob = BnInOp2Blob(GenDiffBn(ModelName())); + if (this->kernel_conf().variable_conf().is_bw_inplace()) { + CHECK_EQ(out_diff_blob->dptr(), model_diff_blob->dptr()); + } else { + CHECK_NE(out_diff_blob->dptr(), model_diff_blob->dptr()); + model_diff_blob->CopyDataContentFrom(ctx.device_ctx, out_diff_blob); + } +} + +template<DeviceType device_type, typename T> +void VariableKernel<device_type, T>::InitModelBlobsWithRandomSeed( + DeviceCtx* ctx, std::mt19937* random_seed_gen, + std::function<Blob*(const std::string&)> BnInOp2Blob) const { + KernelUtil<device_type, T>::InitializeWithProperConf( + ctx, GetMsgPtrFromPbMessage(this->op_conf().variable_conf(), "initializer"), + (*random_seed_gen)(), BnInOp2Blob(ModelName())); +} + +template<DeviceType device_type, typename T> +void VariableKernel<device_type, T>::InitModelBlobsWithDir( + DeviceCtx* ctx, int32_t part_id, int32_t part_num, const std::string& model_load_dir, + std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const std::string& model_name = ModelName(); + Blob* model_blob = BnInOp2Blob(model_name); + KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir, model_blob, + model_name, model_blob->shape().At(0), + model_blob->shape().Count(1)); +} + +template<DeviceType device_type, typename T> +const PbMessage& VariableKernel<device_type, T>::GetCustomizedOpConf() const { + return this->op_conf().variable_conf(); +} + +template<DeviceType device_type, typename T> +const std::string& VariableKernel<device_type, T>::ModelName() const { + return this->op_conf().variable_conf().model_name(); +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kVariableConf, VariableKernel, FLOATING_DATA_TYPE_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/kernel/variable_kernel.h b/oneflow/core/kernel/variable_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..82185acddcd96d8743380a658b69bc572789b2a7 --- /dev/null +++ b/oneflow/core/kernel/variable_kernel.h @@ -0,0 +1,33 @@ +#ifndef ONEFLOW_CORE_KERNEL_VARIABLE_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_VARIABLE_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +class VariableKernel final : public KernelIfWithModel<device_type, T> { + public: + OF_DISALLOW_COPY_AND_MOVE(VariableKernel); + VariableKernel() : tick_(new int64_t(0)) {} + ~VariableKernel() = 
default; + + private: + void ForwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void BackwardDataContent(const KernelCtx&, + std::function<Blob*(const std::string&)>) const override; + void InitModelBlobsWithRandomSeed(DeviceCtx*, std::mt19937*, + std::function<Blob*(const std::string&)>) const override; + void InitModelBlobsWithDir(DeviceCtx*, int32_t part_id, int32_t part_num, + const std::string& model_load_dir, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + const PbMessage& GetCustomizedOpConf() const override; + const std::string& ModelName() const; + + std::unique_ptr<int64_t> tick_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_VARIABLE_KERNEL_H_ diff --git a/oneflow/core/ndarray/binary_func.h b/oneflow/core/ndarray/binary_func.h new file mode 100644 index 0000000000000000000000000000000000000000..c1c8ffee8505ff933bc4b3d5a9056d4fef9eb450 --- /dev/null +++ b/oneflow/core/ndarray/binary_func.h @@ -0,0 +1,70 @@ +#ifndef ONEFLOW_CORE_NDARRAY_BINARY_FUNC_H_ +#define ONEFLOW_CORE_NDARRAY_BINARY_FUNC_H_ + +#include <cstdint> +#include <climits> +#include <cfloat> +#include "oneflow/core/kernel/kernel_util.h" +#include "oneflow/core/common/util.h" + +namespace oneflow { + +template<typename T> +OF_DEVICE_FUNC const T BinaryFuncAdd(const T x, const T y) { + return x + y; +} + +template<typename T> +OF_DEVICE_FUNC const T BinaryFuncSub(const T x, const T y) { + return x - y; +} + +template<typename T> +OF_DEVICE_FUNC const T BinaryFuncMul(const T x, const T y) { + return x * y; +} + +template<typename T> +OF_DEVICE_FUNC const T BinaryFuncDiv(const T x, const T y) { + return x / y; +} + +template<typename T> +OF_DEVICE_FUNC const T BinaryFuncMax(const T x, const T y) { + return x > y ? x : y; +} + +template<typename T> +OF_DEVICE_FUNC const T BinaryFuncMin(const T x, const T y) { + return x < y ? 
x : y; +} + +#define ARITHMETIC_BINARY_FUNC_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(BinaryFuncAdd) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryFuncSub) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryFuncMul) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryFuncDiv) + +template<typename T, const T (*binary_func)(const T, const T), typename Enable = void> +struct UnitOfBinaryFunc; + +#define SPECIALIZE_UNIT_OF_BINARY_FUNC(binary_func, val_trait) \ + template<typename T, const T (*bfunc)(const T, const T)> \ + struct UnitOfBinaryFunc<T, bfunc, typename std::enable_if<bfunc == &binary_func<T>>::type> \ + final { \ + constexpr static T value = val_trait<T>::value; \ + }; +SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncAdd, ZeroVal); +SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncMul, OneVal); +SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncMax, MinVal); +SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncMin, MaxVal); +#undef SPECIALIZE_UNIT_OF_BINARY_FUNC + +#define REDUCE_BINARY_FUNC_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(BinaryFuncAdd) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryFuncMax) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryFuncMin) + +} // namespace oneflow + +#endif // ONEFLOW_CORE_NDARRAY_BINARY_FUNC_H_ diff --git a/oneflow/core/ndarray/cpu_ndarray_assign.h b/oneflow/core/ndarray/cpu_ndarray_assign.h new file mode 100644 index 0000000000000000000000000000000000000000..38cac781358960ef4bb21c8701289680b50fb6fa --- /dev/null +++ b/oneflow/core/ndarray/cpu_ndarray_assign.h @@ -0,0 +1,18 @@ +#ifndef ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_ASSIGN_H_ +#define ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_ASSIGN_H_ + +#include "oneflow/core/ndarray/xpu_var_ndarray.h" +#include "oneflow/core/device/cuda_util.h" +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { + +template<int NDIMS, typename T, typename X> +OF_DEVICE_FUNC void CpuNdArrayAssign(XpuVarNdarray<T>* y, const X& x) { + size_t n = y->shape().ElemNum(); + FOR_RANGE(int, i, 0, n) { *(y->template Mut<NDIMS>(i)) = x.template Get<NDIMS>(i); } +} + +} // namespace oneflow + +#endif // ONEFLOW_CORE_NDARRAY_CPU_NDARRAY_ASSIGN_H_ diff --git a/oneflow/core/ndarray/exec_shape.cpp b/oneflow/core/ndarray/exec_shape.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cd5712afd79fc9ff233ee0386dca686cbcf220ce --- /dev/null +++ b/oneflow/core/ndarray/exec_shape.cpp @@ -0,0 +1,37 @@ +#include "oneflow/core/ndarray/xpu_shape.h" + +namespace oneflow { + +XpuShape::XpuShape(const int64_t dim[], int num_axes) { + num_axes_ = num_axes; + int i = 0; + for (; i < num_axes_; ++i) { dim_[i] = dim[i]; } + UpdateDimElemNumAndElemNum(); + for (; i < sizeof(dim_) / sizeof(dim_[0]); ++i) { + dim_[i] = 1; + dim_elem_num_[i] = 1; + } +} + +XpuShape::XpuShape(const Shape& shape) { + num_axes_ = shape.NumAxes(); + int i = 0; + for (; i < num_axes_; ++i) { dim_[i] = shape.At(i); } + UpdateDimElemNumAndElemNum(); + for (; i < sizeof(dim_) / sizeof(dim_[0]); ++i) { + dim_[i] = 1; + dim_elem_num_[i] = 1; + } +} + +bool XpuShape::operator==(const XpuShape& rhs) const { + if (num_axes_ != rhs.num_axes_) { return false; } + if (elem_num_ != rhs.elem_num_) { return false; } + for (int i = 0; i < num_axes_; ++i) { + if (dim_[i] != rhs.dim_[i]) { return false; } + if (dim_elem_num_[i] != rhs.dim_elem_num_[i]) { return false; } + } + return true; +} + +} // namespace oneflow diff --git a/oneflow/core/ndarray/gpu_ndarray_assign.h b/oneflow/core/ndarray/gpu_ndarray_assign.h new file mode 100644 index 0000000000000000000000000000000000000000..42a6faa9e16f8e57ec905e453c5e78ed79dd53b3 --- /dev/null +++ b/oneflow/core/ndarray/gpu_ndarray_assign.h @@ -0,0 +1,18 @@ +#ifndef 
ONEFLOW_CORE_NDARRAY_GPU_NDARRAY_ASSIGN_H_
+#define ONEFLOW_CORE_NDARRAY_GPU_NDARRAY_ASSIGN_H_
+
+#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/device/cuda_util.h"
+#include "oneflow/core/kernel/kernel_util.cuh"
+
+namespace oneflow {
+
+template<int NDIMS, typename T, typename X>
+__device__ void GpuNdArrayAssign(XpuVarNdarray<T>* y, const X& x) {
+  size_t n = y->shape().ElemNum();
+  CUDA_1D_KERNEL_LOOP(i, n) { *(y->template Mut<NDIMS>(i)) = x.template Get<NDIMS>(i); }
+}
+
+} // namespace oneflow
+
+#endif // ONEFLOW_CORE_NDARRAY_GPU_NDARRAY_ASSIGN_H_
diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h b/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h
new file mode 100644
index 0000000000000000000000000000000000000000..3abcbfd051fde57def8927aa998d470afcfbd96b
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h
@@ -0,0 +1,34 @@
+#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_H_
+#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_H_
+
+#include "oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h"
+#include "oneflow/core/common/util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T, int NDIMS, const T (*binary_func)(const T, const T)>
+struct NdArrayApplyBroadcastBinary final {
+  static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& a,
+                    const XpuVarNdarray<const T>& b) {
+    using BroadcastBinary =
+        NdArrayApplyBroadcastBinaryCoreWrapper<device_type, T, NDIMS, binary_func>;
+    CheckBroadcastable(y, a, b);
+    BroadcastBinary::Apply(ctx, y, a, b);
+  }
+
+  static void CheckBroadcastable(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& a,
+                                 const XpuVarNdarray<const T>& b) {
+    CHECK_EQ(y.shape().NumAxes(), a.shape().NumAxes());
+    CHECK_EQ(y.shape().NumAxes(), b.shape().NumAxes());
+    for (int i = 0; i < y.shape().NumAxes(); ++i) {
+      CHECK_EQ(y.shape().At(i), std::max(a.shape().At(i), b.shape().At(i)));
+      if (a.shape().At(i) != b.shape().At(i)) {
+        CHECK(a.shape().At(i) == 1 || b.shape().At(i) == 1);
+      }
+    }
+  }
+};
+
+} // namespace oneflow
+
+#endif // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_H_
diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cpp b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..9a0fcdaec7bd0a2b9dd8c5e67dc78c4febdcd223
--- /dev/null
+++ b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cpp
@@ -0,0 +1,22 @@
+#include "oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h"
+
+namespace oneflow {
+
+template<typename T, int NDIMS, const T (*binary_func)(const T, const T)>
+struct NdArrayApplyBroadcastBinaryCoreWrapper<DeviceType::kCPU, T, NDIMS, binary_func> final {
+  static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& a,
+                    const XpuVarNdarray<const T>& b) {
+    NdArrayApplyBroadcastBinaryCore<T, NDIMS, binary_func>::Apply(y, a, b);
+  }
+  static void ImplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
+                           const XpuVarNdarray<const T>& x) {
+    NdArrayApplyBroadcastBinaryCore<T, NDIMS,
binary_func>::ImplaceApply(y, x); + } +}; + +#define INSTANTIATE_BROADCAST_BINARY_FUNC(dtype_pair, NDIMS, binary_func) \ + template struct NdArrayApplyBroadcastBinaryCoreWrapper< \ + DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, binary_func>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BROADCAST_BINARY_FUNC, ARITHMETIC_DATA_TYPE_SEQ, + DIM_SEQ, ARITHMETIC_BINARY_FUNC_SEQ) +} // namespace oneflow diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu new file mode 100644 index 0000000000000000000000000000000000000000..94522f200083a04cf8188e7aa26b572c6604c9ce --- /dev/null +++ b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu @@ -0,0 +1,38 @@ +#include "oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h" + +namespace oneflow { + +namespace { + +template<typename T, int NDIMS, const T (*binary_func)(const T, const T)> +__global__ void GpuBroadcastBinaryFunc(const XpuVarNdarray<T> y, const XpuVarNdarray<const T> a, + const XpuVarNdarray<const T> b) { + NdArrayApplyBroadcastBinaryCore<T, NDIMS, binary_func>::Apply(y, a, b); +} +template<typename T, int NDIMS, const T (*binary_func)(const T, const T)> +__global__ void GpuBroadcastBinaryFunc(const XpuVarNdarray<T> y, const XpuVarNdarray<const T> x) { + NdArrayApplyBroadcastBinaryCore<T, NDIMS, binary_func>::ImplaceApply(y, x); +} + +} // namespace + +template<typename T, int NDIMS, const T (*binary_func)(const T, const T)> +struct NdArrayApplyBroadcastBinaryCoreWrapper<DeviceType::kGPU, T, NDIMS, binary_func> final { + static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& a, + const XpuVarNdarray<const T>& b) { + size_t n = y.host_shape().HostElemNum(); + RUN_CUDA_KERNEL((GpuBroadcastBinaryFunc<T, NDIMS, binary_func>), ctx, n, y, a, b); + } + static void ImplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, + const XpuVarNdarray<const T>& x) { + size_t n = y.host_shape().HostElemNum(); + RUN_CUDA_KERNEL((GpuBroadcastBinaryFunc<T, NDIMS, binary_func>), ctx, n, y, x); + } +}; + +#define INSTANTIATE_BROADCAST_BINARY_FUNC(dtype_pair, NDIMS, binary_func) \ + template struct NdArrayApplyBroadcastBinaryCoreWrapper< \ + DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, binary_func>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BROADCAST_BINARY_FUNC, ARITHMETIC_DATA_TYPE_SEQ, + DIM_SEQ, ARITHMETIC_BINARY_FUNC_SEQ) +} // namespace oneflow diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h new file mode 100644 index 0000000000000000000000000000000000000000..ab67762ed03ce40f994ffab991bca1b2aaf9205c --- /dev/null +++ b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.h @@ -0,0 +1,37 @@ +#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_CORE_H_ +#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_CORE_H_ + +#include "oneflow/core/ndarray/xpu_util.h" +#include "oneflow/core/ndarray/xpu_var_ndarray.h" +#include "oneflow/core/ndarray/xpu_broadcast_ndarray.h" +#include "oneflow/core/ndarray/xpu_binary_func_ndarray.h" +#include "oneflow/core/ndarray/binary_func.h" +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T, int NDIMS, const T (*binary_func)(const T, const T)> +struct NdArrayApplyBroadcastBinaryCoreWrapper final { + static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& a, + const XpuVarNdarray<const 
T>& b); + static void ImplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, + const XpuVarNdarray<const T>& x); +}; + +template<typename T, int NDIMS, const T (*binary_func)(const T, const T)> +struct NdArrayApplyBroadcastBinaryCore final { + OF_DEVICE_FUNC static void Apply(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& a, + const XpuVarNdarray<const T>& b) { + const auto& ret = + a.Broadcast(y.shape()).template BinaryFunc<binary_func>(b.Broadcast(y.shape())); + y.template Assign<NDIMS>(ret); + } + OF_DEVICE_FUNC static void ImplaceApply(const XpuVarNdarray<T>& y, + const XpuVarNdarray<const T>& x) { + y.template BinaryAssign<binary_func, NDIMS>(x.Broadcast(y.shape())); + } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_BINARY_CORE_H_ diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_unary.h b/oneflow/core/ndarray/ndarray_apply_broadcast_unary.h new file mode 100644 index 0000000000000000000000000000000000000000..432b966ad37c75d46f06505ae548ca0a3bb24828 --- /dev/null +++ b/oneflow/core/ndarray/ndarray_apply_broadcast_unary.h @@ -0,0 +1,26 @@ +#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_H_ +#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_H_ + +#include "oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h" +#include "oneflow/core/common/util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T, int NDIMS, const T (*unary_func)(const T)> +struct NdArrayApplyBroadcastUnary final { + static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) { + CheckBroadcastable(y, x); + NdArrayApplyBroadcastUnaryCoreWrapper<device_type, T, NDIMS, unary_func>::Apply(ctx, y, x); + } + + static void CheckBroadcastable(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) { + CHECK_EQ(y.shape().NumAxes(), x.shape().NumAxes()); + for (int i = 0; i < y.shape().NumAxes(); ++i) { + CHECK(x.shape().At(i) == 1 || x.shape().At(i) == y.shape().At(i)); + } + } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_H_ diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cpp b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cpp new file mode 100644 index 0000000000000000000000000000000000000000..842c1d20a5c972c95f66b49fd4f87d952f1dac5b --- /dev/null +++ b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cpp @@ -0,0 +1,17 @@ +#include "oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h" + +namespace oneflow { + +template<typename T, int NDIMS, const T (*unary_func)(const T)> +struct NdArrayApplyBroadcastUnaryCoreWrapper<DeviceType::kCPU, T, NDIMS, unary_func> final { + static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) { + NdArrayApplyBroadcastUnaryCore<T, NDIMS, unary_func>::Apply(y, x); + } +}; + +#define INSTANTIATE_BROADCAST_UNARY_FUNC(dtype_pair, NDIMS, unary_func) \ + template struct NdArrayApplyBroadcastUnaryCoreWrapper< \ + DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, unary_func>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BROADCAST_UNARY_FUNC, ARITHMETIC_DATA_TYPE_SEQ, + DIM_SEQ, ARITHMETIC_UNARY_FUNC_SEQ) +} // namespace oneflow diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu new file mode 100644 index 0000000000000000000000000000000000000000..f705f8b069a3fc43c9bf720262800db328651506 --- /dev/null +++ 
b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu @@ -0,0 +1,27 @@ +#include "oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h" + +namespace oneflow { + +namespace { + +template<typename T, int NDIMS, const T (*unary_func)(const T)> +__global__ void GpuBroadcastUnaryFunc(const XpuVarNdarray<T> y, const XpuVarNdarray<const T> x) { + NdArrayApplyBroadcastUnaryCore<T, NDIMS, unary_func>::Apply(y, x); +} + +} // namespace + +template<typename T, int NDIMS, const T (*unary_func)(const T)> +struct NdArrayApplyBroadcastUnaryCoreWrapper<DeviceType::kGPU, T, NDIMS, unary_func> final { + static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) { + size_t n = y.host_shape().HostElemNum(); + RUN_CUDA_KERNEL((GpuBroadcastUnaryFunc<T, NDIMS, unary_func>), ctx, n, y, x); + } +}; + +#define INSTANTIATE_BROADCAST_UNARY_FUNC(dtype_pair, NDIMS, unary_func) \ + template struct NdArrayApplyBroadcastUnaryCoreWrapper< \ + DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, unary_func>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_BROADCAST_UNARY_FUNC, ARITHMETIC_DATA_TYPE_SEQ, + DIM_SEQ, ARITHMETIC_UNARY_FUNC_SEQ) +} // namespace oneflow diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h new file mode 100644 index 0000000000000000000000000000000000000000..18213b7bd858515ba67f2ccd4ab884a5908c546d --- /dev/null +++ b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.h @@ -0,0 +1,27 @@ +#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_CORE_H_ +#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_CORE_H_ + +#include "oneflow/core/ndarray/xpu_util.h" +#include "oneflow/core/ndarray/xpu_var_ndarray.h" +#include "oneflow/core/ndarray/xpu_broadcast_ndarray.h" +#include "oneflow/core/ndarray/xpu_unary_func_ndarray.h" +#include "oneflow/core/ndarray/unary_func.h" +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T, int NDIMS, const T (*unary_func)(const T)> +struct NdArrayApplyBroadcastUnaryCoreWrapper final { + static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x); +}; + +template<typename T, int NDIMS, const T (*unary_func)(const T)> +struct NdArrayApplyBroadcastUnaryCore final { + OF_DEVICE_FUNC static void Apply(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) { + y.template Assign<NDIMS>(x.Broadcast(y.shape()).template UnaryFunc<unary_func>()); + } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_BROADCAST_UNARY_CORE_H_ diff --git a/oneflow/core/ndarray/ndarray_apply_unary.h b/oneflow/core/ndarray/ndarray_apply_unary.h new file mode 100644 index 0000000000000000000000000000000000000000..774cdaa2b63ecd842a4bd6a873dd5197fbfb9fd2 --- /dev/null +++ b/oneflow/core/ndarray/ndarray_apply_unary.h @@ -0,0 +1,18 @@ +#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_H_ +#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_H_ + +#include "oneflow/core/common/util.h" +#include "oneflow/core/ndarray/ndarray_apply_unary_core.h" + +namespace oneflow { + +template<DeviceType device_type, typename T, const T (*unary_func)(const T)> +struct NdArrayApplyUnary final { + static void ImplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y) { + NdArrayApplyUnaryCoreWrapper<device_type, T, unary_func>::ImplaceApply(ctx, y); + } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_H_ diff --git 
a/oneflow/core/ndarray/ndarray_apply_unary_core.cpp b/oneflow/core/ndarray/ndarray_apply_unary_core.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dda45c5e8ff56ff03b76acf99a83cf5e198d5916 --- /dev/null +++ b/oneflow/core/ndarray/ndarray_apply_unary_core.cpp @@ -0,0 +1,19 @@ +#include "oneflow/core/ndarray/ndarray_apply_unary_core.h" +#include "oneflow/core/ndarray/unary_func.h" + +namespace oneflow { + +template<typename T, const T (*unary_func)(const T)> +struct NdArrayApplyUnaryCoreWrapper<DeviceType::kCPU, T, unary_func> final { + static void ImplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y) { + NdArrayApplyUnaryCore<T, unary_func>::ImplaceApply(y.ptr(), y.shape().ElemNum()); + } +}; + +#define INSTANTIATE_NDARRAY_APPLY_UNARY_CORE(dtype_pair, unary_func) \ + template struct NdArrayApplyUnaryCoreWrapper<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype_pair), \ + unary_func>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_APPLY_UNARY_CORE, ARITHMETIC_DATA_TYPE_SEQ, + ARITHMETIC_UNARY_FUNC_SEQ) + +} // namespace oneflow diff --git a/oneflow/core/ndarray/ndarray_apply_unary_core.cu b/oneflow/core/ndarray/ndarray_apply_unary_core.cu new file mode 100644 index 0000000000000000000000000000000000000000..256bf5deb81a757abc0a3794b483eea69a28e609 --- /dev/null +++ b/oneflow/core/ndarray/ndarray_apply_unary_core.cu @@ -0,0 +1,29 @@ +#include "oneflow/core/ndarray/ndarray_apply_unary_core.h" +#include "oneflow/core/ndarray/unary_func.h" + +namespace oneflow { + +namespace { + +template<typename T, const T (*unary_func)(const T)> +__global__ void NdArrayApplyUnaryImplaceApplyGpu(T* ptr, size_t n) { + NdArrayApplyUnaryCore<T, unary_func>::ImplaceApply(ptr, n); +} + +} // namespace + +template<typename T, const T (*unary_func)(const T)> +struct NdArrayApplyUnaryCoreWrapper<DeviceType::kGPU, T, unary_func> final { + static void ImplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y) { + size_t n = y.host_shape().HostElemNum(); + RUN_CUDA_KERNEL((NdArrayApplyUnaryImplaceApplyGpu<T, unary_func>), ctx, n, y.host_ptr(), n); + } +}; + +#define INSTANTIATE_NDARRAY_APPLY_UNARY_CORE(dtype_pair, unary_func) \ + template struct NdArrayApplyUnaryCoreWrapper<DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype_pair), \ + unary_func>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_APPLY_UNARY_CORE, ARITHMETIC_DATA_TYPE_SEQ, + ARITHMETIC_UNARY_FUNC_SEQ) + +} // namespace oneflow diff --git a/oneflow/core/ndarray/ndarray_apply_unary_core.h b/oneflow/core/ndarray/ndarray_apply_unary_core.h new file mode 100644 index 0000000000000000000000000000000000000000..ddec5689e837d9d2cccf7f378195c2799e42cc7d --- /dev/null +++ b/oneflow/core/ndarray/ndarray_apply_unary_core.h @@ -0,0 +1,25 @@ +#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_CORE_H_ +#define ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_CORE_H_ + +#include "oneflow/core/kernel/kernel_util.h" +#include "oneflow/core/ndarray/xpu_unary_func_ndarray.h" +#include "oneflow/core/ndarray/xpu_var_ndarray.h" +#include "oneflow/core/ndarray/xpu_util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T, const T (*unary_func)(const T)> +struct NdArrayApplyUnaryCoreWrapper final { + static void ImplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y); +}; + +template<typename T, const T (*unary_func)(const T)> +struct NdArrayApplyUnaryCore final { + OF_DEVICE_FUNC static void ImplaceApply(T* y, size_t n) { + XPU_1D_KERNEL_LOOP(i, n) { y[i] = unary_func(y[i]); } + } +}; + +} // namespace oneflow + +#endif // 
ONEFLOW_CORE_NDARRAY_NDARRAY_APPLY_UNARY_CORE_H_ diff --git a/oneflow/core/ndarray/ndarray_assign_core.cpp b/oneflow/core/ndarray/ndarray_assign_core.cpp new file mode 100644 index 0000000000000000000000000000000000000000..351edf63db93da80ca07ac972ef13a6b828ec51e --- /dev/null +++ b/oneflow/core/ndarray/ndarray_assign_core.cpp @@ -0,0 +1,18 @@ +#include "oneflow/core/ndarray/ndarray_assign_core.h" +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { + +template<typename T, int NDIMS> +struct NdArrayAssignCoreWrapper<DeviceType::kCPU, T, NDIMS> final { + static void Assign(DeviceCtx* ctx, const XpuVarNdarray<T>& y, + const XpuReducedNdarray<T, NDIMS>& reduced) { + NdArrayAssignCore<T, NDIMS>::Assign(y, reduced); + } +}; + +#define INSTANTIATE_NDARRAY_ASSIGN(dtype_pair, NDIMS) \ + template struct NdArrayAssignCoreWrapper<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, ARITHMETIC_DATA_TYPE_SEQ, DIM_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/ndarray/ndarray_assign_core.cu b/oneflow/core/ndarray/ndarray_assign_core.cu new file mode 100644 index 0000000000000000000000000000000000000000..e7e05d7c7336eca7b8ac5e66fecf21202a4079b8 --- /dev/null +++ b/oneflow/core/ndarray/ndarray_assign_core.cu @@ -0,0 +1,29 @@ +#include "oneflow/core/ndarray/ndarray_assign_core.h" +#include "oneflow/core/device/cuda_util.h" +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { + +namespace { + +template<typename T, int NDIMS> +__global__ void NdArrayAssignGpu(XpuVarNdarray<T> y, const XpuReducedNdarray<T, NDIMS> reduced) { + NdArrayAssignCore<T, NDIMS>::Assign(y, reduced); +} + +} // namespace + +template<typename T, int NDIMS> +struct NdArrayAssignCoreWrapper<DeviceType::kGPU, T, NDIMS> final { + static void Assign(DeviceCtx* ctx, const XpuVarNdarray<T>& y, + const XpuReducedNdarray<T, NDIMS>& reduced) { + size_t n = y.host_shape().HostElemNum(); + RUN_CUDA_KERNEL((NdArrayAssignGpu<T, NDIMS>), ctx, n, y, reduced); + } +}; + +#define INSTANTIATE_NDARRAY_ASSIGN(dtype_pair, NDIMS) \ + template struct NdArrayAssignCoreWrapper<DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, ARITHMETIC_DATA_TYPE_SEQ, DIM_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/ndarray/ndarray_assign_core.h b/oneflow/core/ndarray/ndarray_assign_core.h new file mode 100644 index 0000000000000000000000000000000000000000..4a822b46bdb031ff29e5c93fcdd2394a68ef16f4 --- /dev/null +++ b/oneflow/core/ndarray/ndarray_assign_core.h @@ -0,0 +1,25 @@ +#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_ASSIGN_CORE_H_ +#define ONEFLOW_CORE_NDARRAY_NDARRAY_ASSIGN_CORE_H_ + +#include "oneflow/core/ndarray/xpu_var_ndarray.h" +#include "oneflow/core/ndarray/xpu_reduced_ndarray.h" + +namespace oneflow { + +template<DeviceType device_type, typename T, int NDIMS> +struct NdArrayAssignCoreWrapper final { + static void Assign(DeviceCtx* ctx, const XpuVarNdarray<T>& y, + const XpuReducedNdarray<T, NDIMS>& reduced); +}; + +template<typename T, int NDIMS> +struct NdArrayAssignCore final { + OF_DEVICE_FUNC static void Assign(const XpuVarNdarray<T>& y, + const XpuReducedNdarray<T, NDIMS>& reduced) { + y.template Assign<NDIMS>(reduced); + } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_NDARRAY_NDARRAY_ASSIGN_CORE_H_ diff --git a/oneflow/core/ndarray/ndarray_reduce.h b/oneflow/core/ndarray/ndarray_reduce.h new file mode 100644 index 
0000000000000000000000000000000000000000..f508c5af381042c6d4c74fd2d40874274286eeb6 --- /dev/null +++ b/oneflow/core/ndarray/ndarray_reduce.h @@ -0,0 +1,31 @@ +#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_H_ +#define ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_H_ + +#include "oneflow/core/device/cuda_util.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/ndarray/ndarray_reduce_impl.h" + +namespace oneflow { + +template<DeviceType device_type, typename T, const T (*binary_func)(const T, const T)> +struct NdArrayReduce final { + static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x, + const XpuVarNdarray<T>& tmp_storage) { + CHECK_EQ(y.shape().NumAxes(), x.shape().NumAxes()); + if (NdarrayNoReduce<device_type, T, binary_func>::Matched(y, x)) { + NdarrayNoReduce<device_type, T, binary_func>::Reduce(ctx, y, x, tmp_storage); + } else if (NdarrayScalarReduce<device_type, T, binary_func>::Matched(y, x)) { + NdarrayScalarReduce<device_type, T, binary_func>::Reduce(ctx, y, x, tmp_storage); + } else if (NdarrayMatrixRowReduce<device_type, T, binary_func>::Matched(y, x)) { + NdarrayMatrixRowReduce<device_type, T, binary_func>::Reduce(ctx, y, x, tmp_storage); + } else if (NdarrayMatrixColReduce<device_type, T, binary_func>::Matched(y, x)) { + NdarrayMatrixColReduce<device_type, T, binary_func>::Reduce(ctx, y, x, tmp_storage); + } else { + NdarrayDefaultReduce<device_type, T, binary_func>::Reduce(ctx, y, x, tmp_storage); + } + } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_H_ diff --git a/oneflow/core/ndarray/ndarray_reduce_impl.cpp b/oneflow/core/ndarray/ndarray_reduce_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e34366ae22e86738a90327076d586076684bb2f6 --- /dev/null +++ b/oneflow/core/ndarray/ndarray_reduce_impl.cpp @@ -0,0 +1,45 @@ +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/common/preprocessor.h" +#include "oneflow/core/ndarray/ndarray_reduce_impl.h" +#include "oneflow/core/ndarray/binary_func.h" + +namespace oneflow { + +#define SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(struct_name) \ + template<typename T, const T (*binary_func)(const T, const T)> \ + struct struct_name<DeviceType::kCPU, T, binary_func> final { \ + static bool Matched(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) { \ + return false; \ + } \ + static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x, \ + const XpuVarNdarray<T>& tmp_storage) { \ + UNIMPLEMENTED(); \ + } \ + } +SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(NdarrayScalarReduce); +SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(NdarrayMatrixRowReduce); +SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL(NdarrayMatrixColReduce); +#undef SPECIALIZE_CPU_NDARRAY_REDUCE_IMPL + +#define INSTANTIATE_NDARRAY_REDUCE_IMPL(dtype, binary_func) \ + template struct NdarrayScalarReduce<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype), binary_func>; \ + template struct NdarrayMatrixRowReduce<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype), binary_func>; \ + template struct NdarrayMatrixColReduce<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype), binary_func>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL, ARITHMETIC_DATA_TYPE_SEQ, + REDUCE_BINARY_FUNC_SEQ); + +template<typename T, int NDIMS, const T (*binary_func)(const T, const T)> +struct NdArrayReduceCoreWrapper<DeviceType::kCPU, T, NDIMS, binary_func> final { + static void ReduceAxis(DeviceCtx* ctx, const XpuReducedNdarray<T, NDIMS>& dst_reduced, + const XpuReducedNdarray<T, NDIMS>& x, int 
axis) { + NdArrayReduceCore<T, NDIMS, binary_func>::ReduceAxis(dst_reduced, x, axis); + } +}; + +#define INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER(dtype_pair, NDIMS, binary_func) \ + template struct NdArrayReduceCoreWrapper<DeviceType::kCPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, \ + binary_func>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER, ARITHMETIC_DATA_TYPE_SEQ, + DIM_SEQ, REDUCE_BINARY_FUNC_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/ndarray/ndarray_reduce_impl.cu b/oneflow/core/ndarray/ndarray_reduce_impl.cu new file mode 100644 index 0000000000000000000000000000000000000000..abec90b2c0e48baac7ab01bc067aa1889f5794ce --- /dev/null +++ b/oneflow/core/ndarray/ndarray_reduce_impl.cu @@ -0,0 +1,193 @@ +#include <cub/cub.cuh> +#include "oneflow/core/ndarray/ndarray_reduce_impl.h" +#include "oneflow/core/ndarray/binary_func.h" +#include "oneflow/core/common/preprocessor.h" +#include "oneflow/core/common/shape.h" + +namespace oneflow { + +template<typename T, const T (*binary_func)(const T, const T), typename Enable = void> +struct CubFunctor4BianryFunc; + +#define SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC(binary_func, cub_functor) \ + template<typename T, const T (*bfunc)(const T, const T)> \ + struct CubFunctor4BianryFunc<T, bfunc, typename std::enable_if<bfunc == &binary_func<T>>::type> \ + final { \ + using type = cub_functor; \ + } + +SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC(BinaryFuncAdd, cub::Sum); +SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC(BinaryFuncMax, cub::Max); +SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC(BinaryFuncMin, cub::Min); + +#undef SPECIALIZE_CUB_FUNCTOR_4_BINARY_FUNC + +namespace { + +template<typename T, const T (*binary_func)(const T, const T)> +void __global__ NdarrayMatrixColReduceNaiveCudaKernel(T* y_ptr, const T* x_ptr, int32_t num_rows, + int32_t num_cols) { + CUDA_1D_KERNEL_LOOP(j, num_cols) { + T reduced = x_ptr[j]; + FOR_RANGE(int32_t, i, 1, num_rows) { reduced = binary_func(reduced, x_ptr[i * num_cols + j]); } + y_ptr[j] = reduced; + } +} + +} // namespace + +struct RowOffsetFunctor final { + OF_DEVICE_FUNC explicit RowOffsetFunctor(int32_t num_cols) : num_cols_(num_cols) {} + OF_DEVICE_FUNC int32_t operator()(const int32_t& x) const { return x * num_cols_; } + int32_t num_cols_; +}; + +template<typename T, const T (*binary_func)(const T, const T)> +struct NdarrayScalarReduce<DeviceType::kGPU, T, binary_func> final { + static bool Matched(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) { + return y.shape().ElemNum() == 1; + } + + static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x, + const XpuVarNdarray<T>& tmp_storage) { + CHECK(Matched(y, x)); + size_t x_size = x.shape().ElemNum(); + size_t tmp_storage_bytes = 0; + auto DoReduce = [&](T* tmp_storage_ptr) { + int retcode = + cub::DeviceReduce::Reduce(tmp_storage_ptr, tmp_storage_bytes, x.ptr(), y.ptr(), x_size, + typename CubFunctor4BianryFunc<T, binary_func>::type(), + UnitOfBinaryFunc<T, binary_func>::value, ctx->cuda_stream()); + CHECK_EQ(retcode, 0) << "cub::DeviceSegmentedReduce::Reduce error"; + }; + DoReduce(nullptr); + CHECK_GE(tmp_storage.shape().ElemNum() * sizeof(T), tmp_storage_bytes); + DoReduce(tmp_storage.ptr()); + } +}; + +template<typename T, const T (*binary_func)(const T, const T)> +struct NdarrayMatrixRowReduce<DeviceType::kGPU, T, binary_func> final { + static bool Matched(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) { + if (y.shape().ElemNum() > GetMaxVal<int32_t>()) { return false; } + const auto& 
x_squeezed = SqueezeRight(x.shape()); + const auto& y_squeezed = SqueezeRight(y.shape()); + if (x_squeezed.NumAxes() == 0) { return false; } + for (int i = 0; i < y_squeezed.NumAxes(); ++i) { + if (x_squeezed.At(i) != y_squeezed.At(i)) { return false; } + } + CHECK_EQ(x.shape().ElemNum() % y.shape().ElemNum(), 0); + return true; + } + + static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x, + const XpuVarNdarray<T>& tmp_storage) { + CHECK(Matched(y, x)); + int32_t num_rows = y.shape().ElemNum(); + int32_t num_cols = x.shape().ElemNum() / y.shape().ElemNum(); + RowOffsetFunctor get_row_offset(num_cols); + cub::CountingInputIterator<int32_t> counting_intput_it(0); + cub::TransformInputIterator<int32_t, RowOffsetFunctor, cub::CountingInputIterator<int32_t>> + transform_input_iter(counting_intput_it, get_row_offset); + size_t tmp_storage_bytes = 0; + auto DoReduce = [&](T* tmp_storage_ptr) { + int retcode = cub::DeviceSegmentedReduce::Reduce( + tmp_storage_ptr, tmp_storage_bytes, x.ptr(), y.ptr(), num_rows, transform_input_iter, + transform_input_iter + 1, typename CubFunctor4BianryFunc<T, binary_func>::type(), + UnitOfBinaryFunc<T, binary_func>::value, ctx->cuda_stream()); + CHECK_EQ(retcode, 0) << "cub::DeviceSegmentedReduce::Reduce error"; + }; + DoReduce(nullptr); + CHECK_GE(tmp_storage.shape().ElemNum() * sizeof(T), tmp_storage_bytes); + DoReduce(tmp_storage.ptr()); + } + + private: + static XpuShape SqueezeRight(const XpuShape& shape) { + std::vector<int64_t> dim_vec; + for (int i = 0; i < shape.NumAxes(); ++i) { dim_vec.push_back(shape.At(i)); } + for (int i = shape.NumAxes() - 1; i >= 0; --i) { + if (dim_vec.at(i) != 1) { break; } + dim_vec.pop_back(); + } + if (dim_vec.empty()) { dim_vec.push_back(1LL); } + return XpuShape(Shape(dim_vec)); + } +}; + +template<typename T, const T (*binary_func)(const T, const T)> +struct NdarrayMatrixColReduce<DeviceType::kGPU, T, binary_func> final { + static bool Matched(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) { + if (y.shape().ElemNum() > GetMaxVal<int32_t>()) { return false; } + const auto& x_squeezed = SqueezeLeft(x.shape()); + const auto& y_squeezed = SqueezeLeft(y.shape()); + if (x_squeezed.NumAxes() == 0) { return false; } + for (int i = 0; i < y_squeezed.NumAxes(); ++i) { + if (x_squeezed.At(x_squeezed.NumAxes() - 1 - i) + != y_squeezed.At(y_squeezed.NumAxes() - 1 - i)) { + return false; + } + } + CHECK_EQ(x.shape().ElemNum() % y.shape().ElemNum(), 0); + return true; + } + + static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x, + const XpuVarNdarray<T>& tmp_storage) { + CHECK(Matched(y, x)); + int32_t num_rows = x.shape().ElemNum() / y.shape().ElemNum(); + int32_t num_cols = y.shape().ElemNum(); + RUN_CUDA_KERNEL((NdarrayMatrixColReduceNaiveCudaKernel<T, binary_func>), ctx, num_cols, y.ptr(), + x.ptr(), num_rows, num_cols); + } + + private: + static XpuShape SqueezeLeft(const XpuShape& shape) { + std::vector<int64_t> dim_vec; + bool all_squeezed = false; + for (int i = 0; i < shape.NumAxes(); ++i) { + if (all_squeezed == false) { + if (shape.At(i) == 1) { continue; } + all_squeezed = true; + } + dim_vec.push_back(shape.At(i)); + } + if (dim_vec.empty()) { dim_vec.push_back(1LL); } + return XpuShape(Shape(dim_vec)); + } +}; + +namespace { + +template<typename T, int NDIMS, const T (*binary_func)(const T, const T)> +__global__ void NdArrayReduceGpuImplaceReduceAxis(const XpuReducedNdarray<T, NDIMS> dst_reduced, + const XpuReducedNdarray<T, NDIMS> 
x, int axis) { + NdArrayReduceCore<T, NDIMS, binary_func>::ReduceAxis(dst_reduced, x, axis); +} + +} // namespace + +template<typename T, int NDIMS, const T (*binary_func)(const T, const T)> +struct NdArrayReduceCoreWrapper<DeviceType::kGPU, T, NDIMS, binary_func> final { + static void ReduceAxis(DeviceCtx* ctx, const XpuReducedNdarray<T, NDIMS>& dst_reduced, + const XpuReducedNdarray<T, NDIMS>& x, int axis) { + size_t n = x.host_shape().HostElemNum(); + RUN_CUDA_KERNEL((NdArrayReduceGpuImplaceReduceAxis<T, NDIMS, binary_func>), ctx, n, dst_reduced, + x, axis); + } +}; + +#define INSTANTIATE_NDARRAY_REDUCE_IMPL(dtype, binary_func) \ + template struct NdarrayScalarReduce<DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype), binary_func>; \ + template struct NdarrayMatrixRowReduce<DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype), binary_func>; \ + template struct NdarrayMatrixColReduce<DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype), binary_func>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_IMPL, ARITHMETIC_DATA_TYPE_SEQ, + REDUCE_BINARY_FUNC_SEQ); + +#define INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER(dtype_pair, NDIMS, binary_func) \ + template struct NdArrayReduceCoreWrapper<DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS, \ + binary_func>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_REDUCE_CORE_WRAPPER, ARITHMETIC_DATA_TYPE_SEQ, + DIM_SEQ, REDUCE_BINARY_FUNC_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/ndarray/ndarray_reduce_impl.h b/oneflow/core/ndarray/ndarray_reduce_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..d6191db2b79c352c9a5668a80d2348cb410ad175 --- /dev/null +++ b/oneflow/core/ndarray/ndarray_reduce_impl.h @@ -0,0 +1,109 @@ +#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_IMPL_H_ +#define ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_IMPL_H_ + +#include "oneflow/core/ndarray/xpu_var_ndarray.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/common/switch_func.h" +#include "oneflow/core/ndarray/xpu_ndarray_assign.h" +#include "oneflow/core/ndarray/binary_func.h" + +namespace oneflow { + +#define DECLARE_NDARRAY_REDUCE_IMPL(struct_name) \ + template<DeviceType device_type, typename T, const T (*binary_func)(const T, const T)> \ + struct struct_name final { \ + static bool Matched(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x); \ + static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x, \ + const XpuVarNdarray<T>& tmp_storage); \ + } +DECLARE_NDARRAY_REDUCE_IMPL(NdarrayScalarReduce); +DECLARE_NDARRAY_REDUCE_IMPL(NdarrayMatrixRowReduce); +DECLARE_NDARRAY_REDUCE_IMPL(NdarrayMatrixColReduce); +#undef DECLARE_NDARRAY_REDUCE_IMPL + +template<DeviceType device_type, typename T, const T (*binary_func)(const T, const T)> +struct NdarrayNoReduce final { + static bool Matched(const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) { + return x.shape() == y.shape(); + } + static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x, + const XpuVarNdarray<T>& tmp_storage) { + XpuNdArrayAssign<device_type, T>::Assign(ctx, y, x); + } +}; + +template<DeviceType device_type, typename T, int NDIMS, const T (*binary_func)(const T, const T)> +struct NdArrayReduceCoreWrapper final { + static void ReduceAxis(DeviceCtx* ctx, const XpuReducedNdarray<T, NDIMS>& dst_reduced, + const XpuReducedNdarray<T, NDIMS>& x, int axis); +}; + +template<DeviceType device_type, typename T, const T (*binary_func)(const T, const T)> +struct NdarrayDefaultReduce final { + static void 
Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x, + const XpuVarNdarray<T>& tmp_storage) { + return SwitchReduce(SwitchCase(y.shape().NumAxes()), ctx, y, x, tmp_storage); + } + + private: +#define DEFINE_NDARRAY_REDUCE(func_name, NDIMS) func_name<NDIMS> + DEFINE_STATIC_SWITCH_FUNC(void, Reduce, DEFINE_NDARRAY_REDUCE, MAKE_NDIM_CTRV_SEQ(DIM_SEQ)); +#undef DEFINE_NDARRAY_REDUCE + + template<int NDIMS> + static void Reduce(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x, + const XpuVarNdarray<T>& tmp_storage) { + XpuVarNdarray<T> storage(x.shape(), tmp_storage.ptr()); + XpuShape cur_shape(x.shape()); + CHECK_EQ(y.shape().NumAxes(), x.shape().NumAxes()); + CHECK(x.shape() != y.shape()); + XpuNdArrayAssign<device_type, T>::Assign(ctx, storage, x); + for (int i = 0; i < x.shape().NumAxes(); ++i) { + if (y.shape().At(i) == x.shape().At(i)) { continue; } + CHECK_EQ(y.shape().At(i), 1); + CHECK_GT(x.shape().At(i), y.shape().At(i)); + ImplaceReduceAxis<NDIMS>(ctx, i, storage, &cur_shape); + } + XpuReducedNdarray<T, NDIMS> reduced(y.shape(), storage); + XpuNdArrayAssign<device_type, T>::template Assign<NDIMS>(ctx, y, reduced); + } + + template<int NDIMS> + static void ImplaceReduceAxis(DeviceCtx* ctx, int axis, const XpuVarNdarray<T>& implace, + XpuShape* cur_shape) { + int64_t target_elem_num = cur_shape->ElemNum() / cur_shape->At(axis); + while (cur_shape->At(axis) > 1) { + int64_t shrink = 8 + std::sqrt(target_elem_num); + XpuReducedNdarray<T, NDIMS> from(*cur_shape, implace); + int64_t new_dim_value = (cur_shape->At(axis) + (shrink - 1)) / shrink; + cur_shape->Set(axis, new_dim_value); + XpuReducedNdarray<T, NDIMS> to(*cur_shape, implace); + NdArrayReduceCoreWrapper<device_type, T, NDIMS, binary_func>::ReduceAxis(ctx, to, from, axis); + } + } +}; + +template<typename T, int NDIMS, const T (*binary_func)(const T, const T)> +struct NdArrayReduceCore final { + template<typename X> + OF_DEVICE_FUNC static void ReduceAxis(const XpuReducedNdarray<T, NDIMS>& dst_reduced, const X& x, + int axis) { + size_t n = dst_reduced.shape().ElemNum(); + int64_t dst_dim_val = dst_reduced.shape().At(axis); + XPU_1D_KERNEL_LOOP(i, n) { + T* dst_reduced_ptr = dst_reduced.template Mut(i); + int64_t coord[NDIMS]; + dst_reduced.shape().template Offset2Coordinate<NDIMS>(i, coord); + T reduced = UnitOfBinaryFunc<T, binary_func>::value; + while (coord[axis] < x.shape().At(axis)) { + reduced = binary_func(reduced, x.template Get<NDIMS>(coord)); + coord[axis] += dst_dim_val; + } + *dst_reduced_ptr = reduced; + } + } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_NDARRAY_NDARRAY_REDUCE_IMPL_H_ diff --git a/oneflow/core/ndarray/ndarray_reduce_test.cpp b/oneflow/core/ndarray/ndarray_reduce_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f3d7e11506b8073a36e3c73ba68e6956d08a9885 --- /dev/null +++ b/oneflow/core/ndarray/ndarray_reduce_test.cpp @@ -0,0 +1,42 @@ +#include "oneflow/core/ndarray/ndarray_util.h" +#include <gtest/gtest.h> + +namespace oneflow { + +namespace test { + +namespace { + +void TestMiddleAxis(int num) { + std::vector<int32_t> data(num * num * num, 1); + std::vector<int32_t> tmp_storage(num * num * num, -8888); + XpuVarNdarray<const int32_t> x(XpuShape(Shape({num, num, num})), data.data()); + XpuVarNdarray<int32_t> tmp(XpuShape(Shape({num, num, num})), tmp_storage.data()); + std::vector<int32_t> ret(num * num, -999); + XpuVarNdarray<int32_t> y(XpuShape(Shape({num, 1, num})), ret.data()); + 
NdArrayReduce<DeviceType::kCPU, int32_t, BinaryFuncAdd>::Reduce(nullptr, y, x, tmp); + for (int i = 0; i < num; ++i) { + for (int j = 0; j < num; ++j) { ASSERT_EQ(ret[i * num + j], num); } + } +} + +} // namespace + +TEST(NdArrayReduce, sum) { + std::vector<int32_t> data(100, 1); + std::vector<int32_t> tmp_storage(100, -1); + XpuVarNdarray<const int32_t> x(XpuShape(Shape({100})), data.data()); + XpuVarNdarray<int32_t> tmp(XpuShape(Shape({100})), tmp_storage.data()); + int32_t ret = -100; + XpuVarNdarray<int32_t> y(XpuShape(Shape({1})), &ret); + NdArrayReduce<DeviceType::kCPU, int32_t, BinaryFuncAdd>::Reduce(nullptr, y, x, tmp); + ASSERT_EQ(ret, 100); +} + +TEST(NdArrayReduce, middle_axis_2) { TestMiddleAxis(10); } + +TEST(NdArrayReduce, middle_axis_10) { TestMiddleAxis(125); } + +} // namespace test + +} // namespace oneflow diff --git a/oneflow/core/ndarray/ndarray_util.h b/oneflow/core/ndarray/ndarray_util.h new file mode 100644 index 0000000000000000000000000000000000000000..64813f3b356c2086d64c397e11df77795d89f8ad --- /dev/null +++ b/oneflow/core/ndarray/ndarray_util.h @@ -0,0 +1,84 @@ +#ifndef ONEFLOW_CORE_NDARRAY_NDARRAY_UTIL_H_ +#define ONEFLOW_CORE_NDARRAY_NDARRAY_UTIL_H_ + +#include "oneflow/core/common/util.h" +#include "oneflow/core/ndarray/xpu_var_ndarray.h" +#include "oneflow/core/ndarray/ndarray_reduce.h" +#include "oneflow/core/ndarray/ndarray_apply_unary.h" +#include "oneflow/core/ndarray/ndarray_apply_broadcast_unary.h" +#include "oneflow/core/ndarray/ndarray_apply_broadcast_binary.h" +#include "oneflow/core/ndarray/xpu_reduced_ndarray.h" +#include "oneflow/core/common/switch_func.h" +#include "oneflow/core/common/util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +struct NdarrayUtil final { + template<const T (*unary_func)(const T)> + static void BroadcastApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, + const XpuVarNdarray<const T>& x) { + CHECK_EQ(x.shape().NumAxes(), y.shape().NumAxes()); + return Unary<unary_func>::SwitchBroadcastApply(SwitchCase(x.shape().NumAxes()), ctx, y, x); + } + static void BroadcastTo(DeviceCtx* ctx, const XpuVarNdarray<T>& y, + const XpuVarNdarray<const T>& x) { + return BroadcastApply<UnaryFuncIdentity>(ctx, y, x); + } + template<const T (*binary_func)(const T, const T)> + static void BroadcastApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, + const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) { + CHECK_EQ(a.shape().NumAxes(), y.shape().NumAxes()); + CHECK_EQ(b.shape().NumAxes(), y.shape().NumAxes()); + return Binary<binary_func>::SwitchBroadcastApply(SwitchCase(y.shape().NumAxes()), ctx, y, a, b); + } + static void ReduceSum(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x, + const XpuVarNdarray<T>& tmp_storage) { + return NdArrayReduce<device_type, T, BinaryFuncAdd>::Reduce(ctx, y, x, tmp_storage); + } + static void ReduceMax(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x, + const XpuVarNdarray<T>& tmp_storage) { + return NdArrayReduce<device_type, T, BinaryFuncMax>::Reduce(ctx, y, x, tmp_storage); + } + static void ReduceMin(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x, + const XpuVarNdarray<T>& tmp_storage) { + return NdArrayReduce<device_type, T, BinaryFuncMin>::Reduce(ctx, y, x, tmp_storage); + } + template<const T (*unary_func)(const T)> + static void ImplaceApplyUnary(DeviceCtx* ctx, const XpuVarNdarray<T>& y) { + return NdArrayApplyUnary<device_type, T, unary_func>::ImplaceApply(ctx, y); + } + + private: + 
template<const T (*unary_func)(const T)> + struct Unary final { + template<int NDIMS> + static void BroadcastApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, + const XpuVarNdarray<const T>& x) { + return NdArrayApplyBroadcastUnary<device_type, T, NDIMS, unary_func>::Apply(ctx, y, x); + } +#define DEFINE_NDARRAY_BROADCAST_UNARY(func_name, NDIMS) \ + NdarrayUtil<device_type, T>::Unary<unary_func>::func_name<NDIMS> + DEFINE_STATIC_SWITCH_FUNC(void, BroadcastApply, DEFINE_NDARRAY_BROADCAST_UNARY, + MAKE_NDIM_CTRV_SEQ(DIM_SEQ)); +#undef DEFINE_NDARRAY_BROADCAST_UNARY + }; + + template<const T (*binary_func)(const T, const T)> + struct Binary final { + template<int NDIMS> + static void BroadcastApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, + const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) { + return NdArrayApplyBroadcastBinary<device_type, T, NDIMS, binary_func>::Apply(ctx, y, a, b); + } +#define DEFINE_NDARRAY_BROADCAST_BINARY(func_name, NDIMS) \ + NdarrayUtil<device_type, T>::Binary<binary_func>::func_name<NDIMS> + DEFINE_STATIC_SWITCH_FUNC(void, BroadcastApply, DEFINE_NDARRAY_BROADCAST_BINARY, + MAKE_NDIM_CTRV_SEQ(DIM_SEQ)); +#undef DEFINE_NDARRAY_BROADCAST_BINARY + }; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_NDARRAY_NDARRAY_UTIL_H_ diff --git a/oneflow/core/ndarray/unary_func.h b/oneflow/core/ndarray/unary_func.h new file mode 100644 index 0000000000000000000000000000000000000000..b6642a785ec5bde07bafccd390cd36f88bc1988b --- /dev/null +++ b/oneflow/core/ndarray/unary_func.h @@ -0,0 +1,74 @@ +#ifndef ONEFLOW_CORE_NDARRAY_UNARY_FUNC_H_ +#define ONEFLOW_CORE_NDARRAY_UNARY_FUNC_H_ + +#include "oneflow/core/common/util.h" + +namespace oneflow { + +#define ARITHMETIC_UNARY_FUNC_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(UnaryFuncIdentity) \ + OF_PP_MAKE_TUPLE_SEQ(UnaryFuncMinus) \ + OF_PP_MAKE_TUPLE_SEQ(UnaryFuncLog2) \ + OF_PP_MAKE_TUPLE_SEQ(UnaryFuncExp2) + +template<typename T> +OF_DEVICE_FUNC const T UnaryFuncIdentity(const T x) { + return x; +} + +template<typename T> +OF_DEVICE_FUNC const T UnaryFuncMinus(const T x) { + return -x; +} + +template<typename T> +OF_DEVICE_FUNC + typename std::enable_if<std::is_same<T, float>::value || std::is_same<T, double>::value, + const T>::type + UnaryFuncLog2(const T x) { +#if defined(__CUDACC__) + return log2(x); +#else + return std::log2(x); +#endif +} + +template<typename T> +OF_DEVICE_FUNC + typename std::enable_if<!(std::is_same<T, float>::value || std::is_same<T, double>::value), + const T>::type + UnaryFuncLog2(const T x) { +#if defined(__CUDACC__) + return log2(static_cast<float>(x)); +#else + return std::log2(x); +#endif +} + +template<typename T> +OF_DEVICE_FUNC + typename std::enable_if<std::is_same<T, float>::value || std::is_same<T, double>::value, + const T>::type + UnaryFuncExp2(const T x) { +#if defined(__CUDACC__) + return exp2(x); +#else + return std::exp2(x); +#endif +} + +template<typename T> +OF_DEVICE_FUNC + typename std::enable_if<!(std::is_same<T, float>::value || std::is_same<T, double>::value), + const T>::type + UnaryFuncExp2(const T x) { +#if defined(__CUDACC__) + return exp2(static_cast<float>(x)); +#else + return std::exp2(x); +#endif +} + +} // namespace oneflow + +#endif // ONEFLOW_CORE_NDARRAY_UNARY_FUNC_H_ diff --git a/oneflow/core/ndarray/xpu_binary_func_ndarray.h b/oneflow/core/ndarray/xpu_binary_func_ndarray.h new file mode 100644 index 0000000000000000000000000000000000000000..be2debecfd9b197b0245e1af2649ad599a50ff08 --- /dev/null +++ b/oneflow/core/ndarray/xpu_binary_func_ndarray.h @@ -0,0 +1,23 
@@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_BINARY_FUNC_NDARRAY_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_BINARY_FUNC_NDARRAY_H_
+
+namespace oneflow {
+
+template<typename T, const T (*binary_func)(const T, const T), typename A, typename B>
+class XpuBinaryFuncNdarray final {
+ public:
+  OF_DEVICE_FUNC XpuBinaryFuncNdarray(const A& a, const B& b) : a_(a), b_(b) {}
+
+  template<int NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t offset) const {
+    return binary_func(a_.template Get<NDIMS>(offset), b_.template Get<NDIMS>(offset));
+  }
+
+ private:
+  const A& a_;
+  const B& b_;
+};
+
+} // namespace oneflow
+
+#endif // ONEFLOW_CORE_NDARRAY_XPU_BINARY_FUNC_NDARRAY_H_
diff --git a/oneflow/core/ndarray/xpu_broadcast_ndarray.h b/oneflow/core/ndarray/xpu_broadcast_ndarray.h
new file mode 100644
index 0000000000000000000000000000000000000000..9f46e992df6b9ae718cda6a0c661d0f5eb7c9da7
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_broadcast_ndarray.h
@@ -0,0 +1,53 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_BROADCAST_NDARRAY_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_BROADCAST_NDARRAY_H_
+
+#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/ndarray/xpu_ndarray_base.h"
+
+namespace oneflow {
+
+template<typename T, int NDIMS>
+struct XpuBroadcastNdarrayUtil;
+
+template<typename T>
+class XpuBroadcastNdarray final : public XpuNdarrayBase<XpuBroadcastNdarray<T>, T> {
+ public:
+  OF_DEVICE_FUNC XpuBroadcastNdarray(const XpuShape& shape, const XpuVarNdarray<T>& var)
+      : shape_(shape), var_(var) {}
+  OF_DEVICE_FUNC ~XpuBroadcastNdarray() = default;
+
+  template<int NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t offset) const {
+    int64_t coord[NDIMS];
+    shape_.template Offset2Coordinate<NDIMS>(offset, coord);
+    XpuBroadcastNdarrayUtil<T, NDIMS>::SrcCoordinate(var_.shape(), coord);
+    return var_.template Get<NDIMS>(coord);
+  }
+
+  OF_DEVICE_FUNC const XpuShape& shape() const { return shape_; }
+  OF_DEVICE_FUNC const XpuVarNdarray<T>& var() const { return var_; }
+
+ private:
+  const XpuShape& shape_;
+  const XpuVarNdarray<T>& var_;
+};
+
+#define INPLACE_SET_SRC_COORD(i) coord[i] %= src_shape.At(i);
+#define SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(n)                                                \
+  template<typename T>                                                                          \
+  struct XpuBroadcastNdarrayUtil<T, n + 1> final {                                              \
+    OF_DEVICE_FUNC static void SrcCoordinate(const XpuShape& src_shape, int64_t coord[n + 1]) { \
+      OF_PP_FOR_EACH_TUPLE(INPLACE_SET_SRC_COORD, GET_SEQ(n));                                  \
+    }                                                                                           \
+  }
+SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(0);
+SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(1);
+SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(2);
+SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(3);
+SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(4);
+#undef SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL
+#undef INPLACE_SET_SRC_COORD
+
+} // namespace oneflow
+
+#endif // ONEFLOW_CORE_NDARRAY_XPU_BROADCAST_NDARRAY_H_
diff --git a/oneflow/core/ndarray/xpu_ndarray_assign.cu b/oneflow/core/ndarray/xpu_ndarray_assign.cu
new file mode 100644
index 0000000000000000000000000000000000000000..c53ed205a87b4c019b5470651c61c1efc07ef4e5
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_ndarray_assign.cu
@@ -0,0 +1,29 @@
+#include "oneflow/core/ndarray/ndarray_assign_core.h"
+#include "oneflow/core/device/cuda_util.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T, int NDIMS>
+__global__ void NdArrayAssignGpu(XpuVarNdarray<T> y, const XpuReducedNdarray<T, NDIMS> reduced) {
+  NdArrayAssignCore<T, NDIMS>::Assign(y, reduced);
+}
+
+} // namespace
+
+template<typename T, int NDIMS>
+struct NdArrayAssignCoreWrapper<DeviceType::kGPU, T, NDIMS> final {
+  static void
Assign(DeviceCtx* ctx, XpuVarNdarray<T>* y, + const XpuReducedNdarray<T, NDIMS>& reduced) { + size_t n = y->host_shape().HostElemNum(); + RUN_CUDA_KERNEL((NdArrayAssignGpu<T, NDIMS>), ctx, n, *y, reduced); + } +}; + +#define INSTANTIATE_NDARRAY_ASSIGN(dtype_pair, NDIMS) \ + template struct NdArrayAssignCoreWrapper<DeviceType::kGPU, OF_PP_PAIR_FIRST(dtype_pair), NDIMS>; +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NDARRAY_ASSIGN, ARITHMETIC_DATA_TYPE_SEQ, DIM_SEQ); + +} // namespace oneflow diff --git a/oneflow/core/ndarray/xpu_ndarray_assign.h b/oneflow/core/ndarray/xpu_ndarray_assign.h new file mode 100644 index 0000000000000000000000000000000000000000..ecf34075b2897287f49d6add7a5c9636f458f814 --- /dev/null +++ b/oneflow/core/ndarray/xpu_ndarray_assign.h @@ -0,0 +1,25 @@ +#ifndef ONEFLOW_CORE_NDARRAY_XPU_ASSIGN_H_ +#define ONEFLOW_CORE_NDARRAY_XPU_ASSIGN_H_ + +#include "oneflow/core/ndarray/ndarray_assign_core.h" +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T> +struct XpuNdArrayAssign final { + template<int NDIMS> + static void Assign(DeviceCtx* ctx, const XpuVarNdarray<T>& y, + const XpuReducedNdarray<T, NDIMS>& reduced) { + NdArrayAssignCoreWrapper<device_type, T, NDIMS>::Assign(ctx, y, reduced); + } + static void Assign(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) { + CHECK(y.shape() == x.shape()); + if (x.ptr() == y.ptr()) { return; } + Memcpy<device_type>(ctx, y.ptr(), x.ptr(), y.shape().ElemNum() * sizeof(T)); + } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_NDARRAY_XPU_ASSIGN_H_ diff --git a/oneflow/core/ndarray/xpu_ndarray_base.h b/oneflow/core/ndarray/xpu_ndarray_base.h new file mode 100644 index 0000000000000000000000000000000000000000..aa3fa8e89fdcc23b950e630cafd4b022b4a41e76 --- /dev/null +++ b/oneflow/core/ndarray/xpu_ndarray_base.h @@ -0,0 +1,48 @@ +#ifndef ONEFLOW_CORE_NDARRAY_XPU_NDARRAY_BASE_H_ +#define ONEFLOW_CORE_NDARRAY_XPU_NDARRAY_BASE_H_ + +namespace oneflow { + +template<typename T, const T (*unary_func)(const T), typename X> +class XpuUnaryFuncNdarray; +template<typename T, const T (*binary_func)(const T, const T), typename A, typename B> +class XpuBinaryFuncNdarray; +template<typename T> +class XpuBroadcastNdarray; +template<typename T, int, typename X> +class XpuTransposeNdarray; +template<typename T, int, typename X> +class XpuReshapeNdarray; + +template<typename DerivedT, typename T> +class XpuNdarrayBase { + public: + OF_DEVICE_FUNC XpuNdarrayBase() = default; + OF_DEVICE_FUNC ~XpuNdarrayBase() = default; + + template<const T (*unary_func)(const T)> + OF_DEVICE_FUNC XpuUnaryFuncNdarray<T, unary_func, DerivedT> UnaryFunc() const { + return XpuUnaryFuncNdarray<T, unary_func, DerivedT>(*static_cast<const DerivedT*>(this)); + } + template<const T (*binary_func)(const T, const T), typename X> + OF_DEVICE_FUNC XpuBinaryFuncNdarray<T, binary_func, DerivedT, X> BinaryFunc(const X& x) const { + return XpuBinaryFuncNdarray<T, binary_func, DerivedT, X>(*static_cast<const DerivedT*>(this), + x); + } + OF_DEVICE_FUNC XpuBroadcastNdarray<const T> Broadcast(const XpuShape& shape) const { + return XpuBroadcastNdarray<const T>(shape, *static_cast<const DerivedT*>(this)); + } + template<int NDIMS> + OF_DEVICE_FUNC XpuTransposeNdarray<T, NDIMS, DerivedT> Transpose( + const int64_t perm[NDIMS]) const { + return XpuTransposeNdarray<T, NDIMS, DerivedT>(*static_cast<const DerivedT*>(this), perm); + } + template<int NDIMS> + OF_DEVICE_FUNC XpuReshapeNdarray<T, NDIMS, 
DerivedT> Reshape(const int64_t shape[NDIMS]) {
+    return XpuReshapeNdarray<T, NDIMS, DerivedT>(*static_cast<const DerivedT*>(this), shape);
+  }
+};
+
+} // namespace oneflow
+
+#endif // ONEFLOW_CORE_NDARRAY_XPU_NDARRAY_BASE_H_
diff --git a/oneflow/core/ndarray/xpu_reduced_ndarray.h b/oneflow/core/ndarray/xpu_reduced_ndarray.h
new file mode 100644
index 0000000000000000000000000000000000000000..d2bdd21e979db5406fecf023dcd3ba3b7c874c4f
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_reduced_ndarray.h
@@ -0,0 +1,51 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_REDUCED_NDARRAY_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_REDUCED_NDARRAY_H_
+
+#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/ndarray/xpu_util.h"
+#include "oneflow/core/ndarray/unary_func.h"
+
+namespace oneflow {
+
+template<typename T, int NDIMS, typename X = XpuVarNdarray<T>>
+class XpuReducedNdarray final {
+ public:
+  OF_DEVICE_FUNC XpuReducedNdarray(const XpuShape& shape, const X& data)
+      : shape_(shape), data_(data) {}
+
+  OF_DEVICE_FUNC const XpuShape& shape() const { return shape_; }
+  const XpuShape& host_shape() const { return shape_; }
+  OF_DEVICE_FUNC const X& data() const { return data_; }
+
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t offset) const {
+    int64_t coord[NDIMS];
+    shape_.template Offset2Coordinate<NDIMS>(offset, coord);
+    return Get(coord);
+  }
+
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t coord[ndims]) const {
+    return data_.template Get<ndims>(coord);
+  }
+
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T* Mut(int64_t offset) const {
+    int64_t coord[NDIMS];
+    shape_.template Offset2Coordinate<NDIMS>(offset, coord);
+    return Mut(coord);
+  }
+
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const {
+    return data_.template Mut<NDIMS>(coord);
+  }
+
+ private:
+  XpuShape shape_;
+  X data_;
+};
+
+} // namespace oneflow
+
+#endif // ONEFLOW_CORE_NDARRAY_XPU_REDUCED_NDARRAY_H_
diff --git a/oneflow/core/ndarray/xpu_reshape_ndarray.h b/oneflow/core/ndarray/xpu_reshape_ndarray.h
new file mode 100644
index 0000000000000000000000000000000000000000..612fa8fad97acc8f047f357de935cb1a532b5f2d
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_reshape_ndarray.h
@@ -0,0 +1,39 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_RESHAPE_NDARRAY_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_RESHAPE_NDARRAY_H_
+
+namespace oneflow {
+
+template<typename T, int NDIMS, typename X = XpuVarNdarray<T>>
+class XpuReshapeNdarray final {
+ public:
+  OF_DEVICE_FUNC XpuReshapeNdarray(const X& x, const int64_t dim[NDIMS])
+      : x_(x), shape_(dim, NDIMS) {}
+
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t offset) const {
+    return x_.template Get<ndims>(offset);
+  }
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T* Mut(int64_t offset) const {
+    return x_.template Mut<ndims>(offset);
+  }
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t coord[ndims]) const {
+    return Get<ndims>(Coord2Offset(coord));
+  }
+  template<int ndims = NDIMS>
+  OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const {
+    return Mut<NDIMS>(Coord2Offset(coord));
+  }
+
+ private:
+  OF_DEVICE_FUNC int64_t Coord2Offset(const int64_t coord[NDIMS]) const {
+    return XpuShapeUtil<NDIMS>::Coordinate2Offset(shape_, coord);
+  }
+  const X& x_;
+  XpuShape shape_;
+};
+
+} // namespace oneflow
+
+#endif // ONEFLOW_CORE_NDARRAY_XPU_RESHAPE_NDARRAY_H_
diff --git a/oneflow/core/ndarray/xpu_shape.h b/oneflow/core/ndarray/xpu_shape.h
new file mode 100644
index
0000000000000000000000000000000000000000..0099085d7b1155dc47d05ca1e9f0c46367106f19 --- /dev/null +++ b/oneflow/core/ndarray/xpu_shape.h @@ -0,0 +1,96 @@ +#ifndef ONEFLOW_CORE_NDARRAY_XPU_SHAPE_H_ +#define ONEFLOW_CORE_NDARRAY_XPU_SHAPE_H_ + +#include "oneflow/core/common/shape.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/kernel/kernel_util.h" +#include "oneflow/core/ndarray/xpu_util.h" + +namespace oneflow { + +template<int NDIMS> +struct XpuShapeUtil; + +class XpuShape final { + public: + explicit XpuShape(const Shape& shape); + OF_DEVICE_FUNC XpuShape(const int64_t dim[], int num_axes); + OF_DEVICE_FUNC XpuShape(const XpuShape&) = default; + + OF_DEVICE_FUNC int64_t At(int64_t dim) const { return dim_[dim]; } + + OF_DEVICE_FUNC size_t ElemNum() const { return elem_num_; } + OF_DEVICE_FUNC size_t NumAxes() const { return num_axes_; } + size_t HostElemNum() const { return elem_num_; } + bool operator==(const XpuShape&) const; + bool operator!=(const XpuShape& rhs) const { return !(*this == rhs); } + + OF_DEVICE_FUNC void Set(int64_t axis, int64_t value) { + dim_[axis] = value; + UpdateDimElemNumAndElemNum(); + } + + template<int NDIMS> + OF_DEVICE_FUNC int64_t Coordinate2Offset(const int64_t coord[NDIMS]) const { + return XpuShapeUtil<NDIMS>::Coordinate2Offset(*this, coord); + } + template<int NDIMS> + OF_DEVICE_FUNC void Offset2Coordinate(int64_t offset, int64_t coord[NDIMS]) const { + XpuShapeUtil<NDIMS>::Offset2Coordinate(*this, offset, coord); + } + + OF_DEVICE_FUNC void UpdateDimElemNumAndElemNum() { + elem_num_ = 1; + for (int i = num_axes_ - 1; i >= 0; --i) { + dim_elem_num_[i] = elem_num_; + elem_num_ *= dim_[i]; + } + } + + size_t num_axes_; + size_t elem_num_; + int64_t dim_[OF_PP_SEQ_SIZE(DIM_SEQ)]; + int64_t dim_elem_num_[OF_PP_SEQ_SIZE(DIM_SEQ)]; +}; + +template<> +struct XpuShapeUtil<1> final { + OF_DEVICE_FUNC static int64_t Coordinate2Offset(const XpuShape& shape, const int64_t coord[1]) { + return coord[0]; + } + OF_DEVICE_FUNC static void Offset2Coordinate(const XpuShape& shape, int64_t offset, + int64_t coord[1]) { + coord[0] = offset; + } +}; + +#define COORD_MUL_STRIDE(i) coord[i] * shape.dim_elem_num_[i] + +#define EXTRACT_COORD(i) \ + coord[i] = offset / shape.dim_elem_num_[i]; \ + offset %= shape.dim_elem_num_[i]; + +#define SPECIALIZE_XPU_SHAPE_UTIL(n) \ + template<> \ + struct XpuShapeUtil<n + 2> final { \ + OF_DEVICE_FUNC static int64_t Coordinate2Offset(const XpuShape& shape, \ + const int64_t coord[n + 2]) { \ + return OF_PP_FOR_EACH_TUPLE(COORD_MUL_STRIDE, GET_SEQ(n)) coord[n + 1]; \ + } \ + OF_DEVICE_FUNC static void Offset2Coordinate(const XpuShape& shape, int64_t offset, \ + int64_t coord[n + 2]) { \ + OF_PP_FOR_EACH_TUPLE(EXTRACT_COORD, GET_SEQ(n)); \ + coord[n + 1] = offset; \ + } \ + }; + +SPECIALIZE_XPU_SHAPE_UTIL(0); +SPECIALIZE_XPU_SHAPE_UTIL(1); +SPECIALIZE_XPU_SHAPE_UTIL(2); +SPECIALIZE_XPU_SHAPE_UTIL(3); +#undef SPECIALIZE_XPU_SHAPE_UTIL +#undef EXTRACT_COORD +#undef COORD_MUL_STRIDE + +} // namespace oneflow + +#endif // ONEFLOW_CORE_NDARRAY_XPU_SHAPE_H_ diff --git a/oneflow/core/ndarray/xpu_transpose_ndarray.h b/oneflow/core/ndarray/xpu_transpose_ndarray.h new file mode 100644 index 0000000000000000000000000000000000000000..e233076541d6c6ee558214f185c850b2fe98dce6 --- /dev/null +++ b/oneflow/core/ndarray/xpu_transpose_ndarray.h @@ -0,0 +1,64 @@ +#ifndef ONEFLOW_CORE_NDARRAY_XPU_TRANSPOSE_NDARRAY_H_ +#define ONEFLOW_CORE_NDARRAY_XPU_TRANSPOSE_NDARRAY_H_ + +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { + 
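+// XpuTransposeNdarray below is a lazy transposed view over an existing ndarray
+// x_: shape_ stores the permuted dims and PermuteCoord maps a view coordinate
+// back to an x_ coordinate via permuted_coord[perm_[i]] = coord[i]. For example
+// (shapes purely illustrative), with perm = {1, 0} over an x of shape (2, 3) the
+// view reports shape (3, 2) and Get({i, j}) reads x at {j, i}; no element is copied.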
+template<typename T, int NDIMS, typename X = XpuVarNdarray<T>>
+class XpuTransposeNdarray final {
+ public:
+  OF_DEVICE_FUNC XpuTransposeNdarray(const X& x, const int64_t perm[NDIMS])
+      : x_(x), shape_(x.shape()) {
+    for (int i = 0; i < NDIMS; ++i) {
+      perm_[i] = perm[i];
+      shape_.Set(i, x.shape().At(perm[i]));
+    }
+  }
+
+  template<int ndims, typename = typename std::enable_if<ndims == NDIMS>::type>
+  OF_DEVICE_FUNC T Get(int64_t offset) const {
+    int64_t coord[NDIMS];
+    Offset2Coord(offset, coord);
+    return Get(coord);
+  }
+
+  template<int ndims, typename = typename std::enable_if<ndims == NDIMS>::type>
+  OF_DEVICE_FUNC T* Mut(int64_t offset) const {
+    int64_t coord[NDIMS];
+    Offset2Coord(offset, coord);
+    return Mut(coord);
+  }
+
+  template<int ndims, typename = typename std::enable_if<ndims == NDIMS>::type>
+  OF_DEVICE_FUNC T Get(int64_t coord[ndims]) const {
+    int64_t permuted_coord[NDIMS];
+    PermuteCoord(coord, permuted_coord);
+    return x_.template Get<ndims>(permuted_coord);
+  }
+
+  template<int ndims, typename = typename std::enable_if<ndims == NDIMS>::type>
+  OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const {
+    int64_t permuted_coord[NDIMS];
+    PermuteCoord(coord, permuted_coord);
+    return x_.template Mut<NDIMS>(permuted_coord);
+  }
+
+ private:
+  OF_DEVICE_FUNC void Offset2Coord(int64_t offset, int64_t coord[NDIMS]) const {
+    shape_.Offset2Coordinate<NDIMS>(offset, coord);
+  }
+
+  OF_DEVICE_FUNC void PermuteCoord(const int64_t coord[NDIMS],
+                                   int64_t permuted_coord[NDIMS]) const {
+    for (int i = 0; i < NDIMS; ++i) { permuted_coord[perm_[i]] = coord[i]; }
+  }
+
+  const X& x_;
+  XpuShape shape_;
+  int64_t perm_[NDIMS];
+};
+
+} // namespace oneflow
+
+#endif // ONEFLOW_CORE_NDARRAY_XPU_TRANSPOSE_NDARRAY_H_
diff --git a/oneflow/core/ndarray/xpu_unary_func_ndarray.h b/oneflow/core/ndarray/xpu_unary_func_ndarray.h
new file mode 100644
index 0000000000000000000000000000000000000000..d4dbdd2a4efb399cb695c14403255538eae19ad6
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_unary_func_ndarray.h
@@ -0,0 +1,22 @@
+#ifndef ONEFLOW_CORE_UNARY_FUNC_NDARRAY_H_
+#define ONEFLOW_CORE_UNARY_FUNC_NDARRAY_H_
+
+namespace oneflow {
+
+template<typename T, const T (*unary_func)(const T), typename X>
+class XpuUnaryFuncNdarray final {
+ public:
+  OF_DEVICE_FUNC XpuUnaryFuncNdarray(const X& x) : x_(x) {}
+
+  template<int NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t offset) const {
+    return unary_func(x_.template Get<NDIMS>(offset));
+  }
+
+ private:
+  const X& x_;
+};
+
+} // namespace oneflow
+
+#endif // ONEFLOW_CORE_UNARY_FUNC_NDARRAY_H_
diff --git a/oneflow/core/ndarray/xpu_util.h b/oneflow/core/ndarray/xpu_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..2876e46b1a6112013cb11ffd417fd36b4062757d
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_util.h
@@ -0,0 +1,40 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_UTIL_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_UTIL_H_
+
+#include "oneflow/core/kernel/kernel_util.h"
+#include "oneflow/core/device/cuda_util.h"
+
+namespace oneflow {
+
+#if defined(__CUDACC__)
+#define XPU_1D_KERNEL_LOOP(i, n) CUDA_1D_KERNEL_LOOP(i, n)
+#else
+#define XPU_1D_KERNEL_LOOP(i, n) FOR_RANGE(int64_t, i, 0, n)
+#endif
+
+#if defined(__CUDACC__)
+#define XPU_BLOAD_THREAD_2D_KERNEL_LOOP(i, j, m, n)     \
+  for (int64_t i = blockIdx.x; i < (m); i += gridDim.x) \
+    for (int64_t j = threadIdx.x; j < (n); j += blockDim.x)
+#else
+#define XPU_BLOAD_THREAD_2D_KERNEL_LOOP(i, j, m, n) \
+  for (int64_t i = 0; i < (m); ++i)                 \
+    for (int64_t j = 0; j < (n); ++j)
+#endif
+
+#if defined(__CUDACC__)
+#define OF_GLOBAL_FUNC __global__
+#else
+#define OF_GLOBAL_FUNC
+#endif
+
+#define GET_SEQ(n) OF_PP_CAT(OF_PP_CAT(GET_SEQ_, n), )
+#define GET_SEQ_0 OF_PP_MAKE_TUPLE_SEQ(0)
+#define GET_SEQ_1 GET_SEQ_0 OF_PP_MAKE_TUPLE_SEQ(1)
+#define GET_SEQ_2 GET_SEQ_1 OF_PP_MAKE_TUPLE_SEQ(2)
+#define GET_SEQ_3 GET_SEQ_2 OF_PP_MAKE_TUPLE_SEQ(3)
+#define GET_SEQ_4 GET_SEQ_3 OF_PP_MAKE_TUPLE_SEQ(4)
+#define GET_SEQ_5 GET_SEQ_4 OF_PP_MAKE_TUPLE_SEQ(5)
+} // namespace oneflow
+
+#endif // ONEFLOW_CORE_NDARRAY_XPU_UTIL_H_
diff --git a/oneflow/core/ndarray/xpu_var_ndarray.h b/oneflow/core/ndarray/xpu_var_ndarray.h
new file mode 100644
index 0000000000000000000000000000000000000000..f7ecdc5550611f5bf0479018a6b7bda056c4b921
--- /dev/null
+++ b/oneflow/core/ndarray/xpu_var_ndarray.h
@@ -0,0 +1,73 @@
+#ifndef ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_H_
+#define ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_H_
+
+#include "oneflow/core/ndarray/xpu_shape.h"
+#include "oneflow/core/kernel/kernel_util.h"
+#include "oneflow/core/common/util.h"
+#include "oneflow/core/register/blob.h"
+#include "oneflow/core/ndarray/xpu_util.h"
+#include "oneflow/core/ndarray/xpu_ndarray_base.h"
+
+namespace oneflow {
+
+template<typename T>
+class XpuVarNdarray final : public XpuNdarrayBase<XpuVarNdarray<T>, T> {
+ public:
+  explicit XpuVarNdarray(const Blob* blob, int ndims_extend_to)
+      : shape_(blob->shape().CreateLeftExtendedShape(ndims_extend_to)),
+        ptr_(blob->dptr<typename std::remove_const<T>::type>()) {}
+  explicit XpuVarNdarray(Blob* blob, int ndims_extend_to)
+      : shape_(blob->shape().CreateLeftExtendedShape(ndims_extend_to)), ptr_(blob->mut_dptr<T>()) {}
+  XpuVarNdarray(const Shape& shape, T* ptr) : shape_(shape), ptr_(ptr) {}
+  OF_DEVICE_FUNC ALWAYS_INLINE XpuVarNdarray(const XpuVarNdarray&) = default;
+  OF_DEVICE_FUNC ALWAYS_INLINE XpuVarNdarray(const XpuShape& shape, T* ptr)
+      : shape_(shape), ptr_(ptr) {}
+
+  const XpuShape& host_shape() const { return shape_; }
+  T* host_ptr() const { return ptr_; }
+
+  OF_DEVICE_FUNC const XpuShape& shape() const { return shape_; }
+  OF_DEVICE_FUNC T* ptr() const { return ptr_; }
+
+  template<int NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t offset) const {
+    return ptr_[offset];
+  }
+  template<int NDIMS>
+  OF_DEVICE_FUNC T Get(int64_t coord[NDIMS]) const {
+    return ptr_[shape().template Coordinate2Offset<NDIMS>(coord)];
+  }
+
+  template<int NDIMS>
+  OF_DEVICE_FUNC T* Mut(int64_t offset) const {
+    return ptr_ + offset;
+  }
+
+  template<int NDIMS>
+  OF_DEVICE_FUNC T* Mut(int64_t coord[NDIMS]) const {
+    return ptr_ + shape().template Coordinate2Offset<NDIMS>(coord);
+  }
+
+  template<int NDIMS, typename X>
+  OF_DEVICE_FUNC void Assign(const X& x) const {
+    size_t n = shape_.ElemNum();
+    XPU_1D_KERNEL_LOOP(i, n) { ptr_[i] = x.template Get<NDIMS>(i); }
+  }
+
+  template<const T (*binary_func)(const T, const T), int NDIMS, typename X>
+  OF_DEVICE_FUNC void BinaryAssign(const X& x) const {
+    size_t n = shape_.ElemNum();
+    XPU_1D_KERNEL_LOOP(i, n) {
+      T* ptr_i = ptr_ + i;
+      *ptr_i = binary_func(*ptr_i, x.template Get<NDIMS>(i));
+    }
+  }
+
+ private:
+  XpuShape shape_;
+  T* ptr_;
+};
+
+} // namespace oneflow
+
+#endif // ONEFLOW_CORE_NDARRAY_XPU_VAR_NDARRAY_H_
diff --git a/oneflow/core/operator/accumulate_op.h b/oneflow/core/operator/accumulate_op.h
index b3bd2dfa84a95a25c359215f6a35cfc7ead4d7ff..89e798ac34969666d7b96ba2aa9016ec0af502bc 100644
--- a/oneflow/core/operator/accumulate_op.h
+++ b/oneflow/core/operator/accumulate_op.h
@@ -16,8 +16,15 @@ class AccumulateOp final : public Operator {
 void InferBlobDescs(std::function<BlobDesc*(const
std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override {} + void InferOutputBlobTimeShape(std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, + const ParallelContext* parallel_ctx, + Shape* time_shape) const override { + TODO(); + } private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); } }; diff --git a/oneflow/core/operator/accuracy_op.cpp b/oneflow/core/operator/accuracy_op.cpp index 40372efffe12cfeb4e04fe31d85a469da6bac12b..1d742a6b13184e9e4528d83c44c02f482d45baed 100644 --- a/oneflow/core/operator/accuracy_op.cpp +++ b/oneflow/core/operator/accuracy_op.cpp @@ -7,6 +7,10 @@ void AccuracyOp::InitFromOpConf() { EnrollInputBn("label", false); EnrollOutputBn("accuracy", false); EnrollOutputBn("accuracy_instance_num", false); + if (op_conf().accuracy_conf().has_weight()) { + EnrollInputBn("weight", false); + EnrollDataTmpBn("weight_reduce_tmp"); + } } const PbMessage& AccuracyOp::GetCustomizedConf() const { return op_conf().accuracy_conf(); } @@ -20,8 +24,7 @@ void AccuracyOp::VirtualGenKernelConf( } void AccuracyOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, - std::function<void(OpContext*)>) const { + const ParallelContext* parallel_ctx) const { BlobDesc* pred_blob_desc = GetBlobDesc4BnInOp("prediction"); BlobDesc* label_blob_desc = GetBlobDesc4BnInOp("label"); CHECK_EQ(pred_blob_desc->has_data_id_field(), label_blob_desc->has_data_id_field()); @@ -34,6 +37,20 @@ void AccuracyOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> Get CHECK_GE(pred_blob_desc->shape().NumAxes(), 2); CHECK_EQ(label_blob_desc->shape(), Shape({pred_blob_desc->shape().At(0)})); + if (op_conf().accuracy_conf().has_weight()) { + const BlobDesc* weight = GetBlobDesc4BnInOp("weight"); + CHECK_EQ(weight->shape(), label_blob_desc->shape()); + CHECK_EQ(weight->data_type(), pred_blob_desc->data_type()); + CHECK_EQ(weight->has_dim0_valid_num_field(), label_blob_desc->has_dim0_valid_num_field()); + CHECK_EQ(weight->has_dim0_inner_shape(), label_blob_desc->has_dim0_inner_shape()); + if (label_blob_desc->has_dim0_inner_shape()) { + CHECK_EQ(weight->dim0_inner_shape(), label_blob_desc->dim0_inner_shape()); + } + BlobDesc* weight_reduce_tmp = GetBlobDesc4BnInOp("weight_reduce_tmp"); + weight_reduce_tmp->mut_shape() = weight->shape(); + weight_reduce_tmp->set_data_type(weight->data_type()); + } + // accuracy BlobDesc* accuracy_blob_desc = GetBlobDesc4BnInOp("accuracy"); *accuracy_blob_desc = *pred_blob_desc; diff --git a/oneflow/core/operator/accuracy_op.h b/oneflow/core/operator/accuracy_op.h index 2fea279dc66a0b654e932adf850d6f06e784486d..ece825f2c64da93b5edac889217266ffbe312dbf 100644 --- a/oneflow/core/operator/accuracy_op.h +++ b/oneflow/core/operator/accuracy_op.h @@ -17,12 +17,13 @@ class AccuracyOp final : public Operator { const PbMessage& GetCustomizedConf() const override; void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, - std::function<void(OpContext*)> EnrollOpCtx) const override; + const ParallelContext* parallel_ctx) const override; void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, KernelConf* 
kernel_conf) const; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + LogicalBlobId obn2lbi(const std::string& output_bn) const override; }; diff --git a/oneflow/core/operator/accuracy_print_op.h b/oneflow/core/operator/accuracy_print_op.h index 60bf09bfcb84fa86cfdab79bf7e9e1bd7aa4e352..ac027e6210d5c038913661ef09965021eb0d0778 100644 --- a/oneflow/core/operator/accuracy_print_op.h +++ b/oneflow/core/operator/accuracy_print_op.h @@ -15,6 +15,8 @@ class AccuracyPrintOp final : public Operator { const PbMessage& GetCustomizedConf() const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + LogicalBlobId ibn2lbi(const std::string& input_bn) const override; }; diff --git a/oneflow/core/operator/adam_model_update_op.cpp b/oneflow/core/operator/adam_model_update_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f95358e170521cb74b90af29244751429148423f --- /dev/null +++ b/oneflow/core/operator/adam_model_update_op.cpp @@ -0,0 +1,40 @@ +#include "oneflow/core/operator/adam_model_update_op.h" + +namespace oneflow { + +void AdamModelUpdateOp::MdUpdtVirtualInitFromOpConf() { + const auto& adam_conf = op_conf().normal_mdupdt_conf().user_conf().adam_conf(); + CHECK_GE(adam_conf.beta1(), 0); + CHECK_LT(adam_conf.beta1(), 1); + CHECK_GE(adam_conf.beta2(), 0); + CHECK_LT(adam_conf.beta2(), 1); + + EnrollForwardModelBn("m"); + EnrollForwardModelBn("v"); + if (adam_conf.do_bias_correction()) { + EnrollForwardModelBn("beta1_t"); + EnrollForwardModelBn("beta2_t"); + } +} + +void AdamModelUpdateOp::MdUpdtVirtualInferBlobDescs( + std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + const auto& adam_conf = op_conf().normal_mdupdt_conf().user_conf().adam_conf(); + const BlobDesc* model_blob_desc = GetBlobDesc4BnInOp("model"); + CHECK_EQ(model_blob_desc->data_type(), Global<JobDesc>::Get()->DefaultDataType()); + CHECK_EQ(model_blob_desc->has_data_id_field(), false); + *GetBlobDesc4BnInOp("m") = *model_blob_desc; + *GetBlobDesc4BnInOp("v") = *model_blob_desc; + + if (adam_conf.do_bias_correction()) { + *GetBlobDesc4BnInOp("beta1_t") = *model_blob_desc; + *GetBlobDesc4BnInOp("beta2_t") = *model_blob_desc; + GetBlobDesc4BnInOp("beta1_t")->mut_shape() = Shape({1}); + GetBlobDesc4BnInOp("beta2_t")->mut_shape() = Shape({1}); + } +} + +REGISTER_CLASS(NormalModelUpdateOpUserConf::kAdamConf, NormalModelUpdtOp, AdamModelUpdateOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/adam_model_update_op.h b/oneflow/core/operator/adam_model_update_op.h new file mode 100644 index 0000000000000000000000000000000000000000..21cc800bfac47a8f578b29ef3a6f039b7d130a24 --- /dev/null +++ b/oneflow/core/operator/adam_model_update_op.h @@ -0,0 +1,22 @@ +#ifndef ONEFLOW_CORE_OPERATOR_ADAM_MODEL_UPDATE_OP_H_ +#define ONEFLOW_CORE_OPERATOR_ADAM_MODEL_UPDATE_OP_H_ + +#include "oneflow/core/operator/normal_model_update_op.h" + +namespace oneflow { + +class AdamModelUpdateOp final : public NormalModelUpdtOp { + public: + OF_DISALLOW_COPY_AND_MOVE(AdamModelUpdateOp); + AdamModelUpdateOp() = default; + ~AdamModelUpdateOp() = default; + + private: + void MdUpdtVirtualInitFromOpConf() override; + void MdUpdtVirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_ADAM_MODEL_UPDATE_OP_H_ diff 
--git a/oneflow/core/operator/add_op.cpp b/oneflow/core/operator/add_op.cpp index 17d9ff2b6aa2ab615aaad19d10dbb787d1081f0a..365e734186acba54b8b885c078cf5f8807689605 100644 --- a/oneflow/core/operator/add_op.cpp +++ b/oneflow/core/operator/add_op.cpp @@ -4,8 +4,9 @@ namespace oneflow { void AddOp::VirtualInitFromOpConf() { CHECK(op_conf().has_add_conf()); } const PbMessage& AddOp::GetCustomizedConf() const { return op_conf().add_conf(); } -void AddOp::FixInDiffBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx) const { +void AddOp::VirtualFixInDiffBlobDescs( + std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { if (!Global<JobDesc>::Get()->enable_blob_mem_sharing()) { return; } int64_t blob_mem_id = oneflow_cast<int64_t>(NewUniqueId()); FOR_RANGE(size_t, i, 0, input_diff_bns().size()) { diff --git a/oneflow/core/operator/add_op.h b/oneflow/core/operator/add_op.h index c607e03a2a153d39b3e5eb984fa1e9fc3095d43a..086cb16100550c1f31266030cc792189ff78627f 100644 --- a/oneflow/core/operator/add_op.h +++ b/oneflow/core/operator/add_op.h @@ -15,9 +15,13 @@ class AddOp final : public CWiseOp { const PbMessage& GetCustomizedConf() const override; bool NeedInBlobWhenBackward() const override { return false; } bool NeedOutBlobWhenBackward() const override { return false; } - virtual void FixInDiffBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext*) const override; + void VirtualFixInDiffBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext*) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } }; + } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_ADD_OP_H_ diff --git a/oneflow/core/operator/basic_rnn_op.h b/oneflow/core/operator/basic_rnn_op.h index d01a167b1903479cb0ca0129f47166436c20edfd..99b4dcd29ef500cd4f2681e6024bef716314ee6b 100644 --- a/oneflow/core/operator/basic_rnn_op.h +++ b/oneflow/core/operator/basic_rnn_op.h @@ -13,6 +13,8 @@ class BasicRnnOp final : public RecurrentOp { const PbMessage& GetCustomizedConf() const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } + void VirtualInitFromOpConf(); void VirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const; diff --git a/oneflow/core/operator/batch_gather_op.cpp b/oneflow/core/operator/batch_gather_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..989fe38f3d1be0b91d25b88276974f9795ccf9e0 --- /dev/null +++ b/oneflow/core/operator/batch_gather_op.cpp @@ -0,0 +1,38 @@ +#include "oneflow/core/operator/batch_gather_op.h" + +namespace oneflow { + +void BatchGatherOp::InitFromOpConf() { + CHECK(op_conf().has_batch_gather_conf()); + EnrollInputBn("in"); + EnrollInputBn("indices", false); + EnrollOutputBn("out"); +} + +const PbMessage& BatchGatherOp::GetCustomizedConf() const { return op_conf().batch_gather_conf(); } + +void BatchGatherOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + const BlobDesc* in = GetBlobDesc4BnInOp("in"); + CHECK_GT(in->shape().NumAxes(), 0); + const BlobDesc* indices = GetBlobDesc4BnInOp("indices"); + CHECK_GT(indices->shape().NumAxes(), 0); + CHECK(IsIntegralDataType(indices->data_type())); + 
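+  // Shape contract for the code below: every axis of `indices` except the last
+  // must match the corresponding leading axis of `in`; the output shape is the
+  // full `indices` shape followed by the remaining trailing axes of `in`. For
+  // example (shapes purely illustrative): in of shape (N, M, D) gathered with
+  // indices of shape (N, K) yields out of shape (N, K, D).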
const std::vector<int64_t>& in_dim_vec = in->shape().dim_vec(); + const std::vector<int64_t>& indices_dim_vec = indices->shape().dim_vec(); + CHECK_LE(indices_dim_vec.size(), in_dim_vec.size()); + FOR_RANGE(int64_t, i, 0, indices_dim_vec.size() - 1) { + CHECK_EQ(indices_dim_vec.at(i), in_dim_vec.at(i)); + } + // out + std::vector<int64_t> out_dim_vec(indices_dim_vec); + out_dim_vec.insert(out_dim_vec.end(), in_dim_vec.cbegin() + indices_dim_vec.size(), + in_dim_vec.cend()); + BlobDesc* out = GetBlobDesc4BnInOp("out"); + *out = *in; + out->mut_shape() = Shape(out_dim_vec); +} + +REGISTER_OP(OperatorConf::kBatchGatherConf, BatchGatherOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/batch_gather_op.h b/oneflow/core/operator/batch_gather_op.h new file mode 100644 index 0000000000000000000000000000000000000000..469100f9057ae2482edd72dd63fddbba593f78cd --- /dev/null +++ b/oneflow/core/operator/batch_gather_op.h @@ -0,0 +1,26 @@ +#ifndef ONEFLOW_CORE_OPERATOR_BATCH_GATHER_OP_H_ +#define ONEFLOW_CORE_OPERATOR_BATCH_GATHER_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class BatchGatherOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(BatchGatherOp); + BatchGatherOp() = default; + ~BatchGatherOp() override = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + bool NeedOutBlobWhenBackward() const override { return false; } + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_BATCH_GATHER_OP_H_ diff --git a/oneflow/core/operator/bias_add_op.cpp b/oneflow/core/operator/bias_add_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..66b0107cc3a6f78a35d8048674be71cf995b8821 --- /dev/null +++ b/oneflow/core/operator/bias_add_op.cpp @@ -0,0 +1,45 @@ +#include "oneflow/core/operator/bias_add_op.h" +#include "oneflow/core/common/balanced_splitter.h" + +namespace oneflow { + +void BiasAddOp::InitFromOpConf() { + CHECK(op_conf().has_bias_add_conf()); + EnrollInputBn("a"); + EnrollInputBn("b"); + EnrollOutputBn("out"); + EnrollConstBufBn("bias_multiplier"); +} + +const PbMessage& BiasAddOp::GetCustomizedConf() const { return op_conf().bias_add_conf(); } + +bool BiasAddOp::IsInputBlobAllowedModelSplit(const std::string& ibn) const { + CHECK(std::find(input_bns().begin(), input_bns().end(), ibn) != input_bns().end()); + return ibn == "b"; +} + +void BiasAddOp::GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const { + op_parallel_signatures->emplace_back(Make_DS_MB_2_DS_OpParallelSignature(this)); + auto EqZero = [](int32_t axis) { return axis == 0; }; + op_parallel_signatures->emplace_back(Make_DB_MS_2_MS_OpParallelSignature(this, EqZero)); +} + +void BiasAddOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + const BlobDesc* a_blob_desc = GetBlobDesc4BnInOp("a"); + const BlobDesc* b_blob_desc = GetBlobDesc4BnInOp("b"); + + CHECK_EQ(a_blob_desc->shape().NumAxes(), 2); + CHECK_EQ(b_blob_desc->shape().NumAxes(), 1); + CHECK_EQ(a_blob_desc->shape().At(1), b_blob_desc->shape().At(0)); + CHECK_EQ(a_blob_desc->data_type(), Global<JobDesc>::Get()->DefaultDataType()); + 
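+  // Shape summary: `a` is expected to be 2-D (N, C) and `b` 1-D (C), matching
+  // a's second axis; `out` copies a's descriptor. `bias_multiplier` is sized
+  // (N, 1) below; presumably the kernel fills it with ones so the bias can be
+  // broadcast over the batch (a kernel-side detail assumed here, not
+  // established by this file).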
CHECK_EQ(b_blob_desc->data_type(), Global<JobDesc>::Get()->DefaultDataType()); + + *GetBlobDesc4BnInOp("out") = *a_blob_desc; + GetBlobDesc4BnInOp("bias_multiplier")->mut_shape() = Shape({a_blob_desc->shape().At(0), 1}); +} + +REGISTER_OP(OperatorConf::kBiasAddConf, BiasAddOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/bias_add_op.h b/oneflow/core/operator/bias_add_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f9fa0332199b731969fe8c5fdfbb5762d0ef9f1a --- /dev/null +++ b/oneflow/core/operator/bias_add_op.h @@ -0,0 +1,34 @@ +#ifndef ONEFLOW_CORE_OPERATOR_BIAS_ADD_OP_H_ +#define ONEFLOW_CORE_OPERATOR_BIAS_ADD_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class BiasAddOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(BiasAddOp); + BiasAddOp() = default; + ~BiasAddOp() = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + bool NeedInBlobWhenBackward() const override { return false; } + bool NeedOutBlobWhenBackward() const override { return false; } + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + int32_t OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const std::string& obn) const override { + return 1; + } + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override; + void GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>*) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_BIAS_ADD_OP_H_ diff --git a/oneflow/core/operator/boxing_op.h b/oneflow/core/operator/boxing_op.h index 2df43329bbc9d782c92143058d2f76fba8f79e9b..a5607953e4f7cd8d6a9f23a089f52a145407af53 100644 --- a/oneflow/core/operator/boxing_op.h +++ b/oneflow/core/operator/boxing_op.h @@ -23,6 +23,8 @@ class BoxingOp final : public Operator { KernelConf* kernel_conf) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } + LogicalBlobId ibn2lbi(const std::string& input_bn) const override; LogicalBlobId obn2lbi(const std::string& output_bn) const override; void InferDataTmpBlobDesc(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, diff --git a/oneflow/core/operator/broadcast_add_op.cpp b/oneflow/core/operator/broadcast_add_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..67634bca061c1132b1d20d5225f7cae704ceba4a --- /dev/null +++ b/oneflow/core/operator/broadcast_add_op.cpp @@ -0,0 +1,11 @@ +#include "oneflow/core/operator/broadcast_add_op.h" + +namespace oneflow { + +const PbMessage& BroadcastAddOp::GetCustomizedConf() const { + return op_conf().broadcast_add_conf(); +} + +REGISTER_OP(OperatorConf::kBroadcastAddConf, BroadcastAddOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/broadcast_add_op.h b/oneflow/core/operator/broadcast_add_op.h new file mode 100644 index 0000000000000000000000000000000000000000..94dbcf02a75d8d2eaf8957cd1dd51213371dfb1f --- /dev/null +++ b/oneflow/core/operator/broadcast_add_op.h @@ -0,0 +1,22 @@ +#ifndef ONEFLOW_CORE_OPERATOR_BROADCAST_ADD_OP_H_ +#define ONEFLOW_CORE_OPERATOR_BROADCAST_ADD_OP_H_ + +#include "oneflow/core/operator/broadcast_binary_op.h" + +namespace oneflow { + +class BroadcastAddOp final : public BroadcastBinaryOp { + public: + OF_DISALLOW_COPY_AND_MOVE(BroadcastAddOp); + 
BroadcastAddOp() = default; + ~BroadcastAddOp() = default; + + bool NeedInBlobWhenBackward() const override { return false; } + bool NeedOutBlobWhenBackward() const override { return false; } + + const PbMessage& GetCustomizedConf() const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_BROADCAST_ADD_OP_H_ diff --git a/oneflow/core/operator/broadcast_binary_op.cpp b/oneflow/core/operator/broadcast_binary_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b411e90f026f260e5c94ed41da53272c92ada30b --- /dev/null +++ b/oneflow/core/operator/broadcast_binary_op.cpp @@ -0,0 +1,119 @@ +#include "oneflow/core/operator/broadcast_binary_op.h" + +namespace oneflow { + +namespace { + +bool IsScalarBlob(const BlobDesc* blob) { + return blob->shape().NumAxes() == 1 && blob->shape().At(0) == 1; +} + +class BroadcastBinaryOpParallelSignature final : public OpParallelSignature { + public: + OF_DISALLOW_COPY_AND_MOVE(BroadcastBinaryOpParallelSignature); + ~BroadcastBinaryOpParallelSignature() override = default; + + BroadcastBinaryOpParallelSignature(const Operator* op, + const HashSet<std::string>& model_input_bns) + : OpParallelSignature(op), model_input_bns_(model_input_bns) {} + + const std::string Description() const override { + return op().op_name() + ": (C, ..., S(0), ...) -> (S(0), ...)"; + } + + const OpParallelMatchResult GetMatchResult( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const ParallelContext* parallel_ctx) const override { + for (const auto& bn : op().input_bns()) { + const auto& sbp_infer_hint = SbpInferHint4Ibn(bn); + bool is_model_input_bns = (model_input_bns_.find(bn) != model_input_bns_.end()); + bool has_actual_model_input = sbp_infer_hint.is_model_blob(); + if (is_model_input_bns ^ has_actual_model_input) { + return MakeOpParallelMatchSignatureMismatch(); + } + } + if (parallel_ctx->policy() == kDataParallel) { return MakeOpParallelMatchSuccess(); } + return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kDataParallel); + } + + void GenerateSignature( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + HashMap<std::string, SbpParallel>* bn2sbp) const override { + for (const auto& bn : op().input_bns()) { + if (model_input_bns_.find(bn) != model_input_bns_.end()) { + (*bn2sbp)[bn].mutable_broadcast_parallel(); + } else { + const auto& in_sbp = SbpInferHint4Ibn(bn).sbp_parallel(); + if (in_sbp.has_broadcast_parallel()) { + (*bn2sbp)[bn].mutable_broadcast_parallel(); + } else { + (*bn2sbp)[bn].mutable_split_parallel()->set_axis(0); + } + } + } + for (const auto& bn : op().output_bns()) { + (*bn2sbp)[bn].mutable_split_parallel()->set_axis(0); + } + } + + private: + HashSet<std::string> model_input_bns_; +}; + +std::unique_ptr<const OpParallelSignature> MakeBroadcastBinaryOpParallelSignature( + const Operator* op, const HashSet<std::string>& model_input_bns) { + return std::unique_ptr<const OpParallelSignature>( + new BroadcastBinaryOpParallelSignature(op, model_input_bns)); +} + +} // namespace + +void BroadcastBinaryOp::InitFromOpConf() { + EnrollInputBn("a"); + EnrollInputBn("b"); + EnrollOutputBn("out"); + EnrollBwBufBn("bw_buf"); +} + +void BroadcastBinaryOp::InferBlobDescs( + std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + const BlobDesc* a_blob_desc = GetBlobDesc4BnInOp("a"); + const BlobDesc* b_blob_desc = GetBlobDesc4BnInOp("b"); + CHECK_EQ(a_blob_desc->data_type(), 
b_blob_desc->data_type()); + BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); + size_t output_num_axes = std::max(a_blob_desc->shape().NumAxes(), b_blob_desc->shape().NumAxes()); + if (IsScalarBlob(a_blob_desc)) { + *out_blob_desc = *b_blob_desc; + } else if (IsScalarBlob(b_blob_desc)) { + *out_blob_desc = *a_blob_desc; + } else { + const auto& a_shape = a_blob_desc->shape().CreateLeftExtendedShape(output_num_axes); + const auto& b_shape = b_blob_desc->shape().CreateLeftExtendedShape(output_num_axes); + *out_blob_desc = *a_blob_desc; + Shape out_shape(a_shape); + FOR_RANGE(int64_t, i, 0, a_shape.NumAxes()) { + CHECK(a_shape.At(i) == 1 || b_shape.At(i) == 1 || a_shape.At(i) == b_shape.At(i)); + out_shape.Set(i, std::max(a_shape.At(i), b_shape.At(i))); + } + out_blob_desc->mut_shape() = out_shape; + } +} + +void BroadcastBinaryOp::InferBwBufBlobDescs( + std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*) const { + const BlobDesc* out = GetBlobDesc4BnInOp("out"); + BlobDesc* bw_buf = GetBlobDesc4BnInOp("bw_buf"); + bw_buf->mut_shape() = Shape({out->shape().elem_cnt()}); + bw_buf->set_data_type(out->data_type()); +} + +void BroadcastBinaryOp::GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const { + op_parallel_signatures->emplace_back(MakeBroadcastBinaryOpParallelSignature(this, {})); + op_parallel_signatures->emplace_back(MakeBroadcastBinaryOpParallelSignature(this, {"a"})); + op_parallel_signatures->emplace_back(MakeBroadcastBinaryOpParallelSignature(this, {"b"})); + op_parallel_signatures->emplace_back(MakeBroadcastBinaryOpParallelSignature(this, {"a", "b"})); +} + +} // namespace oneflow diff --git a/oneflow/core/operator/broadcast_binary_op.h b/oneflow/core/operator/broadcast_binary_op.h new file mode 100644 index 0000000000000000000000000000000000000000..669c634401df9f9d942d1f2a41e0c20d928078ee --- /dev/null +++ b/oneflow/core/operator/broadcast_binary_op.h @@ -0,0 +1,29 @@ +#ifndef ONEFLOW_CORE_OPERATOR_BROADCAST_BINARY_OP_H_ +#define ONEFLOW_CORE_OPERATOR_BROADCAST_BINARY_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class BroadcastBinaryOp : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(BroadcastBinaryOp); + BroadcastBinaryOp() = default; + virtual ~BroadcastBinaryOp() = default; + + void InitFromOpConf() override; + bool IsAllOutputConst() const override { return GetValFromCustomizedConf<bool>("is_const"); } + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + void InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext*) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } + void GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>*) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_BROADCAST_BINARY_OP_H_ diff --git a/oneflow/core/operator/broadcast_div_op.cpp b/oneflow/core/operator/broadcast_div_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2dc4d625b79167a308ca599c1a083e84966089db --- /dev/null +++ b/oneflow/core/operator/broadcast_div_op.cpp @@ -0,0 +1,11 @@ +#include "oneflow/core/operator/broadcast_div_op.h" + +namespace oneflow { + +const PbMessage& BroadcastDivOp::GetCustomizedConf() const { + return op_conf().broadcast_div_conf(); +} + 
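+// BroadcastDivOp inherits BroadcastBinaryOp::InferBlobDescs: both inputs are
+// left-extended to the larger rank and each output dim is the max of the two,
+// with mismatched dims required to be 1. For example (shapes purely
+// illustrative): a of shape (4, 1, 3) with b of shape (3) gives out (4, 1, 3),
+// and a of shape (4, 5, 3) with b of shape (5, 3) gives out (4, 5, 3).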
+REGISTER_OP(OperatorConf::kBroadcastDivConf, BroadcastDivOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/broadcast_div_op.h b/oneflow/core/operator/broadcast_div_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a6eb5eec6abb8fc14eb6922e0f03d4916d47f2c2 --- /dev/null +++ b/oneflow/core/operator/broadcast_div_op.h @@ -0,0 +1,19 @@ +#ifndef ONEFLOW_CORE_OPERATOR_BROADCAST_DIV_OP_H_ +#define ONEFLOW_CORE_OPERATOR_BROADCAST_DIV_OP_H_ + +#include "oneflow/core/operator/broadcast_binary_op.h" + +namespace oneflow { + +class BroadcastDivOp final : public BroadcastBinaryOp { + public: + OF_DISALLOW_COPY_AND_MOVE(BroadcastDivOp); + BroadcastDivOp() = default; + ~BroadcastDivOp() = default; + + const PbMessage& GetCustomizedConf() const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_BROADCAST_DIV_OP_H_ diff --git a/oneflow/core/operator/broadcast_mul_op.cpp b/oneflow/core/operator/broadcast_mul_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..17aad41f418e92ddefedb6eefb14b9b18eee742d --- /dev/null +++ b/oneflow/core/operator/broadcast_mul_op.cpp @@ -0,0 +1,11 @@ +#include "oneflow/core/operator/broadcast_mul_op.h" + +namespace oneflow { + +const PbMessage& BroadcastMulOp::GetCustomizedConf() const { + return op_conf().broadcast_mul_conf(); +} + +REGISTER_OP(OperatorConf::kBroadcastMulConf, BroadcastMulOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/broadcast_mul_op.h b/oneflow/core/operator/broadcast_mul_op.h new file mode 100644 index 0000000000000000000000000000000000000000..5a7576936ab3dfbcd79c0d448d5fa1c09b239a28 --- /dev/null +++ b/oneflow/core/operator/broadcast_mul_op.h @@ -0,0 +1,20 @@ +#ifndef ONEFLOW_CORE_OPERATOR_BROADCAST_MUL_OP_H_ +#define ONEFLOW_CORE_OPERATOR_BROADCAST_MUL_OP_H_ + +#include "oneflow/core/operator/broadcast_binary_op.h" + +namespace oneflow { + +class BroadcastMulOp final : public BroadcastBinaryOp { + public: + OF_DISALLOW_COPY_AND_MOVE(BroadcastMulOp); + BroadcastMulOp() = default; + ~BroadcastMulOp() = default; + + const PbMessage& GetCustomizedConf() const override; + bool NeedInBlobWhenBackward() const override { return true; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_BROADCAST_MUL_OP_H_ diff --git a/oneflow/core/operator/broadcast_sub_op.cpp b/oneflow/core/operator/broadcast_sub_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c91b3e18fa4e5becffe55941cb8396278876fc90 --- /dev/null +++ b/oneflow/core/operator/broadcast_sub_op.cpp @@ -0,0 +1,11 @@ +#include "oneflow/core/operator/broadcast_sub_op.h" + +namespace oneflow { + +const PbMessage& BroadcastSubOp::GetCustomizedConf() const { + return op_conf().broadcast_sub_conf(); +} + +REGISTER_OP(OperatorConf::kBroadcastSubConf, BroadcastSubOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/broadcast_sub_op.h b/oneflow/core/operator/broadcast_sub_op.h new file mode 100644 index 0000000000000000000000000000000000000000..2ef143bb3219393c2b4123330e08ac8f733291c2 --- /dev/null +++ b/oneflow/core/operator/broadcast_sub_op.h @@ -0,0 +1,21 @@ +#ifndef ONEFLOW_CORE_OPERATOR_BROADCAST_SUB_OP_H_ +#define ONEFLOW_CORE_OPERATOR_BROADCAST_SUB_OP_H_ + +#include "oneflow/core/operator/broadcast_binary_op.h" + +namespace oneflow { + +class BroadcastSubOp final : public BroadcastBinaryOp { + public: + OF_DISALLOW_COPY_AND_MOVE(BroadcastSubOp); + BroadcastSubOp() = default; + ~BroadcastSubOp() = default; + + const PbMessage& GetCustomizedConf() const override; 
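+  // For subtraction the input gradients are just the output diff (negated for
+  // `b`, and reduced over any broadcast axes), so neither the inputs nor the
+  // output are needed during the backward pass: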
+ bool NeedInBlobWhenBackward() const override { return false; } + bool NeedOutBlobWhenBackward() const override { return false; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_BROADCAST_SUB_OP_H_ diff --git a/oneflow/core/operator/cast_op.cpp b/oneflow/core/operator/cast_op.cpp index 4fbdefc8c78b192c3ffe5ee569e0254febf2388f..4c3d58080b7f82eaa66ab233ed44f0e6872ece5b 100644 --- a/oneflow/core/operator/cast_op.cpp +++ b/oneflow/core/operator/cast_op.cpp @@ -18,6 +18,6 @@ void CastOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlob out_blob_desc->set_data_type(op_conf().cast_conf().data_type()); } -REGISTER_CPU_OP(OperatorConf::kCastConf, CastOp); +REGISTER_OP(OperatorConf::kCastConf, CastOp); } // namespace oneflow diff --git a/oneflow/core/operator/cast_op.h b/oneflow/core/operator/cast_op.h index 2e8585bd0162a3b4e077efcb2790da4129836022..3f4f52f7201cfc90fc0f4949df24513df9bd3699 100644 --- a/oneflow/core/operator/cast_op.h +++ b/oneflow/core/operator/cast_op.h @@ -13,9 +13,11 @@ class CastOp final : public Operator { void InitFromOpConf() override; const PbMessage& GetCustomizedConf() const override; - bool IsElemWiseOp() const override { return true; } void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } }; } // namespace oneflow diff --git a/oneflow/core/operator/clone_op.h b/oneflow/core/operator/clone_op.h index af9a70f7e73c8e002598fd17dcdac4879816ba5b..699f8357ea97d7b059c8fd960839d6dafd44cb86 100644 --- a/oneflow/core/operator/clone_op.h +++ b/oneflow/core/operator/clone_op.h @@ -22,6 +22,8 @@ class CloneOp final : public Operator { const ParallelContext* parallel_ctx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } + LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return LogicalBlobId(); } LogicalBlobId obn2lbi(const std::string& output_bn) const override { return LogicalBlobId(); } void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, diff --git a/oneflow/core/operator/concat_op.h b/oneflow/core/operator/concat_op.h index e5cb8994609fa067aab7e38e703e12fa2d0144f4..2d61e95f0c0ca9b74642e975c691f8cfcc571a31 100644 --- a/oneflow/core/operator/concat_op.h +++ b/oneflow/core/operator/concat_op.h @@ -19,6 +19,9 @@ class ConcatOp final : public Operator { void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } }; } // namespace oneflow diff --git a/oneflow/core/operator/constant_op.cpp b/oneflow/core/operator/constant_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8467491cd186b701a0a43c6e7c752b2f900ee0f7 --- /dev/null +++ b/oneflow/core/operator/constant_op.cpp @@ -0,0 +1,58 @@ +#include "oneflow/core/operator/constant_op.h" + +namespace oneflow { + +void ConstantOp::InitFromOpConf() { + CHECK(op_conf().has_constant_conf()); + EnrollInputBn("tick", false); + EnrollOutputBn("out", false); +} + +const PbMessage& ConstantOp::GetCustomizedConf() const { return op_conf().constant_conf(); } + +void ConstantOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, 
+ int64_t record_piece_size) const { + CHECK_EQ(parallel_ctx->policy(), ParallelPolicy::kDataParallel); + const ConstantOpConf& conf = op_conf().constant_conf(); + const DataType& data_type = + conf.has_data_type() ? conf.data_type() : Global<JobDesc>::Get()->DefaultDataType(); + std::vector<int64_t> dim_vec; + if (conf.use_device_piece_size_as_dim0()) { + CHECK_EQ(record_piece_size % parallel_ctx->parallel_num(), 0); + dim_vec.push_back(record_piece_size / parallel_ctx->parallel_num()); + } + if (conf.has_shape()) { + dim_vec.insert(dim_vec.end(), conf.shape().dim().cbegin(), conf.shape().dim().cend()); + } + if (dim_vec.empty()) { dim_vec.push_back(1); } + BlobDesc* out = GetBlobDesc4BnInOp("out"); + out->set_data_type(data_type); + out->mut_shape() = Shape(dim_vec); +} + +void ConstantOp::VirtualGenKernelConf( + std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const { + kernel_conf->mutable_constant_conf()->set_random_seed(NewRandomSeed()); + const DataType& data_type = GetBlobDesc4BnInOp("out")->data_type(); + if (op_conf().constant_conf().has_initializer()) { + *kernel_conf->mutable_constant_conf()->mutable_initializer() = + op_conf().constant_conf().initializer(); + } else if (IsFloatingDataType(data_type)) { + InitializerConf conf; + conf.mutable_constant_conf()->set_value(0); + *kernel_conf->mutable_constant_conf()->mutable_initializer() = conf; + } else if (IsIntegralDataType(data_type)) { + InitializerConf conf; + conf.mutable_constant_int_conf()->set_value(0); + *kernel_conf->mutable_constant_conf()->mutable_initializer() = conf; + } else { + UNIMPLEMENTED(); + } + kernel_conf->set_data_type(data_type); +} + +REGISTER_OP(OperatorConf::kConstantConf, ConstantOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/constant_op.h b/oneflow/core/operator/constant_op.h new file mode 100644 index 0000000000000000000000000000000000000000..4c791cb01603b3f5633bd2dd53afc64fa1941107 --- /dev/null +++ b/oneflow/core/operator/constant_op.h @@ -0,0 +1,33 @@ +#ifndef ONEFLOW_CORE_OPERATOR_CONSTANT_OP_H_ +#define ONEFLOW_CORE_OPERATOR_CONSTANT_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class ConstantOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(ConstantOp); + ConstantOp() = default; + ~ConstantOp() override = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, + int64_t record_piece_size) const override; + bool NeedInBlobWhenBackward() const override { return false; } + bool NeedOutBlobWhenBackward() const override { return false; } + bool IsAllOutputConst() const override { return true; } + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + + void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, + KernelConf* kernel_conf) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_CONSTANT_OP_H_ diff --git a/oneflow/core/operator/conv_op.cpp b/oneflow/core/operator/conv_op.cpp index 550fa3fc101df74e5956f1d75323ca0fa2509dd1..7df482164e6109dbef249688f2e5360b02e6d3ef 100644 --- a/oneflow/core/operator/conv_op.cpp +++ b/oneflow/core/operator/conv_op.cpp @@ -94,7 +94,7 @@ void ConvOp<NDims>::InitFromOpConf() { 
template<int32_t NDims> void ConvOp<NDims>::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, + const ParallelContext* parallel_ctx, int64_t record_piece_size, std::function<void(OpContext*)> EnrollOpCtx) const { const std::string& data_format = GetValFromCustomizedConf<std::string>("data_format"); @@ -288,7 +288,9 @@ PbMessage* ConvOp<NDims>::MutableCustomizedKernelConf(KernelConf* kernel_conf) c } template<int32_t NDims> -int32_t ConvOp<NDims>::ModelSplitAxis() const { +int32_t ConvOp<NDims>::OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const std::string& obn) const { if (GetValFromCustomizedConf<std::string>("data_format") == "channels_first") { return 1; } else if (GetValFromCustomizedConf<std::string>("data_format") == "channels_last") { @@ -298,11 +300,6 @@ int32_t ConvOp<NDims>::ModelSplitAxis() const { } } -template<int32_t NDims> -int32_t ConvOp<NDims>::MaxModelSplitNum() const { - return GetValFromCustomizedConf<int32_t>("filters"); -} - #ifdef WITH_CUDA template<int32_t NDims> void ConvOp<NDims>::InferCudnnAlgo( diff --git a/oneflow/core/operator/conv_op.h b/oneflow/core/operator/conv_op.h index d0d782edce761f57870407483c375f807a813911..4d2316b0b77ac5501367012f2ee18bd5d3512e07 100644 --- a/oneflow/core/operator/conv_op.h +++ b/oneflow/core/operator/conv_op.h @@ -41,15 +41,18 @@ class ConvOp : public Operator { void InitFromOpConf() override; bool NeedOutBlobWhenBackward() const override { return false; } void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext*, + const ParallelContext* parallel_ctx, int64_t record_piece_size, std::function<void(OpContext*)> EnrollOpCtx) const override; void InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*, const OpContext*) const override; - int32_t ModelSplitAxis() const override; - int32_t MaxModelSplitNum() const override; + int32_t OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const std::string& obn) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + PbMessage* MutableCustomizedKernelConf(KernelConf*) const override; void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*, KernelConf*, const OpContext*) const override; diff --git a/oneflow/core/operator/copy_comm_net_op.h b/oneflow/core/operator/copy_comm_net_op.h index 3709ed055674d3fd31abfb93e31da58d663eb4c9..1954eda28d8ddbde5b7ca2bd91335e15a8e4a888 100644 --- a/oneflow/core/operator/copy_comm_net_op.h +++ b/oneflow/core/operator/copy_comm_net_op.h @@ -15,6 +15,8 @@ class CopyCommNetOp final : public Operator { const PbMessage& GetCustomizedConf() const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } + LogicalBlobId ibn2lbi(const std::string& input_bn) const override; LogicalBlobId obn2lbi(const std::string& output_bn) const override; }; diff --git a/oneflow/core/operator/copy_hd_op.h b/oneflow/core/operator/copy_hd_op.h index 657d7b0f906bc8d8041579c011fb21f12102457f..fa35b16b185d0cf42c588963d0e2e7e5148a49fd 100644 --- a/oneflow/core/operator/copy_hd_op.h +++ b/oneflow/core/operator/copy_hd_op.h @@ -15,6 +15,8 @@ class CopyHdOp final : public Operator { const PbMessage& GetCustomizedConf() 
const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } + LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); } }; diff --git a/oneflow/core/operator/debug_op.cpp b/oneflow/core/operator/debug_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d7ae2a9ee4fa6cebf7ae1b882602101a67187260 --- /dev/null +++ b/oneflow/core/operator/debug_op.cpp @@ -0,0 +1,17 @@ +#include "oneflow/core/operator/debug_op.h" + +namespace oneflow { + +void DebugOp::InitFromOpConf() { + EnrollInputBn("in"); + EnrollOutputBn("out"); +} + +void DebugOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in"); +} + +REGISTER_CPU_OP(OperatorConf::kDebugConf, DebugOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/debug_op.h b/oneflow/core/operator/debug_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0b854da59a82ebcd817de7f8cb3734bacc498d53 --- /dev/null +++ b/oneflow/core/operator/debug_op.h @@ -0,0 +1,28 @@ +#ifndef ONEFLOW_CORE_OPERATOR_DEBUG_OP_H_ +#define ONEFLOW_CORE_OPERATOR_DEBUG_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class DebugOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(DebugOp); + DebugOp() = default; + ~DebugOp() = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override { return op_conf().debug_conf(); } + bool NeedInBlobWhenBackward() const override { return false; } + bool NeedOutBlobWhenBackward() const override { return false; } + + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_DEBUG_OP_H_ diff --git a/oneflow/core/operator/decode_ofrecord_op.h b/oneflow/core/operator/decode_ofrecord_op.h index cab4107ec7411dba39896f56339904bb0b5ba785..f94dcbde24334f2bb364784a96f5e5679b68e903 100644 --- a/oneflow/core/operator/decode_ofrecord_op.h +++ b/oneflow/core/operator/decode_ofrecord_op.h @@ -24,6 +24,8 @@ class DecodeOFRecordOp final : public Operator { KernelConf* kernel_conf) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + LogicalBlobId obn2lbi(const std::string& output_bn) const override; }; diff --git a/oneflow/core/operator/decode_random_op.cpp b/oneflow/core/operator/decode_random_op.cpp index 22c1b6a00214e2ac87d315debd1392d726c5604a..4e35ca19aaabca7b71d0c32f62372740ff69ae33 100644 --- a/oneflow/core/operator/decode_random_op.cpp +++ b/oneflow/core/operator/decode_random_op.cpp @@ -18,13 +18,13 @@ void DecodeRandomOp::VirtualGenKernelConf( } void DecodeRandomOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx) const { + const ParallelContext* parallel_ctx, + int64_t record_piece_size) const { BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); const DecodeRandomOpConf& conf = op_conf().decode_random_conf(); std::vector<int64_t> dim_vec(1 + conf.shape().dim_size()); - int64_t global_piece_size = 
Global<JobDesc>::Get()->PieceSize(); - CHECK_EQ(global_piece_size % parallel_ctx->parallel_num(), 0); - dim_vec[0] = global_piece_size / parallel_ctx->parallel_num(); + CHECK_EQ(record_piece_size % parallel_ctx->parallel_num(), 0); + dim_vec[0] = record_piece_size / parallel_ctx->parallel_num(); FOR_RANGE(size_t, j, 1, dim_vec.size()) { dim_vec[j] = conf.shape().dim(j - 1); } out_blob_desc->mut_shape() = Shape(dim_vec); out_blob_desc->set_data_type(conf.data_type()); diff --git a/oneflow/core/operator/decode_random_op.h b/oneflow/core/operator/decode_random_op.h index 0aaad916d826c883bddb95f4dfc8d14646db6f92..2ab009feffd36b07e75e40291b8b07c403872168 100644 --- a/oneflow/core/operator/decode_random_op.h +++ b/oneflow/core/operator/decode_random_op.h @@ -17,9 +17,12 @@ class DecodeRandomOp final : public Operator { LogicalNode* NewProperLogicalNode() override { return new DecodeRandomLogicalNode; } void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx) const override; + const ParallelContext* parallel_ctx, + int64_t record_piece_size) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const override; diff --git a/oneflow/core/operator/define_test_blob_op.cpp b/oneflow/core/operator/define_test_blob_op.cpp index a56f8fb13d6fdabc695678601ad25f282fa95435..bdbd8fc0955151464963445f2ec9d1226cc0e68d 100644 --- a/oneflow/core/operator/define_test_blob_op.cpp +++ b/oneflow/core/operator/define_test_blob_op.cpp @@ -4,7 +4,7 @@ namespace oneflow { void DefineTestBlobOp::InitFromOpConf() { CHECK(op_conf().has_define_test_blob_conf()); - EnrollOutputBn("out", false); + EnrollOutputBn("out", op_conf().define_test_blob_conf().has_diff()); } const PbMessage& DefineTestBlobOp::GetCustomizedConf() const { diff --git a/oneflow/core/operator/define_test_blob_op.h b/oneflow/core/operator/define_test_blob_op.h index 69a91479c51052f04984e4fe9904d0cca6a9a63d..f8034c787668cc5bb39df4b7108717b1e740b20c 100644 --- a/oneflow/core/operator/define_test_blob_op.h +++ b/oneflow/core/operator/define_test_blob_op.h @@ -18,6 +18,9 @@ class DefineTestBlobOp final : public Operator { void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } }; } // namespace oneflow diff --git a/oneflow/core/operator/dot_op.h b/oneflow/core/operator/dot_op.h index dc2291eb6d50b517100765c30021a4dea06bbb99..bc926bfe0ecce38672a2cfd29fed2f32d091d3c7 100644 --- a/oneflow/core/operator/dot_op.h +++ b/oneflow/core/operator/dot_op.h @@ -12,6 +12,9 @@ class DotOp final : public Operator { const PbMessage& GetCustomizedConf() const override; void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } }; } // namespace oneflow diff --git a/oneflow/core/operator/dropout_op.h b/oneflow/core/operator/dropout_op.h index c634d1b05d658c98971c3801221a43a90c569103..b837f27862f919da110a487eade2987dec987c9f 100644 --- a/oneflow/core/operator/dropout_op.h +++ b/oneflow/core/operator/dropout_op.h @@ -13,7 +13,6 @@ 
class DropoutOp final : public Operator { void InitFromOpConf() override; const PbMessage& GetCustomizedConf() const override; - bool IsElemWiseOp() const override { return true; } bool NeedInBlobWhenBackward() const override { return false; } bool NeedOutBlobWhenBackward() const override { return false; } void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, @@ -22,6 +21,9 @@ class DropoutOp final : public Operator { void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } }; } // namespace oneflow diff --git a/oneflow/core/operator/embedding_lookup_accumulate_op.h b/oneflow/core/operator/embedding_lookup_accumulate_op.h index dc42310fde5ffa21907d6fa0fcd38f7a43d7077d..c95274ded53a5f7c418b8cbb7489858b42582a7d 100644 --- a/oneflow/core/operator/embedding_lookup_accumulate_op.h +++ b/oneflow/core/operator/embedding_lookup_accumulate_op.h @@ -13,11 +13,11 @@ class EmbeddingLookupAccumulateOp final : public Operator { void InitFromOpConf() override; const PbMessage& GetCustomizedConf() const override; - void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); } }; diff --git a/oneflow/core/operator/embedding_lookup_op.h b/oneflow/core/operator/embedding_lookup_op.h index ef9611f0b797c188119c738fa5c807ffd6336737..4ac9c0ce343c1bc962ac45d5ad68e1d5442f3985 100644 --- a/oneflow/core/operator/embedding_lookup_op.h +++ b/oneflow/core/operator/embedding_lookup_op.h @@ -16,10 +16,14 @@ class EmbeddingLookupOp final : public Operator { const PbMessage& GetCustomizedConf() const override; void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; - int32_t ModelSplitAxis() const override { return 1; } - int32_t MaxModelSplitNum() const override { return op_conf().embedding_lookup_conf().units(); } + int32_t OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const std::string& obn) const override { + return 1; + } private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } }; } // namespace oneflow diff --git a/oneflow/core/operator/fully_connected_op.h b/oneflow/core/operator/fully_connected_op.h index 0adf7a3a2612e2b53a972367c5355967e5e7d52a..efa229120d17549385f2c89657d7cc3ab59df39d 100644 --- a/oneflow/core/operator/fully_connected_op.h +++ b/oneflow/core/operator/fully_connected_op.h @@ -16,8 +16,14 @@ class FullyConnectedOp final : public Operator { bool NeedOutBlobWhenBackward() const override { return false; } void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; - int32_t ModelSplitAxis() const override { return 1; } - int32_t MaxModelSplitNum() const override { return op_conf().fully_connected_conf().units(); } + int32_t OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const std::string&)>& 
SbpInferHint4Ibn, + const std::string& obn) const override { + return 1; + } + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } }; } // namespace oneflow diff --git a/oneflow/core/operator/gather_op.cpp b/oneflow/core/operator/gather_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e4f8658ca88adf69f11a6fe9b243a468a32c5859 --- /dev/null +++ b/oneflow/core/operator/gather_op.cpp @@ -0,0 +1,115 @@ +#include "oneflow/core/operator/gather_op.h" + +namespace oneflow { + +namespace { + +int64_t GetGatherAxis(const GatherOpConf& conf, int64_t num_axes) { + const int64_t axis = conf.axis() < 0 ? num_axes + conf.axis() : conf.axis(); + CHECK_GE(axis, 0); + CHECK_LT(axis, num_axes); + return axis; +} + +int64_t GetGatherAxis(const GatherOpConf& conf, const BlobDesc* in_blob_desc) { + return GetGatherAxis(conf, in_blob_desc->shape().NumAxes()); +} + +class Gather_DB_MS_2_P_OpParallelSignature final : public OpParallelSignature { + public: + OF_DISALLOW_COPY_AND_MOVE(Gather_DB_MS_2_P_OpParallelSignature); + ~Gather_DB_MS_2_P_OpParallelSignature() override = default; + + Gather_DB_MS_2_P_OpParallelSignature(const Operator* op) : OpParallelSignature(op) {} + + const std::string Description() const override { return op().op_name() + ": (C, S) -> P"; } + + const OpParallelMatchResult GetMatchResult( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4BnInOp, + const ParallelContext* parallel_ctx) const override { + const SbpInferHint& in_sbp_infer_hint = SbpInferHint4BnInOp("in"); + if (!in_sbp_infer_hint.is_model_split()) { return MakeOpParallelMatchSignatureMismatch(); } + if (in_sbp_infer_hint.split_axis() != 0) { return MakeOpParallelMatchSignatureMismatch(); } + if (parallel_ctx->policy() == kModelParallel) { return MakeOpParallelMatchSuccess(); } + return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kModelParallel); + } + + void GenerateSignature( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4BnInOp, + HashMap<std::string, SbpParallel>* bn2sbp) const override { + (*bn2sbp)["indices"].mutable_broadcast_parallel(); + (*bn2sbp)["in"].mutable_split_parallel()->set_axis(0); + (*bn2sbp)["out"].mutable_partial_sum_parallel(); + } +}; + +} // namespace + +void GatherOp::InitFromOpConf() { + CHECK(op_conf().has_gather_conf()); + EnrollInputBn("indices", false); + EnrollInputBn("in"); + EnrollOutputBn("out"); +} + +const PbMessage& GatherOp::GetCustomizedConf() const { return op_conf().gather_conf(); } + +bool GatherOp::IsInputBlobAllowedModelSplit(const std::string& ibn) const { + CHECK(std::find(input_bns().begin(), input_bns().end(), ibn) != input_bns().end()); + return ibn == "in"; +} + +void GatherOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + const BlobDesc* indices = GetBlobDesc4BnInOp("indices"); + CHECK(IsIntegralDataType(indices->data_type())); + CHECK_GT(indices->shape().NumAxes(), 0); + const BlobDesc* in = GetBlobDesc4BnInOp("in"); + CHECK_GT(in->shape().NumAxes(), 0); + const int64_t axis = GetGatherAxis(op_conf().gather_conf(), in); + BlobDesc* out = GetBlobDesc4BnInOp("out"); + *out = *in; + std::vector<int64_t> dim_vec; + dim_vec.insert(dim_vec.end(), in->shape().dim_vec().cbegin(), + in->shape().dim_vec().cbegin() + axis); + dim_vec.insert(dim_vec.end(), indices->shape().dim_vec().cbegin(), + indices->shape().dim_vec().cend()); + 
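// out shape = in.shape[0:axis] + indices.shape + in.shape[axis+1:], e.g. gathering a (4,5,6) blob on axis 1 with (2,3) indices yields a (4,2,3,6) blob +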
dim_vec.insert(dim_vec.end(), in->shape().dim_vec().cbegin() + axis + 1, + in->shape().dim_vec().end()); + out->mut_shape() = Shape(dim_vec); +} + +void GatherOp::VirtualGenKernelConf( + std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const { + const int64_t axis = GetGatherAxis(op_conf().gather_conf(), GetBlobDesc4BnInOp("in")); + kernel_conf->mutable_gather_conf()->set_axis(axis); +} + +void GatherOp::GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const { + op_parallel_signatures->emplace_back(MakeDataSplitOpParallelSignature(this)); + op_parallel_signatures->emplace_back(Make_DS_MB_2_DS_OpParallelSignature(this)); + auto GtZero = [](int32_t axis) { return axis > 0; }; + op_parallel_signatures->emplace_back(Make_DB_MS_2_MS_OpParallelSignature(this, GtZero)); + op_parallel_signatures->emplace_back(new Gather_DB_MS_2_P_OpParallelSignature(this)); +} + +int32_t GatherOp::OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const std::string& obn) const { + const SbpInferHint& indices_sbp_infer_hint = SbpInferHint4Ibn("indices"); + CHECK(indices_sbp_infer_hint.is_data_blob()); + const SbpInferHint& in_sbp_infer_hint = SbpInferHint4Ibn("in"); + const int64_t in_num_axes = in_sbp_infer_hint.num_axes(); + const int64_t gather_axis = GetGatherAxis(op_conf().gather_conf(), in_num_axes); + CHECK(in_sbp_infer_hint.is_model_split()); + CHECK_GT(in_sbp_infer_hint.split_axis(), 0); + CHECK_GT(in_sbp_infer_hint.split_axis(), gather_axis); + CHECK_LT(in_sbp_infer_hint.split_axis(), in_num_axes); + return in_sbp_infer_hint.split_axis() + indices_sbp_infer_hint.num_axes() - 1; +} + +REGISTER_OP(OperatorConf::kGatherConf, GatherOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/gather_op.h b/oneflow/core/operator/gather_op.h new file mode 100644 index 0000000000000000000000000000000000000000..7d9e3287fc2ff91a011f73d1e3339f9d38df1669 --- /dev/null +++ b/oneflow/core/operator/gather_op.h @@ -0,0 +1,34 @@ +#ifndef ONEFLOW_CORE_OPERATOR_GATHER_OP_H_ +#define ONEFLOW_CORE_OPERATOR_GATHER_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class GatherOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(GatherOp); + GatherOp() = default; + ~GatherOp() override = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + bool NeedOutBlobWhenBackward() const override { return false; } + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, + KernelConf* kernel_conf) const override; + int32_t OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const std::string& obn) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override; + void GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>*) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_GATHER_OP_H_ diff --git a/oneflow/core/operator/gelu_op.cpp b/oneflow/core/operator/gelu_op.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..bfc4f64e1ccf5ed325a35a9ab986188929a12fc4 --- /dev/null +++ b/oneflow/core/operator/gelu_op.cpp @@ -0,0 +1,22 @@ +#include "oneflow/core/operator/gelu_op.h" +#include "oneflow/core/common/balanced_splitter.h" + +namespace oneflow { + +void GeluOp::InitFromOpConf() { + CHECK(op_conf().has_gelu_conf()); + + EnrollInputBn("in"); + EnrollOutputBn("out"); +} + +const PbMessage& GeluOp::GetCustomizedConf() const { return op_conf().gelu_conf(); } + +void GeluOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in"); +} + +REGISTER_OP(OperatorConf::kGeluConf, GeluOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/gelu_op.h b/oneflow/core/operator/gelu_op.h new file mode 100644 index 0000000000000000000000000000000000000000..66a083b3f13f7ef4e8659d0acc1cb8d6c4c4a016 --- /dev/null +++ b/oneflow/core/operator/gelu_op.h @@ -0,0 +1,26 @@ +#ifndef ONEFLOW_CORE_OPERATOR_GELU_OP_H_ +#define ONEFLOW_CORE_OPERATOR_GELU_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class GeluOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(GeluOp); + GeluOp() = default; + ~GeluOp() = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + bool NeedOutBlobWhenBackward() const override { return false; } + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_GELU_OP_H_ diff --git a/oneflow/core/operator/hinge_loss_op.h b/oneflow/core/operator/hinge_loss_op.h index 527754fd6c2d27b61a89844c5f214f910ed814cb..f17cc11c6ccb956faf7ff93bd2055c40627a6ca0 100644 --- a/oneflow/core/operator/hinge_loss_op.h +++ b/oneflow/core/operator/hinge_loss_op.h @@ -14,6 +14,8 @@ class HingeLossOp final : public LossOp { const PbMessage& GetCustomizedConf() const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + void VirtualInitFromOpConf() override; void VirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; diff --git a/oneflow/core/operator/identity_loss_op.cpp b/oneflow/core/operator/identity_loss_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..92100a568f360d5a9795b5d13ab40f34f4f4890b --- /dev/null +++ b/oneflow/core/operator/identity_loss_op.cpp @@ -0,0 +1,26 @@ +#include "oneflow/core/operator/identity_loss_op.h" + +namespace oneflow { + +const PbMessage& IdentityLossOp::GetCustomizedConf() const { + return op_conf().identity_loss_conf(); +} + +LossKernelConf* IdentityLossOp::GetMutLossKernelConf(KernelConf* kernel_conf) const { + return kernel_conf->mutable_identity_loss_conf()->mutable_loss_conf(); +} + +void IdentityLossOp::VirtualInitFromOpConf() { EnrollConstBufBn("ones"); } + +void IdentityLossOp::VirtualInferBlobDescs( + std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + const BlobDesc* prediction = GetBlobDesc4BnInOp("prediction"); + BlobDesc* ones = GetBlobDesc4BnInOp("ones"); + ones->set_data_type(prediction->data_type()); + ones->mut_shape() = prediction->shape(); +} + 
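+// "ones" is a const buffer matching the prediction's shape and data type; presumably an all-ones weight the kernel can use when reducing the elementwise loss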
+REGISTER_OP(OperatorConf::kIdentityLossConf, IdentityLossOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/identity_loss_op.h b/oneflow/core/operator/identity_loss_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0fc406d710799149c20e561b8b818caf9a43c336 --- /dev/null +++ b/oneflow/core/operator/identity_loss_op.h @@ -0,0 +1,27 @@ +#ifndef ONEFLOW_CORE_OPERATOR_IDENTITY_LOSS_OP_H_ +#define ONEFLOW_CORE_OPERATOR_IDENTITY_LOSS_OP_H_ + +#include "oneflow/core/operator/loss_op.h" + +namespace oneflow { + +class IdentityLossOp final : public LossOp { + public: + OF_DISALLOW_COPY_AND_MOVE(IdentityLossOp); + IdentityLossOp() = default; + ~IdentityLossOp() override = default; + + const PbMessage& GetCustomizedConf() const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + + LossKernelConf* GetMutLossKernelConf(KernelConf*) const override; + void VirtualInitFromOpConf() override; + void VirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_IDENTITY_LOSS_OP_H_ diff --git a/oneflow/core/operator/identity_op.cpp b/oneflow/core/operator/identity_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7287398718997ae51b2588a81ed59945ab58de72 --- /dev/null +++ b/oneflow/core/operator/identity_op.cpp @@ -0,0 +1,15 @@ +#include "oneflow/core/operator/identity_op.h" + +namespace oneflow { + +void IdentityOp::InitFromOpConf() { + EnrollInputBn("in"); + EnrollOutputBn("out"); +} + +void IdentityOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in"); +} + +} // namespace oneflow diff --git a/oneflow/core/operator/identity_op.h b/oneflow/core/operator/identity_op.h new file mode 100644 index 0000000000000000000000000000000000000000..bca97c449adb7640fdc455ad7e99ecda2f10e9ff --- /dev/null +++ b/oneflow/core/operator/identity_op.h @@ -0,0 +1,27 @@ +#ifndef ONEFLOW_CORE_OPERATOR_IDENTITY_OP_H_ +#define ONEFLOW_CORE_OPERATOR_IDENTITY_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class IdentityOp : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(IdentityOp); + IdentityOp() = default; + virtual ~IdentityOp() = default; + + void InitFromOpConf() override; + bool NeedInBlobWhenBackward() const override { return false; } + bool NeedOutBlobWhenBackward() const override { return false; } + + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_IDENTITY_OP_H_ diff --git a/oneflow/core/operator/l2_normalize_op.h b/oneflow/core/operator/l2_normalize_op.h index 1d5a8c5369d1b1dc2a98d7f2d3439484e44a3026..9f827c4079db3e2ef0238d1e3d91d77b57fdad50 100644 --- a/oneflow/core/operator/l2_normalize_op.h +++ b/oneflow/core/operator/l2_normalize_op.h @@ -16,6 +16,9 @@ class L2NormalizeOp final : public Operator { bool NeedInBlobWhenBackward() const override { return false; } void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; + + private: + bool 
IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } }; } // namespace oneflow diff --git a/oneflow/core/operator/lars_model_update_op.cpp b/oneflow/core/operator/lars_model_update_op.cpp index 3c3d24a9a29af3f9e61b11f8136b6967d27fe7e4..51a0ca916f58baa1a42874e2076382c30e855d37 100644 --- a/oneflow/core/operator/lars_model_update_op.cpp +++ b/oneflow/core/operator/lars_model_update_op.cpp @@ -7,7 +7,7 @@ void LARSModelUpdateOp::MdUpdtVirtualInitFromOpConf() { EnrollDataTmpBn("data_tmp"); } -void LARSModelUpdateOp::InferBlobDescs( +void LARSModelUpdateOp::MdUpdtVirtualInferBlobDescs( std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const BlobDesc* model_blob_desc = GetBlobDesc4BnInOp("model"); diff --git a/oneflow/core/operator/lars_model_update_op.h b/oneflow/core/operator/lars_model_update_op.h index 1040b8101d3cd98f7c4fdadd315344bfb7848a22..35673b31877b29eaaa52a6eed5be3638e23e19d5 100644 --- a/oneflow/core/operator/lars_model_update_op.h +++ b/oneflow/core/operator/lars_model_update_op.h @@ -11,11 +11,10 @@ class LARSModelUpdateOp final : public NormalModelUpdtOp { LARSModelUpdateOp() = default; ~LARSModelUpdateOp() = default; - void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx) const override; - private: void MdUpdtVirtualInitFromOpConf() override; + void MdUpdtVirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; }; } // namespace oneflow diff --git a/oneflow/core/operator/layer_norm_op.cpp b/oneflow/core/operator/layer_norm_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..35069f6e4c07bac6c6af5a03669ddb9196a6255c --- /dev/null +++ b/oneflow/core/operator/layer_norm_op.cpp @@ -0,0 +1,86 @@ +#include "oneflow/core/operator/layer_norm_op.h" + +namespace oneflow { + +namespace { + +int64_t ShiftNegativeAxisIfNeed(const Shape& shape, int64_t axis) { + const int64_t shifted = axis < 0 ? 
axis + shape.NumAxes() : axis; + CHECK_GE(shifted, 0); + CHECK_LT(shifted, shape.NumAxes()); + return shifted; +} + +} // namespace + +void LayerNormOp::InitFromOpConf() { + CHECK(op_conf().has_layer_norm_conf()); + const LayerNormOpConf& conf = op_conf().layer_norm_conf(); + if (!(conf.center() || conf.scale())) { mut_op_conf()->set_trainable(false); } + EnrollInputBn("in"); + EnrollOutputBn("out"); + if (conf.center()) { EnrollModelBn("beta"); } + if (conf.scale()) { + EnrollModelBn("gamma"); + EnrollDataTmpBn("normalize_out"); + } + EnrollDataTmpBn("cudnn_bn_mean"); + EnrollDataTmpBn("cudnn_bn_inv_variance"); + EnrollConstBufBn("cudnn_bn_scale_ones"); + EnrollConstBufBn("cudnn_bn_bias_zeros"); + EnrollBwBufBn("cudnn_bn_scale_diff_buf"); + EnrollBwBufBn("cudnn_bn_bias_diff_buf"); + if (op_conf().trainable()) { EnrollBwBufBn("bw_reduce_buf"); } +} + +void LayerNormOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + CHECK(parallel_ctx->policy() != kModelParallel); + const BlobDesc* in = GetBlobDesc4BnInOp("in"); + *GetBlobDesc4BnInOp("out") = *in; + const LayerNormOpConf& conf = op_conf().layer_norm_conf(); + const int64_t begin_params_axis = ShiftNegativeAxisIfNeed(in->shape(), conf.begin_params_axis()); + const Shape param_shape = Shape({in->shape().Count(begin_params_axis)}); + if (conf.center()) { + BlobDesc* beta = GetBlobDesc4BnInOp("beta"); + beta->mut_shape() = param_shape; + beta->set_data_type(in->data_type()); + } + if (conf.scale()) { + BlobDesc* gamma = GetBlobDesc4BnInOp("gamma"); + gamma->mut_shape() = param_shape; + gamma->set_data_type(in->data_type()); + *GetBlobDesc4BnInOp("normalize_out") = *in; + } + const int64_t begin_norm_axis = ShiftNegativeAxisIfNeed(in->shape(), conf.begin_norm_axis()); + const Shape bn_param_shape = Shape({in->shape().Count(0, begin_norm_axis)}); + BlobDesc* cudnn_bn_mean = GetBlobDesc4BnInOp("cudnn_bn_mean"); + cudnn_bn_mean->mut_shape() = bn_param_shape; + cudnn_bn_mean->set_data_type(in->data_type()); + *GetBlobDesc4BnInOp("cudnn_bn_inv_variance") = *cudnn_bn_mean; + *GetBlobDesc4BnInOp("cudnn_bn_scale_ones") = *cudnn_bn_mean; + *GetBlobDesc4BnInOp("cudnn_bn_bias_zeros") = *cudnn_bn_mean; +} + +void LayerNormOp::InferBwBufBlobDescs( + std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + CHECK(parallel_ctx->policy() != kModelParallel); + const BlobDesc* in = GetBlobDesc4BnInOp("in"); + const LayerNormOpConf& conf = op_conf().layer_norm_conf(); + if (op_conf().trainable()) { *GetBlobDesc4BnInOp("bw_reduce_buf") = *in; } + const int64_t begin_norm_axis = ShiftNegativeAxisIfNeed(in->shape(), conf.begin_norm_axis()); + const Shape bn_param_shape = Shape({in->shape().Count(0, begin_norm_axis)}); + BlobDesc* bn_scale_diff = GetBlobDesc4BnInOp("cudnn_bn_scale_diff_buf"); + bn_scale_diff->mut_shape() = bn_param_shape; + bn_scale_diff->set_data_type(in->data_type()); + *GetBlobDesc4BnInOp("cudnn_bn_bias_diff_buf") = *bn_scale_diff; +} + +void LayerNormOp::VirtualFixParallelDesc(ParallelDesc* pr_desc) const { + pr_desc->set_policy(ParallelPolicy::kDataParallel); +} + +REGISTER_OP(OperatorConf::kLayerNormConf, LayerNormOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/layer_norm_op.h b/oneflow/core/operator/layer_norm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9fcf8327ee0ae7d32d50a06d2b8da81e9d3733db --- /dev/null +++ b/oneflow/core/operator/layer_norm_op.h @@ -0,0 
+1,30 @@ +#ifndef ONEFLOW_CORE_OPERATOR_LAYER_NORM_OP_H_ +#define ONEFLOW_CORE_OPERATOR_LAYER_NORM_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class LayerNormOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(LayerNormOp); + LayerNormOp() = default; + ~LayerNormOp() override = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override { return op_conf().layer_norm_conf(); } + bool NeedInBlobWhenBackward() const override { return true; } + bool NeedOutBlobWhenBackward() const override { return false; } + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext*) const override; + void InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + void VirtualFixParallelDesc(ParallelDesc* pr_desc) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_LAYER_NORM_OP_H_ diff --git a/oneflow/core/operator/local_reponse_normalization_op.h b/oneflow/core/operator/local_reponse_normalization_op.h index 22864365d8fc7351250c59765c7452d1eaf1c128..1028687023f7141c3dea67cd3ca79a2181b8cb6f 100644 --- a/oneflow/core/operator/local_reponse_normalization_op.h +++ b/oneflow/core/operator/local_reponse_normalization_op.h @@ -17,6 +17,8 @@ class LocalResponseNormalizationOp final : public Operator { const ParallelContext* parallel_ctx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const override; diff --git a/oneflow/core/operator/log_counter_op.h b/oneflow/core/operator/log_counter_op.h index f464a9d926b72fbb9db581050d6763bbd85c6a91..28ad2ac1506c671d5605fcc654cd3ff2d552366f 100644 --- a/oneflow/core/operator/log_counter_op.h +++ b/oneflow/core/operator/log_counter_op.h @@ -15,6 +15,9 @@ class LogCounterOp final : public Operator { void InitFromOpConf() override; const PbMessage& GetCustomizedConf() const override; virtual LogicalNode* NewProperLogicalNode() { return new PrintLogicalNode; } + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } }; } // namespace oneflow diff --git a/oneflow/core/operator/loss_op.cpp b/oneflow/core/operator/loss_op.cpp index d91c02ee1a14c5f9f29ab75e86081881beced44f..ed450083398f9eb5d08a72e34e2e39ea991d9aa7 100644 --- a/oneflow/core/operator/loss_op.cpp +++ b/oneflow/core/operator/loss_op.cpp @@ -4,7 +4,7 @@ namespace oneflow { void LossOp::InitFromOpConf() { EnrollInputBn("prediction"); - EnrollInputBn("label", false); + if (HasFieldInCustomizedConf("label")) { EnrollInputBn("label", false); } EnrollOutputBn("loss", false); EnrollOutputBn("loss_instance_num", false); if (!GetValFromCustomizedConf<std::string>("weight").empty()) { @@ -20,23 +20,29 @@ void LossOp::VirtualGenKernelConf( const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const { LossKernelConf* conf = GetMutLossKernelConf(kernel_conf); conf->set_prediction_type(GetBlobDesc4BnInOp("prediction")->data_type()); - conf->set_label_type(GetBlobDesc4BnInOp("label")->data_type()); + if (HasFieldInCustomizedConf("label")) { + 
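// label-bearing losses forward the label's data type to the kernel conf; label-less losses such as identity loss fall back to kInvalidDataType below +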
conf->set_label_type(GetBlobDesc4BnInOp("label")->data_type()); + } else { + conf->set_label_type(DataType::kInvalidDataType); + } conf->set_weight_scalar(GetValFromCustomizedConf<float>("weight_scalar")); conf->set_reduction(static_cast<LossReductionType>(GetEnumFromCustomizedConf("reduction"))); } void LossOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, - std::function<void(OpContext*)>) const { + const ParallelContext* parallel_ctx) const { const BlobDesc* pred_blob_desc = GetBlobDesc4BnInOp("prediction"); - const BlobDesc* label_blob_desc = GetBlobDesc4BnInOp("label"); - CHECK_EQ(pred_blob_desc->has_data_id_field(), label_blob_desc->has_data_id_field()); - CHECK_EQ(pred_blob_desc->has_dim0_valid_num_field(), label_blob_desc->has_dim0_valid_num_field()); - CHECK_EQ(pred_blob_desc->has_dim0_inner_shape(), label_blob_desc->has_dim0_inner_shape()); + if (HasFieldInCustomizedConf("label")) { + const BlobDesc* label_blob_desc = GetBlobDesc4BnInOp("label"); + CHECK_EQ(pred_blob_desc->has_data_id_field(), label_blob_desc->has_data_id_field()); + CHECK_EQ(pred_blob_desc->has_dim0_valid_num_field(), + label_blob_desc->has_dim0_valid_num_field()); + CHECK_EQ(pred_blob_desc->has_dim0_inner_shape(), label_blob_desc->has_dim0_inner_shape()); + } if (pred_blob_desc->has_dim0_inner_shape()) { CHECK_EQ(pred_blob_desc->dim0_inner_shape().At(0), 1); } - CHECK_GE(pred_blob_desc->shape().NumAxes(), 2); + CHECK_GT(pred_blob_desc->shape().NumAxes(), 0); // loss BlobDesc* loss_blob_desc = GetBlobDesc4BnInOp("loss"); *loss_blob_desc = *pred_blob_desc; diff --git a/oneflow/core/operator/loss_op.h b/oneflow/core/operator/loss_op.h index ad05fd4e4420dd4b4766945b83eb735dc1bf17ed..1dc47d2d08940c562162ff2c9ff01b33c35b0cb5 100644 --- a/oneflow/core/operator/loss_op.h +++ b/oneflow/core/operator/loss_op.h @@ -16,8 +16,7 @@ class LossOp : public Operator { LogicalNode* NewProperLogicalNode() override { return new LossLogicalNode; } void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, - std::function<void(OpContext*)> EnrollOpCtx) const override; + const ParallelContext* parallel_ctx) const override; bool IsLossOp() const override { return true; } protected: @@ -32,6 +31,8 @@ class LossOp : public Operator { KernelConf* kernel_conf) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + LogicalBlobId obn2lbi(const std::string& output_bn) const override; }; diff --git a/oneflow/core/operator/loss_print_op.h b/oneflow/core/operator/loss_print_op.h index b7d5ef9d025b91cdeb159be9fd8ce9466f72160b..4c62db725486c4085b153991fbb03d4446b62c3e 100644 --- a/oneflow/core/operator/loss_print_op.h +++ b/oneflow/core/operator/loss_print_op.h @@ -15,6 +15,8 @@ class LossPrintOp final : public Operator { const PbMessage& GetCustomizedConf() const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + LogicalBlobId ibn2lbi(const std::string& input_bn) const override; }; diff --git a/oneflow/core/operator/matmul_op.cpp b/oneflow/core/operator/matmul_op.cpp index 62bcb2a1daacb077c47500b288160ffca1c64da2..a53920733d5c1a62d980fac521bd6db8b382eca5 100644 --- a/oneflow/core/operator/matmul_op.cpp +++ b/oneflow/core/operator/matmul_op.cpp @@ -2,37 +2,139 @@ #include "oneflow/core/common/balanced_splitter.h" namespace oneflow { +namespace { + +class 
Matmul_MS_MS_2_P_OpParallelSignature final : public OpParallelSignature { + public: + OF_DISALLOW_COPY_AND_MOVE(Matmul_MS_MS_2_P_OpParallelSignature); + ~Matmul_MS_MS_2_P_OpParallelSignature() override = default; + + Matmul_MS_MS_2_P_OpParallelSignature(const Operator* op) : OpParallelSignature(op) {} + + const std::string Description() const override { return op().op_name() + ": (S, S) -> P"; } + + const OpParallelMatchResult GetMatchResult( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const ParallelContext* parallel_ctx) const override { + const auto& b_sbp_infer_hint = SbpInferHint4Ibn("b"); + if (!b_sbp_infer_hint.is_model_split()) { return MakeOpParallelMatchSignatureMismatch(); } + int32_t b_expected_split_axis = (op().op_conf().matmul_conf().transpose_b() ? 1 : 0); + if (b_sbp_infer_hint.split_axis() != b_expected_split_axis) { + return MakeOpParallelMatchSignatureMismatch(); + } + if (parallel_ctx->policy() == kModelParallel) { return MakeOpParallelMatchSuccess(); } + return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kModelParallel); + } + + void GenerateSignature( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + HashMap<std::string, SbpParallel>* bn2sbp) const override { + int32_t a_split_axis = (op().op_conf().matmul_conf().transpose_a() ? 0 : 1); + (*bn2sbp)["a"].mutable_split_parallel()->set_axis(a_split_axis); + (*bn2sbp)["b"] = SbpInferHint4Ibn("b").sbp_parallel(); + (*bn2sbp)["out"].mutable_partial_sum_parallel(); + } +}; + +} // namespace + void MatmulOp::InitFromOpConf() { CHECK(op_conf().has_matmul_conf()); - - EnrollInputBn("in"); - EnrollInputBn("weight"); + EnrollInputBn("a"); + EnrollInputBn("b"); EnrollOutputBn("out"); - if (op_conf().matmul_conf().has_bias()) { - EnrollInputBn("bias"); - EnrollConstBufBn("bias_multiplier"); - } + EnrollFwBufBn("fw_buf"); + EnrollBwBufBn("bw_buf"); } const PbMessage& MatmulOp::GetCustomizedConf() const { return op_conf().matmul_conf(); } +bool MatmulOp::IsInputBlobAllowedModelSplit(const std::string& ibn) const { + CHECK(std::find(input_bns().begin(), input_bns().end(), ibn) != input_bns().end()); + return ibn == "b"; +} + +void MatmulOp::GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const { + op_parallel_signatures->emplace_back(MakeDataSplitOpParallelSignature(this)); + op_parallel_signatures->emplace_back(Make_DS_MB_2_DS_OpParallelSignature(this)); + auto IsValidSplit = [this](int32_t axis) { + int32_t b_expected_split_axis = (op_conf().matmul_conf().transpose_b() ? 
0 : 1); + return axis == b_expected_split_axis; + }; + op_parallel_signatures->emplace_back(Make_DB_MS_2_MS_OpParallelSignature(this, IsValidSplit)); + op_parallel_signatures->emplace_back(new Matmul_MS_MS_2_P_OpParallelSignature(this)); +} + void MatmulOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const MatmulOpConf& conf = op_conf().matmul_conf(); - BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); - CHECK_EQ(in_blob_desc->data_type(), Global<JobDesc>::Get()->DefaultDataType()); - int32_t units = conf.units(); - if (parallel_ctx->policy() == kModelParallel) { - BalancedSplitter splitter(units, parallel_ctx->parallel_num()); - units = splitter.At(parallel_ctx->parallel_id()).size(); - } - // out + BlobDesc* a_blob_desc = GetBlobDesc4BnInOp("a"); + BlobDesc* b_blob_desc = GetBlobDesc4BnInOp("b"); + CHECK_EQ(a_blob_desc->shape().NumAxes(), b_blob_desc->shape().NumAxes()); + CHECK_GE(a_blob_desc->shape().NumAxes(), 2); + size_t num_axes = a_blob_desc->shape().NumAxes(); + if (conf.transpose_a()) { + CHECK(!a_blob_desc->has_dim0_valid_num_field()); + CHECK(!a_blob_desc->has_dim1_valid_num_field()); + CHECK(!a_blob_desc->has_dim2_valid_num_field()); + } + if (conf.transpose_b()) { + CHECK(!b_blob_desc->has_dim0_valid_num_field()); + CHECK(!b_blob_desc->has_dim1_valid_num_field()); + CHECK(!b_blob_desc->has_dim2_valid_num_field()); + } + BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); + *out_blob_desc = *a_blob_desc; + FOR_RANGE(int32_t, i, 0, num_axes - 2) { + CHECK_EQ(a_blob_desc->shape().At(i), b_blob_desc->shape().At(i)); + } + int64_t a_dim_index = conf.transpose_a() ? num_axes - 1 : num_axes - 2; + out_blob_desc->mut_shape().Set(num_axes - 2, a_blob_desc->shape().At(a_dim_index)); + int64_t b_dim_index = conf.transpose_b() ? num_axes - 2 : num_axes - 1; + out_blob_desc->mut_shape().Set(num_axes - 1, b_blob_desc->shape().At(b_dim_index)); + int64_t a_mid_dim_index = conf.transpose_a() ? num_axes - 2 : num_axes - 1; + int64_t b_mid_dim_index = conf.transpose_b() ? 
num_axes - 1 : num_axes - 2; + CHECK_EQ(a_blob_desc->shape().At(a_mid_dim_index), b_blob_desc->shape().At(b_mid_dim_index)); + if (device_type() == DeviceType::kGPU && num_axes >= 3) { + int batch_num = a_blob_desc->shape().Count(0, num_axes - 2); + // Assume gpu address is 64 bit + BlobDesc* fw_buf_blob_desc = GetBlobDesc4BnInOp("fw_buf"); + *fw_buf_blob_desc = *out_blob_desc; + fw_buf_blob_desc->mut_shape() = {3 * batch_num}; + fw_buf_blob_desc->set_data_type(DataType::kInt64); + fw_buf_blob_desc->set_has_data_id_field(false); + } +} + +int32_t MatmulOp::OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const std::string& obn) const { + CHECK_EQ(SbpInferHint4Ibn("a").num_axes(), SbpInferHint4Ibn("b").num_axes()); + const auto& b_sbp_infer_hint = SbpInferHint4Ibn("b"); + CHECK_EQ(SbpInferHint4Ibn("b").num_axes(), 2); + CHECK(b_sbp_infer_hint.is_model_split()); + int32_t b_model_split_axis = b_sbp_infer_hint.split_axis(); + if (op_conf().matmul_conf().transpose_b()) { + if (b_model_split_axis == 0) { return 1; } + } else { + if (b_model_split_axis == 1) { return 1; } + } + UNIMPLEMENTED(); + return -1; +} + +void MatmulOp::InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext*) const { BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); - *out_blob_desc = *in_blob_desc; - out_blob_desc->mut_shape() = Shape({in_blob_desc->shape().At(0), units}); - if (conf.has_bias()) { - // bias_multiplier - GetBlobDesc4BnInOp("bias_multiplier")->mut_shape() = Shape({in_blob_desc->shape().At(0), 1}); + size_t num_axes = out_blob_desc->shape().NumAxes(); + if (device_type() == DeviceType::kGPU && num_axes >= 3) { + BlobDesc* bw_buf_blob_desc = GetBlobDesc4BnInOp("bw_buf"); + int32_t batch_num = out_blob_desc->shape().Count(0, num_axes - 2); + *bw_buf_blob_desc = *out_blob_desc; + bw_buf_blob_desc->mut_shape() = {3 * batch_num}; + bw_buf_blob_desc->set_data_type(DataType::kInt64); + bw_buf_blob_desc->set_has_data_id_field(false); } } diff --git a/oneflow/core/operator/matmul_op.h b/oneflow/core/operator/matmul_op.h index 9c07cb64157015f71c3467727c730919c21f25ea..a7ced9dabb37a9f5240eb702ca5272f10597f8e6 100644 --- a/oneflow/core/operator/matmul_op.h +++ b/oneflow/core/operator/matmul_op.h @@ -8,12 +8,22 @@ class MatmulOp final : public Operator { OF_DISALLOW_COPY_AND_MOVE(MatmulOp); MatmulOp() = default; ~MatmulOp() = default; + void InitFromOpConf() override; const PbMessage& GetCustomizedConf() const override; + bool NeedOutBlobWhenBackward() const override { return false; } void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; - int32_t ModelSplitAxis() const override { return 1; } - int32_t MaxModelSplitNum() const override { return op_conf().matmul_conf().units(); } + void InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext*) const override; + int32_t OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const std::string& obn) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override; + void GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>*) const override; }; } // namespace oneflow diff --git a/oneflow/core/operator/maximum_op.h b/oneflow/core/operator/maximum_op.h index 
e76aed1dd0b1783d41839266fc2a8f1aee07355a..d67c7ee35b80b82192e6472556b19685f229e02d 100644 --- a/oneflow/core/operator/maximum_op.h +++ b/oneflow/core/operator/maximum_op.h @@ -14,9 +14,11 @@ class MaximumOp final : public CWiseOp { void VirtualInitFromOpConf() override; const PbMessage& GetCustomizedConf() const override; - void VirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } }; } // namespace oneflow diff --git a/oneflow/core/operator/mean_op.cpp b/oneflow/core/operator/mean_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..851d14b6c2457bcab165f824241beb1f5ea05a26 --- /dev/null +++ b/oneflow/core/operator/mean_op.cpp @@ -0,0 +1,31 @@ +#include "oneflow/core/operator/mean_op.h" + +namespace oneflow { + +void MeanOp::InitFromOpConf() { + CHECK(op_conf().has_mean_conf()); + EnrollInputBn("in"); + EnrollOutputBn("out"); + EnrollFwBufBn("fw_tmp"); + EnrollBwBufBn("bw_tmp"); +} + +void MeanOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + const BlobDesc* in_blob = GetBlobDesc4BnInOp("in"); + BlobDesc* out_blob = GetBlobDesc4BnInOp("out"); + *out_blob = *in_blob; + std::vector<int64_t> dim_vec = in_blob->shape().dim_vec(); + dim_vec.back() = 1; + out_blob->mut_shape() = Shape(std::move(dim_vec)); + *GetBlobDesc4BnInOp("fw_tmp") = *in_blob; +} + +void MeanOp::InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext*) const { + *GetBlobDesc4BnInOp("bw_tmp") = *GetBlobDesc4BnInOp("out"); +} + +REGISTER_OP(OperatorConf::kMeanConf, MeanOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/mean_op.h b/oneflow/core/operator/mean_op.h new file mode 100644 index 0000000000000000000000000000000000000000..2a977d23d4705532e3302af68dd48e0c48e6f99e --- /dev/null +++ b/oneflow/core/operator/mean_op.h @@ -0,0 +1,29 @@ +#ifndef ONEFLOW_CORE_OPERATOR_MEAN_OP_H_ +#define ONEFLOW_CORE_OPERATOR_MEAN_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class MeanOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(MeanOp); + MeanOp() = default; + ~MeanOp() = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override { return op_conf().mean_conf(); } + bool NeedInBlobWhenBackward() const override { return false; } + bool NeedOutBlobWhenBackward() const override { return false; } + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + void InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext*) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_MEAN_OP_H_ diff --git a/oneflow/core/operator/model_save_op.h b/oneflow/core/operator/model_save_op.h index 157af7e8fa778b2e86f3d41848397f072eb33be9..e80dd03ecfcaf0e4ac4d8cefcbb1c3ab61460945 100644 --- a/oneflow/core/operator/model_save_op.h +++ b/oneflow/core/operator/model_save_op.h @@ -13,6 +13,9 @@ class ModelSaveOp final : public Operator { void InitFromOpConf() override; const PbMessage& GetCustomizedConf() const override; + + private: + bool 
IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } }; } // namespace oneflow diff --git a/oneflow/core/operator/momentum_model_update_op.cpp b/oneflow/core/operator/momentum_model_update_op.cpp index f39cc3892e34b9ac032f8a50d00d1969c81c487d..674dd29e1e93ef0a4ac0ec8a5ee43ec34b088781 100644 --- a/oneflow/core/operator/momentum_model_update_op.cpp +++ b/oneflow/core/operator/momentum_model_update_op.cpp @@ -4,7 +4,7 @@ namespace oneflow { void MomentumModelUpdateOp::MdUpdtVirtualInitFromOpConf() { EnrollForwardModelBn("momentum"); } -void MomentumModelUpdateOp::InferBlobDescs( +void MomentumModelUpdateOp::MdUpdtVirtualInferBlobDescs( std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const BlobDesc* model_blob_desc = GetBlobDesc4BnInOp("model"); diff --git a/oneflow/core/operator/momentum_model_update_op.h b/oneflow/core/operator/momentum_model_update_op.h index 7be6e82f85b9daca548fb8c358082b96bc4dd92b..085deb7f3525c11652b46a280b95767b39591caf 100644 --- a/oneflow/core/operator/momentum_model_update_op.h +++ b/oneflow/core/operator/momentum_model_update_op.h @@ -11,11 +11,10 @@ class MomentumModelUpdateOp final : public NormalModelUpdtOp { MomentumModelUpdateOp() = default; ~MomentumModelUpdateOp() = default; - void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx) const override; - private: void MdUpdtVirtualInitFromOpConf() override; + void MdUpdtVirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; }; } // namespace oneflow diff --git a/oneflow/core/operator/multiply_op.h b/oneflow/core/operator/multiply_op.h index 9194a1c210a94cd188d422e0f0fb85388a56555f..5500497c508215ba38637d99b3ddbd4ba99a37a5 100644 --- a/oneflow/core/operator/multiply_op.h +++ b/oneflow/core/operator/multiply_op.h @@ -12,6 +12,9 @@ class MultiplyOp final : public Operator { const PbMessage& GetCustomizedConf() const override; void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } }; } // namespace oneflow diff --git a/oneflow/core/operator/naive_model_update_op.h b/oneflow/core/operator/naive_model_update_op.h index 2582f20aa6921f6f1d6c00c3c68755524c67e0e5..414d5f14d4252d1d2c94a1aa5e1ec57d1db1c78e 100644 --- a/oneflow/core/operator/naive_model_update_op.h +++ b/oneflow/core/operator/naive_model_update_op.h @@ -10,9 +10,6 @@ class NaiveModelUpdateOp final : public NormalModelUpdtOp { OF_DISALLOW_COPY_AND_MOVE(NaiveModelUpdateOp); NaiveModelUpdateOp() = default; ~NaiveModelUpdateOp() = default; - - void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx) const override {} }; } // namespace oneflow diff --git a/oneflow/core/operator/nccl_all_gather_op.h b/oneflow/core/operator/nccl_all_gather_op.h index d57f0189778001356e73e1fe159a5ad02b7ff508..b5fc4c1a289cbfc381db2add952d6b98b7852995 100644 --- a/oneflow/core/operator/nccl_all_gather_op.h +++ b/oneflow/core/operator/nccl_all_gather_op.h @@ -18,6 +18,8 @@ class NcclAllGatherOp final : public Operator { const ParallelContext* parallel_ctx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } + 
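// model-split queries are not expected to reach NCCL system ops, hence UNIMPLEMENTED() +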
LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } LogicalBlobId obn2lbi(const std::string& output_bn) const override; }; diff --git a/oneflow/core/operator/nccl_all_reduce_op.h b/oneflow/core/operator/nccl_all_reduce_op.h index 969f758b906a814cc86e7ebd9660ba410cbcfbb5..b634fa962a4143a95f928af532b43cd46d5c5059 100644 --- a/oneflow/core/operator/nccl_all_reduce_op.h +++ b/oneflow/core/operator/nccl_all_reduce_op.h @@ -18,6 +18,8 @@ class NcclAllReduceOp final : public Operator { const ParallelContext* parallel_ctx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } + LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } LogicalBlobId obn2lbi(const std::string& output_bn) const override; }; diff --git a/oneflow/core/operator/nccl_reduce_scatter_op.h b/oneflow/core/operator/nccl_reduce_scatter_op.h index 9441205425cc7dbb48bf704f7d260dc21d93b3e2..086d1732b3099f8aae5580ce9007d30299c16128 100644 --- a/oneflow/core/operator/nccl_reduce_scatter_op.h +++ b/oneflow/core/operator/nccl_reduce_scatter_op.h @@ -18,6 +18,7 @@ class NcclReduceScatterOp final : public Operator { const ParallelContext* parallel_ctx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } LogicalBlobId obn2lbi(const std::string& output_bn) const override; }; diff --git a/oneflow/core/operator/normal_model_update_op.cpp b/oneflow/core/operator/normal_model_update_op.cpp index a36e4ab03a24b1e7ca8cdde99c8a409193d7f406..28c7a9eb6df72270b6ee637e85758cb75f3f9ec0 100644 --- a/oneflow/core/operator/normal_model_update_op.cpp +++ b/oneflow/core/operator/normal_model_update_op.cpp @@ -2,6 +2,7 @@ #include "oneflow/core/operator/rmsprop_model_update_op.h" #include "oneflow/core/operator/momentum_model_update_op.h" #include "oneflow/core/operator/lars_model_update_op.h" +#include "oneflow/core/operator/adam_model_update_op.h" namespace oneflow { @@ -9,9 +10,24 @@ void NormalModelUpdtOp::InitFromOpConf() { EnrollInputBn("model_diff", false); EnrollInputBn("total_instance_num_diff", false); EnrollOutputBn("model", false); + if (op_conf().normal_mdupdt_conf().user_conf().has_clip_conf() + && op_conf().normal_mdupdt_conf().user_conf().clip_conf().has_clip_by_global_norm()) { + EnrollDataTmpBn("data_tmp"); + } MdUpdtVirtualInitFromOpConf(); } +void NormalModelUpdtOp::InferBlobDescs( + std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + if (op_conf().normal_mdupdt_conf().user_conf().has_clip_conf() + && op_conf().normal_mdupdt_conf().user_conf().clip_conf().has_clip_by_global_norm()) { + *GetBlobDesc4BnInOp("data_tmp") = *GetBlobDesc4BnInOp("model_diff"); + GetBlobDesc4BnInOp("data_tmp")->mut_shape() = Shape({1}); + } + MdUpdtVirtualInferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx); +} + const PbMessage& NormalModelUpdtOp::GetCustomizedConf() const { return op_conf().normal_mdupdt_conf(); } diff --git a/oneflow/core/operator/normal_model_update_op.h b/oneflow/core/operator/normal_model_update_op.h index 160c20e1a9277cd9ed70c9700de2c16bf8c57a28..32b2f1fcbd45cae3f0d0817d13f69164d44ba49a 100644 --- a/oneflow/core/operator/normal_model_update_op.h +++ b/oneflow/core/operator/normal_model_update_op.h @@ -10,14 +10,21 @@ class NormalModelUpdtOp : public Operator { 
OF_DISALLOW_COPY_AND_MOVE(NormalModelUpdtOp); virtual ~NormalModelUpdtOp() = default; - virtual void InitFromOpConf(); + void InitFromOpConf() override; + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; const PbMessage& GetCustomizedConf() const override; protected: NormalModelUpdtOp() = default; virtual void MdUpdtVirtualInitFromOpConf() {} + virtual void MdUpdtVirtualInferBlobDescs( + std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext*) const {} private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } + LogicalBlobId obn2lbi(const std::string& output_bn) const override; }; diff --git a/oneflow/core/operator/normalization_op.cpp b/oneflow/core/operator/normalization_op.cpp index f53dfd90fdf457b5a511626c4f4f0b21fad36ce8..211e81b1d8d18d4b2a2e284d94d5b0efae5c3bde 100644 --- a/oneflow/core/operator/normalization_op.cpp +++ b/oneflow/core/operator/normalization_op.cpp @@ -39,7 +39,8 @@ const PbMessage& NormalizationOp::GetCustomizedConf() const { void NormalizationOp::InferBlobDescs( std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, std::function<void(OpContext*)> EnrollOpCtx) const { + const ParallelContext* parallel_ctx, int64_t record_piece_size, + std::function<void(OpContext*)> EnrollOpCtx) const { const auto& conf = op_conf().normalization_conf(); const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); const DataType in_data_type = in_blob_desc->data_type(); diff --git a/oneflow/core/operator/normalization_op.h b/oneflow/core/operator/normalization_op.h index 373bae04dbceb22a1e6534a9a872e4632d850a80..5d63921d8e3c88605aa1d3c14aab3631f020b8f4 100644 --- a/oneflow/core/operator/normalization_op.h +++ b/oneflow/core/operator/normalization_op.h @@ -25,10 +25,12 @@ class NormalizationOp final : public Operator { bool NeedOutBlobWhenBackward() const override { return false; } void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext*, + const ParallelContext* parallel_ctx, int64_t record_piece_size, std::function<void(OpContext*)> EnrollOpCtx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + void InferParamBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const NormalizationOpConf&, int64_t norm_part_num, DataType in_data_type, bool use_cudnn) const; diff --git a/oneflow/core/operator/one_hot_op.cpp b/oneflow/core/operator/one_hot_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3c510125f44a4e21dfcdeb6760935171cdffe933 --- /dev/null +++ b/oneflow/core/operator/one_hot_op.cpp @@ -0,0 +1,33 @@ +#include "oneflow/core/operator/one_hot_op.h" + +namespace oneflow { + +void OneHotOp::InitFromOpConf() { + CHECK(op_conf().has_one_hot_conf()); + EnrollInputBn("indices", false); + EnrollOutputBn("out", false); +} + +const PbMessage& OneHotOp::GetCustomizedConf() const { return op_conf().one_hot_conf(); } + +void OneHotOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + const OneHotOpConf& conf = op_conf().one_hot_conf(); + const int64_t depth = conf.depth(); + const DataType data_type = + conf.has_data_type() ? 
conf.data_type() : Global<JobDesc>::Get()->DefaultDataType(); + CHECK_GT(depth, 0); + const BlobDesc* indices = GetBlobDesc4BnInOp("indices"); + CHECK(IsIntegralDataType(indices->data_type())); + CHECK_GT(indices->shape().NumAxes(), 0); + BlobDesc* out = GetBlobDesc4BnInOp("out"); + *out = *indices; + out->set_data_type(data_type); + std::vector<int64_t> dim_vec = indices->shape().dim_vec(); + dim_vec.push_back(depth); + out->mut_shape() = Shape(dim_vec); +} + +REGISTER_OP(OperatorConf::kOneHotConf, OneHotOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/one_hot_op.h b/oneflow/core/operator/one_hot_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b8618e290a6b34cdefb5649947330d42d09c464a --- /dev/null +++ b/oneflow/core/operator/one_hot_op.h @@ -0,0 +1,26 @@ +#ifndef ONEFLOW_CORE_OPERATOR_ONE_HOT_OP_H_ +#define ONEFLOW_CORE_OPERATOR_ONE_HOT_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class OneHotOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(OneHotOp); + OneHotOp() = default; + ~OneHotOp() override = default; + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + bool NeedOutBlobWhenBackward() const override { return false; } + bool NeedInBlobWhenBackward() const override { return false; } + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_ONE_HOT_OP_H_ diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto index 19a856a481790816beb68464f721f8a802096e6b..033834a309d4f9c1e13dbba5af05d8493c295083 100644 --- a/oneflow/core/operator/op_conf.proto +++ b/oneflow/core/operator/op_conf.proto @@ -4,7 +4,9 @@ package oneflow; import "oneflow/core/common/shape.proto"; import "oneflow/core/common/data_type.proto"; import "oneflow/core/record/image.proto"; +import "oneflow/core/record/record.proto"; import "oneflow/core/job/resource.proto"; +import "oneflow/core/job/sbp_parallel.proto"; import "oneflow/core/register/logical_blob_id.proto"; enum ActivationType { @@ -37,6 +39,10 @@ message RandomNormalInitializerConf { optional float std = 2 [default = 1]; } +message TruncatedNormalInitializerConf { + optional float std = 1 [default = 1]; +} + enum VarianceNorm { kFanIn = 0; kFanOut = 1; @@ -51,6 +57,20 @@ message MsraInitializerConf { required VarianceNorm variance_norm = 1; } +//output[D_0 ... D_(axis - 1) i D_(axis + 1) ... 
D_n] = start + i * stride +message RangeInitializerConf { + optional double start = 1 [default = 0]; + optional double stride = 2 [default = 1]; + optional int64 axis = 3 [default = -1]; +} + +message IntRangeInitializerConf { + optional int64 start = 1 [default = 0]; + optional int64 stride = 2 [default = 1]; + optional int64 axis = 3 [default = -1]; +} + + message InitializerConf { oneof type { ConstantInitializerConf constant_conf = 1; @@ -58,8 +78,11 @@ message InitializerConf { RandomUniformInitializerConf random_uniform_conf = 3; RandomUniformIntInitializerConf random_uniform_int_conf = 4; RandomNormalInitializerConf random_normal_conf = 5; - XavierInitializerConf xavier_conf = 6; - MsraInitializerConf msra_conf = 7; + TruncatedNormalInitializerConf truncated_normal_conf = 6; + XavierInitializerConf xavier_conf = 7; + MsraInitializerConf msra_conf = 8; + RangeInitializerConf range_conf = 9; + IntRangeInitializerConf int_range_conf = 10; } } @@ -206,6 +229,12 @@ enum LossReductionType { kSumOverNonZeroWeight = 3; } +message SparseCrossEntropyOpConf { + required string prediction = 1; + required string label = 2; + required string out = 3; +} + message SparseSoftmaxCrossEntropyLossOpConf { required string prediction = 1; required string label = 2; @@ -224,6 +253,25 @@ message SparseCrossEntropyLossOpConf { optional string weight = 6; } +message SigmoidCrossEntropyLossOpConf { + required string prediction = 1; + required string label = 2; + required string loss = 3; + optional bool normalize = 4 [default = true]; + optional float scale = 5 [default = 1.0]; + optional LossReductionType reduction = 6 [default = kSumOverN]; + optional float weight_scalar = 7 [default = 1.0]; + optional string weight = 8; +} + +message IdentityLossOpConf { + required string prediction = 1; + required string loss = 2; + optional LossReductionType reduction = 3 [default = kSumOverN]; + optional float weight_scalar = 4 [default = 1.0]; + optional string weight = 5; +} + message ConcatOpConf { repeated string in = 1; required string out = 2; @@ -293,6 +341,13 @@ message LARSModelUpdateConf { optional float lars_coefficient = 3 [default = 0.0001]; } +message AdamModelUpdateConf { + optional float beta1 = 1 [default = 0.9]; + optional float beta2 = 2 [default = 0.999]; + optional float epsilon = 3 [default = 1e-8]; + optional bool do_bias_correction = 4 [default = false]; +} + message ExponentialDecayConf { required int64 decay_batches = 1; required double decay_rate = 2; @@ -324,12 +379,12 @@ message PolynomialDecayConf { } message CosineDecayConf { - required int64 decay_batches = 1; + required int64 decay_batches = 1; optional double alpha = 2 [default = 0.0]; } message LinearCosineDecayConf { - required int64 decay_batches = 1; + required int64 decay_batches = 1; optional double num_periods = 2 [default = 0.5]; optional double alpha = 3 [default = 0.0]; optional double beta = 4 [default = 0.001]; @@ -364,14 +419,27 @@ message WarmupConf { } } +message ClipByGlobalNormConf { + required float clip_norm = 1; + optional float global_norm = 2; +} + +message ClipConf { + oneof type { + ClipByGlobalNormConf clip_by_global_norm = 1; + } +} + message NormalModelUpdateOpUserConf { optional LearningRateDecayConf learning_rate_decay = 1; optional WarmupConf warmup_conf = 2; + optional ClipConf clip_conf = 3; oneof normal_mdupdt { NaiveModelUpdateConf naive_conf = 1000; MomentumModelUpdateConf momentum_conf = 1001; RMSPropModelUpdateConf rmsprop_conf = 1002; LARSModelUpdateConf lars_conf = 1003; + AdamModelUpdateConf 
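// --- illustrative sketch (not part of this diff) --------------------------------
// AdamModelUpdateConf above only stores hyper-parameters (beta1, beta2, epsilon,
// do_bias_correction). For reference, a sketch of the standard Adam step those fields
// parameterize; the helper below is hypothetical, not the OneFlow update kernel.
#include <cmath>
#include <cstdint>

inline void AdamStepSketch(float* model, float* m, float* v, const float* model_diff,
                           int64_t n, float lr, float beta1, float beta2, float epsilon,
                           bool do_bias_correction, int64_t step /* 1-based */) {
  const float t = static_cast<float>(step);
  for (int64_t i = 0; i < n; ++i) {
    m[i] = beta1 * m[i] + (1.0f - beta1) * model_diff[i];
    v[i] = beta2 * v[i] + (1.0f - beta2) * model_diff[i] * model_diff[i];
    float m_hat = m[i];
    float v_hat = v[i];
    if (do_bias_correction) {
      m_hat /= (1.0f - std::pow(beta1, t));
      v_hat /= (1.0f - std::pow(beta2, t));
    }
    model[i] -= lr * m_hat / (std::sqrt(v_hat) + epsilon);
  }
}
// --- end sketch ------------------------------------------------------------------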
adam_conf = 1004; } } @@ -413,6 +481,11 @@ message LogCounterOpConf { optional int32 interval = 2 [default = 1]; } +message GeluOpConf { + required string in = 1; + required string out = 2; +} + message LossPrintOpConf { required LogicalBlobId loss_lbi = 1; required LogicalBlobId loss_instance_num_lbi = 2; @@ -433,8 +506,15 @@ message ReduceSumOpConf { LogicalBlobId in_sys = 2; // For System } required string out = 3; - optional int32 axis = 4; - optional bool keepdims = 5 [default = false]; + repeated int32 axis = 4; + optional bool keep_dims = 5 [default = false]; +} + +message ReduceMeanOpConf { + required string in = 1; + required string out = 2; + repeated int32 axis = 3; + optional bool keep_dims = 4 [default = false]; } message BasicRnnOpConf { @@ -460,6 +540,7 @@ message ReshapeOpConf { required string in = 1; required string out = 2; required ShapeProto shape = 3; + optional bool has_dim0_in_shape = 4; } message EmbeddingLookupOpConf { @@ -491,10 +572,20 @@ message CastOpConf { required DataType data_type = 3; } +message VariableOpConf { + required string tick = 1; + required string out = 2; + required ShapeProto shape = 3; + optional DataType data_type = 4; + optional InitializerConf initializer = 5; + optional string model_name = 6 [default = "weight"]; + optional int32 model_split_axis = 7 [default = 0]; +} + message LocalResponseNormalizationOpConf { required string in = 1; required string out = 2; - required string data_format = 3; + required string data_format = 3; optional int32 depth_radius = 4 [default = 5]; optional double bias = 5 [default = 1]; optional double alpha = 6 [default = 1]; @@ -542,11 +633,16 @@ message PreprocessConf { } } +message RandomShuffleConf { + optional int32 buffer_size = 1 [default = 1024]; +} + message RecordLoadOpConf { required string out = 1; required string data_dir = 2; optional string part_name_prefix = 3 [default = "part-"]; optional int32 part_name_suffix_length = 4 [default = -1]; + optional RandomShuffleConf random_shuffle_conf = 5; } message BlobConf { @@ -564,6 +660,7 @@ message DecodeOFRecordOpConf { optional int32 part_name_suffix_length = 3 [default = -1]; optional string in = 4; repeated BlobConf blob = 5; + optional RandomShuffleConf random_shuffle_conf = 6; } message DecodeRandomOpConf { @@ -582,6 +679,7 @@ message DefineTestBlobOpConf { optional int64 dim1_valid_num = 6; optional int64 dim2_valid_num = 7; repeated int64 record_id_in_device_piece = 8; + optional bool has_diff = 9 [default = false]; } message NormalizationOpConf { @@ -648,14 +746,17 @@ message AccuracyOpConf { required string label = 2; optional int32 top_k = 3 [default = 1]; required string accuracy = 4; + optional string weight = 5; } message MatmulOpConf { - required string in = 1; - required string weight = 2; - optional string bias = 3; - required int32 units = 4; + // input lbn + required string a = 1; + required string b = 2; + // output bn required string out = 5; + optional bool transpose_a = 6 [default = false]; + optional bool transpose_b = 7 [default = false]; } message DotOpConf { @@ -689,28 +790,193 @@ message HingeLossOpConf { message PackOpConf { required string in = 1; required string out = 2; - oneof pack_num_conf { - int32 pack_num = 3; - int32 pack_num_per_record = 4; - } - required string related_unpack = 5; + required int32 pack_num = 3; + required string related_unpack = 4; } message UnpackOpConf { required string in = 1; required string out = 2; - oneof unpack_num_conf { - int32 unpack_num = 3; - int32 unpack_num_per_record = 4; - } + 
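// --- illustrative sketch (not part of this diff) --------------------------------
// ReduceSumOpConf / ReduceMeanOpConf above now take a repeated `axis` plus `keep_dims`.
// A small sketch of the implied output shape: reduced axes either collapse to 1
// (keep_dims = true) or are dropped (keep_dims = false); negative axes and validation
// are omitted here.
#include <cstdint>
#include <set>
#include <vector>

inline std::vector<int64_t> ReducedShapeSketch(const std::vector<int64_t>& in_shape,
                                               const std::set<int32_t>& axes,
                                               bool keep_dims) {
  std::vector<int64_t> out;
  for (int32_t i = 0; i < static_cast<int32_t>(in_shape.size()); ++i) {
    if (axes.count(i) == 0) {
      out.push_back(in_shape[i]);
    } else if (keep_dims) {
      out.push_back(1);
    }
  }
  return out;
}
// e.g. in_shape {4, 3, 5}, axes {1}: keep_dims=false -> {4, 5}, keep_dims=true -> {4, 1, 5}.
// --- end sketch ------------------------------------------------------------------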
required int32 unpack_num = 3; } message RepeatOpConf { required string in = 1; required string out = 2; - oneof repeat_num_conf { - int32 repeat_num = 3; - int32 repeat_num_per_record = 4; + required int32 repeat_num = 3; +} + +message GatherOpConf { + required string in = 1; + required string indices = 2; + required string out = 3; + optional int64 axis = 4 [default = 0]; +} + +message BatchGatherOpConf { + required string in = 1; + required string indices = 2; + required string out = 3; +} + +message SqrtOpConf { + required string in = 1; + required string out = 2; +} + +message RsqrtOpConf { + required string in = 1; + required string out = 2; +} + +message SquareOpConf { + required string in = 1; + required string out = 2; +} + +message BroadcastAddOpConf { + required string a = 1; + required string b = 2; + required string out = 3; + optional bool is_const = 4 [default = false]; +} + +message BroadcastSubOpConf { + required string a = 1; + required string b = 2; + required string out = 3; + optional bool is_const = 4 [default = false]; +} + +message BroadcastMulOpConf { + required string a = 1; + required string b = 2; + required string out = 3; + optional bool is_const = 4 [default = false]; +} + +message BroadcastDivOpConf { + required string a = 1; + required string b = 2; + required string out = 3; + optional bool is_const = 4 [default = false]; +} + + +message BiasAddOpConf { + // inputs + required string a = 1; + required string b = 2; + // output + required string out = 3; +} + +message MeanOpConf { + required string in = 1; + required string out = 2; + // TODO: axis of mean +} + +message DimSliceConf { + optional int32 start = 1 [default = 0]; + optional int32 end = 2 [default = 0]; + optional int32 stride = 3 [default = 1]; +} + +message SliceOpConf { + required string in = 1; + required string out = 2; + repeated DimSliceConf dim_slice_conf = 3; +} + +message LayerNormOpConf { + required string in = 1; + required string out = 2; + optional bool center = 3 [default = true]; + optional bool scale = 4 [default = true]; + optional ActivationType activation = 5 [default = kNone]; + optional int64 begin_norm_axis = 6 [default = 1]; + optional int64 begin_params_axis = 7 [default = -1]; + optional double epsilon = 8 [default = 1e-5]; +} + +message ConstantOpConf { + required string tick = 1; + required string out = 2; + optional ShapeProto shape = 3; + optional DataType data_type = 4; + optional InitializerConf initializer = 5; + optional bool use_device_piece_size_as_dim0 = 6 [default = false]; +} + +message DebugOpConf { + required string in = 1; + required string out = 2; + optional string in_blob_dump_dir = 3; + optional string out_diff_blob_dump_dir = 4; + optional string part_name_prefix = 5 [default = "part-"]; + optional int32 part_name_suffix_length = 6 [default = -1]; + oneof const_out { + string const_out_feature_load_filepath = 7; + Feature const_out_feature = 8; + } + oneof const_in_diff { + string const_in_diff_feature_load_filepath = 9; + Feature const_in_diff_feature = 10; + } +} + +message OneHotOpConf { + required string indices = 1; + required string out = 2; + required int64 depth = 3; + optional DataType data_type = 4; +} + +message ScalarAddOpConf { + required string in = 1; + required string out = 2; + oneof scalar_operand { + int64 int_operand = 3; + double float_operand = 4; + } +} + +message ScalarMulOpConf { + required string in = 1; + required string out = 2; + oneof scalar_operand { + int64 int_operand = 3; + double float_operand = 4; + } +} + +message 
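// --- illustrative sketch (not part of this diff) --------------------------------
// GatherOpConf above selects entries along `axis`. Assuming the usual TensorFlow-style
// gather semantics, the output shape is in_shape[0:axis] + indices_shape + in_shape[axis+1:];
// the helper below is a shape-only sketch (negative axis and bounds checks omitted).
#include <cstdint>
#include <vector>

inline std::vector<int64_t> GatherOutShapeSketch(const std::vector<int64_t>& in_shape,
                                                 const std::vector<int64_t>& indices_shape,
                                                 int64_t axis) {
  std::vector<int64_t> out(in_shape.begin(), in_shape.begin() + axis);
  out.insert(out.end(), indices_shape.begin(), indices_shape.end());
  out.insert(out.end(), in_shape.begin() + axis + 1, in_shape.end());
  return out;
}
// e.g. in {5, 4, 3}, indices {2, 6}, axis = 1  ->  out {5, 2, 6, 3}.
// --- end sketch ------------------------------------------------------------------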
ReduceIdentityOpConf { +} + +message TickOpConf { + optional string in = 1; + required string out = 2; +} + +message TupleIdentityOpConf { + repeated string in = 1; + repeated string out = 2; +} + +message TopKOpConf { + required string in = 1; + required string out = 2; + optional int32 k = 3 [default = 1]; + optional bool sorted = 4 [default = true]; +} + +message ParallelCastOpConf { + required string in = 1; + required string out = 2; + oneof parallel_type { + SplitParallel split_parallel = 3; + BroadcastParallel broadcast_parallel = 4; } } @@ -751,8 +1017,12 @@ message OperatorConf { ModelSaveOpConf model_save_conf = 119; SharedModelDiffAddOpConf shared_model_diff_add_conf = 120; CastOpConf cast_conf = 121; - + VariableOpConf variable_conf = 122; + ReduceIdentityOpConf reduce_identity_conf = 123; + TickOpConf tick_conf = 124; + // domain op + TupleIdentityOpConf tuple_identity_conf = 200; TransposeOpConf transpose_conf = 201; ReshapeOpConf reshape_conf = 202; BasicRnnOpConf basic_rnn_conf = 203; @@ -793,7 +1063,34 @@ message OperatorConf { UnpackOpConf unpack_conf = 238; RepeatOpConf repeat_conf = 239; LogCounterOpConf log_counter_conf = 240; - L2NormalizeOpConf l2_normalize_conf = 241; + GeluOpConf gelu_conf = 241; + GatherOpConf gather_conf = 242; + BatchGatherOpConf batch_gather_conf = 243; + MeanOpConf mean_conf = 251; + SliceOpConf slice_conf = 252; + BiasAddOpConf bias_add_conf = 253; + LayerNormOpConf layer_norm_conf = 254; + ConstantOpConf constant_conf = 255; + DebugOpConf debug_conf = 256; + SigmoidCrossEntropyLossOpConf sigmoid_cross_entropy_loss_conf = 257; + OneHotOpConf one_hot_conf = 258; + IdentityLossOpConf identity_loss_conf = 259; + SparseCrossEntropyOpConf sparse_cross_entropy_conf= 260; + ReduceMeanOpConf reduce_mean_conf = 261; + TopKOpConf top_k_conf = 262; + ParallelCastOpConf parallel_cast_conf = 263; + L2NormalizeOpConf l2_normalize_conf = 264; + + // math op + BroadcastAddOpConf broadcast_add_conf = 500; + BroadcastSubOpConf broadcast_sub_conf = 501; + BroadcastMulOpConf broadcast_mul_conf = 502; + BroadcastDivOpConf broadcast_div_conf = 503; + SquareOpConf square_conf = 504; + SqrtOpConf sqrt_conf = 505; + RsqrtOpConf rsqrt_conf = 506; + ScalarAddOpConf scalar_add_conf = 507; + ScalarMulOpConf scalar_mul_conf = 508; } } diff --git a/oneflow/core/operator/op_parallel_match_result.proto b/oneflow/core/operator/op_parallel_match_result.proto new file mode 100644 index 0000000000000000000000000000000000000000..2984274383bd3b33dee38eafe2ad5055e19bdb7d --- /dev/null +++ b/oneflow/core/operator/op_parallel_match_result.proto @@ -0,0 +1,39 @@ +syntax = "proto2"; +package oneflow; + +import "oneflow/core/job/placement.proto"; + +message OpParallelMatchSucess { +} + +message OpParallelSignatureMismatch { +} + +message ParallelPolicyError { + required ParallelPolicy configured = 1; + required ParallelPolicy expected = 2; +} + +message ParallelNumError { + required int64 configured = 1; + required int64 expected = 2; +} + +message OpParallelConfError { + optional ParallelPolicyError parallel_policy_error = 1; + optional ParallelNumError parallel_num_error = 2; +} + +message OpParallelMatchFail { + oneof fail_type { + OpParallelSignatureMismatch signature_mismatch = 1; + OpParallelConfError conf_error = 2; + } +} + +message OpParallelMatchResult { + oneof result_type { + OpParallelMatchSucess success = 1; + OpParallelMatchFail fail = 2; + } +} \ No newline at end of file diff --git a/oneflow/core/operator/op_parallel_signature.cpp 
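// --- illustrative sketch (not part of this diff) --------------------------------
// The ParallelCastOpConf added above lets a job pin the SBP parallel kind of a blob.
// A sketch of filling such a conf through the generated protobuf accessors; the op
// name and blob names are made up, and OperatorConf's usual `name` field is assumed.
#include "oneflow/core/operator/op_conf.pb.h"

inline oneflow::OperatorConf MakeBroadcastCastSketch() {
  oneflow::OperatorConf op_conf;
  op_conf.set_name("force_broadcast");      // hypothetical op name
  auto* cast = op_conf.mutable_parallel_cast_conf();
  cast->set_in("producer_op/out");          // hypothetical input lbn
  cast->set_out("out");
  cast->mutable_broadcast_parallel();       // or: cast->mutable_split_parallel()->set_axis(0)
  return op_conf;
}
// --- end sketch ------------------------------------------------------------------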
b/oneflow/core/operator/op_parallel_signature.cpp new file mode 100644 index 0000000000000000000000000000000000000000..decf39b7913c11afe1b098271503b02a8d07dbfa --- /dev/null +++ b/oneflow/core/operator/op_parallel_signature.cpp @@ -0,0 +1,347 @@ +#include "oneflow/core/operator/op_parallel_signature.h" +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +const OpParallelMatchResult MakeOpParallelMatchSuccess() { + OpParallelMatchResult success; + success.mutable_success(); + return success; +} + +const OpParallelMatchResult MakeOpParallelMatchSignatureMismatch() { + OpParallelMatchResult signature_mismatch; + signature_mismatch.mutable_fail()->mutable_signature_mismatch(); + return signature_mismatch; +} + +const OpParallelMatchResult MakeOpParallelMatchParallelPolicyError(ParallelPolicy configured, + ParallelPolicy expected) { + OpParallelMatchResult policy_error; + auto* err = policy_error.mutable_fail()->mutable_conf_error()->mutable_parallel_policy_error(); + err->set_configured(configured); + err->set_expected(expected); + return policy_error; +} + +const OpParallelMatchResult MakeOpParallelMatchParallelNumError(int64_t configured, + int64_t expected) { + OpParallelMatchResult parallel_num_error; + auto* err = parallel_num_error.mutable_fail()->mutable_conf_error()->mutable_parallel_num_error(); + err->set_configured(configured); + err->set_expected(expected); + return parallel_num_error; +} + +namespace { + +class DataSplitOpParallelSignature final : public OpParallelSignature { + public: + OF_DISALLOW_COPY_AND_MOVE(DataSplitOpParallelSignature); + ~DataSplitOpParallelSignature() override = default; + + DataSplitOpParallelSignature(const Operator* op) : OpParallelSignature(op) {} + + const std::string Description() const override { + return op().op_name() + ": (S(0), ...) 
-> (S(0), ...)"; + } + + const OpParallelMatchResult GetMatchResult( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const ParallelContext* parallel_ctx) const override { + bool is_data_split = true; + for (const auto& bn : op().input_bns()) { + const SbpInferHint& sbp_infer_hint = SbpInferHint4Ibn(bn); + if (!sbp_infer_hint.is_data_blob()) { + is_data_split = false; + break; + } + } + if (!is_data_split) { return MakeOpParallelMatchSignatureMismatch(); } + if (parallel_ctx->policy() == kDataParallel) { return MakeOpParallelMatchSuccess(); } + return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kDataParallel); + } + + void GenerateSignature( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + HashMap<std::string, SbpParallel>* bn2sbp) const override { + for (const auto& bn : op().input_bns()) { (*bn2sbp)[bn].mutable_split_parallel()->set_axis(0); } + for (const auto& bn : op().output_bns()) { + (*bn2sbp)[bn].mutable_split_parallel()->set_axis(0); + } + } +}; + +class BroadcastOpParallelSignature final : public OpParallelSignature { + public: + OF_DISALLOW_COPY_AND_MOVE(BroadcastOpParallelSignature); + ~BroadcastOpParallelSignature() override = default; + + BroadcastOpParallelSignature(const Operator* op) : OpParallelSignature(op) { + CHECK_EQ(op->input_bns().size(), 1); + CHECK(op->model_bns().empty()); + CHECK(op->const_model_bns().empty()); + } + + const std::string Description() const override { return op().op_name() + ": (B,) -> (B, ...)"; } + + const OpParallelMatchResult GetMatchResult( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const ParallelContext* parallel_ctx) const override { + if (!SbpInferHint4Ibn(op().SoleIbn()).sbp_parallel().has_broadcast_parallel()) { + return MakeOpParallelMatchSignatureMismatch(); + } + int64_t expected_parallel_num = SbpInferHint4Ibn(op().SoleIbn()).parallel_num(); + bool parallel_policy_matched = (parallel_ctx->policy() == kDataParallel); + bool parallel_num_matched = (parallel_ctx->parallel_num() == expected_parallel_num); + if (parallel_policy_matched && parallel_num_matched) { + return MakeOpParallelMatchSuccess(); + } else { + OpParallelMatchResult ret; + if (!parallel_policy_matched) { + auto* err = ret.mutable_fail()->mutable_conf_error()->mutable_parallel_policy_error(); + err->set_configured(parallel_ctx->policy()); + err->set_expected(kDataParallel); + } else { + auto* err = ret.mutable_fail()->mutable_conf_error()->mutable_parallel_num_error(); + err->set_configured(parallel_ctx->parallel_num()); + err->set_expected(parallel_num_matched); + } + return ret; + } + } + + void GenerateSignature( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + HashMap<std::string, SbpParallel>* bn2sbp) const override { + for (const auto& bn : op().input_bns()) { (*bn2sbp)[bn].mutable_broadcast_parallel(); } + for (const auto& bn : op().output_bns()) { (*bn2sbp)[bn].mutable_broadcast_parallel(); } + } +}; + +class DS_MB_2_DS_OpParallelSignature final : public OpParallelSignature { + public: + OF_DISALLOW_COPY_AND_MOVE(DS_MB_2_DS_OpParallelSignature); + ~DS_MB_2_DS_OpParallelSignature() override = default; + + DS_MB_2_DS_OpParallelSignature(const Operator* op) : OpParallelSignature(op) { + std::vector<std::string> model_input_bns; + for (const auto& bn : op->input_bns()) { + if (op->IsInputBlobAllowedModelSplit(bn)) { + model_input_bns.push_back(bn); + } else { + data_input_bns_.push_back(bn); + } + } + 
CHECK_GT(data_input_bns_.size(), 0); + CHECK_EQ(model_input_bns.size(), 1); + model_input_bn_ = model_input_bns.at(0); + } + + const std::string Description() const override { + return op().op_name() + ": (B, S(0), ...) -> (S(0), ...)"; + } + + const OpParallelMatchResult GetMatchResult( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const ParallelContext* parallel_ctx) const override { + const auto& sbp_infer_hint = SbpInferHint4Ibn(model_input_bn_); + if (!sbp_infer_hint.is_model_broadcast()) { return MakeOpParallelMatchSignatureMismatch(); } + bool parallel_policy_matched = (parallel_ctx->policy() == kDataParallel); + bool parallel_num_matched = (parallel_ctx->parallel_num() == sbp_infer_hint.parallel_num()); + if (!parallel_policy_matched || !parallel_num_matched) { + OpParallelMatchResult ret; + if (!parallel_policy_matched) { + auto* err = ret.mutable_fail()->mutable_conf_error()->mutable_parallel_policy_error(); + err->set_configured(parallel_ctx->policy()); + err->set_expected(kDataParallel); + } + if (!parallel_num_matched) { + auto* err = ret.mutable_fail()->mutable_conf_error()->mutable_parallel_num_error(); + err->set_configured(parallel_ctx->parallel_num()); + err->set_expected(parallel_num_matched); + } + return ret; + } + return MakeOpParallelMatchSuccess(); + } + + void GenerateSignature( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + HashMap<std::string, SbpParallel>* bn2sbp) const override { + for (const auto& bn : data_input_bns_) { (*bn2sbp)[bn].mutable_split_parallel()->set_axis(0); } + (*bn2sbp)[model_input_bn_].mutable_broadcast_parallel(); + for (const auto& bn : op().output_bns()) { + (*bn2sbp)[bn].mutable_split_parallel()->set_axis(0); + } + } + + private: + std::vector<std::string> data_input_bns_; + std::string model_input_bn_; +}; + +class SoleIbnOpModelSplitOpParallelSignature final : public OpParallelSignature { + public: + OF_DISALLOW_COPY_AND_MOVE(SoleIbnOpModelSplitOpParallelSignature); + ~SoleIbnOpModelSplitOpParallelSignature() override = default; + + SoleIbnOpModelSplitOpParallelSignature(const Operator* op) : OpParallelSignature(op) {} + + const std::string Description() const override { return op().op_name() + ": (S,) -> (S, ...)"; } + + const OpParallelMatchResult GetMatchResult( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const ParallelContext* parallel_ctx) const override { + const SbpInferHint& sbp_infer_hint = SbpInferHint4Ibn(op().SoleIbn()); + if (!(sbp_infer_hint.is_model_split() + || (sbp_infer_hint.is_data_split() && sbp_infer_hint.split_axis() > 0))) { + return MakeOpParallelMatchSignatureMismatch(); + } + int64_t expected_parallel_num = sbp_infer_hint.parallel_num(); + bool parallel_policy_matched = (parallel_ctx->policy() == kModelParallel); + bool parallel_num_matched = (parallel_ctx->parallel_num() == expected_parallel_num); + if (!(parallel_policy_matched && parallel_num_matched)) { + OpParallelMatchResult ret; + if (!parallel_policy_matched) { + auto* err = ret.mutable_fail()->mutable_conf_error()->mutable_parallel_policy_error(); + err->set_configured(parallel_ctx->policy()); + err->set_expected(kModelParallel); + } + if (!parallel_num_matched) { + auto* err = ret.mutable_fail()->mutable_conf_error()->mutable_parallel_num_error(); + err->set_configured(parallel_ctx->parallel_num()); + err->set_expected(parallel_num_matched); + } + return ret; + } + return MakeOpParallelMatchSuccess(); + } + + void GenerateSignature( + 
const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + HashMap<std::string, SbpParallel>* bn2sbp) const override { + (*bn2sbp)[op().SoleIbn()] = SbpInferHint4Ibn(op().SoleIbn()).sbp_parallel(); + for (const auto& bn : op().output_bns()) { + (*bn2sbp)[bn].mutable_split_parallel()->set_axis( + op().OutputBlobModelSplitAxis(SbpInferHint4Ibn, bn)); + } + } +}; + +class ModelBnOpModelSplitOpParallelSignature final : public OpParallelSignature { + public: + OF_DISALLOW_COPY_AND_MOVE(ModelBnOpModelSplitOpParallelSignature); + ~ModelBnOpModelSplitOpParallelSignature() override = default; + + ModelBnOpModelSplitOpParallelSignature(const Operator* op) : OpParallelSignature(op) {} + + const std::string Description() const override { + return op().op_name() + ": (B, ...) -> (S, ...)"; + } + + const OpParallelMatchResult GetMatchResult( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const ParallelContext* parallel_ctx) const override { + if (parallel_ctx->policy() == kModelParallel) { return MakeOpParallelMatchSuccess(); } + return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kModelParallel); + } + + void GenerateSignature( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + HashMap<std::string, SbpParallel>* bn2sbp) const override { + for (const auto& bn : op().input_bns()) { (*bn2sbp)[bn].mutable_broadcast_parallel(); } + for (const auto& bn : op().output_bns()) { + (*bn2sbp)[bn].mutable_split_parallel()->set_axis( + op().OutputBlobModelSplitAxis(SbpInferHint4Ibn, bn)); + } + } +}; + +class DB_MS_2_MS_OpParallelSignature final : public OpParallelSignature { + public: + OF_DISALLOW_COPY_AND_MOVE(DB_MS_2_MS_OpParallelSignature); + ~DB_MS_2_MS_OpParallelSignature() override = default; + + DB_MS_2_MS_OpParallelSignature(const Operator* op, std::function<bool(int32_t)> IsExpectedAxis) + : OpParallelSignature(op), IsExpectedAxis_(IsExpectedAxis) { + std::vector<std::string> model_input_bns; + for (const auto& bn : op->input_bns()) { + if (op->IsInputBlobAllowedModelSplit(bn)) { + model_input_bns.push_back(bn); + } else { + data_input_bns_.push_back(bn); + } + } + CHECK_GT(data_input_bns_.size(), 0); + CHECK_EQ(model_input_bns.size(), 1); + model_input_bn_ = model_input_bns.at(0); + } + + const std::string Description() const override { + return op().op_name() + ": (B, S, ...) 
-> (S, ...)"; + } + + const OpParallelMatchResult GetMatchResult( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const ParallelContext* parallel_ctx) const override { + const SbpInferHint& model_sbp_infer_hint = SbpInferHint4Ibn(model_input_bn_); + if (!(model_sbp_infer_hint.is_model_split() + && IsValidSplit(model_sbp_infer_hint.split_axis()))) { + return MakeOpParallelMatchSignatureMismatch(); + } + if (parallel_ctx->policy() == kModelParallel) { return MakeOpParallelMatchSuccess(); } + return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kModelParallel); + } + + void GenerateSignature( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + HashMap<std::string, SbpParallel>* bn2sbp) const override { + for (const auto& bn : data_input_bns_) { (*bn2sbp)[bn].mutable_broadcast_parallel(); } + (*bn2sbp)[model_input_bn_] = SbpInferHint4Ibn(model_input_bn_).sbp_parallel(); + for (const auto& bn : op().output_bns()) { + (*bn2sbp)[bn].mutable_split_parallel()->set_axis( + op().OutputBlobModelSplitAxis(SbpInferHint4Ibn, bn)); + } + } + + private: + bool IsValidSplit(int32_t axis) const { return axis != -1 && IsExpectedAxis_(axis); } + + const std::function<bool(int32_t)> IsExpectedAxis_; + std::vector<std::string> data_input_bns_; + std::string model_input_bn_; +}; + +} // namespace + +std::unique_ptr<const OpParallelSignature> MakeDataSplitOpParallelSignature(const Operator* op) { + return std::unique_ptr<const OpParallelSignature>(new DataSplitOpParallelSignature(op)); +} + +std::unique_ptr<const OpParallelSignature> MakeBroadcastOpParallelSignature(const Operator* op) { + return std::unique_ptr<const OpParallelSignature>(new BroadcastOpParallelSignature(op)); +} + +std::unique_ptr<const OpParallelSignature> MakeModelSplitOpParallelSignature(const Operator* op) { + if (op->IsSoleInputBlobAllowedModelSplit()) { + return std::unique_ptr<const OpParallelSignature>( + new SoleIbnOpModelSplitOpParallelSignature(op)); + } else { + CHECK(!op->model_bns().empty() || !op->const_model_bns().empty()); + return std::unique_ptr<const OpParallelSignature>( + new ModelBnOpModelSplitOpParallelSignature(op)); + } +} + +std::unique_ptr<const OpParallelSignature> Make_DS_MB_2_DS_OpParallelSignature(const Operator* op) { + return std::unique_ptr<const OpParallelSignature>(new DS_MB_2_DS_OpParallelSignature(op)); +} + +std::unique_ptr<const OpParallelSignature> Make_DB_MS_2_MS_OpParallelSignature( + const Operator* op, std::function<bool(int32_t)> IsExpectedAxis) { + return std::unique_ptr<const OpParallelSignature>( + new DB_MS_2_MS_OpParallelSignature(op, IsExpectedAxis)); +} + +} // namespace oneflow diff --git a/oneflow/core/operator/op_parallel_signature.h b/oneflow/core/operator/op_parallel_signature.h new file mode 100644 index 0000000000000000000000000000000000000000..94b034b6973c2f620e15406da2627a238bed087a --- /dev/null +++ b/oneflow/core/operator/op_parallel_signature.h @@ -0,0 +1,63 @@ +#ifndef ONEFLOW_CORE_OPERATOR_OP_PARALLEL_SIGNATURE_H_ +#define ONEFLOW_CORE_OPERATOR_OP_PARALLEL_SIGNATURE_H_ + +#include "oneflow/core/common/util.h" +#include "oneflow/core/job/sbp_parallel.h" +#include "oneflow/core/operator/op_parallel_match_result.pb.h" +#include "oneflow/core/job/sbp_infer_hint.h" + +namespace oneflow { + +class Operator; + +class OpParallelSignature { + public: + virtual ~OpParallelSignature() = default; + virtual const std::string Description() const = 0; + virtual const OpParallelMatchResult GetMatchResult( + 
const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const ParallelContext* parallel_ctx) const = 0; + virtual void GenerateSignature( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + HashMap<std::string, SbpParallel>* bn2sbp) const = 0; + + protected: + OpParallelSignature(const Operator* op) : op_(op) {} + const Operator& op() const { return *op_; } + + private: + const Operator* op_; +}; + +const OpParallelMatchResult MakeOpParallelMatchSuccess(); +const OpParallelMatchResult MakeOpParallelMatchSignatureMismatch(); +const OpParallelMatchResult MakeOpParallelMatchParallelPolicyError(ParallelPolicy configured, + ParallelPolicy expected); +const OpParallelMatchResult MakeOpParallelMatchParallelNumError(int64_t configured, + int64_t expected); + +class Operator; + +// (S(0), ...) -> (S(0), ...) +std::unique_ptr<const OpParallelSignature> MakeDataSplitOpParallelSignature(const Operator* op); + +// (S,) -> (S, ...) or (B, ...) -> (S, ...) +std::unique_ptr<const OpParallelSignature> MakeModelSplitOpParallelSignature(const Operator* op); + +// (B,) -> (B, ...) +std::unique_ptr<const OpParallelSignature> MakeBroadcastOpParallelSignature(const Operator* op); + +// (B, S(0), ...) -> (S(0), ...) +// return blobs: data splitted +// intput blobs: split data input blobs and broadcast model input blobs +std::unique_ptr<const OpParallelSignature> Make_DS_MB_2_DS_OpParallelSignature(const Operator* op); + +// (B, S, ...) -> (S, ...) +// return blobs: model splitted +// input blobs: broadcast data input blobs and split model input blobs +std::unique_ptr<const OpParallelSignature> Make_DB_MS_2_MS_OpParallelSignature( + const Operator* op, std::function<bool(int32_t)> IsExpectedAxis); + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_OP_PARALLEL_SIGNATURE_H_ diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 8d7aa8188998432edc035124633592abcc2b2922..7e050ecb9241c817e51e4da5aa119a11990561eb 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -1,5 +1,6 @@ #include "oneflow/core/operator/operator.h" #include "oneflow/core/graph/logical_node.h" +#include "oneflow/core/common/balanced_splitter.h" namespace oneflow { @@ -73,17 +74,23 @@ const std::string& Operator::SoleBbbn() const { } void Operator::InferBlobDescsIf(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, + const ParallelContext* parallel_ctx, int64_t record_piece_size, std::function<void(OpContext*)> EnrollOpCtx) const { - InferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx, EnrollOpCtx); + InferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx, record_piece_size, EnrollOpCtx); if (op_attribute_.model_bns().size() > 0) { InferTotalInstanceNumDesc(GetBlobDesc4BnInOp, parallel_ctx, EnrollOpCtx); } } void Operator::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, + const ParallelContext* parallel_ctx, int64_t record_piece_size, std::function<void(OpContext*)> EnrollOpCtx) const { + InferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx, record_piece_size); +} + +void Operator::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, + int64_t record_piece_size) const { InferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx); } @@ -101,25 +108,149 @@ void Operator::InferBwBufBlobDescsIf( } } +void Operator::InferOutputBlobTimeShapeIf( + 
std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, + const ParallelContext* parallel_ctx, Shape* time_shape) const { + InferOutputBlobTimeShape(GetTimeShape4BnInOp, parallel_ctx, time_shape); +} + +void Operator::InferOutputBlobTimeShape( + std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, const ParallelContext*, + Shape* time_shape) const { + for (const std::string& bn : input_bns()) { + CHECK_EQ(*GetTimeShape4BnInOp(input_bns().Get(0)), *GetTimeShape4BnInOp(bn)); + } + if (input_bns().empty() == false) { + *time_shape = *GetTimeShape4BnInOp(input_bns().Get(0)); + } else { + *time_shape = Shape( + {Global<JobDesc>::Get()->TotalBatchNum(), Global<JobDesc>::Get()->NumOfPiecesInBatch()}); + } +} + +int32_t Operator::OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const std::string& obn) const { + if (IsSoleInputBlobAllowedModelSplit()) { + return SbpInferHint4Ibn(SoleIbn()).split_axis(); + } else { + UNIMPLEMENTED(); + return -1; + } +} + +void Operator::GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const { + bool has_model = !(model_bns().empty() && const_model_bns().empty()); + op_parallel_signatures->emplace_back(MakeDataSplitOpParallelSignature(this)); + if (IsSoleInputBlobAllowedModelSplit()) { + CHECK(!has_model); + op_parallel_signatures->emplace_back(MakeModelSplitOpParallelSignature(this)); + op_parallel_signatures->emplace_back(MakeBroadcastOpParallelSignature(this)); + } else if (has_model) { + for (const auto& ibn : input_bns()) { CHECK(!IsInputBlobAllowedModelSplit(ibn)); } + op_parallel_signatures->emplace_back(MakeModelSplitOpParallelSignature(this)); + } else if (input_bns().size() == 1) { + op_parallel_signatures->emplace_back(MakeBroadcastOpParallelSignature(this)); + } else { + // do nothing + } +} + +void Operator::InferInputOutputSbpParallelIf( + std::function<SbpParallel*(const std::string&)> SbpParallel4BnInOp, + std::function<const SbpInferHint&(const std::string&)> SbpInferHint4Ibn, + const ParallelContext* parallel_ctx) const { + std::vector<std::unique_ptr<const OpParallelSignature>> op_parallel_signatures; + GetOpParallelSignatures(&op_parallel_signatures); + std::vector<OpParallelMatchResult> match_results; + for (const auto& signature : op_parallel_signatures) { + match_results.push_back(signature->GetMatchResult(SbpInferHint4Ibn, parallel_ctx)); + } + int32_t match_success_cnt = 0; + for (const auto& result : match_results) { + if (result.has_success()) { ++match_success_cnt; } + } + if (match_success_cnt == 1) { + const OpParallelSignature* match_signature = nullptr; + FOR_RANGE(int32_t, i, 0, op_parallel_signatures.size()) { + if (match_results.at(i).has_success()) { + match_signature = op_parallel_signatures.at(i).get(); + } + } + HashMap<std::string, SbpParallel> bn2sbp; + match_signature->GenerateSignature(SbpInferHint4Ibn, &bn2sbp); + for (const auto& pair : bn2sbp) { + auto* sbp_parallel = SbpParallel4BnInOp(pair.first); + *sbp_parallel = pair.second; + } + } else if (match_success_cnt == 0) { + std::stringstream ss; + FOR_RANGE(int32_t, i, 0, op_parallel_signatures.size()) { + CHECK(match_results.at(i).has_fail()); + const auto& failed_msg = match_results.at(i).fail(); + ss << "op_parallel_signature match failed\n" + << op_parallel_signatures.at(i)->Description() << ":\n"; + if (failed_msg.has_signature_mismatch()) { + ss << "\t" + << "signature mismatch" + << "\n"; + } else { + 
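// --- illustrative aside (not part of this diff) ----------------------------------
// Worked example of the signature selection above, with hypothetical names: for an op
// with one data input "in" and one model-carrying input "weight" running under
// kDataParallel, only the (B, S(0), ...) -> (S(0), ...) signature reports success, so
// match_success_cnt is 1 and GenerateSignature fills bn2sbp roughly as
//   { "in": S(0), "weight": B, "out": S(0) }.
// With zero successes, the branch below assembles the per-signature failure report.
// --- end aside --------------------------------------------------------------------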
CHECK(failed_msg.has_conf_error()); + if (failed_msg.conf_error().has_parallel_policy_error()) { + const auto& policy_error_msg = failed_msg.conf_error().parallel_policy_error(); + ss << "\t" + << "parallel_policy conf error, configured: " + << ParallelPolicy_Name(policy_error_msg.configured()) + << ", expected: " << ParallelPolicy_Name(policy_error_msg.expected()) << "\n"; + } + if (failed_msg.conf_error().has_parallel_num_error()) { + const auto& parallel_num_error_msg = failed_msg.conf_error().parallel_num_error(); + ss << "\t" + << "parallel_num conf error, configured: " << parallel_num_error_msg.configured() + << ", expected: " << parallel_num_error_msg.expected() << "\n"; + } + } + } + LOG(FATAL) << ss.str(); + } else { + UNIMPLEMENTED(); + } +} + +bool Operator::IsSoleInputBlobAllowedModelSplit() const { + return input_bns().size() == 1 && IsInputBlobAllowedModelSplit(SoleIbn()); +} + +void Operator::InferIsModelBlob4OutputBlobsIf( + std::function<bool*(const std::string&)> IsModelBlob4BnInOp) const { + InferIsModelBlob4OutputBlobs(IsModelBlob4BnInOp); +} + +void Operator::InferIsModelBlob4OutputBlobs( + std::function<bool*(const std::string&)> IsModelBlob4BnInOp) const { + bool is_model_blob = (IsSoleInputBlobAllowedModelSplit() && *IsModelBlob4BnInOp(SoleIbn())); + for (const std::string& obn : output_bns()) { *IsModelBlob4BnInOp(obn) = is_model_blob; } +} + void Operator::InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const OpContext* op_ctx) const { InferBwBufBlobDescs(GetBlobDesc4BnInOp, parallel_ctx); } -void Operator::FixParallelDesc(ParallelDesc* pr_desc) const { - if (model_bns().empty() && const_model_bns().empty()) { - pr_desc->set_policy(ParallelPolicy::kDataParallel); - } - if (pr_desc->policy() == kModelParallel && MaxModelSplitNum() != -1) { - pr_desc->RemoveNeedlessDevice(op_name(), MaxModelSplitNum()); +void Operator::FixInDiffBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* ctx) const { + VirtualFixInDiffBlobDescs(GetBlobDesc4BnInOp, ctx); + for (const std::string& input_diff_bn : input_diff_bns()) { + BlobDesc* blob_desc = GetBlobDesc4BnInOp(input_diff_bn); + if (!blob_desc) { continue; } + blob_desc->set_has_loss_instance_num_field(true); } - if (pr_desc->policy() == kDataParallel) { - pr_desc->RemoveNeedlessDevice(op_name(), Global<JobDesc>::Get()->PieceSize()); - } - VirtualFixParallelDesc(pr_desc); } +void Operator::FixParallelDesc(ParallelDesc* pr_desc) const { VirtualFixParallelDesc(pr_desc); } + void Operator::FixLbiWhenShareModel(const std::string& shared_op_name) { for (const std::string& model_bn : model_bns()) { mut_bn_in_op2lbi()->at(model_bn).set_op_name(shared_op_name); diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h index 135752a438ff35b966cfaf349c5430111c963698..c7e4bb4840525318a1300d4f98476c85c29062eb 100644 --- a/oneflow/core/operator/operator.h +++ b/oneflow/core/operator/operator.h @@ -6,10 +6,11 @@ #include "oneflow/core/common/auto_registration_factory.h" #include "oneflow/core/job/keyword.h" #include "oneflow/core/job/parallel_desc.h" -#include "oneflow/core/job/placement.pb.h" +#include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/kernel/kernel.pb.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/register/blob_desc.h" +#include "oneflow/core/operator/op_parallel_signature.h" namespace oneflow { @@ -28,7 +29,8 @@ class Operator { // void 
InitFromOpConf(const OperatorConf& op_conf); virtual void InitFromOpConf() = 0; - virtual bool IsElemWiseOp() const { return false; } + bool IsSoleInputBlobAllowedModelSplit() const; + virtual bool IsInputBlobAllowedModelSplit(const std::string& ibn) const = 0; ActivationType GetActivationType() const; @@ -37,6 +39,7 @@ class Operator { virtual bool IsLossOp() const { return false; } virtual bool IsRecurrentOp() const { return false; } virtual bool IsEmbeddingLookupOp() const { return false; } + virtual bool IsAllOutputConst() const { return false; } bool NeedOutBlobWhenBackwardIf() const { return NeedOutBlobWhenBackward() || (GetActivationType() != ActivationType::kNone); @@ -44,6 +47,8 @@ class Operator { virtual bool NeedOutBlobWhenBackward() const { return true; } bool NeedInBlobWhenBackwardIf() const { return NeedInBlobWhenBackward(); } virtual bool NeedInBlobWhenBackward() const { return true; } + virtual bool IsForwardInplace() const { return false; } + virtual bool IsBackwardInplace() const { return false; } // bn_in_op <-> lbi const LogicalBlobId& BnInOp2Lbi(const std::string& bn_in_op) const; @@ -114,10 +119,13 @@ class Operator { // Read: shape of input_blobs // Write: shape of output_blobs, model_blobs, data_tmp_blobs, const_model_blobs, const_buf_blobs void InferBlobDescsIf(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext*, std::function<void(OpContext*)> EnrollOpCtx) const; + const ParallelContext*, int64_t record_piece_size, + std::function<void(OpContext*)> EnrollOpCtx) const; virtual void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext*, + const ParallelContext*, int64_t record_piece_size, std::function<void(OpContext*)> EnrollOpCtx) const; + virtual void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext*, int64_t record_piece_size) const; virtual void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*) const; void InferBwBufBlobDescsIf(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, @@ -131,17 +139,39 @@ class Operator { const ParallelContext*) const { UNIMPLEMENTED(); } + // Infer out blob's time shape + void InferOutputBlobTimeShapeIf( + std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, const ParallelContext*, + Shape* time_shape) const; + virtual void InferOutputBlobTimeShape( + std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, const ParallelContext*, + Shape* time_shape) const; + // Infer blob's SbpParallel + void InferInputOutputSbpParallelIf( + std::function<SbpParallel*(const std::string&)> SbpParallel4BnInOp, + std::function<const SbpInferHint&(const std::string&)> SbpInferHint4Ibn, + const ParallelContext* parallel_ctx) const; + // Infer is_model_blob + void InferIsModelBlob4OutputBlobsIf( + std::function<bool*(const std::string&)> IsModelBlob4BnInOp) const; virtual void FixInDiffBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext*) const {} + const ParallelContext*) const; + virtual void VirtualFixInDiffBlobDescs( + std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext*) const {} void FixParallelDesc(ParallelDesc* pr_desc) const; void FixLbiWhenShareModel(const std::string& shared_op_name); - virtual int32_t ModelSplitAxis() const { return -1; } - virtual int32_t MaxModelSplitNum() const { return -1; } + virtual int32_t 
OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const std::string& obn) const; void GenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, bool is_forward, const ParallelContext*, KernelConf*, const OpContext*) const; protected: + virtual void InferIsModelBlob4OutputBlobs( + std::function<bool*(const std::string&)> IsModelBlob4BnInOp) const; + int64_t cudnn_buf_limit_byte() const; virtual PbMessage* MutableCustomizedKernelConf(KernelConf*) const { @@ -220,6 +250,8 @@ class Operator { void StrFieldTolower(const std::string& field_name); private: + virtual void GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>*) const; LogicalBlobId dtbn2lbi(const std::string& data_tmp_bn) const; LogicalBlobId fbbn2lbi(const std::string& fw_buf_bn) const { return dtbn2lbi(fw_buf_bn); } LogicalBlobId bbbn2lbi(const std::string& bw_buf_bn) const { return dtbn2lbi(bw_buf_bn); } @@ -265,6 +297,12 @@ struct OnlyCpuSupportPredicator { std::shared_ptr<Operator> ConstructOp(const OperatorConf& op_conf); +inline std::shared_ptr<Operator> ConstructOp(const OperatorConf& op_conf, DeviceType device_type) { + OperatorConf dev_op_conf = op_conf; + dev_op_conf.set_device_type(device_type); + return ConstructOp(dev_op_conf); +} + void EraseEmptyBnInVec(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, PbRpf<std::string>* bns); @@ -283,6 +321,14 @@ inline LogicalBlobId GenLogicalBlobId(const std::string& lbn) { return lbi; } +inline std::string GenLogicalBlobName(const LogicalBlobId& lbi) { + CHECK_EQ(lbi.has_op_name(), true); + CHECK_EQ(lbi.has_blob_name(), true); + CHECK_EQ(lbi.has_clone_id(), false); + CHECK_EQ(lbi.is_packed_id(), false); + return lbi.op_name() + "/" + lbi.blob_name(); +} + } // namespace oneflow #endif // ONEFLOW_CORE_OPERATOR_OPERATOR_H_ diff --git a/oneflow/core/operator/pack_op.cpp b/oneflow/core/operator/pack_op.cpp index a38c9a531b4df0d4cff51638fe6b02313795f130..6dec573945d1298316f21dcffe8fb9d96a256054 100644 --- a/oneflow/core/operator/pack_op.cpp +++ b/oneflow/core/operator/pack_op.cpp @@ -9,20 +9,20 @@ void PackOp::InitFromOpConf() { EnrollOutputBn("out", false); } -int32_t PackOp::GetPackNum(int64_t parallel_num) const { +void PackOp::InferOutputBlobTimeShape( + std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, + const ParallelContext* parallel_ctx, Shape* time_shape) const { + std::vector<int64_t> dim_vec(GetTimeShape4BnInOp("in")->dim_vec()); + int32_t pack_num = GetPackNum(); + CHECK_EQ(pack_num, dim_vec.back()); + dim_vec.pop_back(); + *time_shape = Shape(dim_vec); +} + +int32_t PackOp::GetPackNum() const { CHECK(op_conf().has_pack_conf()); const PackOpConf& conf = op_conf().pack_conf(); - if (conf.has_pack_num()) { - return conf.pack_num(); - } else if (conf.has_pack_num_per_record()) { - CHECK_EQ(Global<JobDesc>::Get()->PieceSize() % parallel_num, 0); - int64_t pack_num = - Global<JobDesc>::Get()->PieceSize() / parallel_num * conf.pack_num_per_record(); - CHECK_LE(pack_num, static_cast<int64_t>(MaxVal<int32_t>())); - return static_cast<int32_t>(pack_num); - } else { - UNIMPLEMENTED(); - } + return conf.pack_num(); } REGISTER_OP(OperatorConf::kPackConf, PackOp); diff --git a/oneflow/core/operator/pack_op.h b/oneflow/core/operator/pack_op.h index 206d5a323a8a52b5c220759149c21d408e4d24be..835514527c5f7169698046cc3d1fb77a01923713 100644 --- a/oneflow/core/operator/pack_op.h +++ b/oneflow/core/operator/pack_op.h @@ -15,10 +15,16 @@ 
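// --- illustrative sketch (not part of this diff) --------------------------------
// PackOp::InferOutputBlobTimeShape above drops the trailing time dimension, which must
// equal pack_num. A tiny sketch of that bookkeeping with made-up numbers:
#include <cstdint>
#include <vector>

inline std::vector<int64_t> PackedTimeShapeSketch(std::vector<int64_t> in_time_shape,
                                                  int32_t pack_num) {
  // mirrors the CHECK_EQ(pack_num, dim_vec.back()) in the op above
  if (in_time_shape.empty() || in_time_shape.back() != pack_num) { return {}; }
  in_time_shape.pop_back();
  return in_time_shape;
}
// e.g. in time shape {total_batch_num, pieces_in_batch, 4} with pack_num = 4
//      -> out time shape {total_batch_num, pieces_in_batch}.
// --- end sketch ------------------------------------------------------------------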
class PackOp final : public Operator { void InitFromOpConf() override; const PbMessage& GetCustomizedConf() const override { return op_conf().pack_conf(); } LogicalNode* NewProperLogicalNode() { return new PackForwardLogicalNode; } + void InferOutputBlobTimeShape(std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, + const ParallelContext* parallel_ctx, + Shape* time_shape) const override; bool NeedInBlobWhenBackward() const override { return true; } bool NeedOutBlobWhenBackward() const override { return false; } - int32_t GetPackNum(int64_t parallel_num) const; + int32_t GetPackNum() const; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } }; } // namespace oneflow diff --git a/oneflow/core/operator/parallel_cast_op.cpp b/oneflow/core/operator/parallel_cast_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d7d803aa0df41884cdc05a024abecee07a1fd484 --- /dev/null +++ b/oneflow/core/operator/parallel_cast_op.cpp @@ -0,0 +1,58 @@ +#include "oneflow/core/operator/parallel_cast_op.h" + +namespace oneflow { + +namespace { + +SbpParallel GetSbpParallel(const ParallelCastOpConf& conf) { + SbpParallel ret; + if (conf.has_split_parallel()) { + *ret.mutable_split_parallel() = conf.split_parallel(); + } else if (conf.has_broadcast_parallel()) { + *ret.mutable_broadcast_parallel() = conf.broadcast_parallel(); + } else { + UNIMPLEMENTED(); + } + return ret; +} + +class ParallelCastOpParallelSignature final : public OpParallelSignature { + public: + OF_DISALLOW_COPY_AND_MOVE(ParallelCastOpParallelSignature); + ~ParallelCastOpParallelSignature() override = default; + + ParallelCastOpParallelSignature(const Operator* op) : OpParallelSignature(op) {} + + const std::string Description() const override { return op().op_name() + ": A -> A"; } + + const OpParallelMatchResult GetMatchResult( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4BnInOp, + const ParallelContext* parallel_ctx) const override { + const auto& configured_sbp_parallel = GetSbpParallel(op().op_conf().parallel_cast_conf()); + if (SbpInferHint4BnInOp("in").sbp_parallel() == configured_sbp_parallel + && parallel_ctx->parallel_num() != SbpInferHint4BnInOp("in").parallel_num()) { + return MakeOpParallelMatchParallelNumError(parallel_ctx->parallel_num(), + SbpInferHint4BnInOp("in").parallel_num()); + } + return MakeOpParallelMatchSuccess(); + } + + void GenerateSignature( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4BnInOp, + HashMap<std::string, SbpParallel>* bn2sbp) const override { + const auto& sbp_parallel = GetSbpParallel(op().op_conf().parallel_cast_conf()); + (*bn2sbp)["in"] = sbp_parallel; + (*bn2sbp)["out"] = sbp_parallel; + } +}; + +} // namespace + +void ParallelCastOp::GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const { + op_parallel_signatures->emplace_back(new ParallelCastOpParallelSignature(this)); +} + +REGISTER_OP(OperatorConf::kParallelCastConf, ParallelCastOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/parallel_cast_op.h b/oneflow/core/operator/parallel_cast_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b2805dd05438dc6f8204745ae5ac8ffb86e46120 --- /dev/null +++ b/oneflow/core/operator/parallel_cast_op.h @@ -0,0 +1,23 @@ +#ifndef ONEFLOW_CORE_OPERATOR_PARALLEL_CAST_OP_H_ +#define ONEFLOW_CORE_OPERATOR_PARALLEL_CAST_OP_H_ + +#include 
"oneflow/core/operator/identity_op.h" + +namespace oneflow { + +class ParallelCastOp final : public IdentityOp { + public: + OF_DISALLOW_COPY_AND_MOVE(ParallelCastOp); + ParallelCastOp() = default; + ~ParallelCastOp() override = default; + + const PbMessage& GetCustomizedConf() const override { return op_conf().parallel_cast_conf(); } + + private: + void GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>*) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_PARALLEL_CAST_OP_H_ diff --git a/oneflow/core/operator/pooling_op.h b/oneflow/core/operator/pooling_op.h index db08d6a0c7fecd4fe1bfb54b3a2f51f88bbe7c80..a23153a21189622048747bddf1285e5175bc1252 100644 --- a/oneflow/core/operator/pooling_op.h +++ b/oneflow/core/operator/pooling_op.h @@ -24,6 +24,8 @@ class PoolingOp : public Operator { KernelConf* kernel_conf) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + void CheckPoolSizeAndStrides() const; Shape GetOutShape(int64_t in_n, int64_t in_c, const std::vector<int64_t>& out) const; }; diff --git a/oneflow/core/operator/print_op.h b/oneflow/core/operator/print_op.h index 030d02d25ab1f6741b0b5cdf1cc11dcec9b3a69c..a02fac123b27d1a0e4f7c0653e6444af5cb0d2af 100644 --- a/oneflow/core/operator/print_op.h +++ b/oneflow/core/operator/print_op.h @@ -20,6 +20,8 @@ class PrintOp final : public Operator { const ParallelContext* parallel_ctx) const override {} private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } + LogicalBlobId ibn2lbi(const std::string& input_bn) const override; }; diff --git a/oneflow/core/operator/record_load_op.cpp b/oneflow/core/operator/record_load_op.cpp index 28f8ba5872eadcac5917bf2f28cf00a09488ef51..b119e979ccf05a0cc3724a03c8bae34f81538118 100644 --- a/oneflow/core/operator/record_load_op.cpp +++ b/oneflow/core/operator/record_load_op.cpp @@ -10,14 +10,21 @@ void RecordLoadOp::InitFromOpConf() { const PbMessage& RecordLoadOp::GetCustomizedConf() const { return op_conf().record_load_conf(); } void RecordLoadOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx) const { + const ParallelContext* parallel_ctx, + int64_t record_piece_size) const { BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); - int64_t global_piece_size = Global<JobDesc>::Get()->PieceSize(); - CHECK_EQ(global_piece_size % parallel_ctx->parallel_num(), 0); - out_blob_desc->mut_shape() = Shape({global_piece_size / parallel_ctx->parallel_num()}); + CHECK_EQ(record_piece_size % parallel_ctx->parallel_num(), 0); + out_blob_desc->mut_shape() = Shape({record_piece_size / parallel_ctx->parallel_num()}); out_blob_desc->set_data_type(kOFRecord); } +void RecordLoadOp::VirtualGenKernelConf( + std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const { + int64_t device_piece_size = GetBlobDesc4BnInOp("out")->shape().At(0); + kernel_conf->mutable_record_load_conf()->set_device_piece_size(device_piece_size); +} + REGISTER_CPU_OP(OperatorConf::kRecordLoadConf, RecordLoadOp); } // namespace oneflow diff --git a/oneflow/core/operator/record_load_op.h b/oneflow/core/operator/record_load_op.h index cc069578176561ac1c650f9f7a122f3e0c002264..8ad1274de7ee4344bff3cdd86f498a66fc65a859 100644 --- a/oneflow/core/operator/record_load_op.h +++ b/oneflow/core/operator/record_load_op.h @@ -15,11 +15,15 @@ class 
RecordLoadOp final : public Operator { void InitFromOpConf() override; const PbMessage& GetCustomizedConf() const override; void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx) const override; + const ParallelContext* parallel_ctx, + int64_t record_piece_size) const override; + void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext*, KernelConf*) const override; LogicalNode* NewProperLogicalNode() override { return new RecordLoadLogicalNode; } private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } }; } // namespace oneflow diff --git a/oneflow/core/operator/recurrent_op.cpp b/oneflow/core/operator/recurrent_op.cpp index 5afbdbcc0d84a8e22e040c89c0f012733d4d02d1..84948bb42bfb30a7685cb2eb7adf51f275888ef4 100644 --- a/oneflow/core/operator/recurrent_op.cpp +++ b/oneflow/core/operator/recurrent_op.cpp @@ -19,10 +19,6 @@ void RecurrentOp::InitFromOpConf() { VirtualInitFromOpConf(); } -int32_t RecurrentOp::MaxModelSplitNum() const { - return GetValFromCustomizedConf<int32_t>("hidden_size"); -} - void RecurrentOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); diff --git a/oneflow/core/operator/recurrent_op.h b/oneflow/core/operator/recurrent_op.h index cef95876e727ea184d9d9e5dbf13f63ad57dfc49..ae210368e45a69b984f17fd8bf66405c5bec153a 100644 --- a/oneflow/core/operator/recurrent_op.h +++ b/oneflow/core/operator/recurrent_op.h @@ -16,16 +16,21 @@ class RecurrentOp : public Operator { void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; - int32_t ModelSplitAxis() const override { return 1; } - int32_t MaxModelSplitNum() const override; + int32_t OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const std::string& obn) const override { + return 1; + } private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } virtual void VirtualInitFromOpConf() { UNIMPLEMENTED(); } virtual void VirtualInferBlobDescs( std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { UNIMPLEMENTED(); } + LogicalBlobId ibn2lbi(const std::string& input_bn) const override; LogicalBlobId obn2lbi(const std::string& output_bn) const override; }; diff --git a/oneflow/core/operator/reduce_add_op.h b/oneflow/core/operator/reduce_add_op.h index 276abf7fc78035952c2f7329d98d95c2c3eb8f6d..12d37e36475d682e9d46885154188d49e6a44863 100644 --- a/oneflow/core/operator/reduce_add_op.h +++ b/oneflow/core/operator/reduce_add_op.h @@ -18,6 +18,8 @@ class ReduceAddOp final : public Operator { const ParallelContext* parallel_ctx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } + LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } LogicalBlobId obn2lbi(const std::string& output_bn) const override; }; diff --git a/oneflow/core/operator/reduce_concat_op.cpp b/oneflow/core/operator/reduce_concat_op.cpp index 09861d11d8b37173cfb7edd624380b7e6821a350..65b4d171241431200a379efa95588a8d0c8ffb79 100644 --- a/oneflow/core/operator/reduce_concat_op.cpp +++ 
b/oneflow/core/operator/reduce_concat_op.cpp @@ -4,6 +4,11 @@ namespace oneflow { +struct ReduceConcatOpCtx : public OpContext { + ReduceConcatOpCtx(const int64_t elem_cnt) : out_blob_elem_cnt(elem_cnt) {} + int64_t out_blob_elem_cnt; +}; + void ReduceConcatOp::InitFromOpConf() { CHECK(op_conf().has_reduce_concat_conf()); for (int32_t i = 0; i < op_conf().reduce_concat_conf().in_num(); ++i) { @@ -17,29 +22,49 @@ const PbMessage& ReduceConcatOp::GetCustomizedConf() const { } void ReduceConcatOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx) const { - int32_t in_num = op_conf().reduce_concat_conf().in_num(); - CHECK_GE(in_num, 2); - BlobDesc* first_in_blob = GetBlobDesc4BnInOp(input_bns().Get(0)); + const ParallelContext* parallel_ctx, int64_t record_piece_size, + std::function<void(OpContext*)> EnrollOpCtx) const { + const BlobDesc* first_in_blob = GetBlobDesc4BnInOp(input_bns().Get(0)); + const DataType data_type = first_in_blob->data_type(); + for (int32_t i = 1; i < op_conf().reduce_concat_conf().in_num(); ++i) { + CHECK_EQ(data_type, GetBlobDesc4BnInOp(input_bns().Get(i))->data_type()); + } + BlobDesc* out_blob = GetBlobDesc4BnInOp(SoleObn()); *out_blob = *first_in_blob; - int64_t out_blob_elem_cnt = first_in_blob->shape().elem_cnt(); - for (int32_t i = 1; i < in_num; ++i) { - out_blob_elem_cnt += GetBlobDesc4BnInOp(input_bns().Get(i))->shape().elem_cnt(); + int64_t in_blob_body_size_sum = 0; + for (int32_t i = 0; i < op_conf().reduce_concat_conf().in_num(); ++i) { + in_blob_body_size_sum += + RtBlobDesc(*(GetBlobDesc4BnInOp(input_bns().Get(i)))).ByteSizeOfBlobBody(); } + const int64_t data_type_byte_size = + static_cast<int64_t>(GetSizeOfDataType(first_in_blob->data_type())); + CHECK_EQ(in_blob_body_size_sum % data_type_byte_size, 0); + const int64_t out_blob_elem_cnt = + RoundUp(in_blob_body_size_sum / data_type_byte_size, parallel_ctx->parallel_num()); out_blob->mut_shape() = Shape({out_blob_elem_cnt}); + + // construct reduce_concat_op_ctx for later CHECK in ReduceConcatOp::VirtualGenKernelConf + ReduceConcatOpCtx* reduce_concat_op_ctx = new ReduceConcatOpCtx(out_blob_elem_cnt); + EnrollOpCtx(reduce_concat_op_ctx); } void ReduceConcatOp::VirtualGenKernelConf( - std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*, - KernelConf* kernel_conf) const { + std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, KernelConf* kernel_conf, const OpContext* op_ctx) const { ReduceConcatKernelConf* reduce_concat_conf = kernel_conf->mutable_reduce_concat_conf(); int64_t offset = 0; for (int32_t i = 0; i < op_conf().reduce_concat_conf().in_num(); ++i) { reduce_concat_conf->mutable_data_offset()->Add(offset); offset += RtBlobDesc(*(GetBlobDesc4BnInOp(input_bns().Get(i)))).ByteSizeOfBlobBody(); } - CHECK_EQ(offset, RtBlobDesc(*GetBlobDesc4BnInOp(SoleObn())).ByteSizeOfBlobBody()); + const int64_t data_type_byte_size = + static_cast<int64_t>(GetSizeOfDataType(GetBlobDesc4BnInOp(input_bns().Get(0))->data_type())); + CHECK_EQ(offset % data_type_byte_size, 0); + const int64_t out_blob_elem_cnt = + RoundUp(offset / data_type_byte_size, parallel_ctx->parallel_num()); + const ReduceConcatOpCtx* reduce_concat_op_ctx = static_cast<const ReduceConcatOpCtx*>(op_ctx); + CHECK_EQ(reduce_concat_op_ctx->out_blob_elem_cnt, out_blob_elem_cnt); } LogicalBlobId ReduceConcatOp::obn2lbi(const std::string& output_bn) const { diff --git 
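// A worked example of the element-count padding in ReduceConcatOp::InferBlobDescs
// above, assuming RoundUp(n, m) rounds n up to the next multiple of m: three
// float32 inputs of 5, 7 and 9 elements sum to 84 body bytes, i.e. 21 elements,
// and with parallel_num() == 4 the concatenated blob is padded to 24 elements so
// the downstream reduce stages can split it evenly across the four ranks.
#include <cassert>
#include <cstdint>

inline int64_t RoundUpTo(int64_t n, int64_t multiple) {
  return (n + multiple - 1) / multiple * multiple;
}

inline void ReduceConcatElemCntExample() {
  const int64_t in_blob_body_size_sum = (5 + 7 + 9) * 4;  // bytes, float32 inputs
  const int64_t elem_cnt = in_blob_body_size_sum / 4;     // 21 elements
  assert(RoundUpTo(elem_cnt, 4) == 24);                    // padded for 4 ranks
}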
a/oneflow/core/operator/reduce_concat_op.h b/oneflow/core/operator/reduce_concat_op.h index c2419c5178972343666f205b54df79355e9ec1d2..7e86ebd4f976d36e7116673f39905749e846211e 100644 --- a/oneflow/core/operator/reduce_concat_op.h +++ b/oneflow/core/operator/reduce_concat_op.h @@ -15,11 +15,15 @@ class ReduceConcatOp final : public Operator { const PbMessage& GetCustomizedConf() const override; void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx) const override; + const ParallelContext* parallel_ctx, int64_t record_piece_size, + std::function<void(OpContext*)> EnrollOpCtx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } + void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext*, KernelConf*) const override; + const ParallelContext*, KernelConf*, + const OpContext* op_ctx) const override; LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } LogicalBlobId obn2lbi(const std::string& output_bn) const override; }; diff --git a/oneflow/core/operator/reduce_gather_op.h b/oneflow/core/operator/reduce_gather_op.h index 2753330de0403791d7d4ab646784301f5cb990be..0b97f9e8f1729f7d242a1d6cc99fda7bac40d44b 100644 --- a/oneflow/core/operator/reduce_gather_op.h +++ b/oneflow/core/operator/reduce_gather_op.h @@ -18,6 +18,8 @@ class ReduceGatherOp final : public Operator { const ParallelContext* parallel_ctx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } + void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*, KernelConf*) const override; LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } diff --git a/oneflow/core/operator/reduce_identity_op.cpp b/oneflow/core/operator/reduce_identity_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..baffcc20dc086f865d6cf081cb68edc3a7601b18 --- /dev/null +++ b/oneflow/core/operator/reduce_identity_op.cpp @@ -0,0 +1,26 @@ +#include "oneflow/core/operator/reduce_identity_op.h" + +namespace oneflow { + +void ReduceIdentityOp::InitFromOpConf() { + EnrollInputBn("in", false); + EnrollOutputBn("out", false); +} + +void ReduceIdentityOp::InferBlobDescs( + std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in"); + CHECK_EQ(GetBlobDesc4BnInOp("out")->shape().elem_cnt() % parallel_ctx->parallel_num(), 0); +} + +LogicalBlobId ReduceIdentityOp::obn2lbi(const std::string& output_bn) const { + LogicalBlobId ret; + ret.set_op_name(op_name()); + ret.set_blob_name(output_bn); + return ret; +} + +REGISTER_OP(OperatorConf::kReduceIdentityConf, ReduceIdentityOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/reduce_identity_op.h b/oneflow/core/operator/reduce_identity_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f5b2d2933bace02abad785a0148885daef1bb51c --- /dev/null +++ b/oneflow/core/operator/reduce_identity_op.h @@ -0,0 +1,32 @@ +#ifndef ONEFLOW_CORE_OPERATOR_REDUCE_IDENTITY_OP_H_ +#define ONEFLOW_CORE_OPERATOR_REDUCE_IDENTITY_OP_H_ + +#include "oneflow/core/operator/operator.h" +#include "oneflow/core/graph/logical_node.h" + +namespace oneflow { + +class ReduceIdentityOp final : public 
Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(ReduceIdentityOp); + ReduceIdentityOp() = default; + ~ReduceIdentityOp() = default; + + LogicalNode* NewProperLogicalNode() { return new ReduceIdentityLogicalNode; } + void InitFromOpConf() override; + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + const PbMessage& GetCustomizedConf() const override { return op_conf().reduce_identity_conf(); } + bool NeedInBlobWhenBackward() const override { return false; } + bool NeedOutBlobWhenBackward() const override { return false; } + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } + + LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } + LogicalBlobId obn2lbi(const std::string& output_bn) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_REDUCE_IDENTITY_OP_H_ diff --git a/oneflow/core/operator/reduce_mean_op.cpp b/oneflow/core/operator/reduce_mean_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0e39682fcd1bf7876363ea11a4931da36960f592 --- /dev/null +++ b/oneflow/core/operator/reduce_mean_op.cpp @@ -0,0 +1,103 @@ +#include "oneflow/core/operator/reduce_mean_op.h" +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { +namespace { +std::vector<int64_t> KeepDims(const std::vector<int64_t> dim_vec, + const std::vector<int64_t> axis_vec) { + std::vector<int64_t> ret = dim_vec; + for (const auto& axis : axis_vec) { ret[axis] = 1; } + return ret; +} + +std::vector<int64_t> DropDims(const std::vector<int64_t> dim_vec, + const std::vector<int64_t> axis_vec) { + std::vector<int64_t> ret; + std::vector<int32_t> dim2is_reduced(dim_vec.size()); + for (const auto& axis : axis_vec) { dim2is_reduced[axis] = 1; } + FOR_RANGE(int64_t, i, 0, dim_vec.size()) { + if (dim2is_reduced[i] != 1) { ret.push_back(dim_vec[i]); } + } + if (ret.empty()) { ret.push_back(1); } + return ret; +} + +std::vector<int64_t> ShiftAxisIfNegative(std::vector<int64_t> axis_vec, const int64_t num_axes) { + FOR_RANGE(size_t, i, 0, axis_vec.size()) { + if (axis_vec[i] < 0) { axis_vec[i] += num_axes; } + CHECK_LT(axis_vec[i], num_axes); + CHECK_GE(axis_vec[i], 0); + } + return axis_vec; +} + +} // namespace + +void ReduceMeanOp::InitFromOpConf() { + CHECK(op_conf().has_reduce_mean_conf()); + EnrollInputBn("in"); + EnrollOutputBn("out"); + EnrollFwBufBn("fw_tmp"); + EnrollBwBufBn("bw_tmp"); +} + +const PbMessage& ReduceMeanOp::GetCustomizedConf() const { return op_conf().reduce_mean_conf(); } + +void ReduceMeanOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext*) const { + const ReduceMeanOpConf& conf = op_conf().reduce_mean_conf(); + const BlobDesc* in_blob = GetBlobDesc4BnInOp("in"); + *GetBlobDesc4BnInOp("fw_tmp") = *in_blob; + std::vector<int64_t> out_dim_vec; + if (conf.axis().empty()) { + if (conf.keep_dims() == true) { + out_dim_vec.resize(in_blob->shape().NumAxes()); + std::fill(out_dim_vec.begin(), out_dim_vec.end(), 1); + } else { + out_dim_vec = {1}; + } + } else { + const PbRf<int32_t>& axis_repeated = conf.axis(); + std::vector<int64_t> axis_vec = {axis_repeated.begin(), axis_repeated.end()}; + axis_vec = ShiftAxisIfNegative(axis_vec, in_blob->shape().NumAxes()); + std::sort(axis_vec.begin(), axis_vec.end()); + CHECK(std::unique(axis_vec.begin(), axis_vec.end()) == axis_vec.end()) + << "duplicate found in axis"; + if 
(conf.keep_dims() == true) { + out_dim_vec = KeepDims(in_blob->shape().dim_vec(), axis_vec); + } else { + out_dim_vec = DropDims(in_blob->shape().dim_vec(), axis_vec); + } + } + CHECK(!out_dim_vec.empty()); + BlobDesc* out_blob = GetBlobDesc4BnInOp("out"); + out_blob->set_data_type(in_blob->data_type()); + out_blob->mut_shape() = Shape(out_dim_vec); +} + +void ReduceMeanOp::VirtualGenKernelConf( + std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*, + KernelConf* kernel_conf) const { + const ReduceMeanOpConf& conf = op_conf().reduce_mean_conf(); + const BlobDesc* in_blob = GetBlobDesc4BnInOp("in"); + std::vector<int64_t> kept_dims; + if (conf.axis().empty()) { + kept_dims.resize(in_blob->shape().NumAxes()); + std::fill(kept_dims.begin(), kept_dims.end(), 1); + } else { + const PbRf<int32_t>& axis_repeated = op_conf().reduce_mean_conf().axis(); + std::vector<int64_t> axis_vec = {axis_repeated.begin(), axis_repeated.end()}; + kept_dims = KeepDims(in_blob->shape().dim_vec(), + ShiftAxisIfNegative(axis_vec, in_blob->shape().NumAxes())); + } + Shape(kept_dims).ToProto(kernel_conf->mutable_reduce_sum_conf()->mutable_kept_dims_shape()); +} + +void ReduceMeanOp::InferBwBufBlobDescs( + std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*) const { + *GetBlobDesc4BnInOp("bw_tmp") = *GetBlobDesc4BnInOp("out"); +} + +REGISTER_OP(OperatorConf::kReduceMeanConf, ReduceMeanOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/reduce_mean_op.h b/oneflow/core/operator/reduce_mean_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0477d2d25899146bdf4fea2af5eef235f240b739 --- /dev/null +++ b/oneflow/core/operator/reduce_mean_op.h @@ -0,0 +1,32 @@ +#ifndef ONEFLOW_CORE_OPERATOR_REDUCE_MEAN_OP_H_ +#define ONEFLOW_CORE_OPERATOR_REDUCE_MEAN_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class ReduceMeanOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(ReduceMeanOp); + ReduceMeanOp() = default; + ~ReduceMeanOp() = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + bool NeedInBlobWhenBackward() const override { return false; } + bool NeedOutBlobWhenBackward() const override { return false; } + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + void InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext*) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, + KernelConf* kernel_conf) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_REDUCE_MEAN_OP_H_ diff --git a/oneflow/core/operator/reduce_scatter_op.h b/oneflow/core/operator/reduce_scatter_op.h index 745c4d1bca108b7cfcdb2ec65e031d67e4e8696b..be911903031770b6c49623493029a2e7874234ab 100644 --- a/oneflow/core/operator/reduce_scatter_op.h +++ b/oneflow/core/operator/reduce_scatter_op.h @@ -18,6 +18,8 @@ class ReduceScatterOp final : public Operator { const ParallelContext* parallel_ctx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } + LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return 
GenPackedLbi(); } LogicalBlobId obn2lbi(const std::string& output_bn) const override; }; diff --git a/oneflow/core/operator/reduce_split_op.cpp b/oneflow/core/operator/reduce_split_op.cpp index 78995a57cc7de19e5cb0fd8ccdf2d2865355edc5..da809794d612f6515f4a1aea4f738b78c01d499d 100644 --- a/oneflow/core/operator/reduce_split_op.cpp +++ b/oneflow/core/operator/reduce_split_op.cpp @@ -15,15 +15,21 @@ void ReduceSplitOp::InitFromOpConf() { const PbMessage& ReduceSplitOp::GetCustomizedConf() const { return op_conf().reduce_split_conf(); } void ReduceSplitOp::VirtualGenKernelConf( - std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*, - KernelConf* kernel_conf) const { + std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const { ReduceSplitKernelConf* reduce_split_conf = kernel_conf->mutable_reduce_split_conf(); int64_t offset = 0; for (int32_t i = 0; i < op_conf().reduce_split_conf().out_num(); ++i) { reduce_split_conf->mutable_data_offset()->Add(offset); offset += RtBlobDesc(*(GetBlobDesc4BnInOp(output_bns().Get(i)))).ByteSizeOfBlobBody(); } - CHECK_EQ(offset, RtBlobDesc(*GetBlobDesc4BnInOp(SoleIbn())).ByteSizeOfBlobBody()); + const int64_t data_type_byte_size = + static_cast<int64_t>(GetSizeOfDataType(GetBlobDesc4BnInOp(SoleIbn())->data_type())); + CHECK_EQ(offset % data_type_byte_size, 0); + const int64_t out_blob_elem_cnt_sum = + RoundUp(offset / data_type_byte_size, parallel_ctx->parallel_num()); + const int64_t in_blob_elem_cnt = GetBlobDesc4BnInOp(SoleIbn())->shape().elem_cnt(); + CHECK_EQ(out_blob_elem_cnt_sum, in_blob_elem_cnt); } REGISTER_OP(OperatorConf::kReduceSplitConf, ReduceSplitOp); diff --git a/oneflow/core/operator/reduce_split_op.h b/oneflow/core/operator/reduce_split_op.h index a6f632fe6d3a58842f7f9a67c627bd8e535f933e..d5f02679dc9ed4022973bf0eacabd48d5bb56f2f 100644 --- a/oneflow/core/operator/reduce_split_op.h +++ b/oneflow/core/operator/reduce_split_op.h @@ -13,11 +13,11 @@ class ReduceSplitOp final : public Operator { void InitFromOpConf() override; const PbMessage& GetCustomizedConf() const override; - void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override {} private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); } void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*, KernelConf*) const override; LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } diff --git a/oneflow/core/operator/reduce_sum_op.cpp b/oneflow/core/operator/reduce_sum_op.cpp index 45d470a6483bc9f2dc93e8ebbef9deca3425ab65..38367bc1cb8fb2314d4ed860e7caf3d84a00f508 100644 --- a/oneflow/core/operator/reduce_sum_op.cpp +++ b/oneflow/core/operator/reduce_sum_op.cpp @@ -2,55 +2,98 @@ #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { +namespace { +std::vector<int64_t> KeepDims(const std::vector<int64_t> dim_vec, + const std::vector<int64_t> axis_vec) { + std::vector<int64_t> ret = dim_vec; + for (const auto& axis : axis_vec) { ret[axis] = 1; } + return ret; +} + +std::vector<int64_t> DropDims(const std::vector<int64_t> dim_vec, + const std::vector<int64_t> axis_vec) { + std::vector<int64_t> ret; + std::vector<int32_t> dim2is_reduced(dim_vec.size()); + for (const auto& axis : axis_vec) { dim2is_reduced[axis] = 1; } + FOR_RANGE(int64_t, i, 0, 
dim_vec.size()) { + if (dim2is_reduced[i] != 1) { ret.push_back(dim_vec[i]); } + } + if (ret.empty()) { ret.push_back(1); } + return ret; +} + +std::vector<int64_t> ShiftAxisIfNegative(std::vector<int64_t> axis_vec, const int64_t num_axes) { + FOR_RANGE(size_t, i, 0, axis_vec.size()) { + if (axis_vec[i] < 0) { axis_vec[i] += num_axes; } + CHECK_LT(axis_vec[i], num_axes); + CHECK_GE(axis_vec[i], 0); + } + return axis_vec; +} + +} // namespace void ReduceSumOp::InitFromOpConf() { + CHECK(op_conf().has_reduce_sum_conf()); EnrollInputBn("in"); EnrollOutputBn("out"); - EnrollDataTmpBn("fw_tmp"); + if (op_conf().reduce_sum_conf().has_in_sys()) { + EnrollDataTmpBn("fw_tmp"); + } else { + EnrollFwBufBn("fw_tmp"); + } } const PbMessage& ReduceSumOp::GetCustomizedConf() const { return op_conf().reduce_sum_conf(); } void ReduceSumOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*) const { + const ReduceSumOpConf& conf = op_conf().reduce_sum_conf(); const BlobDesc* in_blob = GetBlobDesc4BnInOp("in"); - std::vector<int64_t> out_dim_vec = {1}; - if (op_conf().reduce_sum_conf().has_axis()) { - out_dim_vec = in_blob->shape().dim_vec(); - int32_t axis = GetCorrectAxis(GetBlobDesc4BnInOp); - if (op_conf().reduce_sum_conf().keepdims() == true) { - out_dim_vec[axis] = 1; + *GetBlobDesc4BnInOp("fw_tmp") = *in_blob; + std::vector<int64_t> out_dim_vec; + if (conf.axis().empty()) { + if (conf.keep_dims() == true) { + out_dim_vec.resize(in_blob->shape().NumAxes()); + std::fill(out_dim_vec.begin(), out_dim_vec.end(), 1); } else { - out_dim_vec.erase(out_dim_vec.begin() + axis); + out_dim_vec = {1}; } - if (out_dim_vec.empty()) { out_dim_vec.push_back(1); } } else { - BlobDesc* fw_tmp_blob = GetBlobDesc4BnInOp("fw_tmp"); - fw_tmp_blob->mut_shape() = Shape({static_cast<int64_t>( - GetTmpSizeForReduceSum(in_blob->data_type(), in_blob->shape().elem_cnt()))}); - fw_tmp_blob->set_data_type(DataType::kChar); + const PbRf<int32_t>& axis_repeated = conf.axis(); + std::vector<int64_t> axis_vec = {axis_repeated.begin(), axis_repeated.end()}; + axis_vec = ShiftAxisIfNegative(axis_vec, in_blob->shape().NumAxes()); + std::sort(axis_vec.begin(), axis_vec.end()); + CHECK(std::unique(axis_vec.begin(), axis_vec.end()) == axis_vec.end()) + << "duplicate found in axis"; + if (conf.keep_dims() == true) { + out_dim_vec = KeepDims(in_blob->shape().dim_vec(), axis_vec); + } else { + out_dim_vec = DropDims(in_blob->shape().dim_vec(), axis_vec); + } } + CHECK(!out_dim_vec.empty()); BlobDesc* out_blob = GetBlobDesc4BnInOp("out"); - if (op_conf().reduce_sum_conf().has_axis() && GetCorrectAxis(GetBlobDesc4BnInOp) > 0) { - *out_blob = *in_blob; - } else { - out_blob->set_data_type(in_blob->data_type()); - } + out_blob->set_data_type(in_blob->data_type()); out_blob->mut_shape() = Shape(out_dim_vec); } void ReduceSumOp::VirtualGenKernelConf( std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*, KernelConf* kernel_conf) const { - if (op_conf().reduce_sum_conf().has_axis() == false) { return; } - kernel_conf->mutable_reduce_sum_conf()->set_axis(GetCorrectAxis(GetBlobDesc4BnInOp)); -} - -int32_t ReduceSumOp::GetCorrectAxis( - std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp) const { - int32_t axis = op_conf().reduce_sum_conf().axis(); - if (axis < 0) { axis += GetBlobDesc4BnInOp("in")->shape().NumAxes(); } - return axis; + const ReduceSumOpConf& conf = op_conf().reduce_sum_conf(); + const BlobDesc* in_blob = GetBlobDesc4BnInOp("in"); 
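// A worked example of how the ShiftAxisIfNegative and KeepDims helpers above
// combine when kept_dims is filled in just below, assuming a rank-3 input of
// shape {2, 3, 4} and conf.axis() == {1, -1}: the negative axis shifts to 2 and
// each reduced axis collapses to 1 while the rank is preserved, giving {2, 1, 1};
// with keep_dims() == false the same axes would instead drop to shape {2}.
#include <cassert>
#include <cstdint>
#include <vector>

inline std::vector<int64_t> KeptDimsExample() {
  std::vector<int64_t> dims = {2, 3, 4};
  std::vector<int64_t> axes = {1, -1};
  for (int64_t& a : axes) {
    if (a < 0) { a += static_cast<int64_t>(dims.size()); }  // {1, -1} -> {1, 2}
  }
  std::vector<int64_t> kept = dims;
  for (int64_t a : axes) { kept[a] = 1; }                   // {2, 3, 4} -> {2, 1, 1}
  assert((kept == std::vector<int64_t>{2, 1, 1}));
  return kept;
}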
+ std::vector<int64_t> kept_dims; + if (conf.axis().empty()) { + kept_dims.resize(in_blob->shape().NumAxes()); + std::fill(kept_dims.begin(), kept_dims.end(), 1); + } else { + const PbRf<int32_t>& axis_repeated = op_conf().reduce_sum_conf().axis(); + std::vector<int64_t> axis_vec = {axis_repeated.begin(), axis_repeated.end()}; + kept_dims = KeepDims(in_blob->shape().dim_vec(), + ShiftAxisIfNegative(axis_vec, in_blob->shape().NumAxes())); + } + Shape(kept_dims).ToProto(kernel_conf->mutable_reduce_sum_conf()->mutable_kept_dims_shape()); } REGISTER_OP(OperatorConf::kReduceSumConf, ReduceSumOp); diff --git a/oneflow/core/operator/reduce_sum_op.h b/oneflow/core/operator/reduce_sum_op.h index 0a43b4fe52e7c523ccb8fa3efb516f1a2db508a9..b8f0276653d7d9cc97d0e724a74d11b00d12442f 100644 --- a/oneflow/core/operator/reduce_sum_op.h +++ b/oneflow/core/operator/reduce_sum_op.h @@ -13,23 +13,24 @@ class ReduceSumOp final : public Operator { void InitFromOpConf() override; const PbMessage& GetCustomizedConf() const override; - + bool NeedInBlobWhenBackward() const override { return false; } + bool NeedOutBlobWhenBackward() const override { return false; } void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const override; - int32_t GetCorrectAxis( - std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp) const; LogicalBlobId ibn2lbi(const std::string& input_bn) const override { const ReduceSumOpConf& conf = op_conf().reduce_sum_conf(); if (conf.has_in_sys()) { - CHECK_EQ(conf.axis(), 0); + CHECK_EQ(conf.axis_size(), 1); + CHECK_EQ(conf.axis().Get(0), 0); return conf.in_sys(); } else if (conf.has_in()) { - CHECK_GE(conf.axis(), 1); return GenLogicalBlobId(conf.in()); } else { UNIMPLEMENTED(); diff --git a/oneflow/core/operator/relu_op.h b/oneflow/core/operator/relu_op.h index 6a8f82025482b6d527dd7a73ecad119d6911a434..b3e64b543c55483fee6657c5f8abd665fe849c26 100644 --- a/oneflow/core/operator/relu_op.h +++ b/oneflow/core/operator/relu_op.h @@ -13,13 +13,13 @@ class ReluOp final : public Operator { void InitFromOpConf() override; const PbMessage& GetCustomizedConf() const override; - bool IsElemWiseOp() const override { return true; } bool NeedInBlobWhenBackward() const override { return false; } void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } }; } // namespace oneflow diff --git a/oneflow/core/operator/repeat_op.cpp b/oneflow/core/operator/repeat_op.cpp index 5a1e396a804d9b16364f7076ad179cbb804fab0a..5721a994e7eb5ed552d229ff68f42001065e936c 100644 --- a/oneflow/core/operator/repeat_op.cpp +++ b/oneflow/core/operator/repeat_op.cpp @@ -6,31 +6,24 @@ namespace oneflow { void oneflow::RepeatOp::InitFromOpConf() { CHECK(op_conf().has_repeat_conf()); const RepeatOpConf& conf = op_conf().repeat_conf(); - if (conf.has_repeat_num()) { - CHECK_GE(conf.repeat_num(), 1); - } else if (conf.has_repeat_num_per_record()) { - CHECK_GE(conf.repeat_num_per_record(), 1); - } else { - UNIMPLEMENTED(); - } + CHECK_GE(conf.repeat_num(), 1); EnrollInputBn("in"); EnrollOutputBn("out"); } 
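// The InferOutputBlobTimeShape override introduced for RepeatOp just below
// appends the repeat count as a new innermost time axis. A small sketch of the
// effect, assuming the input blob carries time shape {4, 2} and repeat_num() == 3:
#include <cassert>
#include <cstdint>
#include <vector>

inline std::vector<int64_t> RepeatTimeShapeExample() {
  std::vector<int64_t> out_time_shape = {4, 2};  // assumed input time shape
  const int32_t repeat_num = 3;
  out_time_shape.push_back(repeat_num);          // {4, 2} -> {4, 2, 3}
  assert((out_time_shape == std::vector<int64_t>{4, 2, 3}));
  return out_time_shape;
}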
-int32_t RepeatOp::GetRepeatNum(int64_t parallel_num) const { +void RepeatOp::InferOutputBlobTimeShape( + std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, + const ParallelContext* parallel_ctx, Shape* time_shape) const { + std::vector<int64_t> dim_vec(GetTimeShape4BnInOp("in")->dim_vec()); + int32_t repeat_num = GetRepeatNum(); + dim_vec.push_back(repeat_num); + *time_shape = Shape(dim_vec); +} + +int32_t RepeatOp::GetRepeatNum() const { CHECK(op_conf().has_repeat_conf()); const RepeatOpConf& conf = op_conf().repeat_conf(); - if (conf.has_repeat_num()) { - return conf.repeat_num(); - } else if (conf.has_repeat_num_per_record()) { - CHECK_EQ(Global<JobDesc>::Get()->PieceSize() % parallel_num, 0); - int64_t repeat_num = - Global<JobDesc>::Get()->PieceSize() / parallel_num * conf.repeat_num_per_record(); - CHECK_LE(repeat_num, static_cast<int64_t>(MaxVal<int32_t>())); - return static_cast<int32_t>(repeat_num); - } else { - UNIMPLEMENTED(); - } + return conf.repeat_num(); } const PbMessage& RepeatOp::GetCustomizedConf() const { return op_conf().repeat_conf(); } diff --git a/oneflow/core/operator/repeat_op.h b/oneflow/core/operator/repeat_op.h index 3635aaebef4fe8a05dd327f6aad8f96c50390eef..d7b0b27b48d28d551720aeea19d21e4fbfe28b7b 100644 --- a/oneflow/core/operator/repeat_op.h +++ b/oneflow/core/operator/repeat_op.h @@ -11,12 +11,17 @@ class RepeatOp final : public Operator { RepeatOp() = default; ~RepeatOp() override = default; - int32_t GetRepeatNum(int64_t parallel_num) const; + int32_t GetRepeatNum() const; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } const PbMessage& GetCustomizedConf() const override; void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; + void InferOutputBlobTimeShape(std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, + const ParallelContext* parallel_ctx, + Shape* time_shape) const override; + void InferDiffBlobDescsWithoutFwBlob( std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*) const override; diff --git a/oneflow/core/operator/reshape_op.cpp b/oneflow/core/operator/reshape_op.cpp index cae8727816de3b2955fdb2d8224292804ceca5a5..66395f865c0768347fe36d302ba59050d0bf1155 100644 --- a/oneflow/core/operator/reshape_op.cpp +++ b/oneflow/core/operator/reshape_op.cpp @@ -17,9 +17,25 @@ void ReshapeOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetB *out_blob_desc = *in_blob_desc; const ReshapeOpConf& conf = op_conf().reshape_conf(); - std::vector<int64_t> dim_vec(1 + conf.shape().dim_size()); - dim_vec[0] = in_blob_desc->shape().At(0); - FOR_RANGE(size_t, i, 1, dim_vec.size()) { dim_vec[i] = conf.shape().dim(i - 1); } + + std::vector<int64_t> dim_vec; + if (!conf.has_dim0_in_shape()) { dim_vec.push_back(in_blob_desc->shape().At(0)); } + for (int32_t i = 0; i < conf.shape().dim_size(); ++i) { dim_vec.push_back(conf.shape().dim(i)); } + int32_t dim_cnt_need_infer = 0; + int32_t dim_index_need_infer = -1; + int64_t elem_cnt = 1; + for (int32_t i = 0; i < dim_vec.size(); ++i) { + if (dim_vec[i] == -1) { + ++dim_cnt_need_infer; + dim_index_need_infer = i; + } else { + elem_cnt *= dim_vec[i]; + } + } + CHECK_LE(dim_cnt_need_infer, 1); + if (dim_cnt_need_infer == 1) { + dim_vec[dim_index_need_infer] = in_blob_desc->shape().elem_cnt() / elem_cnt; + } out_blob_desc->mut_shape() = Shape(dim_vec); CHECK_EQ(out_blob_desc->shape().elem_cnt(), 
in_blob_desc->shape().elem_cnt()); } diff --git a/oneflow/core/operator/reshape_op.h b/oneflow/core/operator/reshape_op.h index 53bdb3ab279f15a496706ac510e0837f941ef29b..8c3b3dc111125382c1c1d56f61844340a7a0ff64 100644 --- a/oneflow/core/operator/reshape_op.h +++ b/oneflow/core/operator/reshape_op.h @@ -13,12 +13,16 @@ class ReshapeOp final : public Operator { void InitFromOpConf() override; const PbMessage& GetCustomizedConf() const override; - bool IsElemWiseOp() const override { return true; } + bool NeedInBlobWhenBackward() const override { return false; } + bool NeedOutBlobWhenBackward() const override { return false; } + bool IsForwardInplace() const override { return true; } + bool IsBackwardInplace() const override { return true; } void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } }; } // namespace oneflow diff --git a/oneflow/core/operator/rmsprop_model_update_op.cpp b/oneflow/core/operator/rmsprop_model_update_op.cpp index c9702acec5c72b365ffef6abd0b583753b791c11..80c7ebfddd10d24c8a69b1ddca2a167d8c0a9dd2 100644 --- a/oneflow/core/operator/rmsprop_model_update_op.cpp +++ b/oneflow/core/operator/rmsprop_model_update_op.cpp @@ -4,7 +4,7 @@ namespace oneflow { void RMSPropModelUpdateOp::MdUpdtVirtualInitFromOpConf() { EnrollDataTmpBn("mean_square"); } -void RMSPropModelUpdateOp::InferBlobDescs( +void RMSPropModelUpdateOp::MdUpdtVirtualInferBlobDescs( std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const BlobDesc* model_blob_desc = GetBlobDesc4BnInOp("model"); diff --git a/oneflow/core/operator/rmsprop_model_update_op.h b/oneflow/core/operator/rmsprop_model_update_op.h index 52c80395ae9ee607be1972688ebe078be3469e24..f2deced85ecc5f20ded3d73a355e62188f0cbb86 100644 --- a/oneflow/core/operator/rmsprop_model_update_op.h +++ b/oneflow/core/operator/rmsprop_model_update_op.h @@ -11,11 +11,10 @@ class RMSPropModelUpdateOp final : public NormalModelUpdtOp { RMSPropModelUpdateOp() = default; ~RMSPropModelUpdateOp() = default; - void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx) const override; - private: void MdUpdtVirtualInitFromOpConf() override; + void MdUpdtVirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; }; } // namespace oneflow diff --git a/oneflow/core/operator/rsqrt_op.cpp b/oneflow/core/operator/rsqrt_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8a135b7f51cee26e5914fd520d8e867520a7e702 --- /dev/null +++ b/oneflow/core/operator/rsqrt_op.cpp @@ -0,0 +1,20 @@ +#include "oneflow/core/operator/rsqrt_op.h" + +namespace oneflow { + +void RsqrtOp::InitFromOpConf() { + CHECK(op_conf().has_rsqrt_conf()); + EnrollInputBn("in"); + EnrollOutputBn("out"); +} + +const PbMessage& RsqrtOp::GetCustomizedConf() const { return op_conf().rsqrt_conf(); } + +void RsqrtOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in"); +} + +REGISTER_OP(OperatorConf::kRsqrtConf, RsqrtOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/rsqrt_op.h b/oneflow/core/operator/rsqrt_op.h new file mode 100644 index 
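// A worked example of the -1 handling added to ReshapeOp::InferBlobDescs above,
// assuming the conf's shape covers dim0: an input of shape {6, 4} (24 elements)
// reshaped with shape {3, -1} multiplies the known dims to 3, so the single -1
// entry is inferred as 24 / 3 = 8 and the output shape becomes {3, 8}.
#include <cstdint>
#include <vector>

inline std::vector<int64_t> InferReshapeDims(int64_t in_elem_cnt,
                                             std::vector<int64_t> dims) {
  int64_t known_elem_cnt = 1;
  int64_t infer_idx = -1;
  for (size_t i = 0; i < dims.size(); ++i) {
    if (dims[i] == -1) {
      infer_idx = static_cast<int64_t>(i);
    } else {
      known_elem_cnt *= dims[i];
    }
  }
  if (infer_idx >= 0) { dims[infer_idx] = in_elem_cnt / known_elem_cnt; }
  return dims;  // InferReshapeDims(24, {3, -1}) == {3, 8}
}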
0000000000000000000000000000000000000000..c5d9a09ed70b3c3ad9f4f075d3a2569270d39706 --- /dev/null +++ b/oneflow/core/operator/rsqrt_op.h @@ -0,0 +1,27 @@ +#ifndef ONEFLOW_CORE_OPERATOR_RSQRT_H_ +#define ONEFLOW_CORE_OPERATOR_RSQRT_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class RsqrtOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(RsqrtOp); + RsqrtOp() = default; + ~RsqrtOp() = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + bool NeedInBlobWhenBackward() const override { return false; } + + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_RSQRT_H_ diff --git a/oneflow/core/operator/scalar_add_op.cpp b/oneflow/core/operator/scalar_add_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8cac8b8203d5917a1eca90f88d6a86206343c106 --- /dev/null +++ b/oneflow/core/operator/scalar_add_op.cpp @@ -0,0 +1,17 @@ +#include "oneflow/core/operator/scalar_add_op.h" + +namespace oneflow { + +void ScalarAddOp::InitFromOpConf() { + EnrollInputBn("in"); + EnrollOutputBn("out"); +} + +void ScalarAddOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in"); +} + +REGISTER_OP(OperatorConf::kScalarAddConf, ScalarAddOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/scalar_add_op.h b/oneflow/core/operator/scalar_add_op.h new file mode 100644 index 0000000000000000000000000000000000000000..12f52b53f712583f0caadeffe8b0e138aa9a88c3 --- /dev/null +++ b/oneflow/core/operator/scalar_add_op.h @@ -0,0 +1,27 @@ +#ifndef ONEFLOW_CORE_OPERATOR_SCALAR_ADD_OP_H_ +#define ONEFLOW_CORE_OPERATOR_SCALAR_ADD_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class ScalarAddOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(ScalarAddOp); + ScalarAddOp() = default; + ~ScalarAddOp() = default; + + void InitFromOpConf() override; + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + const PbMessage& GetCustomizedConf() const override { return op_conf().scalar_add_conf(); } + bool NeedInBlobWhenBackward() const override { return false; } + bool NeedOutBlobWhenBackward() const override { return false; } + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_SCALAR_ADD_OP_H_ diff --git a/oneflow/core/operator/scalar_mul_op.cpp b/oneflow/core/operator/scalar_mul_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..175d49a5a173c5b9fa7525d95907c9ee373d79a7 --- /dev/null +++ b/oneflow/core/operator/scalar_mul_op.cpp @@ -0,0 +1,17 @@ +#include "oneflow/core/operator/scalar_mul_op.h" + +namespace oneflow { + +void ScalarMulOp::InitFromOpConf() { + EnrollInputBn("in"); + EnrollOutputBn("out"); +} + +void ScalarMulOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in"); +} + +REGISTER_OP(OperatorConf::kScalarMulConf, 
ScalarMulOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/scalar_mul_op.h b/oneflow/core/operator/scalar_mul_op.h new file mode 100644 index 0000000000000000000000000000000000000000..fc0126135aa1bdf7c052b7dae43a5c78802b5653 --- /dev/null +++ b/oneflow/core/operator/scalar_mul_op.h @@ -0,0 +1,27 @@ +#ifndef ONEFLOW_CORE_OPERATOR_SCALAR_MUL_OP_H_ +#define ONEFLOW_CORE_OPERATOR_SCALAR_MUL_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class ScalarMulOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(ScalarMulOp); + ScalarMulOp() = default; + ~ScalarMulOp() = default; + + void InitFromOpConf() override; + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + const PbMessage& GetCustomizedConf() const override { return op_conf().scalar_mul_conf(); } + bool NeedInBlobWhenBackward() const override { return false; } + bool NeedOutBlobWhenBackward() const override { return false; } + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_SCALAR_MUL_OP_H_ diff --git a/oneflow/core/operator/shared_model_diff_add_op.h b/oneflow/core/operator/shared_model_diff_add_op.h index 61d6b36a49db76f07774f478a1bb4a86d84369b6..a95b61bb3018327cdc093696efcb88a9da287ab6 100644 --- a/oneflow/core/operator/shared_model_diff_add_op.h +++ b/oneflow/core/operator/shared_model_diff_add_op.h @@ -17,6 +17,10 @@ class SharedModelDiffAddOp final : public Operator { const PbMessage& GetCustomizedConf() const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { + return UNIMPLEMENTED(); + } + LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); } }; diff --git a/oneflow/core/operator/sigmoid_cross_entropy_loss_op.cpp b/oneflow/core/operator/sigmoid_cross_entropy_loss_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..adb5a828747208e5537e74f925ce54b7deb4b90b --- /dev/null +++ b/oneflow/core/operator/sigmoid_cross_entropy_loss_op.cpp @@ -0,0 +1,61 @@ +#include "oneflow/core/operator/sigmoid_cross_entropy_loss_op.h" +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { + +void SigmoidCrossEntropyLossOp::VirtualInitFromOpConf() { + EnrollDataTmpBn("count"); + EnrollDataTmpBn("label_num"); + EnrollDataTmpBn("loss_buf"); + EnrollDataTmpBn("sum_buf"); +} + +const PbMessage& SigmoidCrossEntropyLossOp::GetCustomizedConf() const { + return op_conf().sigmoid_cross_entropy_loss_conf(); +} + +LossKernelConf* SigmoidCrossEntropyLossOp::GetMutLossKernelConf(KernelConf* kernel_conf) const { + return kernel_conf->mutable_sigmoid_cross_entropy_loss_conf()->mutable_loss_conf(); +} + +void SigmoidCrossEntropyLossOp::VirtualInferBlobDescs( + std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + // label + const BlobDesc* label_blob_desc = GetBlobDesc4BnInOp("label"); + // a label must be in {-1, 0, 1} while -1 indicates ignorance + CHECK_GE(label_blob_desc->shape().NumAxes(), 2); + // prediction + const BlobDesc* pred_blob_desc = GetBlobDesc4BnInOp("prediction"); + CHECK_EQ(pred_blob_desc->shape(), label_blob_desc->shape()); + CHECK_GE(pred_blob_desc->shape().NumAxes(), 2); + + int64_t data_num = 
pred_blob_desc->shape().At(0); + int64_t data_dim = pred_blob_desc->shape().Count(1); + + // loss + BlobDesc* loss_blob_desc = GetBlobDesc4BnInOp("loss"); + loss_blob_desc->mut_shape() = Shape({data_num}); + loss_blob_desc->set_data_type(pred_blob_desc->data_type()); + // count + BlobDesc* count_blob_desc = GetBlobDesc4BnInOp("count"); + count_blob_desc->mut_shape() = Shape({data_dim}); + count_blob_desc->set_data_type(pred_blob_desc->data_type()); + // loss_buf + BlobDesc* loss_buf_desc = GetBlobDesc4BnInOp("loss_buf"); + loss_buf_desc->mut_shape() = Shape({data_dim}); + loss_buf_desc->set_data_type(pred_blob_desc->data_type()); + // label_num + BlobDesc* label_num_blob_desc = GetBlobDesc4BnInOp("label_num"); + label_num_blob_desc->mut_shape() = Shape({1}); + label_num_blob_desc->set_data_type(pred_blob_desc->data_type()); + // sum buf + BlobDesc* sum_buf_blob_desc = GetBlobDesc4BnInOp("sum_buf"); + const int64_t sum_buf_size = GetTmpSizeForReduceSum(pred_blob_desc->data_type(), data_dim); + sum_buf_blob_desc->mut_shape() = Shape({sum_buf_size}); + sum_buf_blob_desc->set_data_type(DataType::kChar); +} + +REGISTER_OP(OperatorConf::kSigmoidCrossEntropyLossConf, SigmoidCrossEntropyLossOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/sigmoid_cross_entropy_loss_op.h b/oneflow/core/operator/sigmoid_cross_entropy_loss_op.h new file mode 100644 index 0000000000000000000000000000000000000000..33cffad76301f51fe9037f71045862f304fb1cb0 --- /dev/null +++ b/oneflow/core/operator/sigmoid_cross_entropy_loss_op.h @@ -0,0 +1,27 @@ +#ifndef ONEFLOW_CORE_OPERATOR_SIGMOID_CROSS_ENTROPY_LOSS_OP_H_ +#define ONEFLOW_CORE_OPERATOR_SIGMOID_CROSS_ENTROPY_LOSS_OP_H_ + +#include "oneflow/core/operator/loss_op.h" + +namespace oneflow { + +class SigmoidCrossEntropyLossOp final : public LossOp { + public: + OF_DISALLOW_COPY_AND_MOVE(SigmoidCrossEntropyLossOp); + SigmoidCrossEntropyLossOp() = default; + ~SigmoidCrossEntropyLossOp() = default; + + const PbMessage& GetCustomizedConf() const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + + void VirtualInitFromOpConf() override; + void VirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + LossKernelConf* GetMutLossKernelConf(KernelConf*) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_SIGMOID_CROSS_ENTROPY_LOSS_OP_H_ diff --git a/oneflow/core/operator/sigmoid_op.h b/oneflow/core/operator/sigmoid_op.h index e206b0876c44b2ae86d9393c37525d4af7256543..e98707e03a1ba097810c0b773f21b66f7d4c0a71 100644 --- a/oneflow/core/operator/sigmoid_op.h +++ b/oneflow/core/operator/sigmoid_op.h @@ -13,13 +13,13 @@ class SigmoidOp final : public Operator { void InitFromOpConf() override; const PbMessage& GetCustomizedConf() const override; - bool IsElemWiseOp() const override { return true; } bool NeedInBlobWhenBackward() const override { return false; } void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } }; } // namespace oneflow diff --git a/oneflow/core/operator/slice_op.cpp b/oneflow/core/operator/slice_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6442ba3ed41fe09ecd99e712fc528f1cb0210205 --- /dev/null +++ b/oneflow/core/operator/slice_op.cpp @@ -0,0 +1,61 @@ +#include 
"oneflow/core/operator/slice_op.h" + +namespace oneflow { + +void SliceOp::InitFromOpConf() { + CHECK(op_conf().has_slice_conf()); + EnrollInputBn("in"); + EnrollOutputBn("out"); + if (op_conf().device_type() == DeviceType::kGPU) { EnrollConstBufBn("out_to_in_offset"); } +} + +const PbMessage& SliceOp::GetCustomizedConf() const { return op_conf().slice_conf(); } + +void SliceOp::VirtualGenKernelConf( + std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const { + const Shape& in_shape = GetBlobDesc4BnInOp("in")->shape(); + in_shape.ToProto(kernel_conf->mutable_slice_conf()->mutable_in_shape()); +} + +bool SliceOp::IsInputBlobAllowedModelSplit(const std::string& ibn) const { + CHECK(std::find(input_bns().begin(), input_bns().end(), ibn) != input_bns().end()); + return ibn == "in"; +} + +void SliceOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + const SliceOpConf& conf = op_conf().slice_conf(); + const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); + CHECK_EQ(conf.dim_slice_conf_size(), in_blob_desc->shape().NumAxes() - 1); + std::vector<int64_t> shape_vec(in_blob_desc->shape().NumAxes()); + shape_vec[0] = in_blob_desc->shape().At(0); + FOR_RANGE(size_t, i, 0, conf.dim_slice_conf_size()) { + int32_t dim_len = in_blob_desc->shape().At(i + 1); + const DimSliceConf& dim_slice_conf = conf.dim_slice_conf(i); + int32_t step = dim_slice_conf.stride(); + CHECK_GT(step, 0); + int32_t start = dim_slice_conf.has_start() ? dim_slice_conf.start() : 0; + int32_t end = dim_slice_conf.has_end() ? dim_slice_conf.end() : dim_len; + if (start < 0) { start += dim_len; } + if (end < 0) { end += dim_len; } + CHECK_GE(start, 0); + CHECK_LT(start, end); + CHECK_LE(end, dim_len); + shape_vec[i + 1] = (end - start - 1) / std::abs(step) + 1; + } + + BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); + *out_blob_desc = *in_blob_desc; + out_blob_desc->mut_shape() = Shape(shape_vec); + + BlobDesc* offset_blob_desc = GetBlobDesc4BnInOp("out_to_in_offset"); + if (offset_blob_desc) { + *offset_blob_desc = *out_blob_desc; + offset_blob_desc->set_data_type(DataType::kInt64); + } +} + +REGISTER_OP(OperatorConf::kSliceConf, SliceOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/slice_op.h b/oneflow/core/operator/slice_op.h new file mode 100644 index 0000000000000000000000000000000000000000..bb15463cea499d17412cd425fec769a96f6c4d50 --- /dev/null +++ b/oneflow/core/operator/slice_op.h @@ -0,0 +1,28 @@ +#ifndef ONEFLOW_CORE_OPERATOR_SLICE_OP_H_ +#define ONEFLOW_CORE_OPERATOR_SLICE_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class SliceOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(SliceOp); + SliceOp() = default; + ~SliceOp() = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, + KernelConf* kernel_conf) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_SLICE_OP_H_ diff --git a/oneflow/core/operator/softmax_op.cpp 
b/oneflow/core/operator/softmax_op.cpp index b2b3152c66835e519f5bd6923ed69471e3b20213..5a59e9e4bcf2a911c35c4ff258f6f51c8b72d9d9 100644 --- a/oneflow/core/operator/softmax_op.cpp +++ b/oneflow/core/operator/softmax_op.cpp @@ -20,7 +20,7 @@ void SoftmaxOp::InitFromOpConf() { const PbMessage& SoftmaxOp::GetCustomizedConf() const { return op_conf().softmax_conf(); } void SoftmaxOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, + const ParallelContext* parallel_ctx, int64_t record_piece_size, std::function<void(OpContext*)> EnrollOpCtx) const { // in const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); diff --git a/oneflow/core/operator/softmax_op.h b/oneflow/core/operator/softmax_op.h index a3303a58aeb9b47ffcf2e0743a0073fd4aabfabb..d4295a576f727596dfd4e90ebc76f703e8f550c7 100644 --- a/oneflow/core/operator/softmax_op.h +++ b/oneflow/core/operator/softmax_op.h @@ -23,12 +23,14 @@ class SoftmaxOp final : public Operator { const PbMessage& GetCustomizedConf() const override; void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, - const ParallelContext*, + const ParallelContext* parallel_ctx, int64_t record_piece_size, std::function<void(OpContext*)> EnrollOpCtx) const override; void InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*, const OpContext*) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*, KernelConf*, const OpContext*) const override; SoftmaxOpCtx* NewSoftmaxOpCtx(const Shape& in_shape) const; diff --git a/oneflow/core/operator/sparse_cross_entropy_loss_op.h b/oneflow/core/operator/sparse_cross_entropy_loss_op.h index 716aac25f7047b484af04d20e4fd5f983f118af2..7fb10aa9661aaca5ae18c769b380d35a6af5f53f 100644 --- a/oneflow/core/operator/sparse_cross_entropy_loss_op.h +++ b/oneflow/core/operator/sparse_cross_entropy_loss_op.h @@ -14,6 +14,8 @@ class SparseCrossEntropyLossOp final : public LossOp { const PbMessage& GetCustomizedConf() const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + LossKernelConf* GetMutLossKernelConf(KernelConf*) const override; }; diff --git a/oneflow/core/operator/sparse_cross_entropy_op.cpp b/oneflow/core/operator/sparse_cross_entropy_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..833ce7863e79a7c481f435922922e2d4ff9c7684 --- /dev/null +++ b/oneflow/core/operator/sparse_cross_entropy_op.cpp @@ -0,0 +1,46 @@ +#include "oneflow/core/operator/sparse_cross_entropy_op.h" + +namespace oneflow { + +void SparseCrossEntropyOp::InitFromOpConf() { + CHECK(op_conf().has_sparse_cross_entropy_conf()); + EnrollInputBn("prediction"); + EnrollInputBn("label", false); + EnrollOutputBn("out"); +} + +const PbMessage& SparseCrossEntropyOp::GetCustomizedConf() const { + return op_conf().sparse_cross_entropy_conf(); +} + +void SparseCrossEntropyOp::InferBlobDescs( + std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, int64_t record_piece_size, + std::function<void(OpContext*)> EnrollOpCtx) const { + const BlobDesc* pred_blob_desc = GetBlobDesc4BnInOp("prediction"); + const BlobDesc* label_blob_desc = GetBlobDesc4BnInOp("label"); + CHECK(IsIntegralDataType(label_blob_desc->data_type())); + 
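// Shape contract enforced by the checks below: prediction has shape
// {d0, ..., dk, C} and label has shape {d0, ..., dk} (trailing size-1 label
// axes are tolerated via the Count() check); the output drops the final class
// axis, e.g. prediction {64, 10} with label {64} yields out of shape {64}.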
CHECK(IsFloatingDataType(pred_blob_desc->data_type())); + CHECK_EQ(pred_blob_desc->has_data_id_field(), label_blob_desc->has_data_id_field()); + CHECK_EQ(pred_blob_desc->has_dim0_valid_num_field(), label_blob_desc->has_dim0_valid_num_field()); + CHECK_EQ(pred_blob_desc->has_dim0_inner_shape(), label_blob_desc->has_dim0_inner_shape()); + if (pred_blob_desc->has_dim0_inner_shape()) { + CHECK_EQ(pred_blob_desc->dim0_inner_shape().At(0), 1); + CHECK_EQ(pred_blob_desc->dim0_inner_shape(), label_blob_desc->dim0_inner_shape()); + } + CHECK_GE(pred_blob_desc->shape().NumAxes(), 2); + const int64_t num_out_axes = pred_blob_desc->shape().NumAxes() - 1; + CHECK_GE(label_blob_desc->shape().NumAxes(), num_out_axes); + CHECK_EQ(label_blob_desc->shape().Count(num_out_axes), 1); + FOR_RANGE(int64_t, i, 0, num_out_axes) { + CHECK_EQ(pred_blob_desc->shape().At(i), label_blob_desc->shape().At(i)); + } + BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); + *out_blob_desc = *pred_blob_desc; + out_blob_desc->mut_shape() = Shape(std::vector<int64_t>( + pred_blob_desc->shape().dim_vec().cbegin(), pred_blob_desc->shape().dim_vec().cend() - 1)); +} + +REGISTER_OP(OperatorConf::kSparseCrossEntropyConf, SparseCrossEntropyOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/sparse_cross_entropy_op.h b/oneflow/core/operator/sparse_cross_entropy_op.h new file mode 100644 index 0000000000000000000000000000000000000000..6e188b22059d8a4ad866df5bf4a2cd9331efe5b4 --- /dev/null +++ b/oneflow/core/operator/sparse_cross_entropy_op.h @@ -0,0 +1,28 @@ +#ifndef ONEFLOW_CORE_OPERATOR_SPARSE_CROSS_ENTROPY_OP_H_ +#define ONEFLOW_CORE_OPERATOR_SPARSE_CROSS_ENTROPY_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class SparseCrossEntropyOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(SparseCrossEntropyOp); + SparseCrossEntropyOp() = default; + ~SparseCrossEntropyOp() override = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, int64_t record_piece_size, + std::function<void(OpContext*)> EnrollOpCtx) const override; + bool NeedOutBlobWhenBackward() const override { return false; } + bool NeedInBlobWhenBackward() const override { return true; } + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_SPARSE_CROSS_ENTROPY_OP_H_ diff --git a/oneflow/core/operator/sparse_softmax_cross_entropy_loss_op.h b/oneflow/core/operator/sparse_softmax_cross_entropy_loss_op.h index 2836a7f6182eaea4bbb872f88a5b9173d44d451b..b8f510746d700560ce46dfa8e664c787f9315b7c 100644 --- a/oneflow/core/operator/sparse_softmax_cross_entropy_loss_op.h +++ b/oneflow/core/operator/sparse_softmax_cross_entropy_loss_op.h @@ -14,6 +14,8 @@ class SparseSoftmaxCrossEntropyLossOp final : public LossOp { const PbMessage& GetCustomizedConf() const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + void VirtualInitFromOpConf() override; void VirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; diff --git a/oneflow/core/operator/sqrt_op.cpp b/oneflow/core/operator/sqrt_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3f19e25a0c65ec6bb7394ecc87d4eaef5d4e98e0 
--- /dev/null +++ b/oneflow/core/operator/sqrt_op.cpp @@ -0,0 +1,20 @@ +#include "oneflow/core/operator/sqrt_op.h" + +namespace oneflow { + +void SqrtOp::InitFromOpConf() { + CHECK(op_conf().has_sqrt_conf()); + EnrollInputBn("in"); + EnrollOutputBn("out"); +} + +const PbMessage& SqrtOp::GetCustomizedConf() const { return op_conf().sqrt_conf(); } + +void SqrtOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in"); +} + +REGISTER_OP(OperatorConf::kSqrtConf, SqrtOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/sqrt_op.h b/oneflow/core/operator/sqrt_op.h new file mode 100644 index 0000000000000000000000000000000000000000..70904865234a138235192b4261c996db250985fb --- /dev/null +++ b/oneflow/core/operator/sqrt_op.h @@ -0,0 +1,27 @@ +#ifndef ONEFLOW_CORE_OPERATOR_SQRT_H_ +#define ONEFLOW_CORE_OPERATOR_SQRT_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class SqrtOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(SqrtOp); + SqrtOp() = default; + ~SqrtOp() = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + bool NeedInBlobWhenBackward() const override { return false; } + + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_SQRT_H_ diff --git a/oneflow/core/operator/square_op.cpp b/oneflow/core/operator/square_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..77afab0f8763e7142d0999cbaae0297491bb6d22 --- /dev/null +++ b/oneflow/core/operator/square_op.cpp @@ -0,0 +1,20 @@ +#include "oneflow/core/operator/square_op.h" + +namespace oneflow { + +void SquareOp::InitFromOpConf() { + CHECK(op_conf().has_square_conf()); + EnrollInputBn("in"); + EnrollOutputBn("out"); +} + +const PbMessage& SquareOp::GetCustomizedConf() const { return op_conf().square_conf(); } + +void SquareOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in"); +} + +REGISTER_OP(OperatorConf::kSquareConf, SquareOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/square_op.h b/oneflow/core/operator/square_op.h new file mode 100644 index 0000000000000000000000000000000000000000..7af18ce16184a02ce7552135a2b581c0ec3d3db6 --- /dev/null +++ b/oneflow/core/operator/square_op.h @@ -0,0 +1,27 @@ +#ifndef ONEFLOW_CORE_OPERATOR_SQUARE_H_ +#define ONEFLOW_CORE_OPERATOR_SQUARE_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class SquareOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(SquareOp); + SquareOp() = default; + ~SquareOp() = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + bool NeedOutBlobWhenBackward() const override { return false; } + + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_SQUARE_H_ diff --git 
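// A note on why the backward-blob flags of the element-wise Sqrt, Rsqrt and
// Square ops above differ (an inference from their derivatives rather than a
// rationale stated in the code):
//   d/dx sqrt(x)  = 1 / (2 * sqrt(x)) = 1 / (2 * out)  -> "in" not needed
//   d/dx rsqrt(x) = -x^(-3/2) / 2     = -out^3 / 2     -> "in" not needed
//   d/dx x^2      = 2 * x             = 2 * in         -> "out" not needed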
a/oneflow/core/operator/tanh_op.h b/oneflow/core/operator/tanh_op.h index 2d402b8452fda1dd772c0111f2a402f753552a48..b0b8deaab159905931ebec754268432b72e54144 100644 --- a/oneflow/core/operator/tanh_op.h +++ b/oneflow/core/operator/tanh_op.h @@ -13,13 +13,13 @@ class TanHOp final : public Operator { void InitFromOpConf() override; const PbMessage& GetCustomizedConf() const override; - bool IsElemWiseOp() const override { return true; } bool NeedInBlobWhenBackward() const override { return false; } void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } }; } // namespace oneflow diff --git a/oneflow/core/operator/tick_op.cpp b/oneflow/core/operator/tick_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a79b3ed7e43dd6e0a80cb11f25fa2c0eb1ffc837 --- /dev/null +++ b/oneflow/core/operator/tick_op.cpp @@ -0,0 +1,20 @@ +#include "oneflow/core/operator/tick_op.h" + +namespace oneflow { + +void TickOp::InitFromOpConf() { + EnrollInputBn("in", false); + EnrollOutputBn("out", false); +} + +void TickOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + const BlobDesc* in = GetBlobDesc4BnInOp("in"); + BlobDesc* out = GetBlobDesc4BnInOp("out"); + *out = *in; + out->mut_shape() = Shape({1}); +} + +REGISTER_CPU_OP(OperatorConf::kTickConf, TickOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/tick_op.h b/oneflow/core/operator/tick_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8f37fe8818cab8ecdfa6f10c628b1d8481353edc --- /dev/null +++ b/oneflow/core/operator/tick_op.h @@ -0,0 +1,25 @@ +#ifndef ONEFLOW_CORE_OPERATOR_TICK_OP_H_ +#define ONEFLOW_CORE_OPERATOR_TICK_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class TickOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(TickOp); + TickOp() = default; + ~TickOp() = default; + + void InitFromOpConf() override; + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + const PbMessage& GetCustomizedConf() const override { return op_conf().tick_conf(); } + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_TICK_OP_H_ diff --git a/oneflow/core/operator/top_k_op.cpp b/oneflow/core/operator/top_k_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..78a5d33d4cecfa3a0c30e929a677a9e6d21deaf4 --- /dev/null +++ b/oneflow/core/operator/top_k_op.cpp @@ -0,0 +1,38 @@ +#include "oneflow/core/operator/top_k_op.h" + +namespace oneflow { + +void TopKOp::InitFromOpConf() { + CHECK(op_conf().has_top_k_conf()); + EnrollInputBn("in", false); + EnrollFwBufBn("fw_buf"); + EnrollOutputBn("out", false); +} + +void TopKOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + const BlobDesc* in = GetBlobDesc4BnInOp("in"); + CHECK_LE(in->shape().elem_cnt(), GetMaxVal<int32_t>()); + const TopKOpConf& conf = op_conf().top_k_conf(); + CHECK_GE(conf.k(), 1); + CHECK_LE(conf.k(), in->shape().dim_vec().back()); + // fw_buf + BlobDesc* fw_buf = GetBlobDesc4BnInOp("fw_buf"); + fw_buf->mut_shape() = 
Shape({in->shape().dim_vec().back()}); + fw_buf->set_data_type(DataType::kInt32); + // out + BlobDesc* out = GetBlobDesc4BnInOp("out"); + *out = *in; + out->mut_shape().Set(in->shape().NumAxes() - 1, conf.k()); + out->set_data_type(DataType::kInt32); +} + +void TopKOp::VirtualGenKernelConf( + std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*, + KernelConf* kernel_conf) const { + kernel_conf->set_data_type(GetBlobDesc4BnInOp("in")->data_type()); +} + +REGISTER_CPU_OP(OperatorConf::kTopKConf, TopKOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/top_k_op.h b/oneflow/core/operator/top_k_op.h new file mode 100644 index 0000000000000000000000000000000000000000..31d5eb5dc48eb0355ead7d0425922df741b0e5a4 --- /dev/null +++ b/oneflow/core/operator/top_k_op.h @@ -0,0 +1,27 @@ +#ifndef ONEFLOW_CORE_OPERATOR_TOP_K_OP_H_ +#define ONEFLOW_CORE_OPERATOR_TOP_K_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class TopKOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(TopKOp); + TopKOp() = default; + ~TopKOp() override = default; + + void InitFromOpConf() override; + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + const PbMessage& GetCustomizedConf() const override { return op_conf().top_k_conf(); } + + private: + void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext*, KernelConf*) const override; + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_TOP_K_OP_H_ diff --git a/oneflow/core/operator/transpose_op.cpp b/oneflow/core/operator/transpose_op.cpp index f553d3451500a4a6d42484f6ea74868bdefb25f2..df40d757dbf8f3a252a053082df3ec7a02657ba6 100644 --- a/oneflow/core/operator/transpose_op.cpp +++ b/oneflow/core/operator/transpose_op.cpp @@ -5,12 +5,12 @@ namespace oneflow { namespace { void CheckIsPerm(const PbRf<int32_t>& perm) { - std::vector<bool> is_used(perm.size(), 0); + std::vector<bool> is_used(perm.size(), false); FOR_RANGE(size_t, i, 0, perm.size()) { - CHECK_GE(perm[i], 1); + CHECK_GE(perm[i], 0); CHECK_LE(perm[i], perm.size()); - CHECK_EQ(is_used[perm[i] - 1], false); - is_used[perm[i] - 1] = true; + CHECK_EQ(is_used[perm[i]], false); + is_used[perm[i]] = true; } } @@ -29,13 +29,44 @@ void TransposeOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> Ge const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); const Shape& in_blob_shape = in_blob_desc->shape(); const PbRf<int32_t>& perm = op_conf().transpose_conf().perm(); - CHECK_EQ(perm.size(), in_blob_shape.NumAxes() - 1); + CHECK_EQ(perm.size(), in_blob_shape.NumAxes()); CheckIsPerm(perm); + if (perm.Get(0) != 0) { + CHECK(!in_blob_desc->has_dim0_valid_num_field()); + } else if (perm.size() >= 2 && perm.Get(1) != 1) { + CHECK(!in_blob_desc->has_dim1_valid_num_field()); + } else if (perm.size() >= 3 && perm.Get(2) != 2) { + CHECK(!in_blob_desc->has_dim2_valid_num_field()); + } else { + // do nothing + } BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); *out_blob_desc = *in_blob_desc; FOR_RANGE(size_t, i, 0, perm.size()) { - out_blob_desc->mut_shape().Set(i + 1, in_blob_shape.At(perm[i])); + out_blob_desc->mut_shape().Set(i, in_blob_shape.At(perm[i])); + } +} + +int32_t TransposeOp::OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const 
std::string&)>& SbpInferHint4Ibn, + const std::string& obn) const { + const auto& in_sbp_infer_hint = SbpInferHint4Ibn("in"); + const PbRf<int32_t>& perm = op_conf().transpose_conf().perm(); + CHECK_GT(perm.size(), 0); + CHECK_EQ(perm.size(), in_sbp_infer_hint.num_axes()); + int32_t split_axis = -1; + FOR_RANGE(int32_t, i, 0, perm.size()) { + if (perm[i] == in_sbp_infer_hint.split_axis()) { + split_axis = i; + break; + } + } + CHECK_NE(split_axis, -1); + if (in_sbp_infer_hint.is_data_blob()) { + CHECK_GT(in_sbp_infer_hint.split_axis(), 0); + CHECK_GT(split_axis, 0); } + return split_axis; } void TransposeOp::VirtualGenKernelConf( @@ -43,9 +74,8 @@ void TransposeOp::VirtualGenKernelConf( const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const { const PbRf<int32_t>& src_perm = op_conf().transpose_conf().perm(); PbRf<int32_t>* perm = kernel_conf->mutable_transpose_conf()->mutable_perm(); - perm->Add(0); - perm->MergeFrom(src_perm); - CHECK_EQ(perm->size(), src_perm.size() + 1); + *perm = src_perm; + CHECK_EQ(perm->size(), src_perm.size()); PbRf<int32_t>* invert_perm = kernel_conf->mutable_transpose_conf()->mutable_invert_perm(); invert_perm->Reserve(perm->size()); invert_perm->CopyFrom(*perm); diff --git a/oneflow/core/operator/transpose_op.h b/oneflow/core/operator/transpose_op.h index 889062c114f3345b5d458ad7d5368f0d28d6d196..4f37ed3c4fbdf8ecffdfe2b1194ea0f911e436d7 100644 --- a/oneflow/core/operator/transpose_op.h +++ b/oneflow/core/operator/transpose_op.h @@ -18,8 +18,13 @@ class TransposeOp final : public Operator { void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; + int32_t OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const std::string& obn) const override; private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return true; } + void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*, KernelConf*) const override; }; diff --git a/oneflow/core/operator/tuple_identity_op.cpp b/oneflow/core/operator/tuple_identity_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2ff4a0ec0150b86082704e9e2b27c4c6b401474c --- /dev/null +++ b/oneflow/core/operator/tuple_identity_op.cpp @@ -0,0 +1,70 @@ +#include "oneflow/core/operator/tuple_identity_op.h" + +namespace oneflow { + +namespace { + +class TupleIdentityOpParallelSignature final : public OpParallelSignature { + public: + OF_DISALLOW_COPY_AND_MOVE(TupleIdentityOpParallelSignature); + ~TupleIdentityOpParallelSignature() override = default; + + TupleIdentityOpParallelSignature(const Operator* op) : OpParallelSignature(op) {} + + const std::string Description() const override { return op().op_name() + ": A -> A"; } + + const OpParallelMatchResult GetMatchResult( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4BnInOp, + const ParallelContext* parallel_ctx) const override { + const auto& ibn = op().input_bns().Get(0); + if (parallel_ctx->parallel_num() != SbpInferHint4BnInOp(ibn).parallel_num()) { + return MakeOpParallelMatchParallelNumError(parallel_ctx->parallel_num(), + SbpInferHint4BnInOp(ibn).parallel_num()); + } + return MakeOpParallelMatchSuccess(); + } + + void GenerateSignature( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4BnInOp, + HashMap<std::string, SbpParallel>* bn2sbp) const override { + 
FOR_RANGE(int32_t, i, 0, op().input_bns().size()) { + const auto& sbp_parallel = SbpInferHint4BnInOp(op().input_bns().Get(i)).sbp_parallel(); + (*bn2sbp)[op().input_bns().Get(i)] = sbp_parallel; + (*bn2sbp)[op().output_bns().Get(i)] = sbp_parallel; + } + } +}; + +} // namespace + +void TupleIdentityOp::InitFromOpConf() { + CHECK(op_conf().has_tuple_identity_conf()); + int32_t in_size = op_conf().tuple_identity_conf().in_size(); + int32_t out_size = op_conf().tuple_identity_conf().out_size(); + CHECK_GT(in_size, 0); + CHECK_EQ(in_size, out_size); + EnrollRepeatedInputBn("in", in_size); + EnrollRepeatedOutputBn("out", out_size); +} + +const PbMessage& TupleIdentityOp::GetCustomizedConf() const { + return op_conf().tuple_identity_conf(); +} + +void TupleIdentityOp::InferBlobDescs( + std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + size_t bn_size = op_conf().tuple_identity_conf().in_size(); + FOR_RANGE(int, i, 0, bn_size) { + *GetBlobDesc4BnInOp(output_bns().Get(i)) = *GetBlobDesc4BnInOp(input_bns().Get(i)); + } +} + +void TupleIdentityOp::GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const { + op_parallel_signatures->emplace_back(new TupleIdentityOpParallelSignature(this)); +} + +REGISTER_OP(OperatorConf::kTupleIdentityConf, TupleIdentityOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/tuple_identity_op.h b/oneflow/core/operator/tuple_identity_op.h new file mode 100644 index 0000000000000000000000000000000000000000..d4319efe1f16f4080fcbc35c1fb780c01ae26dce --- /dev/null +++ b/oneflow/core/operator/tuple_identity_op.h @@ -0,0 +1,32 @@ +#ifndef ONEFLOW_CORE_OPERATOR_TUPLE_IDENTITY_OP_H_ +#define ONEFLOW_CORE_OPERATOR_TUPLE_IDENTITY_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class TupleIdentityOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(TupleIdentityOp); + TupleIdentityOp() = default; + ~TupleIdentityOp() = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + bool NeedInBlobWhenBackward() const override { return false; } + bool NeedOutBlobWhenBackward() const override { return false; } + + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { + return op_conf().tuple_identity_conf().in_size() == 1 && ibn == SoleIbn(); + } + void GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>*) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_TUPLE_IDENTITY_OP_H_ diff --git a/oneflow/core/operator/unpack_op.cpp b/oneflow/core/operator/unpack_op.cpp index bda0945ae765cba5cd831f95807140c52531785a..3dfe6b06ba15199ff1743ce8ef7a9cebf61c1d3c 100644 --- a/oneflow/core/operator/unpack_op.cpp +++ b/oneflow/core/operator/unpack_op.cpp @@ -14,7 +14,7 @@ void UnpackOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBl const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp(SoleIbn()); BlobDesc* out_blob_desc = GetBlobDesc4BnInOp(SoleObn()); *out_blob_desc = *in_blob_desc; - int32_t unpack_num = GetUnpackNum(parallel_ctx->parallel_num()); + int32_t unpack_num = GetUnpackNum(); if (in_blob_desc->has_dim0_inner_shape()) { CHECK_EQ(0, unpack_num % in_blob_desc->dim0_inner_shape().At(0)); CHECK_EQ(0, 
in_blob_desc->dim0_inner_shape().Count(1) @@ -26,22 +26,17 @@ void UnpackOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBl out_blob_desc->mut_dim0_inner_shape() = Shape({1, out_blob_desc->shape().At(0)}); } -int32_t UnpackOp::GetUnpackNum(int64_t parallel_num) const { - CHECK(op_conf().has_unpack_conf()); - const UnpackOpConf& conf = op_conf().unpack_conf(); - if (conf.has_unpack_num()) { - return conf.unpack_num(); - } else if (conf.has_unpack_num_per_record()) { - CHECK_EQ(Global<JobDesc>::Get()->PieceSize() % parallel_num, 0); - int64_t unpack_num = - Global<JobDesc>::Get()->PieceSize() / parallel_num * conf.unpack_num_per_record(); - CHECK_LE(unpack_num, static_cast<int64_t>(MaxVal<int32_t>())); - return static_cast<int32_t>(unpack_num); - } else { - UNIMPLEMENTED(); - } +void UnpackOp::InferOutputBlobTimeShape( + std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, + const ParallelContext* parallel_ctx, Shape* time_shape) const { + std::vector<int64_t> dim_vec(GetTimeShape4BnInOp("in")->dim_vec()); + int32_t unpack_num = GetUnpackNum(); + dim_vec.push_back(unpack_num); + *time_shape = Shape(dim_vec); } +int32_t UnpackOp::GetUnpackNum() const { return op_conf().unpack_conf().unpack_num(); } + REGISTER_OP(OperatorConf::kUnpackConf, UnpackOp); } // namespace oneflow diff --git a/oneflow/core/operator/unpack_op.h b/oneflow/core/operator/unpack_op.h index 72b12b45575cfb120f1de322fb43ef097345bf02..f52869c4d89dceafcca420c2257445db30012124 100644 --- a/oneflow/core/operator/unpack_op.h +++ b/oneflow/core/operator/unpack_op.h @@ -18,9 +18,15 @@ class UnpackOp final : public Operator { void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override; + void InferOutputBlobTimeShape(std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, + const ParallelContext* parallel_ctx, + Shape* time_shape) const override; bool NeedInBlobWhenBackward() const override { return true; } bool NeedOutBlobWhenBackward() const override { return false; } - int32_t GetUnpackNum(int64_t parallel_num) const; + int32_t GetUnpackNum() const; + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } }; } // namespace oneflow diff --git a/oneflow/core/operator/variable_op.cpp b/oneflow/core/operator/variable_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6382046311a837a6ef5524998b61642c2612cc3a --- /dev/null +++ b/oneflow/core/operator/variable_op.cpp @@ -0,0 +1,125 @@ +#include "oneflow/core/operator/variable_op.h" +#include "oneflow/core/common/balanced_splitter.h" + +namespace oneflow { + +namespace { + +// S(0) -> C +class VariableOpDataSplitOpParallelSignature final : public OpParallelSignature { + public: + OF_DISALLOW_COPY_AND_MOVE(VariableOpDataSplitOpParallelSignature); + ~VariableOpDataSplitOpParallelSignature() override = default; + + VariableOpDataSplitOpParallelSignature(const Operator* op) : OpParallelSignature(op) {} + + const std::string Description() const override { return op().op_name() + ": S(0) -> C"; } + + const OpParallelMatchResult GetMatchResult( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const ParallelContext* parallel_ctx) const override { + if (parallel_ctx->policy() == kDataParallel) { + return MakeOpParallelMatchSuccess(); + } else { + return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kDataParallel); + } + } + + void GenerateSignature( + 
const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + HashMap<std::string, SbpParallel>* bn2sbp) const override { + CHECK(SbpInferHint4Ibn("tick").is_data_split()); + (*bn2sbp)["tick"].mutable_split_parallel()->set_axis(0); + (*bn2sbp)["out"].mutable_broadcast_parallel(); + } +}; + +// S(0) -> S +class VariableOpModelSplitOpParallelSignature final : public OpParallelSignature { + public: + OF_DISALLOW_COPY_AND_MOVE(VariableOpModelSplitOpParallelSignature); + ~VariableOpModelSplitOpParallelSignature() override = default; + + VariableOpModelSplitOpParallelSignature(const Operator* op) : OpParallelSignature(op) {} + + const std::string Description() const override { return op().op_name() + ": S(0) -> S"; } + + const OpParallelMatchResult GetMatchResult( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const ParallelContext* parallel_ctx) const override { + if (parallel_ctx->policy() == kModelParallel) { + return MakeOpParallelMatchSuccess(); + } else { + return MakeOpParallelMatchParallelPolicyError(parallel_ctx->policy(), kModelParallel); + } + } + + void GenerateSignature( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + HashMap<std::string, SbpParallel>* bn2sbp) const override { + CHECK(SbpInferHint4Ibn("tick").is_data_split()); + (*bn2sbp)["tick"].mutable_split_parallel()->set_axis(0); + (*bn2sbp)["out"].mutable_split_parallel()->set_axis( + (op().OutputBlobModelSplitAxis(SbpInferHint4Ibn, "out"))); + } +}; + +} // namespace + +void VariableOp::InitFromOpConf() { + CHECK(op_conf().has_variable_conf()); + EnrollInputBn("tick", false); + EnrollOutputBn("out", Global<JobDesc>::Get()->IsTrain() && op_conf().trainable()); + EnrollModelBn(op_conf().variable_conf().model_name()); +} + +const PbMessage& VariableOp::GetCustomizedConf() const { return op_conf().variable_conf(); } + +int32_t VariableOp::OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const std::string& obn) const { + return op_conf().variable_conf().model_split_axis(); +} + +void VariableOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + const VariableOpConf& variable_conf = op_conf().variable_conf(); + BlobDesc* model_blob_desc = GetBlobDesc4BnInOp(variable_conf.model_name()); + model_blob_desc->mut_shape() = Shape(variable_conf.shape()); + model_blob_desc->set_data_type(variable_conf.has_data_type() + ? 
variable_conf.data_type() + : Global<JobDesc>::Get()->DefaultDataType()); + if (parallel_ctx->policy() == kModelParallel) { + int32_t model_split_axis = variable_conf.model_split_axis(); + CHECK_GE(model_split_axis, 0); + CHECK_LT(model_split_axis, model_blob_desc->shape().NumAxes()); + int64_t split_dim_num = model_blob_desc->shape().At(model_split_axis); + BalancedSplitter bs(split_dim_num, parallel_ctx->parallel_num()); + model_blob_desc->mut_shape().Set(model_split_axis, bs.At(parallel_ctx->parallel_id()).size()); + } else { + CHECK_EQ(parallel_ctx->policy(), kDataParallel); + } + *GetBlobDesc4BnInOp("out") = *model_blob_desc; +} + +void VariableOp::GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>* op_parallel_signatures) const { + op_parallel_signatures->emplace_back(new VariableOpDataSplitOpParallelSignature(this)); + op_parallel_signatures->emplace_back(new VariableOpModelSplitOpParallelSignature(this)); +} + +void VariableOp::InferIsModelBlob4OutputBlobs( + std::function<bool*(const std::string&)> IsModelBlob4BnInOp) const { + *IsModelBlob4BnInOp("out") = true; +} + +void VariableOp::VirtualGenKernelConf( + std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*, + KernelConf* conf) const { + conf->mutable_variable_conf()->set_is_fw_inplace(*is_fw_inplace_); + conf->mutable_variable_conf()->set_is_bw_inplace(*is_bw_inplace_); +} + +REGISTER_OP(OperatorConf::kVariableConf, VariableOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/variable_op.h b/oneflow/core/operator/variable_op.h new file mode 100644 index 0000000000000000000000000000000000000000..79c8291c6417382fc7ee80b6b7137a46e65ac434 --- /dev/null +++ b/oneflow/core/operator/variable_op.h @@ -0,0 +1,45 @@ +#ifndef ONEFLOW_CORE_OPERATOR_VARIABLE_OP_H_ +#define ONEFLOW_CORE_OPERATOR_VARIABLE_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class VariableOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(VariableOp); + VariableOp() + : Operator(), + is_fw_inplace_(std::make_unique<bool>(false)), + is_bw_inplace_(std::make_unique<bool>(false)) {} + ~VariableOp() = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + bool NeedInBlobWhenBackward() const override { return false; } + bool NeedOutBlobWhenBackward() const override { return false; } + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + int32_t OutputBlobModelSplitAxis( + const std::function<const SbpInferHint&(const std::string&)>& SbpInferHint4Ibn, + const std::string& obn) const override; + + void set_is_fw_inplace(bool val) const { *is_fw_inplace_ = val; } + void set_is_bw_inplace(bool val) const { *is_bw_inplace_ = val; } + + private: + bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; } + void GetOpParallelSignatures( + std::vector<std::unique_ptr<const OpParallelSignature>>*) const override; + void InferIsModelBlob4OutputBlobs( + std::function<bool*(const std::string&)> IsModelBlob4BnInOp) const; + void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext*, KernelConf*) const override; + + std::unique_ptr<bool> is_fw_inplace_; + std::unique_ptr<bool> is_bw_inplace_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_VARIABLE_OP_H_ diff --git 
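Under kModelParallel, VariableOp::InferBlobDescs above shards the model blob along model_split_axis with a balanced split across parallel_num devices (via BalancedSplitter). A self-contained sketch of that split rule as I read it, illustration only with a hypothetical helper: with shape {4, 6}, model_split_axis = 1 and parallel_num = 4, the per-device sizes along axis 1 are 2, 2, 1, 1.

#include <cstdint>

// Hypothetical stand-in for the BalancedSplitter usage above: the first
// (dim % parallel_num) shards receive one extra element.
int64_t BalancedShardSize(int64_t dim, int64_t parallel_num, int64_t parallel_id) {
  const int64_t base = dim / parallel_num;
  const int64_t remainder = dim % parallel_num;
  return base + (parallel_id < remainder ? 1 : 0);
}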
a/oneflow/core/persistence/tee_persistent_log_stream.h b/oneflow/core/persistence/tee_persistent_log_stream.h index efe21e30ddb659810c413115e1416e17e594fdd6..5b29239cf44344faf7bde0363f12342b8374e072 100644 --- a/oneflow/core/persistence/tee_persistent_log_stream.h +++ b/oneflow/core/persistence/tee_persistent_log_stream.h @@ -29,10 +29,10 @@ class TeePersistentLogStream final { void Write(const PbMessage& proto); static std::unique_ptr<TeePersistentLogStream> Create(const std::string& path); + void Flush(); private: explicit TeePersistentLogStream(const std::string& path); - void Flush(); std::vector<LogStreamDestination> destinations_; std::vector<std::unique_ptr<PersistentOutStream>> branches_; }; diff --git a/oneflow/core/persistence/windows/windows_file_system.cpp b/oneflow/core/persistence/windows/windows_file_system.cpp index 9b5335ed3dbd947c72f8567008b975c7385027a9..eac14309e9538026c7ddff32e1fb9a8d2570f237 100644 --- a/oneflow/core/persistence/windows/windows_file_system.cpp +++ b/oneflow/core/persistence/windows/windows_file_system.cpp @@ -17,7 +17,7 @@ typedef std::unique_ptr<void, decltype(CloseHandleFunc)> UniqueCloseHandlePtr; // PLEASE NOTE: hfile is expected to be an async handle // (i.e. opened with FILE_FLAG_OVERLAPPED) SSIZE_T pread(HANDLE hfile, char* src, size_t num_bytes, uint64_t offset) { - assert(num_bytes <= MaxVal<DWORD>()); + assert(num_bytes <= GetMaxVal<DWORD>()); OVERLAPPED overlapped = {0}; ULARGE_INTEGER offset_union; offset_union.QuadPart = offset; diff --git a/oneflow/core/record/ofrecord_raw_decoder.cpp b/oneflow/core/record/ofrecord_raw_decoder.cpp index 5db8113b3807381ba0180beb429708643a9b9c0b..8b112fe0fca431dbc93348f9086dfb41eedd154f 100644 --- a/oneflow/core/record/ofrecord_raw_decoder.cpp +++ b/oneflow/core/record/ofrecord_raw_decoder.cpp @@ -85,7 +85,12 @@ void OFRecordDecoderImpl<EncodeCase::kRaw, T>::ReadOneCol( else if (feature.has_##PbT##_list()) { \ const auto& list = feature.PbT##_list(); \ const CppT* in_dptr = list.value().data(); \ - one_col_elem_num = std::max<int64_t>(one_col_elem_num, list.value_size()); \ + if (blob_conf.encode_case().raw().dim1_varying_length()) { \ + CHECK_LE(list.value_size(), one_col_elem_num); \ + one_col_elem_num = list.value_size(); \ + } else { \ + CHECK_EQ(one_col_elem_num, list.value_size()); \ + } \ FixInDptrThenCopyElem<CppT, T>(ctx, in_dptr, col_id, one_col_elem_num, out_dptr); \ } DEFINE_ONE_ELIF(float, float) diff --git a/oneflow/core/record/ofrecord_raw_decoder.h b/oneflow/core/record/ofrecord_raw_decoder.h index dc72bddcb63d521688633ac88ef14791f0d2e9a6..29e4d83e1a372f96a4dd10a131c809f9974078d1 100644 --- a/oneflow/core/record/ofrecord_raw_decoder.h +++ b/oneflow/core/record/ofrecord_raw_decoder.h @@ -11,11 +11,12 @@ class OFRecordDecoderImpl<EncodeCase::kRaw, T> final : public OFRecordDecoder<En bool HasDim1ValidNumField(const EncodeConf& encode_conf) const override; bool HasDim2ValidNumField(const EncodeConf& encode_conf) const override { return false; } - private: - int32_t GetColNumOfFeature(const Feature&, int64_t one_col_elem_num) const override; void ReadOneCol(DeviceCtx*, const Feature&, const BlobConf& blob_conf, int32_t col_id, T* out_dptr, int64_t one_col_elem_num, std::function<int32_t(void)> NextRandomInt) const override; + + private: + int32_t GetColNumOfFeature(const Feature&, int64_t one_col_elem_num) const override; void SetDim1ValidNum(const Feature& feature, Blob* out_blob, int64_t dim0_idx) const override; }; diff --git a/oneflow/core/record/ofrecord_raw_encoder.cpp 
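The raw-decoder change above makes the per-row element count depend on dim1_varying_length: when it is set, a feature may carry fewer values than the blob's static row width (the actual count is then tracked through dim1_valid_num), otherwise the counts must match exactly. A standalone sketch of that rule, illustration only with hypothetical names:

#include <cassert>
#include <cstdint>

// Resolve how many elements to copy for one row of a raw-encoded feature.
int64_t ResolveRowElemCnt(int64_t list_size, int64_t one_col_elem_num, bool dim1_varying_length) {
  if (dim1_varying_length) {
    assert(list_size <= one_col_elem_num);  // shorter rows are allowed
    return list_size;
  }
  assert(list_size == one_col_elem_num);  // fixed-width rows must match exactly
  return one_col_elem_num;
}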
b/oneflow/core/record/ofrecord_raw_encoder.cpp index 53afc259d0b65e49719efdff7bd891110c390871..db85c1ef155befeb1cd15ee8e9a609bc4e16993a 100644 --- a/oneflow/core/record/ofrecord_raw_encoder.cpp +++ b/oneflow/core/record/ofrecord_raw_encoder.cpp @@ -2,6 +2,31 @@ namespace oneflow { +namespace { + +template<typename T> +void EncodeDataToFeature(DeviceCtx* ctx, Feature* feature, const T* in_dptr, size_t elem_num) { + DataType data_type = GetDataType<T>(); + if (data_type == DataType::kInt8) { + feature->mutable_bytes_list()->add_value(reinterpret_cast<const char*>(in_dptr), elem_num); + } +#define DEFINE_ONE_ELIF(CppT, ListT) \ + else if (data_type == GetDataType<CppT>()) { \ + feature->mutable_##ListT##_list()->mutable_value()->Resize(elem_num, 0); \ + CppT* out_dptr = feature->mutable_##ListT##_list()->mutable_value()->mutable_data(); \ + Memcpy<DeviceType::kCPU>(nullptr, out_dptr, in_dptr, elem_num * sizeof(T)); \ + } + DEFINE_ONE_ELIF(float, float) + DEFINE_ONE_ELIF(double, double) + DEFINE_ONE_ELIF(int32_t, int32) +#undef DEFINE_ONE_ELIF + else { + UNIMPLEMENTED(); + } +} + +} // namespace + template<typename T> void OFRecordEncoderImpl<EncodeCase::kRaw, T>::EncodeOneCol(DeviceCtx* ctx, const Blob* in_blob, int64_t in_offset, Feature& feature, @@ -15,23 +40,13 @@ void OFRecordEncoderImpl<EncodeCase::kRaw, T>::EncodeOneCol(DeviceCtx* ctx, cons int64_t elem_num = shape.NumAxes() == 1 ? 1 : in_blob->dim1_valid_num(dim0_idx) * shape.Count(2); const T* in_dptr = in_blob->dptr<T>() + in_offset; - DataType data_type = GetDataType<T>(); - if (data_type == DataType::kInt8) { - feature.mutable_bytes_list()->add_value(reinterpret_cast<const char*>(in_dptr), elem_num); - } -#define DEFINE_ONE_ELIF(CppT, ListT) \ - else if (data_type == GetDataType<CppT>()) { \ - feature.mutable_##ListT##_list()->mutable_value()->Resize(elem_num, 0); \ - CppT* out_dptr = feature.mutable_##ListT##_list()->mutable_value()->mutable_data(); \ - Memcpy<DeviceType::kCPU>(nullptr, out_dptr, in_dptr, elem_num * sizeof(T)); \ - } - DEFINE_ONE_ELIF(float, float) - DEFINE_ONE_ELIF(double, double) - DEFINE_ONE_ELIF(int32_t, int32) -#undef DEFINE_ONE_ELIF - else { - UNIMPLEMENTED(); - } + EncodeDataToFeature(ctx, &feature, in_dptr, elem_num); +} + +template<typename T> +void OFRecordEncoderImpl<EncodeCase::kRaw, T>::EncodeBlob(DeviceCtx* ctx, const Blob* in_blob, + Feature* feature) const { + EncodeDataToFeature(ctx, feature, in_blob->dptr<T>(), in_blob->shape().elem_cnt()); } #define INSTANTIATE_OFRECORD_RAW_ENCODER(type_cpp, type_proto) \ diff --git a/oneflow/core/record/ofrecord_raw_encoder.h b/oneflow/core/record/ofrecord_raw_encoder.h index b8c985bebfe7b35b2b4db6c2d8bcd51ebf5d6806..fcd5ffbd01bd9149c5f870a78db81cd4e9824dde 100644 --- a/oneflow/core/record/ofrecord_raw_encoder.h +++ b/oneflow/core/record/ofrecord_raw_encoder.h @@ -7,6 +7,9 @@ namespace oneflow { template<typename T> class OFRecordEncoderImpl<EncodeCase::kRaw, T> final : public OFRecordEncoderIf { + public: + void EncodeBlob(DeviceCtx* ctx, const Blob* in_blob, Feature* feature) const; + private: void EncodeOneCol(DeviceCtx*, const Blob* in_blob, int64_t in_offset, Feature&, const std::string& field_name, int64_t one_col_elem_num) const override; diff --git a/oneflow/core/record/ofrecord_reader.cpp b/oneflow/core/record/ofrecord_reader.cpp new file mode 100644 index 0000000000000000000000000000000000000000..93d36d2fa5ce79262b36df28d9aa248552a028db --- /dev/null +++ b/oneflow/core/record/ofrecord_reader.cpp @@ -0,0 +1,81 @@ +#include 
"oneflow/core/record/ofrecord_reader.h" + +namespace oneflow { + +constexpr int64_t MAX_CHUNK_SIZE = 64 * 1024 * 1024; // 64M + +namespace { + +bool ReadChunk(PersistentInStream* is, OFRecordChunk* chunk) { + if (is->Read(reinterpret_cast<char*>(&chunk->size), sizeof(int64_t)) == 0) { + CHECK_GE(chunk->size, 0); + CHECK_LE(chunk->size, MAX_CHUNK_SIZE); + chunk->data.reset(new char[chunk->size]); + CHECK_EQ(is->Read(chunk->data.get(), chunk->size), 0); + return true; + } + return false; +} + +} // namespace + +NaiveOFRecordReader::NaiveOFRecordReader(PersistentInStream* in, size_t num_max_read) + : in_stream_(in), num_read_(0), num_max_read_(num_max_read) {} + +size_t NaiveOFRecordReader::Read(size_t n, OFRecord* allocated_records) { + OFRecordChunk chunk; + const size_t can_read = std::min(n, num_max_read_ - num_read_); + FOR_RANGE(size_t, i, 0, can_read) { + if (ReadChunk(in_stream_, &chunk)) { + CHECK(allocated_records[i].ParseFromArray(chunk.data.get(), chunk.size)); + ++num_read_; + } else { + return i; + } + } + return can_read; +} + +RandomShuffleOFRecordReader::RandomShuffleOFRecordReader(PersistentInStream* in, size_t buffer_size, + size_t num_max_read, int32_t random_seed) + : in_stream_(in), + buffer_size_(buffer_size), + num_max_read_(num_max_read), + random_gen_(random_seed), + is_eof_(false) { + CHECK_GT(buffer_size, 0); + buffered_chunks_.reserve(buffer_size); +} + +void RandomShuffleOFRecordReader::FillBuffer() { + for (; num_read_ < num_max_read_ && buffered_chunks_.size() < buffer_size_; ++num_read_) { + OFRecordChunk chunk; + if (ReadChunk(in_stream_, &chunk)) { + buffered_chunks_.emplace_back(std::move(chunk)); + } else { + is_eof_ = true; + break; + } + } + if (num_read_ == num_max_read_) { is_eof_ = true; } +} + +size_t RandomShuffleOFRecordReader::Read(size_t n, OFRecord* allocated_records) { + size_t cur_read = 0; + while (cur_read < n) { + if (!is_eof_) { FillBuffer(); } + if (buffered_chunks_.empty()) { break; } + const size_t pos = + std::uniform_int_distribution<size_t>(0, buffered_chunks_.size() - 1)(random_gen_); + if (pos != buffered_chunks_.size() - 1) { + std::swap(buffered_chunks_[pos], buffered_chunks_.back()); + } + CHECK(allocated_records[cur_read].ParseFromArray(buffered_chunks_.back().data.get(), + buffered_chunks_.back().size)); + buffered_chunks_.pop_back(); + ++cur_read; + } + return cur_read; +} + +} // namespace oneflow diff --git a/oneflow/core/record/ofrecord_reader.h b/oneflow/core/record/ofrecord_reader.h new file mode 100644 index 0000000000000000000000000000000000000000..559647d64b3a23e4cd6b0f34d1599bbe6da2f85b --- /dev/null +++ b/oneflow/core/record/ofrecord_reader.h @@ -0,0 +1,66 @@ +#ifndef ONEFLOW_CORE_RECORD_OFRECORD_READER_H_ +#define ONEFLOW_CORE_RECORD_OFRECORD_READER_H_ + +#include "oneflow/core/record/record.pb.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/persistence/persistent_in_stream.h" + +namespace oneflow { + +struct OFRecordChunk { + int64_t size = 0; + std::unique_ptr<char[]> data; +}; + +class OFRecordReader { + public: + OF_DISALLOW_COPY_AND_MOVE(OFRecordReader); + OFRecordReader() = default; + virtual ~OFRecordReader() = default; + + virtual size_t Read(size_t n, OFRecord* allocated_records) = 0; +}; + +class NaiveOFRecordReader final : public OFRecordReader { + public: + OF_DISALLOW_COPY_AND_MOVE(NaiveOFRecordReader); + explicit NaiveOFRecordReader(PersistentInStream* in) + : NaiveOFRecordReader(in, MaxVal<size_t>::value) {} + NaiveOFRecordReader(PersistentInStream* in, size_t num_max_read); + 
~NaiveOFRecordReader() override = default; + + private: + size_t Read(size_t n, OFRecord* allocated_records) override; + + PersistentInStream* in_stream_; + size_t num_read_; + const size_t num_max_read_; +}; + +class RandomShuffleOFRecordReader final : public OFRecordReader { + public: + OF_DISALLOW_COPY_AND_MOVE(RandomShuffleOFRecordReader); + RandomShuffleOFRecordReader(PersistentInStream* in, size_t buffer_size, size_t num_max_read, + int32_t random_seed); + RandomShuffleOFRecordReader(PersistentInStream* in, size_t buffer_size, size_t num_max_read) + : RandomShuffleOFRecordReader(in, buffer_size, num_max_read, std::random_device()()) {} + RandomShuffleOFRecordReader(PersistentInStream* in, size_t buffer_size) + : RandomShuffleOFRecordReader(in, buffer_size, MaxVal<size_t>::value) {} + ~RandomShuffleOFRecordReader() override = default; + + private: + size_t Read(size_t n, OFRecord* allocated_records) override; + void FillBuffer(); + + PersistentInStream* in_stream_; + const size_t buffer_size_; + const size_t num_max_read_; + std::mt19937 random_gen_; + size_t num_read_; + std::vector<OFRecordChunk> buffered_chunks_; + bool is_eof_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_RECORD_OFRECORD_READER_H_ diff --git a/oneflow/core/register/blob.cpp b/oneflow/core/register/blob.cpp index 38cc3a9d8a4e87ccac52185a8dfee034d6a4d3aa..49dfd2cc33d84ac777513de87f6cdcad6e0cdb90 100644 --- a/oneflow/core/register/blob.cpp +++ b/oneflow/core/register/blob.cpp @@ -28,6 +28,7 @@ void Blob::Init(Regst* regst, const RtBlobDesc* blob_desc, char* header_ptr, cha dim2_valid_num_ptr_ = header_pod_ptr_.MutTensorPtr<int64_t>(FieldKey::kDim2ValidNum, nullptr); record_id_in_device_piece_ptr_ = header_pod_ptr_.MutTensorPtr<int64_t>(FieldKey::kRecordIdInDevicePiece, nullptr); + loss_instance_num_ptr_ = header_pod_ptr_.MutTensorPtr<float>(FieldKey::kLossInstanceNum, nullptr); dptr_ = body_ptr; dynamic_shape_ = blob_desc->shape(); } @@ -131,6 +132,16 @@ void Blob::set_record_id_in_device_piece(int64_t no, int64_t val) { record_id_in_device_piece_ptr_[no] = val; } +float Blob::loss_instance_num() const { + CHECK_NOTNULL(loss_instance_num_ptr_); + return *loss_instance_num_ptr_; +} + +void Blob::set_loss_instance_num(float val) { + CHECK_NOTNULL(loss_instance_num_ptr_); + *loss_instance_num_ptr_ = val; +} + void Blob::set_dim2_valid_num(int64_t dim0_idx, int64_t dim1_idx, int64_t val) { CHECK_NOTNULL(dim2_valid_num_ptr_); CHECK_GE(dim0_idx, 0); diff --git a/oneflow/core/register/blob.h b/oneflow/core/register/blob.h index af2edcb031fb838bd2eb2b7892da08bc6f2bd862..a9c45e113da1593e9fb67204cf5b698f265a4759 100644 --- a/oneflow/core/register/blob.h +++ b/oneflow/core/register/blob.h @@ -54,6 +54,10 @@ class Blob final { void set_record_id_in_device_piece(int64_t no, int64_t val); const int64_t* record_id_in_device_piece_ptr() const { return record_id_in_device_piece_ptr_; } int64_t* mut_record_id_in_device_piece_ptr() { return record_id_in_device_piece_ptr_; } + float loss_instance_num() const; + void set_loss_instance_num(float val); + const float* loss_instance_num_ptr() const { return loss_instance_num_ptr_; } + float* mut_loss_instance_num_ptr() { return loss_instance_num_ptr_; } const void* header_ptr() const { return header_ptr_; } void* mut_header_ptr() { return header_ptr_; } @@ -99,6 +103,7 @@ class Blob final { bool has_record_id_in_device_piece_field() const { return blob_desc_->has_record_id_in_device_piece_field(); } + bool has_loss_instance_num_field() const { return 
blob_desc_->has_loss_instance_num_field(); } int32_t max_col_num() const { return blob_desc_->max_col_num(); } size_t ByteSizeOfBlobHeader() const { return blob_desc_->ByteSizeOfBlobHeader(); } size_t ByteSizeOfDataIdField() const { return blob_desc_->ByteSizeOfDataIdField(); } @@ -161,6 +166,7 @@ class Blob final { int64_t* dim1_valid_num_ptr_; int64_t* dim2_valid_num_ptr_; int64_t* record_id_in_device_piece_ptr_; + float* loss_instance_num_ptr_; void* dptr_; const RtBlobDesc* blob_desc_; Regst* regst_; diff --git a/oneflow/core/register/blob_desc.cpp b/oneflow/core/register/blob_desc.cpp index 4714eeee09aac5abb030e6c9ca8808368341b1ce..e35e5a6ec5c62c52ea729e8ca9f386a69f4c1039 100644 --- a/oneflow/core/register/blob_desc.cpp +++ b/oneflow/core/register/blob_desc.cpp @@ -16,6 +16,7 @@ BlobDesc::BlobDesc(const Shape& shape, DataType data_type, bool has_data_id, boo has_dim1_valid_num_(false), has_dim2_valid_num_(false), has_record_id_in_device_piece_(false), + has_loss_instance_num_(false), max_col_num_(max_col_num), blob_mem_id_(-1), body_field_(shape, data_type) {} @@ -35,6 +36,7 @@ void BlobDesc::InitFromProto(const BlobDescProto& proto) { has_dim1_valid_num_ = false; has_dim2_valid_num_ = false; has_record_id_in_device_piece_ = false; + has_loss_instance_num_ = false; opaque_header_ = FieldDesc(proto.header().opaque_header()); } else { CHECK(proto.header().has_field_header()); @@ -45,6 +47,7 @@ void BlobDesc::InitFromProto(const BlobDescProto& proto) { has_dim1_valid_num_ = header_pod_desc_.HasField(FieldKey::kDim1ValidNum); has_dim2_valid_num_ = header_pod_desc_.HasField(FieldKey::kDim2ValidNum); has_record_id_in_device_piece_ = header_pod_desc_.HasField(FieldKey::kRecordIdInDevicePiece); + has_loss_instance_num_ = header_pod_desc_.HasField(FieldKey::kLossInstanceNum); } if (proto.has_dim0_inner_shape()) { dim0_inner_shape_.reset(new Shape(proto.dim0_inner_shape())); @@ -59,6 +62,7 @@ BlobDesc::BlobDesc(const StructPodDesc& header_pod_desc, int64_t header_byte_siz has_dim1_valid_num_(false), has_dim2_valid_num_(false), has_record_id_in_device_piece_(false), + has_loss_instance_num_(false), max_col_num_(max_col_num), blob_mem_id_(-1), body_field_(shape, data_type) { @@ -102,6 +106,11 @@ void BlobDesc::set_has_record_id_in_device_piece_field(bool val) { has_record_id_in_device_piece_ = val; } +void BlobDesc::set_has_loss_instance_num_field(bool val) { + CHECK(!header_is_opaque_); + has_loss_instance_num_ = val; +} + Shape& BlobDesc::mut_dim0_inner_shape() { CHECK(!header_is_opaque_); if (!dim0_inner_shape_) { dim0_inner_shape_.reset(new Shape()); } @@ -145,6 +154,10 @@ void BlobDesc::RecordIdInDevicePieceToProto(StructPodDesc* header_pod_desc) cons TensorPodDesc(shape, DataType::kInt64)); } +void BlobDesc::LossInstanceNumToProto(StructPodDesc* header_pod_desc) const { + header_pod_desc->AddField(FieldKey::kLossInstanceNum, TensorPodDesc({1}, DataType::kFloat)); +} + void BlobDesc::HeaderToProto(BlobDescProto* proto) const { proto->mutable_header()->set_max_col_num(max_col_num_); proto->mutable_header()->set_blob_mem_id(blob_mem_id_); @@ -157,6 +170,7 @@ void BlobDesc::HeaderToProto(BlobDescProto* proto) const { if (has_dim1_valid_num_field()) { Dim1ValidNumToProto(&header_pod_desc); } if (has_dim2_valid_num_field()) { Dim2ValidNumToProto(&header_pod_desc); } if (has_record_id_in_device_piece_field()) { RecordIdInDevicePieceToProto(&header_pod_desc); } + if (has_loss_instance_num_field()) { LossInstanceNumToProto(&header_pod_desc); } 
header_pod_desc.ToProto(proto->mutable_header()->mutable_header_pod_desc()); } else { opaque_header_.ToProto(proto->mutable_header()->mutable_opaque_header()); @@ -177,8 +191,8 @@ bool BlobDesc::operator==(const BlobDesc& rhs) const { && has_dim1_valid_num_ == rhs.has_dim1_valid_num_ && has_dim2_valid_num_ == rhs.has_dim2_valid_num_ && has_record_id_in_device_piece_ == rhs.has_record_id_in_device_piece_ - && max_col_num_ == rhs.max_col_num_ && blob_mem_id_ == rhs.blob_mem_id_ - && body_field_ == rhs.body_field_; + && has_loss_instance_num_ == rhs.has_loss_instance_num_ && max_col_num_ == rhs.max_col_num_ + && blob_mem_id_ == rhs.blob_mem_id_ && body_field_ == rhs.body_field_; } BlobDesc& BlobDesc::operator=(const BlobDesc& blob_desc) { @@ -248,4 +262,9 @@ std::unique_ptr<BlobDesc> ComputePackedBlobDesc( return ret; } +bool CompareLbiBlobDescPair(const LbiBlobDescPair& lhs, const LbiBlobDescPair& rhs) { + return (lhs.blob_desc().header().blob_mem_id() < rhs.blob_desc().header().blob_mem_id()) + || (lhs.lbi() < rhs.lbi()); +} + } // namespace oneflow diff --git a/oneflow/core/register/blob_desc.h b/oneflow/core/register/blob_desc.h index 39467de729fae56509dd6806ee15e6982b88b9e0..80bbe18d135cff32afe1f3b0a54f762e48e889e3 100644 --- a/oneflow/core/register/blob_desc.h +++ b/oneflow/core/register/blob_desc.h @@ -7,6 +7,7 @@ #include "oneflow/core/register/blob_desc.pb.h" #include "oneflow/core/register/pod_desc.h" #include "oneflow/core/job/job_desc.h" +#include "oneflow/core/register/register_desc.pb.h" namespace oneflow { @@ -50,6 +51,9 @@ class BlobDesc { bool has_record_id_in_device_piece_field() const { return has_record_id_in_device_piece_; } void set_has_record_id_in_device_piece_field(bool val); + bool has_loss_instance_num_field() const { return has_loss_instance_num_; } + void set_has_loss_instance_num_field(bool val); + bool has_col_num_field() const { return has_col_num_; } void set_has_col_num_field(bool val); @@ -72,6 +76,7 @@ class BlobDesc { void Dim1ValidNumToProto(StructPodDesc* header_pod_desc) const; void Dim2ValidNumToProto(StructPodDesc* header_pod_desc) const; void RecordIdInDevicePieceToProto(StructPodDesc* header_pod_desc) const; + void LossInstanceNumToProto(StructPodDesc* header_pod_desc) const; bool header_is_opaque_; FieldDesc opaque_header_; @@ -83,6 +88,7 @@ class BlobDesc { bool has_dim1_valid_num_; bool has_dim2_valid_num_; bool has_record_id_in_device_piece_; + bool has_loss_instance_num_; int64_t max_col_num_; int32_t blob_mem_id_; @@ -93,6 +99,8 @@ class BlobDesc { std::unique_ptr<BlobDesc> ComputePackedBlobDesc( const HashMap<LogicalBlobId, std::unique_ptr<BlobDesc>>& lbi2blob_desc); +bool CompareLbiBlobDescPair(const LbiBlobDescPair& lhs, const LbiBlobDescPair& rhs); + } // namespace oneflow #endif // ONEFLOW_CORE_REGISTER_BLOB_DESC_H_ diff --git a/oneflow/core/register/pod.proto b/oneflow/core/register/pod.proto index 3ca40a886206c61ddf939258b9be66ca50db94b6..a2a1c83f6f66117fc0d83249a3a3f92b5836e37a 100644 --- a/oneflow/core/register/pod.proto +++ b/oneflow/core/register/pod.proto @@ -22,6 +22,7 @@ enum FieldKey { kDim1ValidNum = 4; kDim2ValidNum = 5; kRecordIdInDevicePiece = 6; + kLossInstanceNum = 7; } message FieldId { diff --git a/oneflow/core/register/pod_desc.h b/oneflow/core/register/pod_desc.h index 19444c2077b9ab27286d8c17fb894cfe0ad5df27..145dcf890f86afa6139eacc9f8ad1386147d7296 100644 --- a/oneflow/core/register/pod_desc.h +++ b/oneflow/core/register/pod_desc.h @@ -81,23 +81,25 @@ class StructPodDesc final : public PodDesc { ~StructPodDesc() = 
default; StructPodDesc* MutStructField(const FieldId& field_id); + StructPodDesc* MutStructField(const FieldId& field_id, int32_t default_alignment); const PodDesc& Field(FieldKey field_key) const { return Field(NewFieldId(field_key)); } const PodDesc& Field(const FieldId& field_id) const; void AddField(FieldKey field_key, const PodDesc& pod_desc); void AddField(const FieldId& field_id, const PodDesc& pod_desc); - size_t ByteSize() const override; - void InitFromProto(const StructPodProto& struct_pod); - + void AddField(const FieldId& field_id, const PodDesc& pod_desc, size_t alignment); bool HasField(FieldKey field_key) const { return HasField(NewFieldId(field_key)); } bool HasField(const FieldId& field_id) const; - StructPodDesc& operator=(const StructPodDesc&); + std::unique_ptr<PodDesc> Clone() const override { return std::make_unique<StructPodDesc>(*this); } + void InitFromProto(const StructPodProto& struct_pod); void ToProto(PodProto* pod_proto) const override { ToProto(pod_proto->mutable_struct_pod()); } void ToProto(StructPodProto* pod_proto) const; - StructPodDesc* MutStructField(const FieldId& field_id, int32_t default_alignment); - void AddField(const FieldId& field_id, const PodDesc& pod_desc, size_t alignment); - bool operator==(const PodDesc& rhs) const override; + size_t ByteOffset4Field(const FieldId& field_name) const; + size_t ByteSize() const override; + + StructPodDesc& operator=(const StructPodDesc&); + bool operator==(const PodDesc& rhs) const override; private: void Clear(); diff --git a/oneflow/core/register/register_desc.cpp b/oneflow/core/register/register_desc.cpp index 04b736b61b3c5df5fbaa991de036f66dbfbd58a7..7730c3f348ed97b16601376a64cb71331a9f0e25 100644 --- a/oneflow/core/register/register_desc.cpp +++ b/oneflow/core/register/register_desc.cpp @@ -3,6 +3,7 @@ #include "oneflow/core/graph/copy_task_node.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/register/runtime_blob_desc.h" +#include "oneflow/core/register/runtime_register_desc.h" namespace oneflow { @@ -25,6 +26,12 @@ RegstDesc::RegstDesc() { enable_mem_sharing_ = false; mem_shared_id_ = -1; mem_shared_offset_ = -1; + mem_shared_inplace_block_id_ = -1; +} + +int64_t RegstDesc::mem_shared_offset() const { + CHECK_GE(mem_shared_offset_, 0); + return mem_shared_offset_; } void RegstDesc::AddConsumer(const TaskNode* new_consumer) { @@ -60,6 +67,12 @@ void RegstDesc::CopyBlobDescFrom(const RegstDesc* rhs) { CopyBlobDescWithoutAddLbi(rhs); } +void RegstDesc::CopyMemSharedInfoFrom(const RegstDesc* rhs) { + enable_mem_sharing_ = rhs->enable_mem_sharing_; + mem_shared_id_ = rhs->mem_shared_id_; + mem_shared_offset_ = rhs->mem_shared_offset_; +} + void RegstDesc::CopyBlobDescWithoutAddLbi(const RegstDesc* rhs) { CHECK_EQ(is_locked_, false); for (const auto& pair : lbi2blob_desc_) { @@ -141,6 +154,15 @@ void RegstDesc::ToProto(RegstDescProto* ret) const { ret->set_enable_mem_sharing(enable_mem_sharing_); ret->set_mem_shared_id(mem_shared_id_); ret->set_mem_shared_offset(mem_shared_offset_); + ret->add_mem_block_hierarchy()->set_mem_block_id(Global<IDMgr>::Get()->NewMemBlockId()); + if (mem_shared_inplace_block_id_ != -1) { + ret->add_mem_block_hierarchy()->set_mem_block_id(mem_shared_inplace_block_id_); + } +} + +bool RegstDesc::HasSameMemSize(const RegstDesc* rhs) { + return RtBlobDesc(*(packed_blob_desc_.get())).TotalByteSize() + == RtBlobDesc(*(rhs->packed_blob_desc_.get())).TotalByteSize(); } bool RegstDesc::HasSameBlobDescs(const RegstDesc* rhs) { @@ -153,6 +175,34 @@ bool 
RegstDesc::HasSameBlobDescs(const RegstDesc* rhs) { return true; } +int64_t RegstDesc::ByteOffsetInPackedBlobDescBody(const LogicalBlobId& lbi) const { + RegstDescProto regst_desc_proto; + ToProto(&regst_desc_proto); + RtRegstDesc rt_regst_desc(regst_desc_proto); + std::vector<LbiBlobDescPair> lbi_blob_desc_pairs; + for (const auto& pair : lbi2blob_desc_) { + LbiBlobDescPair lbi_blob_desc_pair; + *lbi_blob_desc_pair.mutable_lbi() = pair.first; + pair.second->ToProto(lbi_blob_desc_pair.mutable_blob_desc()); + lbi_blob_desc_pairs.push_back(lbi_blob_desc_pair); + } + std::sort(lbi_blob_desc_pairs.begin(), lbi_blob_desc_pairs.end(), CompareLbiBlobDescPair); + + bool found = false; + int64_t offset = 0; + rt_regst_desc.ForEachBlobDescOffsetInOnRegst( + lbi_blob_desc_pairs, + [&](const LbiBlobDescPair& lbi_blob_desc_pair, int64_t body_offset, int64_t header_offset) { + if (found) { return; } + if (lbi_blob_desc_pair.lbi() == lbi) { + offset = body_offset; + found = true; + } + }); + CHECK(found); + return offset; +} + void InitCtrlRegstDesc(int64_t producer_task_id, RegstDescProto* ctrl_regst_proto) { CHECK_NOTNULL(ctrl_regst_proto); ctrl_regst_proto->set_regst_desc_id(Global<IDMgr>::Get()->NewRegstDescId()); diff --git a/oneflow/core/register/register_desc.h b/oneflow/core/register/register_desc.h index ed09ec71fb0ccb6ff7cf2e5895194a94c74ab26e..ec473e2b3e42f005b2943825334054376b90d8c9 100644 --- a/oneflow/core/register/register_desc.h +++ b/oneflow/core/register/register_desc.h @@ -50,14 +50,20 @@ class RegstDesc final { // mem const MemoryCase& mem_case() const { return mem_case_; } MemoryCase* mut_mem_case() { return &mem_case_; } + bool enable_mem_sharing() { return enable_mem_sharing_; } void set_enable_mem_sharing(bool enable_mem_sharing) { enable_mem_sharing_ = enable_mem_sharing; } - int64_t mem_shared_offset() const { return mem_shared_offset_; } + int64_t mem_shared_offset() const; void set_mem_shared_offset(int64_t val) { mem_shared_offset_ = val; } + int64_t mem_shared_inplace_block_id() const { return mem_shared_inplace_block_id_; } + void set_mem_shared_inplace_block_id(int64_t val) { mem_shared_inplace_block_id_ = val; } int32_t mem_shared_id() const { return mem_shared_id_; } void set_mem_shared_id(int32_t val) { mem_shared_id_ = val; } + bool HasSetMemSharedId() { return mem_shared_id_ != -1; } + void CopyMemSharedInfoFrom(const RegstDesc*); const std::shared_ptr<Shape>& data_regst_time_shape() const { CHECK(regst_desc_type_.has_data_regst_desc()); + CHECK(data_regst_time_shape_); return data_regst_time_shape_; } std::shared_ptr<Shape>* mut_data_regst_time_shape() { @@ -66,12 +72,14 @@ class RegstDesc final { } RegstDescTypeProto* mut_regst_desc_type() { return &regst_desc_type_; } const RegstDescTypeProto& regst_desc_type() const { return regst_desc_type_; } + bool HasSameMemSize(const RegstDesc*); // util int32_t MaxColNum() const { return packed_blob_desc_->max_col_num(); } void EraseZeroSizeBlob(); void ToProto(RegstDescProto*) const; bool HasSameBlobDescs(const RegstDesc*); + int64_t ByteOffsetInPackedBlobDescBody(const LogicalBlobId& lbi) const; private: int64_t regst_desc_id_; @@ -89,10 +97,28 @@ class RegstDesc final { bool enable_mem_sharing_; int32_t mem_shared_id_; int64_t mem_shared_offset_; + int64_t mem_shared_inplace_block_id_; std::shared_ptr<Shape> data_regst_time_shape_; }; +inline bool operator==(const MemBlock& lhs, const MemBlock& rhs) { + bool ret = (lhs.mem_block_id() == rhs.mem_block_id()); + if (ret) { CHECK_EQ(lhs.mem_reduce_method(), 
rhs.mem_reduce_method()); } + return ret; +} + } // namespace oneflow +namespace std { + +template<> +struct hash<oneflow::MemBlock> final { + size_t operator()(const oneflow::MemBlock& mem_block) const { + return hash<int64_t>()(mem_block.mem_block_id()); + } +}; + +} // namespace std + #endif // ONEFLOW_CORE_REGISTER_REGISTER_DESC_H_ diff --git a/oneflow/core/register/register_desc.proto b/oneflow/core/register/register_desc.proto index e933c283c8eb445b9c2bb877b71f6fdf2ba7e916..37e6111faf00bad2658ed1492f83a44c6ad39fc6 100644 --- a/oneflow/core/register/register_desc.proto +++ b/oneflow/core/register/register_desc.proto @@ -29,6 +29,17 @@ message RegstDescTypeProto { } } +enum MemReduceMethod { + kMemInvalidSharedMethod = 0; + kMemSum = 1; + kMemMax = 2; +} + +message MemBlock { + required int64 mem_block_id = 1; + optional MemReduceMethod mem_reduce_method = 2 [default = kMemMax]; +} + message RegstDescProto { required int64 regst_desc_id = 1; required int64 producer_task_id = 2; @@ -41,4 +52,6 @@ message RegstDescProto { required bool enable_mem_sharing = 9; required int32 mem_shared_id = 10; required int64 mem_shared_offset = 11; + // from bottom to top + repeated MemBlock mem_block_hierarchy = 13; } diff --git a/oneflow/core/register/register_manager.cpp b/oneflow/core/register/register_manager.cpp index 6d79acda4ce17b49e1b14fff8589d7fb91fce069..2ef56a5b6f970d25b04b346faa319a503ae14acb 100644 --- a/oneflow/core/register/register_manager.cpp +++ b/oneflow/core/register/register_manager.cpp @@ -75,12 +75,7 @@ void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto, for (const LbiBlobDescPair& pair : regst_desc_type.data_regst_desc().lbi2blob_desc()) { lbi_pairs.push_back(pair); } - std::sort(lbi_pairs.begin(), lbi_pairs.end(), - [&](const LbiBlobDescPair& lhs, const LbiBlobDescPair& rhs) { - return lhs.blob_desc().header().blob_mem_id() - < rhs.blob_desc().header().blob_mem_id() - || lhs.lbi() < rhs.lbi(); - }); + std::sort(lbi_pairs.begin(), lbi_pairs.end(), &CompareLbiBlobDescPair); CHECK(!lbi_pairs.empty()); CHECK(main_mem_ptr != nullptr); } @@ -126,23 +121,14 @@ void RegstMgr::NewBlobsInOneRegst(const std::vector<LbiBlobDescPair>& lbis, Regs cur_header_pointer = main_mem_ptr; cur_body_pointer = main_mem_ptr + packed_blob_desc->ByteSizeOfBlobHeader(); } - int32_t last_blob_mem_id = -1; - size_t last_size = 0; - for (const LbiBlobDescPair& lbi : lbis) { - const RtBlobDesc* blob_desc = rt_regst_desc->GetRtBlobDescFromLbi(lbi.lbi()); - int32_t cur_blob_mem_id = lbi.blob_desc().header().blob_mem_id(); - if (cur_blob_mem_id == -1 || cur_blob_mem_id != last_blob_mem_id) { - cur_body_pointer += last_size; - } - std::unique_ptr<Blob> blob_ptr( - new Blob(regst, blob_desc, cur_header_pointer, cur_body_pointer)); - InitOFRecordBlobIfNeed(blob_ptr.get()); - CHECK(regst->lbi2blob_.emplace(lbi.lbi(), std::move(blob_ptr)).second); - cur_header_pointer += blob_desc->ByteSizeOfBlobHeader(); - - last_blob_mem_id = cur_blob_mem_id; - last_size = blob_desc->ByteSizeOfBlobBody(); - } + rt_regst_desc->ForEachBlobDescOffsetInOnRegst( + lbis, [&](const LbiBlobDescPair& lbi, int64_t body_offset, int64_t header_offset) { + const RtBlobDesc* blob_desc = rt_regst_desc->GetRtBlobDescFromLbi(lbi.lbi()); + std::unique_ptr<Blob> blob_ptr(new Blob( + regst, blob_desc, cur_header_pointer + header_offset, cur_body_pointer + body_offset)); + InitOFRecordBlobIfNeed(blob_ptr.get()); + CHECK(regst->lbi2blob_.emplace(lbi.lbi(), std::move(blob_ptr)).second); + }); } void RegstMgr::InitOFRecordBlobIfNeed(Blob* 
blob_ptr) { diff --git a/oneflow/core/register/runtime_blob_desc.cpp b/oneflow/core/register/runtime_blob_desc.cpp index 6251d6b736ca0110e3de3ac68d7cb4966d9d1a31..59c1a9f47be64637500b80533bfb50aa518101c9 100644 --- a/oneflow/core/register/runtime_blob_desc.cpp +++ b/oneflow/core/register/runtime_blob_desc.cpp @@ -43,6 +43,10 @@ bool RtBlobDesc::has_record_id_in_device_piece_field() const { return header_pod_desc_.HasField(FieldKey::kRecordIdInDevicePiece); } +bool RtBlobDesc::has_loss_instance_num_field() const { + return header_pod_desc_.HasField(FieldKey::kLossInstanceNum); +} + size_t RtBlobDesc::ByteSizeOfBlobHeader() const { return header_pod_desc_.ByteSize(); } size_t RtBlobDesc::ByteSizeOfBlobBody() const { return body_desc_.AlignedByteSize(); } diff --git a/oneflow/core/register/runtime_blob_desc.h b/oneflow/core/register/runtime_blob_desc.h index 7cb6a8d3f4faa89744b9c026c516f8b8e84b9a55..d84591039e48157d5cb391e061c03f1943b97c89 100644 --- a/oneflow/core/register/runtime_blob_desc.h +++ b/oneflow/core/register/runtime_blob_desc.h @@ -29,6 +29,7 @@ class RtBlobDesc { bool has_dim1_valid_num_field() const; bool has_dim2_valid_num_field() const; bool has_record_id_in_device_piece_field() const; + bool has_loss_instance_num_field() const; const StructPodDesc& header_pod_desc() const { return header_pod_desc_; } int32_t max_col_num() const { return blob_desc_proto_.header().max_col_num(); } diff --git a/oneflow/core/register/runtime_register_desc.cpp b/oneflow/core/register/runtime_register_desc.cpp index 83826b0b355de45d90c5318f83a787a0e0f89018..cbd94e251f49787ab6d443949bd42d8bf989c52d 100644 --- a/oneflow/core/register/runtime_register_desc.cpp +++ b/oneflow/core/register/runtime_register_desc.cpp @@ -70,4 +70,25 @@ const Shape& RtRegstDesc::data_regst_time_shape() const { return *data_regst_time_shape_; } +void RtRegstDesc::ForEachBlobDescOffsetInOnRegst( + const std::vector<LbiBlobDescPair>& lbis, + const std::function<void(const LbiBlobDescPair&, int64_t body_offset, int64_t header_offset)>& + Handler) const { + int32_t last_blob_mem_id = -1; + size_t last_size = 0; + int64_t cur_body_offset = 0; + int64_t cur_header_offset = 0; + for (const LbiBlobDescPair& lbi : lbis) { + const RtBlobDesc* blob_desc = GetRtBlobDescFromLbi(lbi.lbi()); + int32_t cur_blob_mem_id = lbi.blob_desc().header().blob_mem_id(); + if (cur_blob_mem_id == -1 || cur_blob_mem_id != last_blob_mem_id) { + cur_body_offset += last_size; + } + Handler(lbi, cur_body_offset, cur_header_offset); + cur_header_offset += blob_desc->ByteSizeOfBlobHeader(); + last_blob_mem_id = cur_blob_mem_id; + last_size = blob_desc->ByteSizeOfBlobBody(); + } +} + } // namespace oneflow diff --git a/oneflow/core/register/runtime_register_desc.h b/oneflow/core/register/runtime_register_desc.h index 6747b4619a685ea177d28ce22764b2467ecbb550..939e2e5426961698ec5b26468db8b4a51afd647c 100644 --- a/oneflow/core/register/runtime_register_desc.h +++ b/oneflow/core/register/runtime_register_desc.h @@ -31,6 +31,11 @@ class RtRegstDesc { size_t MainByteSize4OneRegst() const; const Shape& data_regst_time_shape() const; + void ForEachBlobDescOffsetInOnRegst( + const std::vector<LbiBlobDescPair>& lbis, + const std::function<void(const LbiBlobDescPair&, int64_t body_offset, int64_t header_offset)>& + Handler) const; + private: int64_t regst_desc_id_; int64_t producer_actor_id_; diff --git a/oneflow/core/thread/gpu_thread.cpp b/oneflow/core/thread/gpu_thread.cpp index 8e36bffdaaeca2ee6d948171c83fb91d1a024eee..ab13574a6a5b7f9a35a751cd6fde2b9486df1588 100644 
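ForEachBlobDescOffsetInOnRegst above walks blobs in blob_mem_id order: consecutive blobs that share a non-negative blob_mem_id are reported at the same body offset (they alias one memory block), while a blob_mem_id of -1 or a change of id advances the offset by the previous blob's body size; header offsets always advance. A self-contained sketch of that body-offset walk as I read it, illustration only with hypothetical types:

#include <cstdint>
#include <vector>

struct FakeBlobLayout { int32_t blob_mem_id; int64_t body_size; };

// Mirror of the body-offset accumulation in ForEachBlobDescOffsetInOnRegst.
std::vector<int64_t> BodyOffsets(const std::vector<FakeBlobLayout>& blobs) {
  std::vector<int64_t> offsets;
  int32_t last_id = -1;
  int64_t last_size = 0;
  int64_t cur = 0;
  for (const FakeBlobLayout& b : blobs) {
    if (b.blob_mem_id == -1 || b.blob_mem_id != last_id) { cur += last_size; }
    offsets.push_back(cur);
    last_id = b.blob_mem_id;
    last_size = b.body_size;
  }
  return offsets;
}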
--- a/oneflow/core/thread/gpu_thread.cpp +++ b/oneflow/core/thread/gpu_thread.cpp @@ -14,7 +14,8 @@ GpuThread::GpuThread(int64_t thrd_id, int64_t dev_id) { ctx.cb_event_chan = &cb_event_chan_; PollMsgChannel(ctx); }); - cb_event_poller_ = std::thread([this]() { + cb_event_poller_ = std::thread([this, dev_id]() { + CudaCheck(cudaSetDevice(dev_id)); CudaCBEvent cb_event; while (cb_event_chan_.Receive(&cb_event) == kChannelStatusSuccess) { CudaCheck(cudaEventSynchronize(cb_event.event));
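The gpu_thread.cpp hunk above matters because the CUDA runtime tracks the current device per host thread: the newly spawned cb_event_poller_ thread would otherwise run on the default device rather than dev_id, so it now binds itself with cudaSetDevice before synchronizing on events recorded by that device. A standalone sketch of the pattern, illustration only with hypothetical names:

#include <cuda_runtime.h>
#include <thread>

// Spawn a poller bound to the device whose events it waits on.
std::thread StartEventPoller(int dev_id, cudaEvent_t event) {
  return std::thread([dev_id, event]() {
    cudaSetDevice(dev_id);        // per-thread device selection
    cudaEventSynchronize(event);  // wait for GPU work recorded on that device
  });
}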