Skip to content
Snippets Groups Projects
Commit 069b3fcf authored by Houjiang Chen's avatar Houjiang Chen Committed by Li Xinqi
Browse files

Dev job set fix infer apis (#2072)

* Refine Infer APIs by return Maybe<void> type

* Fix return type

* Fix code style

* Replace CHECK macros in the implementation of infer APIs

* Revert IsOk

* update
parent bf17ab2e
No related branches found
No related tags found
No related merge requests found
......@@ -122,12 +122,12 @@ std::string Sprintf(const Args&... args) {
} // namespace
#define ASSERT_EQ(lhs, rhs) ((lhs) == (rhs))
#define ASSERT_GE(lhs, rhs) ((lhs) >= (rhs))
#define ASSERT_GT(lhs, rhs) ((lhs) > (rhs))
#define ASSERT_LE(lhs, rhs) ((lhs) <= (rhs))
#define ASSERT_LT(lhs, rhs) ((lhs) < (rhs))
#define ASSERT_NE(lhs, rhs) ((lhs) != (rhs))
#define OF_TEST_EQ(lhs, rhs) ((lhs) == (rhs))
#define OF_TEST_GE(lhs, rhs) ((lhs) >= (rhs))
#define OF_TEST_GT(lhs, rhs) ((lhs) > (rhs))
#define OF_TEST_LE(lhs, rhs) ((lhs) <= (rhs))
#define OF_TEST_LT(lhs, rhs) ((lhs) < (rhs))
#define OF_TEST_NE(lhs, rhs) ((lhs) != (rhs))
#define GEN_ERROR_MSG(type, expr, ...) \
[&]() -> std::string { \
......@@ -147,17 +147,17 @@ std::string Sprintf(const Args&... args) {
} \
}
#define CHECK_EQ_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(ASSERT_EQ(lhs, rhs), __VA_ARGS__)
#define CHECK_EQ_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(OF_TEST_EQ(lhs, rhs), __VA_ARGS__)
#define CHECK_GE_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(ASSERT_GE(lhs, rhs), __VA_ARGS__)
#define CHECK_GE_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(OF_TEST_GE(lhs, rhs), __VA_ARGS__)
#define CHECK_GT_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(ASSERT_GT(lhs, rhs), __VA_ARGS__)
#define CHECK_GT_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(OF_TEST_GT(lhs, rhs), __VA_ARGS__)
#define CHECK_LE_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(ASSERT_LE(lhs, rhs), __VA_ARGS__)
#define CHECK_LE_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(OF_TEST_LE(lhs, rhs), __VA_ARGS__)
#define CHECK_LT_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(ASSERT_LT(lhs, rhs), __VA_ARGS__)
#define CHECK_LT_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(OF_TEST_LT(lhs, rhs), __VA_ARGS__)
#define CHECK_NE_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(ASSERT_NE(lhs, rhs), __VA_ARGS__)
#define CHECK_NE_OR_RETURN(lhs, rhs, ...) CHECK_OR_RETURN(OF_TEST_NE(lhs, rhs), __VA_ARGS__)
#define CHECK_STREQ_OR_RETURN(lhs, rhs, ...) \
CHECK_EQ_OR_RETURN(std::string(lhs), std::string(rhs), __VA_ARGS__)
......@@ -169,10 +169,10 @@ std::string Sprintf(const Args&... args) {
return Maybe<void>(error); \
}
#define UNSUPPORTED_THEN_RETURN(...) ENFORCE_THEN_RETURN(ASSERT_UNSUPPORTED, __VA_ARGS__)
#define UNSUPPORTED_THEN_RETURN(...) ENFORCE_THEN_RETURN(OF_TEST_UNSUPPORTED, __VA_ARGS__)
#define TODO_THEN_RETURN(...) ENFORCE_THEN_RETURN(ASSERT_TODO, __VA_ARGS__)
#define TODO_THEN_RETURN(...) ENFORCE_THEN_RETURN(OF_TEST_TODO, __VA_ARGS__)
#define UNIMPLEMENTED_THEN_RETURN(...) ENFORCE_THEN_RETURN(ASSERT_UNIMPLEMENTED, __VA_ARGS__)
#define UNIMPLEMENTED_THEN_RETURN(...) ENFORCE_THEN_RETURN(OF_TEST_UNIMPLEMENTED, __VA_ARGS__)
#endif // ONEFLOW_CORE_COMMON_MAYBE_H_
......@@ -11,13 +11,13 @@ class AssignOp final : public Operator {
void InitFromOpConf() override;
const PbMessage& GetCustomizedConf() const override;
void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
private:
void InferHasBatchDim(
Maybe<void> InferHasBatchDim(
std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const override {
NaiveInferHasBatchDim(HasBatchDim4BnInOp);
return NaiveInferHasBatchDim(HasBatchDim4BnInOp);
}
void GetSbpSignatures(
const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4Ibn,
......@@ -32,9 +32,11 @@ void AssignOp::InitFromOpConf() {
const PbMessage& AssignOp::GetCustomizedConf() const { return op_conf().assign_conf(); }
void AssignOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
CHECK(*GetBlobDesc4BnInOp("ref") == *GetBlobDesc4BnInOp("value"));
Maybe<void> AssignOp::InferBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
CHECK_OR_RETURN(*GetBlobDesc4BnInOp("ref") == *GetBlobDesc4BnInOp("value"));
return Maybe<void>::Ok();
}
void AssignOp::GetSbpSignatures(
......
......@@ -9,14 +9,17 @@ void PartialTickOp::InitFromOpConf() {
EnrollOutputBn("out", false);
}
void PartialTickOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
Maybe<void> PartialTickOp::InferBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
GetBlobDesc4BnInOp("out")->mut_shape() = Shape({1});
return Maybe<void>::Ok();
}
void PartialTickOp::InferHasBatchDim(
Maybe<void> PartialTickOp::InferHasBatchDim(
std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const {
*HasBatchDim4BnInOp("out") = false;
return Maybe<void>::Ok();
}
void PartialTickOp::GetSbpSignatures(
......
......@@ -13,13 +13,14 @@ class PartialTickOp final : public Operator {
~PartialTickOp() = default;
void InitFromOpConf() override;
void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
const PbMessage& GetCustomizedConf() const override { return op_conf().partial_tick_conf(); }
LogicalNode* NewProperLogicalNode() const override { return new TickLogicalNode; }
private:
void InferHasBatchDim(std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const override;
Maybe<void> InferHasBatchDim(
std::function<bool*(const std::string&)> HasBatchDim4BnInOp) const override;
void GetSbpSignatures(
const std::function<const BlobDesc&(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const override;
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment