diff --git a/oneflow/core/operator/loss_op.cpp b/oneflow/core/operator/loss_op.cpp index f618eebaec1fbc98d7d265bf5a94a7b2ea3bff8a..69c9528bffa5db14c26a4e18485894b280754470 100644 --- a/oneflow/core/operator/loss_op.cpp +++ b/oneflow/core/operator/loss_op.cpp @@ -25,7 +25,8 @@ void LossOp::VirtualGenKernelConf( } void LossOp::InferBlobDescs(std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, size_t* buf_size) const { + const ParallelContext* parallel_ctx, size_t* buf_size, + std::function<void(OpContext*)>) const { const BlobDesc* pred_blob_desc = GetBlobDesc4BnInOp("prediction"); const BlobDesc* label_blob_desc = GetBlobDesc4BnInOp("label"); CHECK_EQ(pred_blob_desc->has_data_id_field(), label_blob_desc->has_data_id_field()); diff --git a/oneflow/core/operator/loss_op.h b/oneflow/core/operator/loss_op.h index 30ef6ef6df6d0340746dd77efbd8b08d408c8d1f..4264ca4881d9b258ef7415f3bff83a29f81956d2 100644 --- a/oneflow/core/operator/loss_op.h +++ b/oneflow/core/operator/loss_op.h @@ -16,15 +16,11 @@ class LossOp : public Operator { LogicalNode* NewProperLogicalNode() override { return new LossLogicalNode; } void InferBlobDescs(std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, size_t* buf_size) const override; + const ParallelContext* parallel_ctx, size_t* buf_size, + std::function<void(OpContext*)> EnrollOpCtx) const override; bool IsLossOp() const override { return true; } protected: - void InferBlobDescs(std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, size_t* buf_size, - std::function<void(OpContext*)> EnrollOpCtx) const override { - InferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx, buf_size); - } virtual void VirtualInitFromOpConf() {} virtual void VirtualInferBlobDescs(std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, size_t* buf_size) const {} diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index a5be3e4bc73fc6cdbb20d3aea8985356b1cbc48d..40a80393e50a3f6ace94a944f672a7e8146ad2dd 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -91,10 +91,6 @@ void Operator::InferBlobDescs(std::function<BlobDesc*(const std::string)> GetBlo std::function<void(OpContext*)> EnrollOpCtx) const { InferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx, EnrollOpCtx); } -void Operator::InferBlobDescs(std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, size_t* buf_size) const { - InferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx); -} void Operator::InferBlobDescs(std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, std::function<void(OpContext*)> EnrollOpCtx) const { diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h index d7b7ffd82d0c95d0ea27fc7b92d36a1669f041ef..b719bb2084d8ecbb85b768a10de9a4e4193422a9 100644 --- a/oneflow/core/operator/operator.h +++ b/oneflow/core/operator/operator.h @@ -110,8 +110,6 @@ class Operator { virtual void InferBlobDescs(std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp, const ParallelContext*, size_t* buf_size, std::function<void(OpContext*)> EnrollOpCtx) const; - virtual void InferBlobDescs(std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, size_t* buf_size) const; virtual void InferBlobDescs(std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp, const ParallelContext*, std::function<void(OpContext*)> EnrollOpCtx) const;