From 8c2451c11b93e24b0f76ad22424fb374633f8a80 Mon Sep 17 00:00:00 2001
From: "kyle.cao" <kyle.cao@vesoft.com>
Date: Tue, 6 Apr 2021 14:59:56 +0800
Subject: [PATCH] Add optimizer rule for having clause (#842)

* push filter down aggNode(having without agg)

* add tck case

* fix tck

Co-authored-by: Yee <2520865+yixinglu@users.noreply.github.com>
---
 src/optimizer/CMakeLists.txt                  |   1 +
 .../rule/PushFilterDownAggregateRule.cpp      | 119 ++++++++++++++++++
 .../rule/PushFilterDownAggregateRule.h        |  34 +++++
 src/planner/Query.cpp                         |  30 +++++
 src/planner/Query.h                           |  12 ++
 .../PushFilterDownAggregateRule.feature       |  47 +++++++
 6 files changed, 243 insertions(+)
 create mode 100644 src/optimizer/rule/PushFilterDownAggregateRule.cpp
 create mode 100644 src/optimizer/rule/PushFilterDownAggregateRule.h
 create mode 100644 tests/tck/features/optimizer/PushFilterDownAggregateRule.feature

diff --git a/src/optimizer/CMakeLists.txt b/src/optimizer/CMakeLists.txt
index 65615a7f..d186bd16 100644
--- a/src/optimizer/CMakeLists.txt
+++ b/src/optimizer/CMakeLists.txt
@@ -19,6 +19,7 @@ nebula_add_library(
     rule/IndexScanRule.cpp
     rule/LimitPushDownRule.cpp
     rule/TopNRule.cpp
+    rule/PushFilterDownAggregateRule.cpp
 )
 
 nebula_add_subdirectory(test)
diff --git a/src/optimizer/rule/PushFilterDownAggregateRule.cpp b/src/optimizer/rule/PushFilterDownAggregateRule.cpp
new file mode 100644
index 00000000..5a21925e
--- /dev/null
+++ b/src/optimizer/rule/PushFilterDownAggregateRule.cpp
@@ -0,0 +1,119 @@
+/* Copyright (c) 2021 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License,
+ * attached with Common Clause Condition 1.0, found in the LICENSES directory.
+ */
+
+#include "optimizer/rule/PushFilterDownAggregateRule.h"
+
+#include "optimizer/OptContext.h"
+#include "optimizer/OptGroup.h"
+#include "planner/PlanNode.h"
+#include "planner/Query.h"
+#include "util/ExpressionUtils.h"
+#include "visitor/ExtractFilterExprVisitor.h"
+#include "visitor/RewriteVisitor.h"
+
+using nebula::graph::PlanNode;
+using nebula::graph::QueryContext;
+
+namespace nebula {
+namespace opt {
+
+std::unique_ptr<OptRule> PushFilterDownAggregateRule::kInstance =
+    std::unique_ptr<PushFilterDownAggregateRule>(new PushFilterDownAggregateRule());
+
+PushFilterDownAggregateRule::PushFilterDownAggregateRule() {
+    RuleSet::QueryRules().addRule(this);
+}
+
+const Pattern& PushFilterDownAggregateRule::pattern() const {
+    static Pattern pattern = Pattern::create(graph::PlanNode::Kind::kFilter,
+                                             {Pattern::create(graph::PlanNode::Kind::kAggregate)});
+    return pattern;
+}
+
+StatusOr<OptRule::TransformResult> PushFilterDownAggregateRule::transform(
+    OptContext* octx,
+    const MatchedResult& matched) const {
+    auto qctx = octx->qctx();
+    auto* filterGroupNode = matched.node;
+    auto* oldFilterNode = filterGroupNode->node();
+    auto deps = matched.dependencies;
+    DCHECK_EQ(deps.size(), 1);
+    auto aggGroupNode = deps.front().node;
+    auto* oldAggNode = aggGroupNode->node();
+    DCHECK(oldFilterNode->kind() == PlanNode::Kind::kFilter);
+    DCHECK(oldAggNode->kind() == PlanNode::Kind::kAggregate);
+    auto* newFilterNode = static_cast<const graph::Filter*>(oldFilterNode)->clone(qctx);
+    auto* newAggNode = static_cast<const graph::Aggregate*>(oldAggNode)->clone(qctx);
+    const auto* condition = newFilterNode->condition();
+    auto& groupItems = newAggNode->groupItems();
+
+    // Check expression recursively to ensure no aggregate items in the filter
+    auto varProps = graph::ExpressionUtils::collectAll(condition, {Expression::Kind::kVarProperty});
+    std::vector<std::string> propNames;
+    for (auto* expr : varProps) {
+        DCHECK(expr->kind() == Expression::Kind::kVarProperty);
+        propNames.emplace_back(*static_cast<const VariablePropertyExpression*>(expr)->prop());
+    }
+    std::unordered_map<std::string, Expression*> rewriteMap;
+    auto colNames = newAggNode->colNames();
+    for (size_t i = 0; i < colNames.size(); ++i) {
+        auto& colName = colNames[i];
+        auto iter = std::find_if(propNames.begin(), propNames.end(), [&colName](const auto& name) {
+            return !colName.compare(name);
+        });
+        if (iter == propNames.end()) continue;
+        if (graph::ExpressionUtils::findAny(groupItems[i], {Expression::Kind::kAggregate})) {
+            return TransformResult::noTransform();
+        }
+        rewriteMap[colName] = groupItems[i];
+    }
+
+    // Rewrite VariablePropertyExpr in filter's condition
+    auto matcher = [&rewriteMap](const Expression* e) -> bool {
+        if (e->kind() != Expression::Kind::kVarProperty) {
+            return false;
+        }
+        auto* propName = static_cast<const VariablePropertyExpression*>(e)->prop();
+        return rewriteMap[*propName];
+    };
+    auto rewriter = [&rewriteMap](const Expression* e) -> Expression* {
+        DCHECK_EQ(e->kind(), Expression::Kind::kVarProperty);
+        auto* propName = static_cast<const VariablePropertyExpression*>(e)->prop();
+        return rewriteMap[*propName]->clone().release();
+    };
+    auto* newCondition =
+        graph::RewriteVisitor::transform(condition, std::move(matcher), std::move(rewriter));
+    qctx->objPool()->add(newCondition);
+    newFilterNode->setCondition(newCondition);
+
+    // Exchange planNode
+    newAggNode->setOutputVar(oldFilterNode->outputVar());
+    newFilterNode->setInputVar(oldAggNode->inputVar());
+    DCHECK(oldAggNode->outputVar() == oldFilterNode->inputVar());
+    newAggNode->setInputVar(oldAggNode->outputVar());
+    newFilterNode->setOutputVar(oldAggNode->outputVar());
+
+    // Push down filter's optGroup and embed newAggGroupNode into old filter's Group
+    auto newAggGroupNode = OptGroupNode::create(octx, newAggNode, filterGroupNode->group());
+    auto newFilterGroup = OptGroup::create(octx);
+    auto newFilterGroupNode = newFilterGroup->makeGroupNode(newFilterNode);
+    newAggGroupNode->dependsOn(newFilterGroup);
+    for (auto dep : aggGroupNode->dependencies()) {
+        newFilterGroupNode->dependsOn(dep);
+    }
+
+    TransformResult result;
+    result.eraseAll = true;
+    result.newGroupNodes.emplace_back(newAggGroupNode);
+    return result;
+}
+
+std::string PushFilterDownAggregateRule::toString() const {
+    return "PushFilterDownAggregateRule";
+}
+
+}   // namespace opt
+}   // namespace nebula
diff --git a/src/optimizer/rule/PushFilterDownAggregateRule.h b/src/optimizer/rule/PushFilterDownAggregateRule.h
new file mode 100644
index 00000000..8d24b1ac
--- /dev/null
+++ b/src/optimizer/rule/PushFilterDownAggregateRule.h
@@ -0,0 +1,34 @@
+/* Copyright (c) 2021 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License,
+ * attached with Common Clause Condition 1.0, found in the LICENSES directory.
+ */
+
+#ifndef OPTIMIZER_RULE_PUSHFILTERDOWNAGGREGATERULE_H_
+#define OPTIMIZER_RULE_PUSHFILTERDOWNAGGREGATERULE_H_
+
+#include <memory>
+#include "optimizer/OptRule.h"
+
+namespace nebula {
+namespace opt {
+
+class PushFilterDownAggregateRule final : public OptRule {
+public:
+    const Pattern &pattern() const override;
+
+    StatusOr<OptRule::TransformResult> transform(OptContext *qctx,
+                                                 const MatchedResult &matched) const override;
+
+    std::string toString() const override;
+
+private:
+    PushFilterDownAggregateRule();
+
+    static std::unique_ptr<OptRule> kInstance;
+};
+
+}   // namespace opt
+}   // namespace nebula
+
+#endif   // OPTIMIZER_RULE_PUSHFILTERDOWNAGGREGATERULE_H_
diff --git a/src/planner/Query.cpp b/src/planner/Query.cpp
index 94f203d6..47ad8392 100644
--- a/src/planner/Query.cpp
+++ b/src/planner/Query.cpp
@@ -150,6 +150,17 @@ std::unique_ptr<PlanNodeDescription> IndexScan::explain() const {
     return desc;
 }
 
+Filter* Filter::clone(QueryContext* qctx) const {
+    auto newFilter = Filter::make(qctx, nullptr, nullptr, needStableFilter_);
+    newFilter->clone(*this);
+    return newFilter;
+}
+
+void Filter::clone(const Filter& f) {
+    SingleInputNode::clone(f);
+    condition_ = qctx_->objPool()->add(f.condition()->clone().release());
+}
+
 std::unique_ptr<PlanNodeDescription> Filter::explain() const {
     auto desc = SingleInputNode::explain();
     addDescription("condition", condition_ ? condition_->toString() : "", desc.get());
@@ -222,6 +233,25 @@ std::unique_ptr<PlanNodeDescription> TopN::explain() const {
     return desc;
 }
 
+Aggregate* Aggregate::clone(QueryContext* qctx) const {
+    std::vector<Expression*> newGroupKeys;
+    std::vector<Expression*> newGroupItems;
+    auto newAggregate =
+        Aggregate::make(qctx, nullptr, std::move(newGroupKeys), std::move(newGroupItems));
+    newAggregate->clone(*this);
+    return newAggregate;
+}
+
+void Aggregate::clone(const Aggregate& agg) {
+    SingleInputNode::clone(agg);
+    for (auto* expr : agg.groupKeys()) {
+        groupKeys_.emplace_back(qctx_->objPool()->add(expr->clone().release()));
+    }
+    for (auto* expr : agg.groupItems()) {
+        groupItems_.emplace_back(qctx_->objPool()->add(expr->clone().release()));
+    }
+}
+
 std::unique_ptr<PlanNodeDescription> Aggregate::explain() const {
     auto desc = SingleInputNode::explain();
     addDescription("groupKeys", folly::toJson(util::toJson(groupKeys_)), desc.get());
diff --git a/src/planner/Query.h b/src/planner/Query.h
index 9a1fef30..7413f04a 100644
--- a/src/planner/Query.h
+++ b/src/planner/Query.h
@@ -549,10 +549,16 @@ public:
         return condition_;
     }
 
+    void setCondition(Expression* condition) {
+        condition_ = condition;
+    }
+
     bool needStableFilter() const {
         return needStableFilter_;
     }
 
+    Filter* clone(QueryContext* qctx) const;
+
     std::unique_ptr<PlanNodeDescription> explain() const override;
 
 private:
@@ -562,6 +568,8 @@ private:
         needStableFilter_ = needStableFilter;
     }
 
+    void clone(const Filter& f);
+
 private:
     // Remain result when true
     Expression*                 condition_{nullptr};
@@ -837,6 +845,8 @@ public:
         return groupItems_;
     }
 
+    Aggregate* clone(QueryContext* qctx) const;
+
     std::unique_ptr<PlanNodeDescription> explain() const override;
 
 private:
@@ -849,6 +859,8 @@ private:
         groupItems_ = std::move(groupItems);
     }
 
+    void clone(const Aggregate&);
+
 private:
     std::vector<Expression*>    groupKeys_;
     std::vector<Expression*>    groupItems_;
diff --git a/tests/tck/features/optimizer/PushFilterDownAggregateRule.feature b/tests/tck/features/optimizer/PushFilterDownAggregateRule.feature
new file mode 100644
index 00000000..9ca947eb
--- /dev/null
+++ b/tests/tck/features/optimizer/PushFilterDownAggregateRule.feature
@@ -0,0 +1,47 @@
+# Copyright (c) 2021 vesoft inc. All rights reserved.
+#
+# This source code is licensed under Apache 2.0 License,
+# attached with Common Clause Condition 1.0, found in the LICENSES directory.
+Feature: Push Filter down Aggregate rule
+
+  Background:
+    Given a graph with space named "nba"
+
+  Scenario: push filter down Aggregate
+    When profiling query:
+      """
+      MATCH (v:player)
+      WITH v.age+1 AS age, COUNT(v.age) as count
+        WHERE age<30
+      RETURN age,count
+        ORDER BY age
+      """
+    Then the result should be, in any order:
+      | age | count |
+      | -3  | 1     |
+      | -2  | 1     |
+      | -1  | 1     |
+      | 0   | 1     |
+      | 1   | 1     |
+      | 21  | 1     |
+      | 23  | 1     |
+      | 24  | 1     |
+      | 25  | 1     |
+      | 26  | 2     |
+      | 27  | 1     |
+      | 28  | 1     |
+      | 29  | 3     |
+    And the execution plan should be:
+      | id | name        | dependencies | operator info |
+      | 13 | DataCollect | 12           |               |
+      | 12 | Sort        | 11           |               |
+      | 11 | Project     | 18           |               |
+      | 18 | Aggregate   | 17           |               |
+      | 17 | Filter      | 8            |               |
+      | 8  | Filter      | 7            |               |
+      | 7  | Project     | 6            |               |
+      | 6  | Project     | 5            |               |
+      | 5  | Filter      | 16           |               |
+      | 16 | GetVertices | 14           |               |
+      | 14 | IndexScan   | 0            |               |
+      | 0  | Start       |              |               |
-- 
GitLab