Skip to content
Snippets Groups Projects
Commit 80b4bbe0 authored by Xinqi's avatar Xinqi
Browse files

ForeignInputOp/ForeignOutputOp

parent 25761df9
No related branches found
No related tags found
No related merge requests found
Showing
with 282 additions and 16 deletions
#include "oneflow/core/kernel/foreign_input_kernel.h"
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/register/foreign_blob.h"
namespace oneflow {
void ForeignInputKernel::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const auto& buffer_name = op_conf().foreign_input_conf().foreign_blob_buffer_name();
std::shared_ptr<ForeignBlob> foreign_blob;
BufferStatus buffer_status = Global<BufferMgr<std::shared_ptr<ForeignBlob>>>::Get()
->Get(buffer_name)
->TryReceive(&foreign_blob);
CHECK_NE(buffer_status, kBufferStatusEmpty);
foreign_blob->CopyTo(BnInOp2Blob("out"));
}
REGISTER_KERNEL(OperatorConf::kForeignInputConf, ForeignInputKernel);
} // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_FOREIGN_INPUT_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_FOREIGN_INPUT_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
class ForeignInputKernel final : public KernelIf<DeviceType::kCPU> {
public:
OF_DISALLOW_COPY_AND_MOVE(ForeignInputKernel);
ForeignInputKernel() = default;
~ForeignInputKernel() = default;
private:
void ForwardDataContent(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_FOREIGN_INPUT_KERNEL_H_
#include "oneflow/core/kernel/foreign_output_kernel.h"
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/register/foreign_blob.h"
namespace oneflow {
void ForeignOutputKernel::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const auto& buffer_name = op_conf().foreign_output_conf().foreign_blob_buffer_name();
std::shared_ptr<ForeignBlob> foreign_blob;
BufferStatus buffer_status = Global<BufferMgr<std::shared_ptr<ForeignBlob>>>::Get()
->Get(buffer_name)
->TryReceive(&foreign_blob);
CHECK_NE(buffer_status, kBufferStatusEmpty);
foreign_blob->CopyFrom(BnInOp2Blob("in"));
}
REGISTER_KERNEL(OperatorConf::kForeignOutputConf, ForeignOutputKernel);
} // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_FOREIGN_INPUT_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_FOREIGN_INPUT_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
class ForeignOutputKernel final : public KernelIf<DeviceType::kCPU> {
public:
OF_DISALLOW_COPY_AND_MOVE(ForeignOutputKernel);
ForeignOutputKernel() = default;
~ForeignOutputKernel() = default;
private:
void ForwardDataContent(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_FOREIGN_INPUT_KERNEL_H_
#ifndef ONEFLOW_CORE_KERNEL_INPUT_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_INPUT_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
......@@ -15,3 +18,5 @@ class InputKernel final : public KernelIf<device_type> {
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_INPUT_KERNEL_H_
#ifndef ONEFLOW_CORE_KERNEL_OUTPUT_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_OUTPUT_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
......@@ -15,3 +18,5 @@ class OutputKernel final : public KernelIf<device_type> {
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_OUTPUT_KERNEL_H_
#include "oneflow/core/operator/foreign_input_op.h"
#include "oneflow/core/job/sbp_signature_builder.h"
namespace oneflow {
namespace {
void CheckOpConf(const OperatorConf& op_conf) {
CHECK(op_conf.ctrl_in_op_name().empty());
if (op_conf.foreign_input_conf().blob_conf().has_dim0_inner_shape()) { TODO(); }
if (op_conf.foreign_input_conf().blob_conf().has_dim1_valid_num()) { TODO(); }
if (op_conf.foreign_input_conf().blob_conf().has_dim2_valid_num()) { TODO(); }
}
} // namespace
void ForeignInputOp::InitFromOpConf() {
CHECK(op_conf().has_foreign_input_conf());
if (op_conf().foreign_input_conf().has_tick()) { EnrollOutputBn("tick", false); }
EnrollOutputBn("out", false);
}
const PbMessage& ForeignInputOp::GetCustomizedConf() const {
return op_conf().foreign_input_conf();
}
void ForeignInputOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
CHECK_EQ(parallel_ctx->parallel_num(), 1);
CheckOpConf(op_conf());
const auto& conf = op_conf().foreign_input_conf().blob_conf();
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
out_blob_desc->mut_shape() = Shape(conf.shape());
if (conf.has_data_type()) {
out_blob_desc->set_data_type(conf.data_type());
} else {
out_blob_desc->set_data_type(Global<JobDesc>::Get()->DefaultDataType());
}
out_blob_desc->set_has_dim1_valid_num_field(conf.dim0_valid_num());
}
void ForeignInputOp::InferHasBatchDim(
std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const {
*HasBatchDim4BnInOp("out") = op_conf().foreign_input_conf().blob_conf().has_batch_dim();
}
void ForeignInputOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {}
REGISTER_OP(OperatorConf::kForeignInputConf, ForeignInputOp);
REGISTER_OP_SAME_OUTPUT_BLOB_MEM_BLOCK_NUM(OperatorConf::kForeignInputConf, 1);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_FOREIGN_INPUT_OP_H_
#define ONEFLOW_CORE_OPERATOR_FOREIGN_INPUT_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class ForeignInputOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(ForeignInputOp);
ForeignInputOp() : Operator() {}
~ForeignInputOp() = default;
void InitFromOpConf() override;
const PbMessage& GetCustomizedConf() const override;
void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
private:
void InferHasBatchDim(std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const override;
void GetSbpSignatures(SbpSignatureList* sbp_sig_list) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_FOREIGN_INPUT_OP_H_
#include "oneflow/core/operator/foreign_output_op.h"
#include "oneflow/core/job/sbp_signature_builder.h"
namespace oneflow {
void ForeignOutputOp::InitFromOpConf() {
CHECK(op_conf().has_foreign_output_conf());
EnrollInputBn("in");
}
void ForeignOutputOp::InferBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
CHECK_EQ(parallel_ctx->parallel_num(), 1);
}
const PbMessage& ForeignOutputOp::GetCustomizedConf() const {
return op_conf().foreign_output_conf();
}
void ForeignOutputOp::InferHasBatchDim(
std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const {}
void ForeignOutputOp::GetSbpSignatures(
const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const {}
REGISTER_OP(OperatorConf::kForeignOutputConf, ForeignOutputOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_FOREIGN_OUTPUT_OP_H_
#define ONEFLOW_CORE_OPERATOR_FOREIGN_OUTPUT_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class ForeignOutputOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(ForeignOutputOp);
ForeignOutputOp() = default;
~ForeignOutputOp() override = default;
void InitFromOpConf() override;
const PbMessage& GetCustomizedConf() const override;
void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
private:
void InferHasBatchDim(std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const override;
void GetSbpSignatures(
const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_FOREIGN_OUTPUT_OP_H_
......@@ -7,10 +7,9 @@ namespace {
void CheckOpConf(const OperatorConf& op_conf) {
CHECK(op_conf.ctrl_in_op_name().empty());
if (op_conf.input_conf().has_dim0_inner_shape()) { TODO(); }
if (op_conf.input_conf().has_dim0_valid_num()) { TODO(); }
if (op_conf.input_conf().has_dim1_valid_num()) { TODO(); }
if (op_conf.input_conf().has_dim2_valid_num()) { TODO(); }
if (op_conf.input_conf().blob_conf().has_dim0_inner_shape()) { TODO(); }
if (op_conf.input_conf().blob_conf().has_dim1_valid_num()) { TODO(); }
if (op_conf.input_conf().blob_conf().has_dim2_valid_num()) { TODO(); }
}
void CheckShape(const Shape& shape) {
......@@ -21,7 +20,7 @@ void CheckShape(const Shape& shape) {
void InputOp::InitFromOpConf() {
CHECK(op_conf().has_input_conf());
if (op_conf().input_conf().has_tick()) { EnrollOutputBn("in", false); }
if (op_conf().input_conf().has_tick()) { EnrollInputBn("tick", false); }
EnrollOutputBn("out", false);
}
......@@ -30,7 +29,7 @@ const PbMessage& InputOp::GetCustomizedConf() const { return op_conf().input_con
void InputOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, int64_t record_piece_size) const {
CheckOpConf(op_conf());
const auto& conf = op_conf().input_conf();
const auto& conf = op_conf().input_conf().blob_conf();
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
out_blob_desc->mut_shape() = Shape(conf.shape());
CheckShape(out_blob_desc->shape());
......@@ -45,14 +44,15 @@ void InputOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlo
} else {
out_blob_desc->set_data_type(Global<JobDesc>::Get()->DefaultDataType());
}
out_blob_desc->set_has_dim1_valid_num_field(conf.dim0_valid_num());
}
void InputOp::InferHasBatchDim(std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const {
*HasBatchDim4BnInOp("out") = op_conf().input_conf().has_batch_dim();
*HasBatchDim4BnInOp("out") = op_conf().input_conf().blob_conf().has_batch_dim();
}
void InputOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {
int64_t num_axes = op_conf().input_conf().shape().dim_size();
int64_t num_axes = op_conf().input_conf().blob_conf().shape().dim_size();
SbpSignatureBuilder()
.Split(input_bns(), 0)
.Split(output_bns(), 0)
......
......@@ -755,16 +755,27 @@ message CastOpConf {
required DataType data_type = 3;
}
message InputBlobConf {
required ShapeProto shape = 1;
optional DataType data_type = 2;
optional ShapeProto dim0_inner_shape = 3;
optional int64 dim0_valid_num = 4;
optional int64 dim1_valid_num = 5;
optional int64 dim2_valid_num = 6;
optional bool has_batch_dim = 7 [default = true];
}
message InputOpConf {
optional string tick = 1;
required string out = 2;
required ShapeProto shape = 3;
optional DataType data_type = 4;
optional ShapeProto dim0_inner_shape = 5;
optional int64 dim0_valid_num = 6;
optional int64 dim1_valid_num = 7;
optional int64 dim2_valid_num = 8;
optional bool has_batch_dim = 9 [default = true];
required InputBlobConf blob_conf = 3;
}
message ForeignInputOpConf {
optional string tick = 1;
required string out = 2;
required InputBlobConf blob_conf = 3;
required string foreign_blob_buffer_name = 4;
}
message OutputOpConf {
......@@ -772,6 +783,12 @@ message OutputOpConf {
required string out = 2;
}
message ForeignOutputOpConf {
required string in = 1;
required string out = 2;
required string foreign_blob_buffer_name = 3;
}
message VariableOpConf {
optional string tick = 1;
required string out = 2;
......@@ -1465,6 +1482,8 @@ message OperatorConf {
WaitAndSendIdsOpConf wait_and_send_ids_conf = 139;
ReentrantLockOpConf reentrant_lock_conf = 140;
CallbackNotifyOpConf callback_notify_conf = 141;
ForeignInputOpConf foreign_input_conf = 142;
ForeignOutputOpConf foreign_output_conf = 143;
// domain op
TupleIdentityOpConf tuple_identity_conf = 200;
......
......@@ -4,6 +4,7 @@
namespace oneflow {
void OutputOp::InitFromOpConf() {
CHECK(op_conf().has_output_conf());
EnrollInputBn("in");
EnrollOutputBn("out");
}
......@@ -22,7 +23,6 @@ void OutputOp::InferHasBatchDim(std::function<bool*(const std::string&)> HasBatc
void OutputOp::GetSbpSignatures(
const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const {
const auto bns = StdVec2PbRpf<std::string>({"in", "out"});
int64_t num_axes = LogicalBlobDesc4Ibn(input_bns().Get(0)).shape().NumAxes();
SbpSignatureBuilder()
.Split(input_bns(), 0)
......
#ifndef ONEFLOW_CORE_REGISTER_FOREIGN_BLOB_H_
#define ONEFLOW_CORE_REGISTER_FOREIGN_BLOB_H_
#include "oneflow/core/register/blob.h"
namespace oneflow {
class ForeignBlob {
public:
OF_DISALLOW_COPY_AND_MOVE(ForeignBlob);
virtual ~ForeignBlob() = default;
virtual void CopyFrom(const Blob* blob) = 0;
virtual void CopyTo(Blob* blob) const = 0;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_REGISTER_FOREIGN_BLOB_H_
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment