From 8c2451c11b93e24b0f76ad22424fb374633f8a80 Mon Sep 17 00:00:00 2001 From: "kyle.cao" <kyle.cao@vesoft.com> Date: Tue, 6 Apr 2021 14:59:56 +0800 Subject: [PATCH] Add optimizer rule for having clause (#842) * push filter down aggNode(having without agg) * add tck case * fix tck Co-authored-by: Yee <2520865+yixinglu@users.noreply.github.com> --- src/optimizer/CMakeLists.txt | 1 + .../rule/PushFilterDownAggregateRule.cpp | 119 ++++++++++++++++++ .../rule/PushFilterDownAggregateRule.h | 34 +++++ src/planner/Query.cpp | 30 +++++ src/planner/Query.h | 12 ++ .../PushFilterDownAggregateRule.feature | 47 +++++++ 6 files changed, 243 insertions(+) create mode 100644 src/optimizer/rule/PushFilterDownAggregateRule.cpp create mode 100644 src/optimizer/rule/PushFilterDownAggregateRule.h create mode 100644 tests/tck/features/optimizer/PushFilterDownAggregateRule.feature diff --git a/src/optimizer/CMakeLists.txt b/src/optimizer/CMakeLists.txt index 65615a7f..d186bd16 100644 --- a/src/optimizer/CMakeLists.txt +++ b/src/optimizer/CMakeLists.txt @@ -19,6 +19,7 @@ nebula_add_library( rule/IndexScanRule.cpp rule/LimitPushDownRule.cpp rule/TopNRule.cpp + rule/PushFilterDownAggregateRule.cpp ) nebula_add_subdirectory(test) diff --git a/src/optimizer/rule/PushFilterDownAggregateRule.cpp b/src/optimizer/rule/PushFilterDownAggregateRule.cpp new file mode 100644 index 00000000..5a21925e --- /dev/null +++ b/src/optimizer/rule/PushFilterDownAggregateRule.cpp @@ -0,0 +1,119 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#include "optimizer/rule/PushFilterDownAggregateRule.h" + +#include "optimizer/OptContext.h" +#include "optimizer/OptGroup.h" +#include "planner/PlanNode.h" +#include "planner/Query.h" +#include "util/ExpressionUtils.h" +#include "visitor/ExtractFilterExprVisitor.h" +#include "visitor/RewriteVisitor.h" + +using nebula::graph::PlanNode; +using nebula::graph::QueryContext; + +namespace nebula { +namespace opt { + +std::unique_ptr<OptRule> PushFilterDownAggregateRule::kInstance = + std::unique_ptr<PushFilterDownAggregateRule>(new PushFilterDownAggregateRule()); + +PushFilterDownAggregateRule::PushFilterDownAggregateRule() { + RuleSet::QueryRules().addRule(this); +} + +const Pattern& PushFilterDownAggregateRule::pattern() const { + static Pattern pattern = Pattern::create(graph::PlanNode::Kind::kFilter, + {Pattern::create(graph::PlanNode::Kind::kAggregate)}); + return pattern; +} + +StatusOr<OptRule::TransformResult> PushFilterDownAggregateRule::transform( + OptContext* octx, + const MatchedResult& matched) const { + auto qctx = octx->qctx(); + auto* filterGroupNode = matched.node; + auto* oldFilterNode = filterGroupNode->node(); + auto deps = matched.dependencies; + DCHECK_EQ(deps.size(), 1); + auto aggGroupNode = deps.front().node; + auto* oldAggNode = aggGroupNode->node(); + DCHECK(oldFilterNode->kind() == PlanNode::Kind::kFilter); + DCHECK(oldAggNode->kind() == PlanNode::Kind::kAggregate); + auto* newFilterNode = static_cast<const graph::Filter*>(oldFilterNode)->clone(qctx); + auto* newAggNode = static_cast<const graph::Aggregate*>(oldAggNode)->clone(qctx); + const auto* condition = newFilterNode->condition(); + auto& groupItems = newAggNode->groupItems(); + + // Check expression recursively to ensure no aggregate items in the filter + auto varProps = graph::ExpressionUtils::collectAll(condition, {Expression::Kind::kVarProperty}); + std::vector<std::string> propNames; + for (auto* expr : varProps) { + DCHECK(expr->kind() == Expression::Kind::kVarProperty); + propNames.emplace_back(*static_cast<const VariablePropertyExpression*>(expr)->prop()); + } + std::unordered_map<std::string, Expression*> rewriteMap; + auto colNames = newAggNode->colNames(); + for (size_t i = 0; i < colNames.size(); ++i) { + auto& colName = colNames[i]; + auto iter = std::find_if(propNames.begin(), propNames.end(), [&colName](const auto& name) { + return !colName.compare(name); + }); + if (iter == propNames.end()) continue; + if (graph::ExpressionUtils::findAny(groupItems[i], {Expression::Kind::kAggregate})) { + return TransformResult::noTransform(); + } + rewriteMap[colName] = groupItems[i]; + } + + // Rewrite VariablePropertyExpr in filter's condition + auto matcher = [&rewriteMap](const Expression* e) -> bool { + if (e->kind() != Expression::Kind::kVarProperty) { + return false; + } + auto* propName = static_cast<const VariablePropertyExpression*>(e)->prop(); + return rewriteMap[*propName]; + }; + auto rewriter = [&rewriteMap](const Expression* e) -> Expression* { + DCHECK_EQ(e->kind(), Expression::Kind::kVarProperty); + auto* propName = static_cast<const VariablePropertyExpression*>(e)->prop(); + return rewriteMap[*propName]->clone().release(); + }; + auto* newCondition = + graph::RewriteVisitor::transform(condition, std::move(matcher), std::move(rewriter)); + qctx->objPool()->add(newCondition); + newFilterNode->setCondition(newCondition); + + // Exchange planNode + newAggNode->setOutputVar(oldFilterNode->outputVar()); + newFilterNode->setInputVar(oldAggNode->inputVar()); + DCHECK(oldAggNode->outputVar() == oldFilterNode->inputVar()); + newAggNode->setInputVar(oldAggNode->outputVar()); + newFilterNode->setOutputVar(oldAggNode->outputVar()); + + // Push down filter's optGroup and embed newAggGroupNode into old filter's Group + auto newAggGroupNode = OptGroupNode::create(octx, newAggNode, filterGroupNode->group()); + auto newFilterGroup = OptGroup::create(octx); + auto newFilterGroupNode = newFilterGroup->makeGroupNode(newFilterNode); + newAggGroupNode->dependsOn(newFilterGroup); + for (auto dep : aggGroupNode->dependencies()) { + newFilterGroupNode->dependsOn(dep); + } + + TransformResult result; + result.eraseAll = true; + result.newGroupNodes.emplace_back(newAggGroupNode); + return result; +} + +std::string PushFilterDownAggregateRule::toString() const { + return "PushFilterDownAggregateRule"; +} + +} // namespace opt +} // namespace nebula diff --git a/src/optimizer/rule/PushFilterDownAggregateRule.h b/src/optimizer/rule/PushFilterDownAggregateRule.h new file mode 100644 index 00000000..8d24b1ac --- /dev/null +++ b/src/optimizer/rule/PushFilterDownAggregateRule.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#ifndef OPTIMIZER_RULE_PUSHFILTERDOWNAGGREGATERULE_H_ +#define OPTIMIZER_RULE_PUSHFILTERDOWNAGGREGATERULE_H_ + +#include <memory> +#include "optimizer/OptRule.h" + +namespace nebula { +namespace opt { + +class PushFilterDownAggregateRule final : public OptRule { +public: + const Pattern &pattern() const override; + + StatusOr<OptRule::TransformResult> transform(OptContext *qctx, + const MatchedResult &matched) const override; + + std::string toString() const override; + +private: + PushFilterDownAggregateRule(); + + static std::unique_ptr<OptRule> kInstance; +}; + +} // namespace opt +} // namespace nebula + +#endif // OPTIMIZER_RULE_PUSHFILTERDOWNAGGREGATERULE_H_ diff --git a/src/planner/Query.cpp b/src/planner/Query.cpp index 94f203d6..47ad8392 100644 --- a/src/planner/Query.cpp +++ b/src/planner/Query.cpp @@ -150,6 +150,17 @@ std::unique_ptr<PlanNodeDescription> IndexScan::explain() const { return desc; } +Filter* Filter::clone(QueryContext* qctx) const { + auto newFilter = Filter::make(qctx, nullptr, nullptr, needStableFilter_); + newFilter->clone(*this); + return newFilter; +} + +void Filter::clone(const Filter& f) { + SingleInputNode::clone(f); + condition_ = qctx_->objPool()->add(f.condition()->clone().release()); +} + std::unique_ptr<PlanNodeDescription> Filter::explain() const { auto desc = SingleInputNode::explain(); addDescription("condition", condition_ ? condition_->toString() : "", desc.get()); @@ -222,6 +233,25 @@ std::unique_ptr<PlanNodeDescription> TopN::explain() const { return desc; } +Aggregate* Aggregate::clone(QueryContext* qctx) const { + std::vector<Expression*> newGroupKeys; + std::vector<Expression*> newGroupItems; + auto newAggregate = + Aggregate::make(qctx, nullptr, std::move(newGroupKeys), std::move(newGroupItems)); + newAggregate->clone(*this); + return newAggregate; +} + +void Aggregate::clone(const Aggregate& agg) { + SingleInputNode::clone(agg); + for (auto* expr : agg.groupKeys()) { + groupKeys_.emplace_back(qctx_->objPool()->add(expr->clone().release())); + } + for (auto* expr : agg.groupItems()) { + groupItems_.emplace_back(qctx_->objPool()->add(expr->clone().release())); + } +} + std::unique_ptr<PlanNodeDescription> Aggregate::explain() const { auto desc = SingleInputNode::explain(); addDescription("groupKeys", folly::toJson(util::toJson(groupKeys_)), desc.get()); diff --git a/src/planner/Query.h b/src/planner/Query.h index 9a1fef30..7413f04a 100644 --- a/src/planner/Query.h +++ b/src/planner/Query.h @@ -549,10 +549,16 @@ public: return condition_; } + void setCondition(Expression* condition) { + condition_ = condition; + } + bool needStableFilter() const { return needStableFilter_; } + Filter* clone(QueryContext* qctx) const; + std::unique_ptr<PlanNodeDescription> explain() const override; private: @@ -562,6 +568,8 @@ private: needStableFilter_ = needStableFilter; } + void clone(const Filter& f); + private: // Remain result when true Expression* condition_{nullptr}; @@ -837,6 +845,8 @@ public: return groupItems_; } + Aggregate* clone(QueryContext* qctx) const; + std::unique_ptr<PlanNodeDescription> explain() const override; private: @@ -849,6 +859,8 @@ private: groupItems_ = std::move(groupItems); } + void clone(const Aggregate&); + private: std::vector<Expression*> groupKeys_; std::vector<Expression*> groupItems_; diff --git a/tests/tck/features/optimizer/PushFilterDownAggregateRule.feature b/tests/tck/features/optimizer/PushFilterDownAggregateRule.feature new file mode 100644 index 00000000..9ca947eb --- /dev/null +++ b/tests/tck/features/optimizer/PushFilterDownAggregateRule.feature @@ -0,0 +1,47 @@ +# Copyright (c) 2021 vesoft inc. All rights reserved. +# +# This source code is licensed under Apache 2.0 License, +# attached with Common Clause Condition 1.0, found in the LICENSES directory. +Feature: Push Filter down Aggregate rule + + Background: + Given a graph with space named "nba" + + Scenario: push filter down Aggregate + When profiling query: + """ + MATCH (v:player) + WITH v.age+1 AS age, COUNT(v.age) as count + WHERE age<30 + RETURN age,count + ORDER BY age + """ + Then the result should be, in any order: + | age | count | + | -3 | 1 | + | -2 | 1 | + | -1 | 1 | + | 0 | 1 | + | 1 | 1 | + | 21 | 1 | + | 23 | 1 | + | 24 | 1 | + | 25 | 1 | + | 26 | 2 | + | 27 | 1 | + | 28 | 1 | + | 29 | 3 | + And the execution plan should be: + | id | name | dependencies | operator info | + | 13 | DataCollect | 12 | | + | 12 | Sort | 11 | | + | 11 | Project | 18 | | + | 18 | Aggregate | 17 | | + | 17 | Filter | 8 | | + | 8 | Filter | 7 | | + | 7 | Project | 6 | | + | 6 | Project | 5 | | + | 5 | Filter | 16 | | + | 16 | GetVertices | 14 | | + | 14 | IndexScan | 0 | | + | 0 | Start | | | -- GitLab