diff --git a/oneflow/core/kernel/foreign_input_kernel.cpp b/oneflow/core/kernel/foreign_input_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4f817bf5fc85508f4a578be8c4da58325287dac5 --- /dev/null +++ b/oneflow/core/kernel/foreign_input_kernel.cpp @@ -0,0 +1,20 @@ +#include "oneflow/core/kernel/foreign_input_kernel.h" +#include "oneflow/core/common/buffer_manager.h" +#include "oneflow/core/register/foreign_blob.h" + +namespace oneflow { + +void ForeignInputKernel::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const auto& buffer_name = op_conf().foreign_input_conf().foreign_blob_buffer_name(); + std::shared_ptr<ForeignBlob> foreign_blob; + BufferStatus buffer_status = Global<BufferMgr<std::shared_ptr<ForeignBlob>>>::Get() + ->Get(buffer_name) + ->TryReceive(&foreign_blob); + CHECK_NE(buffer_status, kBufferStatusEmpty); + foreign_blob->CopyTo(BnInOp2Blob("out")); +} + +REGISTER_KERNEL(OperatorConf::kForeignInputConf, ForeignInputKernel); + +} // namespace oneflow diff --git a/oneflow/core/kernel/foreign_input_kernel.h b/oneflow/core/kernel/foreign_input_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..2e1cae19c0502d5afaa74fdd24ac4b50bcdad019 --- /dev/null +++ b/oneflow/core/kernel/foreign_input_kernel.h @@ -0,0 +1,21 @@ +#ifndef ONEFLOW_CORE_KERNEL_FOREIGN_INPUT_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_FOREIGN_INPUT_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +class ForeignInputKernel final : public KernelIf<DeviceType::kCPU> { + public: + OF_DISALLOW_COPY_AND_MOVE(ForeignInputKernel); + ForeignInputKernel() = default; + ~ForeignInputKernel() = default; + + private: + void ForwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_FOREIGN_INPUT_KERNEL_H_ diff --git a/oneflow/core/kernel/foreign_output_kernel.cpp b/oneflow/core/kernel/foreign_output_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3d00377c2d2d15cb399db1f75c1768ebbcf47e61 --- /dev/null +++ b/oneflow/core/kernel/foreign_output_kernel.cpp @@ -0,0 +1,20 @@ +#include "oneflow/core/kernel/foreign_output_kernel.h" +#include "oneflow/core/common/buffer_manager.h" +#include "oneflow/core/register/foreign_blob.h" + +namespace oneflow { + +void ForeignOutputKernel::ForwardDataContent( + const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + const auto& buffer_name = op_conf().foreign_output_conf().foreign_blob_buffer_name(); + std::shared_ptr<ForeignBlob> foreign_blob; + BufferStatus buffer_status = Global<BufferMgr<std::shared_ptr<ForeignBlob>>>::Get() + ->Get(buffer_name) + ->TryReceive(&foreign_blob); + CHECK_NE(buffer_status, kBufferStatusEmpty); + foreign_blob->CopyFrom(BnInOp2Blob("in")); +} + +REGISTER_KERNEL(OperatorConf::kForeignOutputConf, ForeignOutputKernel); + +} // namespace oneflow diff --git a/oneflow/core/kernel/foreign_output_kernel.h b/oneflow/core/kernel/foreign_output_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..9d9d36292374b31af065925a5cee3e5b914ebf05 --- /dev/null +++ b/oneflow/core/kernel/foreign_output_kernel.h @@ -0,0 +1,21 @@ +#ifndef ONEFLOW_CORE_KERNEL_FOREIGN_INPUT_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_FOREIGN_INPUT_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +class ForeignOutputKernel final : public KernelIf<DeviceType::kCPU> { + public: + OF_DISALLOW_COPY_AND_MOVE(ForeignOutputKernel); + ForeignOutputKernel() = default; + ~ForeignOutputKernel() = default; + + private: + void ForwardDataContent(const KernelCtx& ctx, + std::function<Blob*(const std::string&)> BnInOp2Blob) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_FOREIGN_INPUT_KERNEL_H_ diff --git a/oneflow/core/kernel/input_kernel.h b/oneflow/core/kernel/input_kernel.h index 4aace67a7eac039f38227befd2cfff149cde90f7..4b18da3ae6eb8b07a7f2747294591dedfbc822a9 100644 --- a/oneflow/core/kernel/input_kernel.h +++ b/oneflow/core/kernel/input_kernel.h @@ -1,3 +1,6 @@ +#ifndef ONEFLOW_CORE_KERNEL_INPUT_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_INPUT_KERNEL_H_ + #include "oneflow/core/kernel/kernel.h" namespace oneflow { @@ -15,3 +18,5 @@ class InputKernel final : public KernelIf<device_type> { }; } // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_INPUT_KERNEL_H_ diff --git a/oneflow/core/kernel/output_kernel.h b/oneflow/core/kernel/output_kernel.h index f9b6eed664aac5a54d06346e4e311f23be9e731f..24a0aedb6bb895e91efedcc238a09bc96510569f 100644 --- a/oneflow/core/kernel/output_kernel.h +++ b/oneflow/core/kernel/output_kernel.h @@ -1,3 +1,6 @@ +#ifndef ONEFLOW_CORE_KERNEL_OUTPUT_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_OUTPUT_KERNEL_H_ + #include "oneflow/core/kernel/kernel.h" namespace oneflow { @@ -15,3 +18,5 @@ class OutputKernel final : public KernelIf<device_type> { }; } // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_OUTPUT_KERNEL_H_ diff --git a/oneflow/core/operator/foreign_input_op.cpp b/oneflow/core/operator/foreign_input_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d3c65ec9f88563aa1311f5395e353f678526deb8 --- /dev/null +++ b/oneflow/core/operator/foreign_input_op.cpp @@ -0,0 +1,52 @@ +#include "oneflow/core/operator/foreign_input_op.h" +#include "oneflow/core/job/sbp_signature_builder.h" + +namespace oneflow { + +namespace { + +void CheckOpConf(const OperatorConf& op_conf) { + CHECK(op_conf.ctrl_in_op_name().empty()); + if (op_conf.foreign_input_conf().blob_conf().has_dim0_inner_shape()) { TODO(); } + if (op_conf.foreign_input_conf().blob_conf().has_dim1_valid_num()) { TODO(); } + if (op_conf.foreign_input_conf().blob_conf().has_dim2_valid_num()) { TODO(); } +} + +} // namespace + +void ForeignInputOp::InitFromOpConf() { + CHECK(op_conf().has_foreign_input_conf()); + if (op_conf().foreign_input_conf().has_tick()) { EnrollOutputBn("tick", false); } + EnrollOutputBn("out", false); +} + +const PbMessage& ForeignInputOp::GetCustomizedConf() const { + return op_conf().foreign_input_conf(); +} + +void ForeignInputOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + CHECK_EQ(parallel_ctx->parallel_num(), 1); + CheckOpConf(op_conf()); + const auto& conf = op_conf().foreign_input_conf().blob_conf(); + BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); + out_blob_desc->mut_shape() = Shape(conf.shape()); + if (conf.has_data_type()) { + out_blob_desc->set_data_type(conf.data_type()); + } else { + out_blob_desc->set_data_type(Global<JobDesc>::Get()->DefaultDataType()); + } + out_blob_desc->set_has_dim1_valid_num_field(conf.dim0_valid_num()); +} + +void ForeignInputOp::InferHasBatchDim( + std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const { + *HasBatchDim4BnInOp("out") = op_conf().foreign_input_conf().blob_conf().has_batch_dim(); +} + +void ForeignInputOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {} + +REGISTER_OP(OperatorConf::kForeignInputConf, ForeignInputOp); +REGISTER_OP_SAME_OUTPUT_BLOB_MEM_BLOCK_NUM(OperatorConf::kForeignInputConf, 1); + +} // namespace oneflow diff --git a/oneflow/core/operator/foreign_input_op.h b/oneflow/core/operator/foreign_input_op.h new file mode 100644 index 0000000000000000000000000000000000000000..d2e47a0b55932f0bb02df7db36a48d35a845eed3 --- /dev/null +++ b/oneflow/core/operator/foreign_input_op.h @@ -0,0 +1,26 @@ +#ifndef ONEFLOW_CORE_OPERATOR_FOREIGN_INPUT_OP_H_ +#define ONEFLOW_CORE_OPERATOR_FOREIGN_INPUT_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class ForeignInputOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(ForeignInputOp); + ForeignInputOp() : Operator() {} + ~ForeignInputOp() = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + void InferHasBatchDim(std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const override; + void GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_FOREIGN_INPUT_OP_H_ diff --git a/oneflow/core/operator/foreign_output_op.cpp b/oneflow/core/operator/foreign_output_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d30b2fa9d7509db028859a02ed5dea35e7f6260a --- /dev/null +++ b/oneflow/core/operator/foreign_output_op.cpp @@ -0,0 +1,30 @@ +#include "oneflow/core/operator/foreign_output_op.h" +#include "oneflow/core/job/sbp_signature_builder.h" + +namespace oneflow { + +void ForeignOutputOp::InitFromOpConf() { + CHECK(op_conf().has_foreign_output_conf()); + EnrollInputBn("in"); +} + +void ForeignOutputOp::InferBlobDescs( + std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + CHECK_EQ(parallel_ctx->parallel_num(), 1); +} + +const PbMessage& ForeignOutputOp::GetCustomizedConf() const { + return op_conf().foreign_output_conf(); +} + +void ForeignOutputOp::InferHasBatchDim( + std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const {} + +void ForeignOutputOp::GetSbpSignatures( + const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4Ibn, + SbpSignatureList* sbp_sig_list) const {} + +REGISTER_OP(OperatorConf::kForeignOutputConf, ForeignOutputOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/foreign_output_op.h b/oneflow/core/operator/foreign_output_op.h new file mode 100644 index 0000000000000000000000000000000000000000..4c1896d4e4d937772e967f81ce24c92c206866a2 --- /dev/null +++ b/oneflow/core/operator/foreign_output_op.h @@ -0,0 +1,28 @@ +#ifndef ONEFLOW_CORE_OPERATOR_FOREIGN_OUTPUT_OP_H_ +#define ONEFLOW_CORE_OPERATOR_FOREIGN_OUTPUT_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class ForeignOutputOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(ForeignOutputOp); + ForeignOutputOp() = default; + ~ForeignOutputOp() override = default; + + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + void InferHasBatchDim(std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const override; + void GetSbpSignatures( + const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4Ibn, + SbpSignatureList* sbp_sig_list) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_FOREIGN_OUTPUT_OP_H_ diff --git a/oneflow/core/operator/input_op.cpp b/oneflow/core/operator/input_op.cpp index a584613cca8e5c684960fafc02f16d37feae9529..337e211dddbd4fd880da8d5cec8c91563d38e3dd 100644 --- a/oneflow/core/operator/input_op.cpp +++ b/oneflow/core/operator/input_op.cpp @@ -7,10 +7,9 @@ namespace { void CheckOpConf(const OperatorConf& op_conf) { CHECK(op_conf.ctrl_in_op_name().empty()); - if (op_conf.input_conf().has_dim0_inner_shape()) { TODO(); } - if (op_conf.input_conf().has_dim0_valid_num()) { TODO(); } - if (op_conf.input_conf().has_dim1_valid_num()) { TODO(); } - if (op_conf.input_conf().has_dim2_valid_num()) { TODO(); } + if (op_conf.input_conf().blob_conf().has_dim0_inner_shape()) { TODO(); } + if (op_conf.input_conf().blob_conf().has_dim1_valid_num()) { TODO(); } + if (op_conf.input_conf().blob_conf().has_dim2_valid_num()) { TODO(); } } void CheckShape(const Shape& shape) { @@ -21,7 +20,7 @@ void CheckShape(const Shape& shape) { void InputOp::InitFromOpConf() { CHECK(op_conf().has_input_conf()); - if (op_conf().input_conf().has_tick()) { EnrollOutputBn("in", false); } + if (op_conf().input_conf().has_tick()) { EnrollInputBn("tick", false); } EnrollOutputBn("out", false); } @@ -30,7 +29,7 @@ const PbMessage& InputOp::GetCustomizedConf() const { return op_conf().input_con void InputOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, int64_t record_piece_size) const { CheckOpConf(op_conf()); - const auto& conf = op_conf().input_conf(); + const auto& conf = op_conf().input_conf().blob_conf(); BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); out_blob_desc->mut_shape() = Shape(conf.shape()); CheckShape(out_blob_desc->shape()); @@ -45,14 +44,15 @@ void InputOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlo } else { out_blob_desc->set_data_type(Global<JobDesc>::Get()->DefaultDataType()); } + out_blob_desc->set_has_dim1_valid_num_field(conf.dim0_valid_num()); } void InputOp::InferHasBatchDim(std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const { - *HasBatchDim4BnInOp("out") = op_conf().input_conf().has_batch_dim(); + *HasBatchDim4BnInOp("out") = op_conf().input_conf().blob_conf().has_batch_dim(); } void InputOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { - int64_t num_axes = op_conf().input_conf().shape().dim_size(); + int64_t num_axes = op_conf().input_conf().blob_conf().shape().dim_size(); SbpSignatureBuilder() .Split(input_bns(), 0) .Split(output_bns(), 0) diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto index d5d931d10643a69b5cdd8a6f4f006e649361e0a0..0a5f11c0427c012a4d1faae9dd0941ca3e03427a 100644 --- a/oneflow/core/operator/op_conf.proto +++ b/oneflow/core/operator/op_conf.proto @@ -755,16 +755,27 @@ message CastOpConf { required DataType data_type = 3; } +message InputBlobConf { + required ShapeProto shape = 1; + optional DataType data_type = 2; + optional ShapeProto dim0_inner_shape = 3; + optional int64 dim0_valid_num = 4; + optional int64 dim1_valid_num = 5; + optional int64 dim2_valid_num = 6; + optional bool has_batch_dim = 7 [default = true]; +} + message InputOpConf { optional string tick = 1; required string out = 2; - required ShapeProto shape = 3; - optional DataType data_type = 4; - optional ShapeProto dim0_inner_shape = 5; - optional int64 dim0_valid_num = 6; - optional int64 dim1_valid_num = 7; - optional int64 dim2_valid_num = 8; - optional bool has_batch_dim = 9 [default = true]; + required InputBlobConf blob_conf = 3; +} + +message ForeignInputOpConf { + optional string tick = 1; + required string out = 2; + required InputBlobConf blob_conf = 3; + required string foreign_blob_buffer_name = 4; } message OutputOpConf { @@ -772,6 +783,12 @@ message OutputOpConf { required string out = 2; } +message ForeignOutputOpConf { + required string in = 1; + required string out = 2; + required string foreign_blob_buffer_name = 3; +} + message VariableOpConf { optional string tick = 1; required string out = 2; @@ -1465,6 +1482,8 @@ message OperatorConf { WaitAndSendIdsOpConf wait_and_send_ids_conf = 139; ReentrantLockOpConf reentrant_lock_conf = 140; CallbackNotifyOpConf callback_notify_conf = 141; + ForeignInputOpConf foreign_input_conf = 142; + ForeignOutputOpConf foreign_output_conf = 143; // domain op TupleIdentityOpConf tuple_identity_conf = 200; diff --git a/oneflow/core/operator/output_op.cpp b/oneflow/core/operator/output_op.cpp index eff6052bd74be1ff87d668da064f3b106139059a..7c348c9d330379eba584ece26597a9b23ac1d711 100644 --- a/oneflow/core/operator/output_op.cpp +++ b/oneflow/core/operator/output_op.cpp @@ -4,6 +4,7 @@ namespace oneflow { void OutputOp::InitFromOpConf() { + CHECK(op_conf().has_output_conf()); EnrollInputBn("in"); EnrollOutputBn("out"); } @@ -22,7 +23,6 @@ void OutputOp::InferHasBatchDim(std::function<bool*(const std::string&)> HasBatc void OutputOp::GetSbpSignatures( const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { - const auto bns = StdVec2PbRpf<std::string>({"in", "out"}); int64_t num_axes = LogicalBlobDesc4Ibn(input_bns().Get(0)).shape().NumAxes(); SbpSignatureBuilder() .Split(input_bns(), 0) diff --git a/oneflow/core/register/foreign_blob.h b/oneflow/core/register/foreign_blob.h new file mode 100644 index 0000000000000000000000000000000000000000..79e077cbfdd9f28d58550e61e43f70a5a0dd3ff0 --- /dev/null +++ b/oneflow/core/register/foreign_blob.h @@ -0,0 +1,19 @@ +#ifndef ONEFLOW_CORE_REGISTER_FOREIGN_BLOB_H_ +#define ONEFLOW_CORE_REGISTER_FOREIGN_BLOB_H_ + +#include "oneflow/core/register/blob.h" + +namespace oneflow { + +class ForeignBlob { + public: + OF_DISALLOW_COPY_AND_MOVE(ForeignBlob); + virtual ~ForeignBlob() = default; + + virtual void CopyFrom(const Blob* blob) = 0; + virtual void CopyTo(Blob* blob) const = 0; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_REGISTER_FOREIGN_BLOB_H_