diff --git a/src/optimizer/OptGroup.cpp b/src/optimizer/OptGroup.cpp index cb97fd1f9aded0cae1aaefadaeccebd8a34253b8..bb5cf54ada16388395e9627e8f92750ef0bb8f63 100644 --- a/src/optimizer/OptGroup.cpp +++ b/src/optimizer/OptGroup.cpp @@ -32,6 +32,7 @@ OptGroup::OptGroup(QueryContext *qctx) noexcept : qctx_(qctx) { } void OptGroup::addGroupExpr(OptGroupExpr *groupExpr) { + DCHECK(groupExpr != nullptr); DCHECK(groupExpr->group() == this); groupExprs_.emplace_back(groupExpr); } @@ -49,20 +50,24 @@ Status OptGroup::explore(const OptRule *rule) { for (auto iter = groupExprs_.begin(); iter != groupExprs_.end();) { auto groupExpr = *iter; + DCHECK(groupExpr != nullptr); if (groupExpr->isExplored(rule)) { + ++iter; continue; } // Bottom to up exploration NG_RETURN_IF_ERROR(groupExpr->explore(rule)); // Find more equivalents - if (!rule->match(groupExpr)) { + auto status = rule->match(groupExpr); + if (!status.ok()) { ++iter; continue; } - - OptRule::TransformResult result; - NG_RETURN_IF_ERROR(rule->transform(qctx_, groupExpr, &result)); + auto matched = std::move(status).value(); + auto resStatus = rule->transform(qctx_, matched); + NG_RETURN_IF_ERROR(resStatus); + auto result = std::move(resStatus).value(); if (result.eraseAll) { groupExprs_.clear(); for (auto nge : result.newGroupExprs) { @@ -89,6 +94,19 @@ Status OptGroup::explore(const OptRule *rule) { return Status::OK(); } +Status OptGroup::exploreUtilMaxRound(const OptRule *rule) { + auto maxRound = kMaxExplorationRound; + while (!isExplored(rule)) { + if (0 < maxRound--) { + NG_RETURN_IF_ERROR(explore(rule)); + } else { + setExplored(rule); + break; + } + } + return Status::OK(); +} + std::pair<double, const OptGroupExpr *> OptGroup::findMinCostGroupExpr() const { double minCost = std::numeric_limits<double>::max(); const OptGroupExpr *minGroupExpr = nullptr; @@ -123,19 +141,19 @@ OptGroupExpr::OptGroupExpr(PlanNode *node, const OptGroup *group) noexcept } Status OptGroupExpr::explore(const OptRule *rule) { - if (isExplored(rule)) return Status::OK(); + if (isExplored(rule)) { + return Status::OK(); + } setExplored(rule); for (auto dep : dependencies_) { - if (!dep->isExplored(rule)) { - NG_RETURN_IF_ERROR(dep->explore(rule)); - } + DCHECK(dep != nullptr); + NG_RETURN_IF_ERROR(dep->exploreUtilMaxRound(rule)); } for (auto body : bodies_) { - if (!body->isExplored(rule)) { - NG_RETURN_IF_ERROR(body->explore(rule)); - } + DCHECK(body != nullptr); + NG_RETURN_IF_ERROR(body->exploreUtilMaxRound(rule)); } return Status::OK(); } diff --git a/src/optimizer/OptGroup.h b/src/optimizer/OptGroup.h index 0abab8397aad7c723f638492acfc9196f5c01326..03aa0c589263091b3127811543830fa2c472d77f 100644 --- a/src/optimizer/OptGroup.h +++ b/src/optimizer/OptGroup.h @@ -8,6 +8,7 @@ #define OPTIMIZER_OPTGROUP_H_ #include <algorithm> +#include <list> #include <vector> #include "common/base/Status.h" @@ -44,21 +45,24 @@ public: void addGroupExpr(OptGroupExpr *groupExpr); OptGroupExpr *makeGroupExpr(graph::QueryContext *qctx, graph::PlanNode *node); - const std::vector<OptGroupExpr *> &groupExprs() const { + const std::list<OptGroupExpr *> &groupExprs() const { return groupExprs_; } Status explore(const OptRule *rule); + Status exploreUtilMaxRound(const OptRule *rule); double getCost() const; const graph::PlanNode *getPlan() const; private: explicit OptGroup(graph::QueryContext *qctx) noexcept; + static constexpr int16_t kMaxExplorationRound = 128; + std::pair<double, const OptGroupExpr *> findMinCostGroupExpr() const; graph::QueryContext *qctx_{nullptr}; - std::vector<OptGroupExpr *> groupExprs_; + std::list<OptGroupExpr *> groupExprs_; std::vector<const OptRule *> exploredRules_; }; diff --git a/src/optimizer/OptRule.cpp b/src/optimizer/OptRule.cpp index e15b10f850cc6daf8726d866100eb83f3fea6acc..e6253fcfc2b8e538907328bc7760736e1c50b822 100644 --- a/src/optimizer/OptRule.cpp +++ b/src/optimizer/OptRule.cpp @@ -7,16 +7,80 @@ #include "optimizer/OptRule.h" #include "common/base/Logging.h" +#include "optimizer/OptGroup.h" namespace nebula { namespace opt { -RuleSet &RuleSet::defaultRules() { +Pattern Pattern::create(graph::PlanNode::Kind kind, std::initializer_list<Pattern> patterns) { + Pattern pattern; + pattern.kind_ = kind; + for (auto &p : patterns) { + pattern.dependencies_.emplace_back(p); + } + return pattern; +} + +StatusOr<MatchedResult> Pattern::match(const OptGroupExpr *groupExpr) const { + if (groupExpr->node()->kind() != kind_) { + return Status::Error(); + } + + if (dependencies_.empty()) { + return MatchedResult{groupExpr, {}}; + } + + if (groupExpr->dependencies().size() != dependencies_.size()) { + return Status::Error(); + } + + MatchedResult result; + result.node = groupExpr; + result.dependencies.reserve(dependencies_.size()); + for (size_t i = 0; i < dependencies_.size(); ++i) { + auto group = groupExpr->dependencies()[i]; + const auto &pattern = dependencies_[i]; + auto status = pattern.match(group); + NG_RETURN_IF_ERROR(status); + result.dependencies.emplace_back(std::move(status).value()); + } + return result; +} + +StatusOr<MatchedResult> Pattern::match(const OptGroup *group) const { + for (auto node : group->groupExprs()) { + auto status = match(node); + if (status.ok()) { + return status; + } + } + return Status::Error(); +} + +StatusOr<MatchedResult> OptRule::match(const OptGroupExpr *groupExpr) const { + const auto &pattern = this->pattern(); + auto status = pattern.match(groupExpr); + NG_RETURN_IF_ERROR(status); + auto matched = std::move(status).value(); + if (!this->match(matched)) { + return Status::Error(); + } + return matched; +} + +bool OptRule::match(const MatchedResult &matched) const { + UNUSED(matched); + // Return true if subclass doesn't override this interface, + // so optimizer will only check whether pattern is matched + return true; +} + +RuleSet &RuleSet::DefaultRules() { static RuleSet kDefaultRules("DefaultRuleSet"); return kDefaultRules; } -RuleSet &RuleSet::queryRules() { +RuleSet &RuleSet::QueryRules() { static RuleSet kQueryRules("QueryRules"); return kQueryRules; } diff --git a/src/optimizer/OptRule.h b/src/optimizer/OptRule.h index 7db633c7885e49a7f9ba292b79e4d2b2a5133010..af3a764dd636d1d51078df99d52f6deff3c2d36d 100644 --- a/src/optimizer/OptRule.h +++ b/src/optimizer/OptRule.h @@ -7,11 +7,12 @@ #ifndef OPTIMIZER_OPTRULE_H_ #define OPTIMIZER_OPTRULE_H_ +#include <initializer_list> #include <memory> #include <string> #include <vector> -#include "common/base/Status.h" +#include "common/base/StatusOr.h" #include "planner/PlanNode.h" namespace nebula { @@ -22,21 +23,43 @@ class QueryContext; namespace opt { class OptGroupExpr; +class OptGroup; + +struct MatchedResult { + const OptGroupExpr *node{nullptr}; + std::vector<MatchedResult> dependencies; +}; + +class Pattern final { +public: + static Pattern create(graph::PlanNode::Kind kind, std::initializer_list<Pattern> patterns = {}); + + StatusOr<MatchedResult> match(const OptGroupExpr *groupNode) const; + +private: + Pattern() = default; + StatusOr<MatchedResult> match(const OptGroup *group) const; + + graph::PlanNode::Kind kind_; + std::vector<Pattern> dependencies_; +}; class OptRule { public: struct TransformResult { - bool eraseCurr; - bool eraseAll; + bool eraseCurr{false}; + bool eraseAll{false}; std::vector<OptGroupExpr *> newGroupExprs; }; + StatusOr<MatchedResult> match(const OptGroupExpr *groupExpr) const; + virtual ~OptRule() = default; - virtual bool match(const OptGroupExpr *groupExpr) const = 0; - virtual Status transform(graph::QueryContext *qctx, - const OptGroupExpr *groupExpr, - TransformResult *result) const = 0; + virtual const Pattern &pattern() const = 0; + virtual bool match(const MatchedResult &matched) const; + virtual StatusOr<TransformResult> transform(graph::QueryContext *qctx, + const MatchedResult &matched) const = 0; virtual std::string toString() const = 0; protected: @@ -45,8 +68,8 @@ protected: class RuleSet final { public: - static RuleSet &defaultRules(); - static RuleSet &queryRules(); + static RuleSet &DefaultRules(); + static RuleSet &QueryRules(); RuleSet *addRule(const OptRule *rule); diff --git a/src/optimizer/Optimizer.cpp b/src/optimizer/Optimizer.cpp index 3e94edc34039d0aaceb6e5c49d1ab9837a807237..0baa50ba3674809cb700ac1589d9ee75cc2bc7c0 100644 --- a/src/optimizer/Optimizer.cpp +++ b/src/optimizer/Optimizer.cpp @@ -43,12 +43,9 @@ Status Optimizer::prepare() { } Status Optimizer::doExploration() { - // TODO(yee): Apply all rules recursively, not only once round for (auto ruleSet : ruleSets_) { for (auto rule : ruleSet->rules()) { - if (!rootGroup_->isExplored(rule)) { - NG_RETURN_IF_ERROR(rootGroup_->explore(rule)); - } + NG_RETURN_IF_ERROR(rootGroup_->exploreUtilMaxRound(rule)); } } return Status::OK(); diff --git a/src/optimizer/rule/IndexScanRule.cpp b/src/optimizer/rule/IndexScanRule.cpp index f551af6956da89bdb4400f094c1cf55dd6838b13..136a883e21f70cb5c8db0c94a0e7d5fa7467c6d2 100644 --- a/src/optimizer/rule/IndexScanRule.cpp +++ b/src/optimizer/rule/IndexScanRule.cpp @@ -18,22 +18,23 @@ std::unique_ptr<OptRule> IndexScanRule::kInstance = std::unique_ptr<IndexScanRule>(new IndexScanRule()); IndexScanRule::IndexScanRule() { - RuleSet::defaultRules().addRule(this); + RuleSet::DefaultRules().addRule(this); } -bool IndexScanRule::match(const OptGroupExpr *groupExpr) const { - return groupExpr->node()->kind() == PlanNode::Kind::kIndexScan; +const Pattern& IndexScanRule::pattern() const { + static Pattern pattern = Pattern::create(graph::PlanNode::Kind::kIndexScan); + return pattern; } -Status IndexScanRule::transform(graph::QueryContext *qctx, - const OptGroupExpr *groupExpr, - TransformResult *result) const { - FilterItems items; - ScanKind kind; +StatusOr<OptRule::TransformResult> IndexScanRule::transform(graph::QueryContext* qctx, + const MatchedResult& matched) const { + auto groupExpr = matched.node; auto filter = filterExpr(groupExpr); if (filter == nullptr) { return Status::SemanticError("WHERE clause error"); } + FilterItems items; + ScanKind kind; auto ret = analyzeExpression(filter.get(), &items, &kind, isEdge(groupExpr)); NG_RETURN_IF_ERROR(ret); @@ -48,9 +49,10 @@ Status IndexScanRule::transform(graph::QueryContext *qctx, return Status::Error("Plan node dependencies error"); } newGroupExpr->dependsOn(groupExpr->dependencies()[0]); - result->newGroupExprs.emplace_back(newGroupExpr); - result->eraseAll = true; - return Status::OK(); + TransformResult result; + result.newGroupExprs.emplace_back(newGroupExpr); + result.eraseAll = true; + return result; } std::string IndexScanRule::toString() const { @@ -256,6 +258,7 @@ IndexScanRule::filterExpr(const OptGroupExpr *groupExpr) const { auto in = static_cast<const IndexScan *>(groupExpr->node()); auto qct = in->queryContext(); // The initial IndexScan plan node has only one queryContext. + // TODO(yee): Move this condition to match interface if (qct->size() != 1) { LOG(ERROR) << "Index Scan plan node error"; return nullptr; diff --git a/src/optimizer/rule/IndexScanRule.h b/src/optimizer/rule/IndexScanRule.h index 3e7fed349fc90be3fa09817b023ee2531cff4425..a492342561acb07cacb1ea1cad41143037151a38 100644 --- a/src/optimizer/rule/IndexScanRule.h +++ b/src/optimizer/rule/IndexScanRule.h @@ -31,13 +31,10 @@ class IndexScanRule final : public OptRule { FRIEND_TEST(IndexScanRuleTest, IQCtxTest); public: - static std::unique_ptr<OptRule> kInstance; - - bool match(const OptGroupExpr *groupExpr) const override; + const Pattern& pattern() const override; - Status transform(graph::QueryContext *qctx, - const OptGroupExpr *groupExpr, - TransformResult *result) const override; + StatusOr<TransformResult> transform(graph::QueryContext* qctx, + const MatchedResult& matched) const override; std::string toString() const override; @@ -100,6 +97,8 @@ private: } }; + static std::unique_ptr<OptRule> kInstance; + IndexScanRule(); Status createIndexQueryCtx(IndexQueryCtx &iqctx, diff --git a/src/optimizer/rule/PushFilterDownGetNbrsRule.cpp b/src/optimizer/rule/PushFilterDownGetNbrsRule.cpp index e2d22c1b218077dda50b357a92976357e6f97381..c992e4a489c4e72182f18f43a9193317c40c6651 100644 --- a/src/optimizer/rule/PushFilterDownGetNbrsRule.cpp +++ b/src/optimizer/rule/PushFilterDownGetNbrsRule.cpp @@ -29,32 +29,28 @@ std::unique_ptr<OptRule> PushFilterDownGetNbrsRule::kInstance = std::unique_ptr<PushFilterDownGetNbrsRule>(new PushFilterDownGetNbrsRule()); PushFilterDownGetNbrsRule::PushFilterDownGetNbrsRule() { - RuleSet::queryRules().addRule(this); + RuleSet::QueryRules().addRule(this); } -bool PushFilterDownGetNbrsRule::match(const OptGroupExpr *groupExpr) const { - auto pair = findMatchedGroupExpr(groupExpr); - if (!pair.first) { - return false; - } - - return true; +const Pattern &PushFilterDownGetNbrsRule::pattern() const { + static Pattern pattern = Pattern::create( + graph::PlanNode::Kind::kFilter, {Pattern::create(graph::PlanNode::Kind::kGetNeighbors)}); + return pattern; } -Status PushFilterDownGetNbrsRule::transform(QueryContext *qctx, - const OptGroupExpr *groupExpr, - TransformResult *result) const { - auto pair = findMatchedGroupExpr(groupExpr); - auto filter = static_cast<const Filter *>(groupExpr->node()); - auto gn = static_cast<const GetNeighbors *>(pair.second->node()); +StatusOr<OptRule::TransformResult> PushFilterDownGetNbrsRule::transform( + QueryContext *qctx, + const MatchedResult &matched) const { + auto filterGroupExpr = matched.node; + auto gnGroupExpr = matched.dependencies.front().node; + auto filter = static_cast<const Filter *>(filterGroupExpr->node()); + auto gn = static_cast<const GetNeighbors *>(gnGroupExpr->node()); auto condition = filter->condition()->clone(); graph::ExtractFilterExprVisitor visitor; condition->accept(&visitor); if (!visitor.ok()) { - result->eraseCurr = false; - result->eraseAll = false; - return Status::OK(); + return TransformResult{false, false, {}}; } auto pool = qctx->objPool(); @@ -64,7 +60,7 @@ Status PushFilterDownGetNbrsRule::transform(QueryContext *qctx, auto newFilter = Filter::make(qctx, nullptr, pool->add(remainedExpr.release())); newFilter->setOutputVar(filter->outputVar()); newFilter->setInputVar(filter->inputVar()); - newFilterGroupExpr = OptGroupExpr::create(qctx, newFilter, groupExpr->group()); + newFilterGroupExpr = OptGroupExpr::create(qctx, newFilter, filterGroupExpr->group()); } auto newGNFilter = condition->encode(); @@ -78,48 +74,31 @@ Status PushFilterDownGetNbrsRule::transform(QueryContext *qctx, auto newGN = gn->clone(qctx); newGN->setFilter(newGNFilter); - OptGroupExpr *newGroupExpr = nullptr; + OptGroupExpr *newGnGroupExpr = nullptr; if (newFilterGroupExpr != nullptr) { // Filter(A&&B)->GetNeighbors(C) => Filter(A)->GetNeighbors(B&&C) auto newGroup = OptGroup::create(qctx); - newGroupExpr = OptGroupExpr::create(qctx, newGN, newGroup); + newGnGroupExpr = OptGroupExpr::create(qctx, newGN, newGroup); newFilterGroupExpr->dependsOn(newGroup); } else { // Filter(A)->GetNeighbors(C) => GetNeighbors(A&&C) - newGroupExpr = OptGroupExpr::create(qctx, newGN, groupExpr->group()); + newGnGroupExpr = OptGroupExpr::create(qctx, newGN, filterGroupExpr->group()); newGN->setOutputVar(filter->outputVar()); } - for (auto dep : pair.second->dependencies()) { - newGroupExpr->dependsOn(dep); + for (auto dep : gnGroupExpr->dependencies()) { + newGnGroupExpr->dependsOn(dep); } - result->newGroupExprs.emplace_back(newFilterGroupExpr ? newFilterGroupExpr : newGroupExpr); - result->eraseAll = true; - result->eraseCurr = true; - return Status::OK(); + TransformResult result; + result.newGroupExprs.emplace_back(newFilterGroupExpr ? newFilterGroupExpr : newGnGroupExpr); + result.eraseCurr = true; + return result; } std::string PushFilterDownGetNbrsRule::toString() const { return "PushFilterDownGetNbrsRule"; } -std::pair<bool, const OptGroupExpr *> PushFilterDownGetNbrsRule::findMatchedGroupExpr( - const OptGroupExpr *groupExpr) const { - auto node = groupExpr->node(); - if (node->kind() != PlanNode::Kind::kFilter) { - return std::make_pair(false, nullptr); - } - - for (auto dep : groupExpr->dependencies()) { - for (auto expr : dep->groupExprs()) { - if (expr->node()->kind() == PlanNode::Kind::kGetNeighbors) { - return std::make_pair(true, expr); - } - } - } - return std::make_pair(false, nullptr); -} - } // namespace opt } // namespace nebula diff --git a/src/optimizer/rule/PushFilterDownGetNbrsRule.h b/src/optimizer/rule/PushFilterDownGetNbrsRule.h index 286fa07fbf02c5052b36e54105ab6fa0f1cf6449..11ebbe6d715adf59a950df7c87ca2d57b85a8824 100644 --- a/src/optimizer/rule/PushFilterDownGetNbrsRule.h +++ b/src/optimizer/rule/PushFilterDownGetNbrsRule.h @@ -20,18 +20,15 @@ namespace opt { class PushFilterDownGetNbrsRule final : public OptRule { public: - static std::unique_ptr<OptRule> kInstance; - - bool match(const OptGroupExpr *groupExpr) const override; - Status transform(graph::QueryContext *qctx, - const OptGroupExpr *groupExpr, - TransformResult *result) const override; + const Pattern &pattern() const override; + StatusOr<TransformResult> transform(graph::QueryContext *qctx, + const MatchedResult &matched) const override; std::string toString() const override; private: PushFilterDownGetNbrsRule(); - std::pair<bool, const OptGroupExpr *> findMatchedGroupExpr(const OptGroupExpr *groupExpr) const; + static std::unique_ptr<OptRule> kInstance; }; } // namespace opt diff --git a/src/service/QueryInstance.cpp b/src/service/QueryInstance.cpp index 898c6d98bfdd0dd629a2d96739092ceb357c5795..b0140aef134535bbfab04c78b0a8d4faff2cd575 100644 --- a/src/service/QueryInstance.cpp +++ b/src/service/QueryInstance.cpp @@ -27,9 +27,9 @@ namespace graph { QueryInstance::QueryInstance(std::unique_ptr<QueryContext> qctx) { qctx_ = std::move(qctx); scheduler_ = std::make_unique<Scheduler>(qctx_.get()); - std::vector<const RuleSet *> rulesets{&RuleSet::defaultRules()}; + std::vector<const RuleSet *> rulesets{&RuleSet::DefaultRules()}; if (FLAGS_enable_optimizer) { - rulesets.emplace_back(&RuleSet::queryRules()); + rulesets.emplace_back(&RuleSet::QueryRules()); } optimizer_ = std::make_unique<Optimizer>(qctx_.get(), std::move(rulesets)); } diff --git a/tests/common/nebula_service.py b/tests/common/nebula_service.py index 9243cb2e42955d101cd0b9d26807ff91f3e71d62..73f1066765fe652e8bf0e3b632e4a0589f077c6a 100644 --- a/tests/common/nebula_service.py +++ b/tests/common/nebula_service.py @@ -111,7 +111,6 @@ class NebulaService(object): time.sleep(1) return False - def start(self, debug_log=True): os.chdir(self.work_dir)