Skip to content
Snippets Groups Projects
Unverified Commit aba33fab authored by Yee's avatar Yee Committed by GitHub
Browse files

Check dataflow dependency between plan nodes in optimizer (#679)


* Check the dataflow dependency in optimizer

* Fix end condition

* Cleanup

* Fix failed test cases

* Refactor input variables setter of plan node

* Check whether the data flow is same as the control flow

* Add symbol printer

* Add utils

* Add OptContext

Co-authored-by: default avatarjie.wang <38901892+jievince@users.noreply.github.com>
parent caeaea51
No related branches found
No related tags found
No related merge requests found
Showing
with 248 additions and 78 deletions
......@@ -4,7 +4,7 @@
# The file to host the process id
--pid_file=pids/nebula-graphd.pid
# Whether to enable optimizer
--enable_optimizer=false
--enable_optimizer=true
########## logging ##########
# The directory to host logging files, which must already exists
......@@ -24,7 +24,7 @@
--stderrthreshold=2
########## query ##########
# Whether to treat partial success as an error.
# Whether to treat partial success as an error.
# This flag is only used for Read-only access, and Modify access always treats partial success as an error.
--accept_partial_success=false
......
......@@ -4,7 +4,7 @@
# The file to host the process id
--pid_file=pids/nebula-graphd.pid
# Whether to enable optimizer
--enable_optimizer=false
--enable_optimizer=true
########## logging ##########
# The directory to host logging files, which must already exists
......@@ -24,7 +24,7 @@
--stderrthreshold=2
########## query ##########
# Whether to treat partial success as an error.
# Whether to treat partial success as an error.
# This flag is only used for Read-only access, and Modify access always treats partial success as an error.
--accept_partial_success=false
......
......@@ -10,6 +10,7 @@ nebula_add_library(
ExecutionContext.cpp
Iterator.cpp
Result.cpp
Symbols.cpp
)
......
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/
#include "context/Symbols.h"
#include <sstream>
#include "planner/PlanNode.h"
#include "util/Utils.h"
namespace nebula {
namespace graph {
std::string Variable::toString() const {
std::stringstream ss;
ss << "name: " << name << ", type: " << type << ", colNames: <" << folly::join(",", colNames)
<< ">, readBy: <" << util::join(readBy, [](auto pn) { return pn->toString(); })
<< ">, writtenBy: <" << util::join(writtenBy, [](auto pn) { return pn->toString(); }) << ">";
return ss.str();
}
std::string SymbolTable::toString() const {
std::stringstream ss;
ss << "SymTable: [";
for (const auto& p : vars_) {
ss << "\n" << p.first << ": ";
if (p.second) {
ss << p.second->toString();
}
}
ss << "\n]";
return ss.str();
}
} // namespace graph
} // namespace nebula
......@@ -37,6 +37,7 @@ using ColsDef = std::vector<ColDef>;
struct Variable {
explicit Variable(std::string n) : name(std::move(n)) {}
std::string toString() const;
std::string name;
Value::Type type{Value::Type::DATASET};
......@@ -118,6 +119,8 @@ public:
}
}
std::string toString() const;
private:
ObjectPool* objPool_{nullptr};
// var name -> variable
......
......@@ -10,6 +10,7 @@ nebula_add_library(
Optimizer.cpp
OptGroup.cpp
OptRule.cpp
OptContext.cpp
rule/PushFilterDownGetNbrsRule.cpp
rule/IndexScanRule.cpp
rule/LimitPushDownRule.cpp
......
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/
#include "optimizer/OptContext.h"
#include "common/base/Logging.h"
#include "common/base/ObjectPool.h"
namespace nebula {
namespace opt {
OptContext::OptContext(graph::QueryContext *qctx)
: qctx_(DCHECK_NOTNULL(qctx)), objPool_(std::make_unique<ObjectPool>()) {}
void OptContext::addPlanNodeAndOptGroupNode(int64_t planNodeId, const OptGroupNode *optGroupNode) {
planNodeToOptGroupNodeMap_.emplace(planNodeId, optGroupNode);
}
const OptGroupNode *OptContext::findOptGroupNodeByPlanNodeId(int64_t planNodeId) const {
auto found = planNodeToOptGroupNodeMap_.find(planNodeId);
return found == planNodeToOptGroupNodeMap_.end() ? nullptr : found->second;
}
} // namespace opt
} // namespace nebula
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/
#ifndef OPTIMIZER_OPTCONTEXT_H_
#define OPTIMIZER_OPTCONTEXT_H_
#include <memory>
#include <unordered_map>
#include "common/cpp/helpers.h"
namespace nebula {
class ObjectPool;
namespace graph {
class QueryContext;
} // namespace graph
namespace opt {
class OptGroupNode;
class OptContext final : private cpp::NonCopyable, private cpp::NonMovable {
public:
explicit OptContext(graph::QueryContext *qctx);
graph::QueryContext *qctx() const {
return qctx_;
}
ObjectPool *objPool() const {
return objPool_.get();
}
void addPlanNodeAndOptGroupNode(int64_t planNodeId, const OptGroupNode *optGroupNode);
const OptGroupNode *findOptGroupNodeByPlanNodeId(int64_t planNodeId) const;
private:
graph::QueryContext *qctx_{nullptr};
std::unique_ptr<ObjectPool> objPool_;
std::unordered_map<int64_t, const OptGroupNode *> planNodeToOptGroupNodeMap_;
};
} // namespace opt
} // namespace nebula
#endif // OPTIMIZER_OPTCONTEXT_H_
......@@ -9,6 +9,7 @@
#include <limits>
#include "context/QueryContext.h"
#include "optimizer/OptContext.h"
#include "optimizer/OptRule.h"
#include "planner/Logic.h"
#include "planner/PlanNode.h"
......@@ -23,12 +24,12 @@ using nebula::graph::SingleDependencyNode;
namespace nebula {
namespace opt {
OptGroup *OptGroup::create(QueryContext *qctx) {
return qctx->objPool()->add(new OptGroup(qctx));
OptGroup *OptGroup::create(OptContext *ctx) {
return ctx->objPool()->add(new OptGroup(ctx));
}
OptGroup::OptGroup(QueryContext *qctx) noexcept : qctx_(qctx) {
DCHECK(qctx != nullptr);
OptGroup::OptGroup(OptContext *ctx) noexcept : ctx_(ctx) {
DCHECK(ctx != nullptr);
}
void OptGroup::addGroupNode(OptGroupNode *groupNode) {
......@@ -37,8 +38,8 @@ void OptGroup::addGroupNode(OptGroupNode *groupNode) {
groupNodes_.emplace_back(groupNode);
}
OptGroupNode *OptGroup::makeGroupNode(QueryContext *qctx, PlanNode *node) {
groupNodes_.emplace_back(OptGroupNode::create(qctx, node, this));
OptGroupNode *OptGroup::makeGroupNode(PlanNode *node) {
groupNodes_.emplace_back(OptGroupNode::create(ctx_, node, this));
return groupNodes_.back();
}
......@@ -59,13 +60,13 @@ Status OptGroup::explore(const OptRule *rule) {
NG_RETURN_IF_ERROR(groupNode->explore(rule));
// Find more equivalents
auto status = rule->match(groupNode);
auto status = rule->match(ctx_, groupNode);
if (!status.ok()) {
++iter;
continue;
}
auto matched = std::move(status).value();
auto resStatus = rule->transform(qctx_, matched);
auto resStatus = rule->transform(ctx_, matched);
NG_RETURN_IF_ERROR(resStatus);
auto result = std::move(resStatus).value();
if (result.eraseAll) {
......@@ -130,8 +131,10 @@ const PlanNode *OptGroup::getPlan() const {
return minGroupNode->getPlan();
}
OptGroupNode *OptGroupNode::create(QueryContext *qctx, PlanNode *node, const OptGroup *group) {
return qctx->objPool()->add(new OptGroupNode(node, group));
OptGroupNode *OptGroupNode::create(OptContext *ctx, PlanNode *node, const OptGroup *group) {
auto optGNode = ctx->objPool()->add(new OptGroupNode(node, group));
ctx->addPlanNodeAndOptGroupNode(node->id(), optGNode);
return optGNode;
}
OptGroupNode::OptGroupNode(PlanNode *node, const OptGroup *group) noexcept
......
......@@ -10,22 +10,23 @@
#include <algorithm>
#include <list>
#include <vector>
#include "common/base/Status.h"
namespace nebula {
namespace graph {
class PlanNode;
class QueryContext;
} // namespace graph
namespace opt {
class OptContext;
class OptGroupNode;
class OptRule;
class OptGroup final {
public:
static OptGroup *create(graph::QueryContext *qctx);
static OptGroup *create(OptContext *ctx);
bool isExplored(const OptRule *rule) const {
return std::find(exploredRules_.cbegin(), exploredRules_.cend(), rule) !=
......@@ -44,7 +45,7 @@ public:
}
void addGroupNode(OptGroupNode *groupNode);
OptGroupNode *makeGroupNode(graph::QueryContext *qctx, graph::PlanNode *node);
OptGroupNode *makeGroupNode(graph::PlanNode *node);
const std::list<OptGroupNode *> &groupNodes() const {
return groupNodes_;
}
......@@ -55,22 +56,20 @@ public:
const graph::PlanNode *getPlan() const;
private:
explicit OptGroup(graph::QueryContext *qctx) noexcept;
explicit OptGroup(OptContext *ctx) noexcept;
static constexpr int16_t kMaxExplorationRound = 128;
std::pair<double, const OptGroupNode *> findMinCostGroupNode() const;
graph::QueryContext *qctx_{nullptr};
OptContext *ctx_{nullptr};
std::list<OptGroupNode *> groupNodes_;
std::vector<const OptRule *> exploredRules_;
};
class OptGroupNode final {
public:
static OptGroupNode *create(graph::QueryContext *qctx,
graph::PlanNode *node,
const OptGroup *group);
static OptGroupNode *create(OptContext *ctx, graph::PlanNode *node, const OptGroup *group);
void dependsOn(OptGroup *dep) {
dependencies_.emplace_back(dep);
......
......@@ -7,7 +7,10 @@
#include "optimizer/OptRule.h"
#include "common/base/Logging.h"
#include "context/Symbols.h"
#include "optimizer/OptContext.h"
#include "optimizer/OptGroup.h"
#include "planner/PlanNode.h"
namespace nebula {
namespace opt {
......@@ -57,21 +60,58 @@ StatusOr<MatchedResult> Pattern::match(const OptGroup *group) const {
return Status::Error();
}
StatusOr<MatchedResult> OptRule::match(const OptGroupNode *groupNode) const {
StatusOr<MatchedResult> OptRule::match(OptContext *ctx, const OptGroupNode *groupNode) const {
const auto &pattern = this->pattern();
auto status = pattern.match(groupNode);
NG_RETURN_IF_ERROR(status);
auto matched = std::move(status).value();
if (!this->match(matched)) {
if (!this->match(ctx, matched)) {
return Status::Error();
}
return matched;
}
bool OptRule::match(const MatchedResult &matched) const {
UNUSED(matched);
// Return true if subclass doesn't override this interface,
// so optimizer will only check whether pattern is matched
bool OptRule::match(OptContext *ctx, const MatchedResult &matched) const {
return checkDataflowDeps(ctx, matched, matched.node->node()->outputVar(), true);
}
bool OptRule::checkDataflowDeps(OptContext *ctx,
const MatchedResult &matched,
const std::string &var,
bool isRoot) const {
auto node = matched.node;
auto planNode = node->node();
const auto &outVarName = planNode->outputVar();
if (outVarName != var) {
return false;
}
auto symTbl = ctx->qctx()->symTable();
auto outVar = symTbl->getVar(outVarName);
// Check whether the data flow is same as the control flow in execution plan.
if (!isRoot) {
for (auto pnode : outVar->readBy) {
auto optGNode = ctx->findOptGroupNodeByPlanNodeId(pnode->id());
if (!optGNode) continue;
const auto &deps = optGNode->dependencies();
if (deps.empty()) continue;
auto found = std::find(deps.begin(), deps.end(), node->group());
if (found == deps.end()) {
VLOG(2) << ctx->qctx()->symTable()->toString();
return false;
}
}
}
const auto &deps = matched.dependencies;
if (deps.empty()) {
return true;
}
DCHECK_EQ(deps.size(), node->dependencies().size());
for (size_t i = 0; i < deps.size(); ++i) {
if (!checkDataflowDeps(ctx, deps[i], planNode->inputVar(i), false)) {
return false;
}
}
return true;
}
......@@ -81,7 +121,7 @@ RuleSet &RuleSet::DefaultRules() {
}
RuleSet &RuleSet::QueryRules() {
static RuleSet kQueryRules("QueryRules");
static RuleSet kQueryRules("QueryRuleSet");
return kQueryRules;
}
......
......@@ -16,12 +16,14 @@
#include "planner/PlanNode.h"
namespace nebula {
namespace graph {
class QueryContext;
} // namespace graph
namespace opt {
class OptContext;
class OptGroupNode;
class OptGroup;
......@@ -57,18 +59,25 @@ public:
std::vector<OptGroupNode *> newGroupNodes;
};
StatusOr<MatchedResult> match(const OptGroupNode *groupNode) const;
StatusOr<MatchedResult> match(OptContext *ctx, const OptGroupNode *groupNode) const;
virtual ~OptRule() = default;
virtual const Pattern &pattern() const = 0;
virtual bool match(const MatchedResult &matched) const;
virtual StatusOr<TransformResult> transform(graph::QueryContext *qctx,
virtual bool match(OptContext *ctx, const MatchedResult &matched) const;
virtual StatusOr<TransformResult> transform(OptContext *ctx,
const MatchedResult &matched) const = 0;
virtual std::string toString() const = 0;
protected:
OptRule() = default;
// Return false if the output variable of this matched plan node is not the
// input of other plan node
bool checkDataflowDeps(OptContext *ctx,
const MatchedResult &matched,
const std::string &var,
bool isRoot) const;
};
class RuleSet final {
......
......@@ -7,6 +7,7 @@
#include "optimizer/Optimizer.h"
#include "context/QueryContext.h"
#include "optimizer/OptContext.h"
#include "optimizer/OptGroup.h"
#include "optimizer/OptRule.h"
#include "planner/ExecutionPlan.h"
......@@ -27,9 +28,10 @@ Optimizer::Optimizer(std::vector<const RuleSet *> ruleSets) : ruleSets_(std::mov
StatusOr<const PlanNode *> Optimizer::findBestPlan(QueryContext *qctx) {
DCHECK(qctx != nullptr);
auto optCtx = std::make_unique<OptContext>(qctx);
auto root = qctx->plan()->root();
auto status = prepare(qctx, root);
auto status = prepare(optCtx.get(), root);
NG_RETURN_IF_ERROR(status);
auto rootGroup = std::move(status).value();
......@@ -37,9 +39,9 @@ StatusOr<const PlanNode *> Optimizer::findBestPlan(QueryContext *qctx) {
return rootGroup->getPlan();
}
StatusOr<OptGroup *> Optimizer::prepare(QueryContext *qctx, PlanNode *root) {
StatusOr<OptGroup *> Optimizer::prepare(OptContext *ctx, PlanNode *root) {
std::unordered_map<int64_t, OptGroup *> visited;
return convertToGroup(qctx, root, &visited);
return convertToGroup(ctx, root, &visited);
}
Status Optimizer::doExploration(OptGroup *rootGroup) {
......@@ -51,7 +53,7 @@ Status Optimizer::doExploration(OptGroup *rootGroup) {
return Status::OK();
}
OptGroup *Optimizer::convertToGroup(QueryContext *qctx,
OptGroup *Optimizer::convertToGroup(OptContext *ctx,
PlanNode *node,
std::unordered_map<int64_t, OptGroup *> *visited) {
auto iter = visited->find(node->id());
......@@ -59,8 +61,8 @@ OptGroup *Optimizer::convertToGroup(QueryContext *qctx,
return iter->second;
}
auto group = OptGroup::create(qctx);
auto groupNode = group->makeGroupNode(qctx, node);
auto group = OptGroup::create(ctx);
auto groupNode = group->makeGroupNode(node);
switch (node->dependencies().size()) {
case 0: {
......@@ -70,29 +72,29 @@ OptGroup *Optimizer::convertToGroup(QueryContext *qctx,
case 1: {
if (node->kind() == PlanNode::Kind::kSelect) {
auto select = static_cast<Select *>(node);
auto then = convertToGroup(qctx, const_cast<PlanNode *>(select->then()), visited);
auto then = convertToGroup(ctx, const_cast<PlanNode *>(select->then()), visited);
groupNode->addBody(then);
auto otherNode = const_cast<PlanNode *>(select->otherwise());
auto otherwise = convertToGroup(qctx, otherNode, visited);
auto otherwise = convertToGroup(ctx, otherNode, visited);
groupNode->addBody(otherwise);
} else if (node->kind() == PlanNode::Kind::kLoop) {
auto loop = static_cast<Loop *>(node);
auto body = convertToGroup(qctx, const_cast<PlanNode *>(loop->body()), visited);
auto body = convertToGroup(ctx, const_cast<PlanNode *>(loop->body()), visited);
groupNode->addBody(body);
}
auto dep = static_cast<SingleDependencyNode *>(node)->dep();
DCHECK(dep != nullptr);
auto depGroup = convertToGroup(qctx, const_cast<graph::PlanNode *>(dep), visited);
auto depGroup = convertToGroup(ctx, const_cast<graph::PlanNode *>(dep), visited);
groupNode->dependsOn(depGroup);
break;
}
case 2: {
auto bNode = static_cast<BiInputNode *>(node);
auto leftNode = const_cast<graph::PlanNode *>(bNode->left());
auto leftGroup = convertToGroup(qctx, leftNode, visited);
auto leftGroup = convertToGroup(ctx, leftNode, visited);
groupNode->dependsOn(leftGroup);
auto rightNode = const_cast<graph::PlanNode *>(bNode->right());
auto rightGroup = convertToGroup(qctx, rightNode, visited);
auto rightGroup = convertToGroup(ctx, rightNode, visited);
groupNode->dependsOn(rightGroup);
break;
}
......
......@@ -19,6 +19,7 @@ class QueryContext;
namespace opt {
class OptContext;
class OptGroup;
class OptGroupNode;
class RuleSet;
......@@ -31,10 +32,10 @@ public:
StatusOr<const graph::PlanNode *> findBestPlan(graph::QueryContext *qctx);
private:
StatusOr<OptGroup *> prepare(graph::QueryContext *qctx, graph::PlanNode *root);
StatusOr<OptGroup *> prepare(OptContext *ctx, graph::PlanNode *root);
Status doExploration(OptGroup *rootGroup);
OptGroup *convertToGroup(graph::QueryContext *qctx,
OptGroup *convertToGroup(OptContext *ctx,
graph::PlanNode *node,
std::unordered_map<int64_t, OptGroup *> *visited);
......
......@@ -6,11 +6,13 @@
#include "optimizer/rule/IndexScanRule.h"
#include "common/expression/LabelAttributeExpression.h"
#include "optimizer/OptContext.h"
#include "optimizer/OptGroup.h"
#include "planner/PlanNode.h"
#include "planner/Query.h"
using nebula::graph::IndexScan;
using nebula::graph::OptimizerUtils;
namespace nebula {
namespace opt {
......@@ -27,7 +29,7 @@ const Pattern& IndexScanRule::pattern() const {
return pattern;
}
StatusOr<OptRule::TransformResult> IndexScanRule::transform(graph::QueryContext* qctx,
StatusOr<OptRule::TransformResult> IndexScanRule::transform(OptContext* ctx,
const MatchedResult& matched) const {
auto groupNode = matched.node;
if (isEmptyResultSet(groupNode)) {
......@@ -35,6 +37,7 @@ StatusOr<OptRule::TransformResult> IndexScanRule::transform(graph::QueryContext*
}
auto filter = filterExpr(groupNode);
auto qctx = ctx->qctx();
IndexQueryCtx iqctx = std::make_unique<std::vector<IndexQueryContext>>();
if (filter == nullptr) {
// Only filter is nullptr when lookup on tagname
......@@ -48,7 +51,7 @@ StatusOr<OptRule::TransformResult> IndexScanRule::transform(graph::QueryContext*
auto newIN = static_cast<const IndexScan*>(groupNode->node())->clone(qctx);
newIN->setIndexQueryContext(std::move(iqctx));
auto newGroupNode = OptGroupNode::create(qctx, newIN, groupNode->group());
auto newGroupNode = OptGroupNode::create(ctx, newIN, groupNode->group());
if (groupNode->dependencies().size() != 1) {
return Status::Error("Plan node dependencies error");
}
......
......@@ -11,14 +11,8 @@
#include "optimizer/OptimizerUtils.h"
namespace nebula {
namespace graph {
class IndexScan;
} // namespace graph
namespace opt {
using graph::IndexScan;
using graph::OptimizerUtils;
using graph::PlanNode;
using graph::QueryContext;
using storage::cpp2::IndexQueryContext;
using storage::cpp2::IndexColumnHint;
......@@ -26,6 +20,8 @@ using BVO = graph::OptimizerUtils::BoundValueOperator;
using IndexItem = std::shared_ptr<meta::cpp2::IndexItem>;
using IndexQueryCtx = std::unique_ptr<std::vector<IndexQueryContext>>;
class OptContext;
class IndexScanRule final : public OptRule {
FRIEND_TEST(IndexScanRuleTest, BoundValueTest);
FRIEND_TEST(IndexScanRuleTest, IQCtxTest);
......@@ -34,7 +30,7 @@ class IndexScanRule final : public OptRule {
public:
const Pattern& pattern() const override;
StatusOr<TransformResult> transform(graph::QueryContext* qctx,
StatusOr<TransformResult> transform(OptContext* ctx,
const MatchedResult& matched) const override;
std::string toString() const override;
......
......@@ -12,6 +12,7 @@
#include "common/expression/FunctionCallExpression.h"
#include "common/expression/LogicalExpression.h"
#include "common/expression/UnaryExpression.h"
#include "optimizer/OptContext.h"
#include "optimizer/OptGroup.h"
#include "planner/PlanNode.h"
#include "planner/Query.h"
......@@ -42,7 +43,7 @@ const Pattern &LimitPushDownRule::pattern() const {
}
StatusOr<OptRule::TransformResult> LimitPushDownRule::transform(
QueryContext *qctx,
OptContext *ctx,
const MatchedResult &matched) const {
auto limitGroupNode = matched.node;
auto projGroupNode = matched.dependencies.front().node;
......@@ -57,17 +58,18 @@ StatusOr<OptRule::TransformResult> LimitPushDownRule::transform(
return TransformResult::noTransform();
}
auto qctx = ctx->qctx();
auto newLimit = limit->clone(qctx);
auto newLimitGroupNode = OptGroupNode::create(qctx, newLimit, limitGroupNode->group());
auto newLimitGroupNode = OptGroupNode::create(ctx, newLimit, limitGroupNode->group());
auto newProj = proj->clone(qctx);
auto newProjGroup = OptGroup::create(qctx);
auto newProjGroupNode = newProjGroup->makeGroupNode(qctx, newProj);
auto newProjGroup = OptGroup::create(ctx);
auto newProjGroupNode = newProjGroup->makeGroupNode(newProj);
auto newGn = gn->clone(qctx);
newGn->setLimit(limitRows);
auto newGnGroup = OptGroup::create(qctx);
auto newGnGroupNode = newGnGroup->makeGroupNode(qctx, newGn);
auto newGnGroup = OptGroup::create(ctx);
auto newGnGroupNode = newGnGroup->makeGroupNode(newGn);
newLimitGroupNode->dependsOn(newProjGroup);
newProjGroupNode->dependsOn(newGnGroup);
......
......@@ -12,19 +12,13 @@
#include "optimizer/OptRule.h"
namespace nebula {
namespace graph {
class Limit;
class Project;
class GetNeighbors;
} // namespace graph
namespace opt {
class LimitPushDownRule final : public OptRule {
public:
const Pattern &pattern() const override;
StatusOr<OptRule::TransformResult> transform(graph::QueryContext *qctx,
StatusOr<OptRule::TransformResult> transform(OptContext *ctx,
const MatchedResult &matched) const override;
std::string toString() const override;
......
......@@ -12,6 +12,7 @@
#include "common/expression/FunctionCallExpression.h"
#include "common/expression/LogicalExpression.h"
#include "common/expression/UnaryExpression.h"
#include "optimizer/OptContext.h"
#include "optimizer/OptGroup.h"
#include "planner/PlanNode.h"
#include "planner/Query.h"
......@@ -39,7 +40,7 @@ const Pattern &PushFilterDownGetNbrsRule::pattern() const {
}
StatusOr<OptRule::TransformResult> PushFilterDownGetNbrsRule::transform(
QueryContext *qctx,
OptContext *ctx,
const MatchedResult &matched) const {
auto filterGroupNode = matched.node;
auto gnGroupNode = matched.dependencies.front().node;
......@@ -53,6 +54,7 @@ StatusOr<OptRule::TransformResult> PushFilterDownGetNbrsRule::transform(
return TransformResult::noTransform();
}
auto qctx = ctx->qctx();
auto pool = qctx->objPool();
auto remainedExpr = std::move(visitor).remainedExpr();
OptGroupNode *newFilterGroupNode = nullptr;
......@@ -60,7 +62,7 @@ StatusOr<OptRule::TransformResult> PushFilterDownGetNbrsRule::transform(
auto newFilter = Filter::make(qctx, nullptr, pool->add(remainedExpr.release()));
newFilter->setOutputVar(filter->outputVar());
newFilter->setInputVar(filter->inputVar());
newFilterGroupNode = OptGroupNode::create(qctx, newFilter, filterGroupNode->group());
newFilterGroupNode = OptGroupNode::create(ctx, newFilter, filterGroupNode->group());
}
auto newGNFilter = condition->encode();
......@@ -77,12 +79,12 @@ StatusOr<OptRule::TransformResult> PushFilterDownGetNbrsRule::transform(
OptGroupNode *newGnGroupNode = nullptr;
if (newFilterGroupNode != nullptr) {
// Filter(A&&B)<-GetNeighbors(C) => Filter(A)<-GetNeighbors(B&&C)
auto newGroup = OptGroup::create(qctx);
newGnGroupNode = newGroup->makeGroupNode(qctx, newGN);
auto newGroup = OptGroup::create(ctx);
newGnGroupNode = newGroup->makeGroupNode(newGN);
newFilterGroupNode->dependsOn(newGroup);
} else {
// Filter(A)<-GetNeighbors(C) => GetNeighbors(A&&C)
newGnGroupNode = OptGroupNode::create(qctx, newGN, filterGroupNode->group());
newGnGroupNode = OptGroupNode::create(ctx, newGN, filterGroupNode->group());
newGN->setOutputVar(filter->outputVar());
}
......
......@@ -12,17 +12,13 @@
#include "optimizer/OptRule.h"
namespace nebula {
namespace graph {
class GetNeighbors;
} // namespace graph
namespace opt {
class PushFilterDownGetNbrsRule final : public OptRule {
public:
const Pattern &pattern() const override;
StatusOr<TransformResult> transform(graph::QueryContext *qctx,
StatusOr<TransformResult> transform(OptContext *ctx,
const MatchedResult &matched) const override;
std::string toString() const override;
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment