Skip to content
Snippets Groups Projects
Unverified Commit 2111c2d2 authored by Juncheng's avatar Juncheng Committed by GitHub
Browse files

Add Operator::InferInternalBlobDescs (#4205)

* Add Operator::InferInternalBlobDescs

* XrtLaunchOp::InferOutBlobDescs

* refine
parent 65754966
No related branches found
No related tags found
No related merge requests found
Showing
with 69 additions and 64 deletions
......@@ -24,9 +24,9 @@ void AccTickOp::InitFromOpConf() {
EnrollOutputBn("acc", false);
}
Maybe<void> AccTickOp::InferBlobDescs(
Maybe<void> AccTickOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
*GetBlobDesc4BnInOp("acc") = *GetBlobDesc4BnInOp("one");
GetBlobDesc4BnInOp("acc")->mut_shape() = Shape({1LL});
return Maybe<void>::Ok();
......
......@@ -30,8 +30,9 @@ class AccTickOp final : public Operator {
void InitFromOpConf() override;
LogicalNode* NewProperLogicalNode() const override { return new AccTickLogicalNode; }
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
Maybe<void> InferOutputBlobTimeShape(
std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp,
const ParallelContext* parallel_ctx, Shape* time_shape) const override;
......
......@@ -28,8 +28,9 @@ class AccumulateOp final : public Operator {
void InitFromOpConf() override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override {
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override {
return Maybe<void>::Ok();
}
Maybe<void> InferOutputBlobTimeShape(
......
......@@ -24,8 +24,9 @@ class AssignOp final : public Operator {
~AssignOp() override = default;
void InitFromOpConf() override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
private:
Maybe<void> InferBatchAxis(
......@@ -49,9 +50,9 @@ std::string DebugString(const BlobDesc& blob_desc) {
return blob_desc_proto.DebugString();
}
Maybe<void> AssignOp::InferBlobDescs(
Maybe<void> AssignOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
CHECK_OR_RETURN(*GetBlobDesc4BnInOp("ref") == *GetBlobDesc4BnInOp("value"))
<< "\nref_blob_desc: " << DebugString(*GetBlobDesc4BnInOp("ref"))
<< "\nvalue_blob_desc: " << DebugString(*GetBlobDesc4BnInOp("value"));
......
......@@ -26,14 +26,9 @@ class BoxingIdentityOp : public Operator {
~BoxingIdentityOp() override = default;
void InitFromOpConf() override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
protected:
virtual void VirtualInferBlobDescs(
const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {}
virtual void VirtualInitFromOpConf(){};
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
private:
LogicalBlobId lbi4ibn(const std::string& input_bn) const override;
......@@ -53,9 +48,9 @@ LogicalBlobId BoxingIdentityOp::lbi4obn(const std::string& output_bn) const {
return this->op_conf().boxing_identity_conf().lbi();
}
Maybe<void> BoxingIdentityOp::InferBlobDescs(
Maybe<void> BoxingIdentityOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
*GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in");
return Maybe<void>::Ok();
}
......
......@@ -60,9 +60,9 @@ Symbol<OperatorConf> BoxingOp::GetOpConfWithoutOpNameAndLbn() const {
return SymbolOf(op_conf);
}
Maybe<void> BoxingOp::InferBlobDescs(
Maybe<void> BoxingOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
const BoxingOpConf& conf = op_conf().boxing_conf();
BlobDesc* first_in_blob = GetBlobDesc4BnInOp(input_bns().Get(0));
if (conf.in_box_case() == BoxingOpConf::kAddBox) {
......
......@@ -27,8 +27,9 @@ class BoxingOp final : public Operator {
~BoxingOp() = default;
void InitFromOpConf() override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
protected:
void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
......
......@@ -25,8 +25,9 @@ class BoxingZerosOp : public Operator {
~BoxingZerosOp() override = default;
void InitFromOpConf() override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
private:
LogicalBlobId lbi4ibn(const std::string& input_bn) const override;
......@@ -43,9 +44,9 @@ LogicalBlobId BoxingZerosOp::lbi4obn(const std::string& output_bn) const {
return this->op_conf().boxing_zeros_conf().lbi();
}
Maybe<void> BoxingZerosOp::InferBlobDescs(
Maybe<void> BoxingZerosOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
const BoxingZerosOpConf& conf = this->op_conf().boxing_zeros_conf();
BlobDesc* out = GetBlobDesc4BnInOp("out");
out->set_data_type(conf.data_type());
......
......@@ -54,8 +54,9 @@ class BroadcastToCompatibleWithOp final : public Operator {
EnrollOutputBn("y");
}
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override {
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override {
int64_t num_compatibles = op_conf().broadcast_to_compatible_with_conf().compatible_size();
const BlobDesc* x_desc = GetBlobDesc4BnInOp("x");
Shape broadcasted_shape(x_desc->shape());
......
......@@ -28,9 +28,9 @@ LogicalNode* CallbackNotifyOp::NewProperLogicalNode() const {
return new CallbackNotifyLogicalNode();
}
Maybe<void> CallbackNotifyOp::InferBlobDescs(
Maybe<void> CallbackNotifyOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1);
CHECK_OR_RETURN(GetBlobDesc4BnInOp("in")->shape() == Shape({1}));
CHECK_OR_RETURN(IsIntegralDataType(GetBlobDesc4BnInOp("in")->data_type()));
......
......@@ -27,8 +27,9 @@ class CallbackNotifyOp final : public Operator {
~CallbackNotifyOp() = default;
void InitFromOpConf() override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
LogicalNode* NewProperLogicalNode() const override;
private:
......
......@@ -24,8 +24,9 @@ void CaseOp::InitFromOpConf() {
EnrollRepeatedOutputBn("out", false);
}
Maybe<void> CaseOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
Maybe<void> CaseOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
const BlobDesc* in = GetBlobDesc4BnInOp("in");
CHECK_EQ_OR_RETURN(in->shape().elem_cnt(), 1);
const DataType data_type = in->data_type();
......
......@@ -27,8 +27,9 @@ class CaseOp final : public Operator {
~CaseOp() override = default;
void InitFromOpConf() override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
private:
Maybe<void> InferBatchAxis(
......
......@@ -42,8 +42,9 @@ class CollectiveBoxingGenericOp : public Operator {
return this->op_conf().collective_boxing_generic_conf().lbi();
}
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext*) const override {
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext*,
const SbpSignature* sbp_signature) const override {
const RankDesc& rank_desc = op_conf().collective_boxing_generic_conf().rank_desc();
const DataType data_type = rank_desc.op_desc().data_type();
if (GenericOpHasInput(rank_desc)) {
......
......@@ -27,13 +27,9 @@ class CollectiveBoxingPackOp : public Operator {
void InitFromOpConf() override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
protected:
virtual void VirtualInferBlobDescs(
const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {}
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
private:
LogicalBlobId lbi4ibn(const std::string& input_bn) const override;
......@@ -53,9 +49,9 @@ LogicalBlobId CollectiveBoxingPackOp::lbi4obn(const std::string& output_bn) cons
return this->op_conf().collective_boxing_pack_conf().lbi();
}
Maybe<void> CollectiveBoxingPackOp::InferBlobDescs(
Maybe<void> CollectiveBoxingPackOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
*out_blob_desc = *in_blob_desc;
......
......@@ -27,8 +27,9 @@ class CollectiveBoxingUnpackOp : public Operator {
void InitFromOpConf() override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
protected:
virtual void VirtualInferBlobDescs(
......@@ -53,9 +54,9 @@ LogicalBlobId CollectiveBoxingUnpackOp::lbi4obn(const std::string& output_bn) co
return this->op_conf().collective_boxing_unpack_conf().lbi();
}
Maybe<void> CollectiveBoxingUnpackOp::InferBlobDescs(
Maybe<void> CollectiveBoxingUnpackOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
const CollectiveBoxingUnpackOpConf& unpack_conf = this->op_conf().collective_boxing_unpack_conf();
const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
......
......@@ -29,9 +29,9 @@ class ConstantLikeOp final : public Operator {
EnrollOutputBn("out", false);
}
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override {
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override {
const ConstantLikeOpConf& conf = op_conf().constant_like_conf();
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
*out_blob_desc = *GetBlobDesc4BnInOp("like");
......
......@@ -24,8 +24,9 @@ class CopyHdOp final : public Operator {
~CopyHdOp() override = default;
void InitFromOpConf() override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
private:
Maybe<void> InferBatchAxis(
......@@ -52,9 +53,9 @@ void CopyHdOp::InitFromOpConf() {
EnrollOutputBn("out", false);
}
Maybe<void> CopyHdOp::InferBlobDescs(
Maybe<void> CopyHdOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
*GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in");
return Maybe<void>::Ok();
}
......
......@@ -23,8 +23,9 @@ void CWiseOp::InitFromOpConf() {
VirtualInitFromOpConf();
}
Maybe<void> CWiseOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
Maybe<void> CWiseOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
const BlobDesc* in_0_blob_desc = GetBlobDesc4BnInOp(input_bns().Get(0));
for (size_t i = 1; i < input_bns().size(); ++i) {
const auto* blob_desc = GetBlobDesc4BnInOp(input_bns().Get(i));
......
......@@ -28,8 +28,9 @@ class CWiseOp : public Operator {
void InitFromOpConf() override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
protected:
virtual void VirtualInitFromOpConf() { UNIMPLEMENTED(); }
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment