diff --git a/src/optimizer/test/CMakeLists.txt b/src/optimizer/test/CMakeLists.txt index 65b65760f2ce99484a96d1bc350b51600a45917e..deb21a2bcad09da487340f52a5384f386ff3aa72 100644 --- a/src/optimizer/test/CMakeLists.txt +++ b/src/optimizer/test/CMakeLists.txt @@ -41,6 +41,7 @@ set(OPTIMIZER_TEST_LIB $<TARGET_OBJECTS:validator_obj> $<TARGET_OBJECTS:optimizer_obj> ) + nebula_add_test( NAME index_bound_value_test diff --git a/src/parser/Clauses.h b/src/parser/Clauses.h index cd9842a60befca9b1fe9d62799cfe2ca97b2e00a..984d3f75313afffdcaa94f7f38fb259faa4af258 100644 --- a/src/parser/Clauses.h +++ b/src/parser/Clauses.h @@ -303,6 +303,14 @@ public: return result; } + size_t size() const { + return columns_.size(); + } + + bool empty() const { + return columns_.empty(); + } + std::string toString() const; private: diff --git a/src/parser/MatchSentence.cpp b/src/parser/MatchSentence.cpp index 83e8fc48505b0f81c150c173b4eb8ba9e90948b6..1f00774e11b343c183868ab9386ea2caefefb0f9 100644 --- a/src/parser/MatchSentence.cpp +++ b/src/parser/MatchSentence.cpp @@ -24,14 +24,14 @@ std::string MatchEdge::toString() const { end = "-"; } - if (alias_ != nullptr || edge_ != nullptr || props_ != nullptr) { + if (alias_ != nullptr || type_ != nullptr || props_ != nullptr) { buf += '['; if (alias_ != nullptr) { buf += *alias_; } - if (edge_ != nullptr) { + if (type_ != nullptr) { buf += ':'; - buf += *edge_; + buf += *type_; } if (props_ != nullptr) { buf += props_->toString(); @@ -69,10 +69,10 @@ std::string MatchPath::toString() const { std::string buf; buf.reserve(256); - buf += head_->toString(); - for (auto i = 0u; i < steps_.size(); i++) { - buf += steps_[i].first->toString(); - buf += steps_[i].second->toString(); + buf += node(0)->toString(); + for (auto i = 0u; i < edges_.size(); i++) { + buf += edge(i)->toString(); + buf += node(i + 1)->toString(); } return buf; @@ -98,9 +98,9 @@ std::string MatchSentence::toString() const { buf += "MATCH "; buf += path_->toString(); - if (filter_ != nullptr) { + if (where_ != nullptr) { buf += ' '; - buf += filter_->toString(); + buf += where_->toString(); } if (return_ != nullptr) { buf += ' '; diff --git a/src/parser/MatchSentence.h b/src/parser/MatchSentence.h index a87f9fa1f20eb46df4e4d596d702e07f823f72db..2d906d3a2ea9c5811eb9bd69da559e60a92f9dce 100644 --- a/src/parser/MatchSentence.h +++ b/src/parser/MatchSentence.h @@ -17,19 +17,19 @@ namespace nebula { class MatchEdgeProp final { public: - MatchEdgeProp(std::string *alias, std::string *edge, Expression *props = nullptr) { + MatchEdgeProp(std::string *alias, std::string *type, Expression *props = nullptr) { alias_.reset(alias); - edge_.reset(edge); + type_.reset(type); props_.reset(static_cast<MapExpression*>(props)); } auto get() && { - return std::make_tuple(std::move(alias_), std::move(edge_), std::move(props_)); + return std::make_tuple(std::move(alias_), std::move(type_), std::move(props_)); } private: std::unique_ptr<std::string> alias_; - std::unique_ptr<std::string> edge_; + std::unique_ptr<std::string> type_; std::unique_ptr<MapExpression> props_; }; @@ -41,7 +41,7 @@ public: if (prop != nullptr) { auto tuple = std::move(*prop).get(); alias_ = std::move(std::get<0>(tuple)); - edge_ = std::move(std::get<1>(tuple)); + type_ = std::move(std::get<1>(tuple)); props_ = std::move(std::get<2>(tuple)); delete prop; } @@ -56,8 +56,8 @@ public: return alias_.get(); } - const std::string* edge() const { - return edge_.get(); + const std::string* type() const { + return type_.get(); } const MapExpression* props() const { @@ -69,7 +69,7 @@ public: private: Direction direction_; std::unique_ptr<std::string> alias_; - std::unique_ptr<std::string> edge_; + std::unique_ptr<std::string> type_; std::unique_ptr<MapExpression> props_; }; @@ -107,34 +107,40 @@ private: class MatchPath final { public: - explicit MatchPath(MatchNode *head) { - head_.reset(head); + explicit MatchPath(MatchNode *node) { + nodes_.emplace_back(node); } void add(MatchEdge *edge, MatchNode *node) { - steps_.emplace_back(edge, node); + edges_.emplace_back(edge); + nodes_.emplace_back(node); } - const MatchNode* head() const { - return head_.get(); + const auto& nodes() const { + return nodes_; } - using RawStep = std::pair<const MatchEdge*, const MatchNode*>; - std::vector<RawStep> steps() const { - std::vector<RawStep> result; - result.reserve(steps_.size()); - for (auto &step : steps_) { - result.emplace_back(step.first.get(), step.second.get()); - } - return result; + const auto& edges() const { + return edges_; + } + + size_t steps() const { + return edges_.size(); + } + + const MatchNode* node(size_t i) const { + return nodes_[i].get(); + } + + const MatchEdge* edge(size_t i) const { + return edges_[i].get(); } std::string toString() const; private: - using Step = std::pair<std::unique_ptr<MatchEdge>, std::unique_ptr<MatchNode>>; - std::unique_ptr<MatchNode> head_; - std::vector<Step> steps_; + std::vector<std::unique_ptr<MatchNode>> nodes_; + std::vector<std::unique_ptr<MatchEdge>> edges_; }; @@ -152,6 +158,10 @@ public: return columns_.get(); } + void setColumns(YieldColumns *columns) { + columns_.reset(columns); + } + bool isAll() const { return isAll_; } @@ -166,10 +176,10 @@ private: class MatchSentence final : public Sentence { public: - MatchSentence(MatchPath *path, WhereClause *filter, MatchReturn *ret) + MatchSentence(MatchPath *path, WhereClause *where, MatchReturn *ret) : Sentence(Kind::kMatch) { path_.reset(path); - filter_.reset(filter); + where_.reset(where); return_.reset(ret); } @@ -177,19 +187,31 @@ public: return path_.get(); } - const WhereClause* filter() const { - return filter_.get(); + MatchPath* path() { + return path_.get(); + } + + const WhereClause* where() const { + return where_.get(); + } + + WhereClause* where() { + return where_.get(); } const MatchReturn* ret() const { return return_.get(); } + MatchReturn* ret() { + return return_.get(); + } + std::string toString() const override; private: std::unique_ptr<MatchPath> path_; - std::unique_ptr<WhereClause> filter_; + std::unique_ptr<WhereClause> where_; std::unique_ptr<MatchReturn> return_; }; diff --git a/src/parser/parser.yy b/src/parser/parser.yy index 6c6fccc810f9fe44ea7e4ade314de048ff28631e..1190328da2a950a8fe97cd32bfb5718ec5ed7ad4 100644 --- a/src/parser/parser.yy +++ b/src/parser/parser.yy @@ -989,6 +989,9 @@ match_node | L_PAREN match_alias COLON name_label map_expression R_PAREN { $$ = new MatchNode($2, $4, $5); } + | L_PAREN match_alias map_expression R_PAREN { + $$ = new MatchNode($2, nullptr, $3); + } ; match_alias diff --git a/src/planner/Query.cpp b/src/planner/Query.cpp index f1693ea2761deab4e38a99d9aa3ad70423b2b516..c5e4d28ea7478a68bdb95dc4e750ac1564b96cd5 100644 --- a/src/planner/Query.cpp +++ b/src/planner/Query.cpp @@ -105,8 +105,10 @@ std::unique_ptr<cpp2::PlanNodeDescription> GetEdges::explain() const { IndexScan* IndexScan::clone(QueryContext* qctx) const { auto ctx = std::make_unique<std::vector<storage::cpp2::IndexQueryContext>>(); auto returnCols = std::make_unique<std::vector<std::string>>(*returnColumns()); - return IndexScan::make( + auto *scan = IndexScan::make( qctx, nullptr, space(), std::move(ctx), std::move(returnCols), isEdge(), schemaId()); + scan->setOutputVar(this->outputVar()); + return scan; } std::unique_ptr<cpp2::PlanNodeDescription> IndexScan::explain() const { diff --git a/src/util/ExpressionUtils.cpp b/src/util/ExpressionUtils.cpp index eb20ee2d98bb639cc03a6583e9e19deacda8f105..10a4cc31091fb65c9ce7ff00cc547ec48c0099d6 100644 --- a/src/util/ExpressionUtils.cpp +++ b/src/util/ExpressionUtils.cpp @@ -21,5 +21,49 @@ std::unique_ptr<Expression> ExpressionUtils::foldConstantExpr(const Expression * return newExpr; } +std::vector<const Expression*> ExpressionUtils::pullAnds(const Expression *expr) { + DCHECK(expr->kind() == Expression::Kind::kLogicalAnd); + auto *root = static_cast<const LogicalExpression*>(expr); + std::vector<const Expression*> operands; + + if (root->left()->kind() != Expression::Kind::kLogicalAnd) { + operands.emplace_back(root->left()); + } else { + auto ands = pullAnds(root->left()); + operands.insert(operands.end(), ands.begin(), ands.end()); + } + + if (root->right()->kind() != Expression::Kind::kLogicalAnd) { + operands.emplace_back(root->right()); + } else { + auto ands = pullAnds(root->right()); + operands.insert(operands.end(), ands.begin(), ands.end()); + } + + return operands; +} + +std::vector<const Expression*> ExpressionUtils::pullOrs(const Expression *expr) { + DCHECK(expr->kind() == Expression::Kind::kLogicalOr); + auto *root = static_cast<const LogicalExpression*>(expr); + std::vector<const Expression*> operands; + + if (root->left()->kind() != Expression::Kind::kLogicalOr) { + operands.emplace_back(root->left()); + } else { + auto ands = pullOrs(root->left()); + operands.insert(operands.end(), ands.begin(), ands.end()); + } + + if (root->right()->kind() != Expression::Kind::kLogicalOr) { + operands.emplace_back(root->right()); + } else { + auto ands = pullOrs(root->right()); + operands.insert(operands.end(), ands.begin(), ands.end()); + } + + return operands; +} + } // namespace graph } // namespace nebula diff --git a/src/util/ExpressionUtils.h b/src/util/ExpressionUtils.h index b404719c1a4438f56b8edc7429d47a4981b996f7..2ab87269851d09a6089366e07055b60e90692c4b 100644 --- a/src/util/ExpressionUtils.h +++ b/src/util/ExpressionUtils.h @@ -109,6 +109,10 @@ public: // Clone and fold constant expression static std::unique_ptr<Expression> foldConstantExpr(const Expression* expr); + + static std::vector<const Expression*> pullAnds(const Expression *expr); + + static std::vector<const Expression*> pullOrs(const Expression *expr); }; } // namespace graph diff --git a/src/util/test/ExpressionUtilsTest.cpp b/src/util/test/ExpressionUtilsTest.cpp index d42862554082e5ab109dbc2068402ec4fb170fd5..ba082c6312324f87720e8bff9745b8bb44974807 100644 --- a/src/util/test/ExpressionUtilsTest.cpp +++ b/src/util/test/ExpressionUtilsTest.cpp @@ -203,5 +203,149 @@ TEST_F(ExpressionUtilsTest, CheckComponent) { } } +TEST_F(ExpressionUtilsTest, PullAnds) { + using Kind = Expression::Kind; + // true AND false + { + auto *first = new ConstantExpression(true); + auto *second = new ConstantExpression(false); + LogicalExpression expr(Kind::kLogicalAnd, first, second); + auto ands = ExpressionUtils::pullAnds(&expr); + ASSERT_EQ(2UL, ands.size()); + ASSERT_EQ(first, ands[0]); + ASSERT_EQ(second, ands[1]); + } + // true AND false AND true + { + auto *first = new ConstantExpression(true); + auto *second = new ConstantExpression(false); + auto *third = new ConstantExpression(true); + LogicalExpression expr(Kind::kLogicalAnd, + new LogicalExpression(Kind::kLogicalAnd, first, second), third); + auto ands = ExpressionUtils::pullAnds(&expr); + ASSERT_EQ(3UL, ands.size()); + ASSERT_EQ(first, ands[0]); + ASSERT_EQ(second, ands[1]); + ASSERT_EQ(third, ands[2]); + } + // true AND (false AND true) + { + auto *first = new ConstantExpression(true); + auto *second = new ConstantExpression(false); + auto *third = new ConstantExpression(true); + LogicalExpression expr(Kind::kLogicalAnd, + first, + new LogicalExpression(Kind::kLogicalAnd, second, third)); + auto ands = ExpressionUtils::pullAnds(&expr); + ASSERT_EQ(3UL, ands.size()); + ASSERT_EQ(first, ands[0]); + ASSERT_EQ(second, ands[1]); + ASSERT_EQ(third, ands[2]); + } + // (true OR false) AND (true OR false) + { + auto *first = new LogicalExpression(Kind::kLogicalOr, + new ConstantExpression(true), + new ConstantExpression(false)); + auto *second = new LogicalExpression(Kind::kLogicalOr, + new ConstantExpression(true), + new ConstantExpression(false)); + LogicalExpression expr(Kind::kLogicalAnd, first, second); + auto ands = ExpressionUtils::pullAnds(&expr); + ASSERT_EQ(2UL, ands.size()); + ASSERT_EQ(first, ands[0]); + ASSERT_EQ(second, ands[1]); + } + // true AND ((false AND true) OR false) AND true + { + auto *first = new ConstantExpression(true); + auto *second = new LogicalExpression(Kind::kLogicalOr, + new LogicalExpression(Kind::kLogicalAnd, + new ConstantExpression(false), + new ConstantExpression(true)), + new ConstantExpression(false)); + auto *third = new ConstantExpression(true); + LogicalExpression expr(Kind::kLogicalAnd, + new LogicalExpression(Kind::kLogicalAnd, first, second), third); + auto ands = ExpressionUtils::pullAnds(&expr); + ASSERT_EQ(3UL, ands.size()); + ASSERT_EQ(first, ands[0]); + ASSERT_EQ(second, ands[1]); + ASSERT_EQ(third, ands[2]); + } +} + +TEST_F(ExpressionUtilsTest, PullOrs) { + using Kind = Expression::Kind; + // true OR false + { + auto *first = new ConstantExpression(true); + auto *second = new ConstantExpression(false); + LogicalExpression expr(Kind::kLogicalOr, first, second); + auto ors = ExpressionUtils::pullOrs(&expr); + ASSERT_EQ(2UL, ors.size()); + ASSERT_EQ(first, ors[0]); + ASSERT_EQ(second, ors[1]); + } + // true OR false OR true + { + auto *first = new ConstantExpression(true); + auto *second = new ConstantExpression(false); + auto *third = new ConstantExpression(true); + LogicalExpression expr(Kind::kLogicalOr, + new LogicalExpression(Kind::kLogicalOr, first, second), third); + auto ors = ExpressionUtils::pullOrs(&expr); + ASSERT_EQ(3UL, ors.size()); + ASSERT_EQ(first, ors[0]); + ASSERT_EQ(second, ors[1]); + ASSERT_EQ(third, ors[2]); + } + // true OR (false OR true) + { + auto *first = new ConstantExpression(true); + auto *second = new ConstantExpression(false); + auto *third = new ConstantExpression(true); + LogicalExpression expr(Kind::kLogicalOr, + first, + new LogicalExpression(Kind::kLogicalOr, second, third)); + auto ors = ExpressionUtils::pullOrs(&expr); + ASSERT_EQ(3UL, ors.size()); + ASSERT_EQ(first, ors[0]); + ASSERT_EQ(second, ors[1]); + ASSERT_EQ(third, ors[2]); + } + // (true AND false) OR (true AND false) + { + auto *first = new LogicalExpression(Kind::kLogicalAnd, + new ConstantExpression(true), + new ConstantExpression(false)); + auto *second = new LogicalExpression(Kind::kLogicalAnd, + new ConstantExpression(true), + new ConstantExpression(false)); + LogicalExpression expr(Kind::kLogicalOr, first, second); + auto ors = ExpressionUtils::pullOrs(&expr); + ASSERT_EQ(2UL, ors.size()); + ASSERT_EQ(first, ors[0]); + ASSERT_EQ(second, ors[1]); + } + // true OR ((false OR true) AND false) OR true + { + auto *first = new ConstantExpression(true); + auto *second = new LogicalExpression(Kind::kLogicalAnd, + new LogicalExpression(Kind::kLogicalOr, + new ConstantExpression(false), + new ConstantExpression(true)), + new ConstantExpression(false)); + auto *third = new ConstantExpression(true); + LogicalExpression expr(Kind::kLogicalOr, + new LogicalExpression(Kind::kLogicalOr, first, second), third); + auto ors = ExpressionUtils::pullOrs(&expr); + ASSERT_EQ(3UL, ors.size()); + ASSERT_EQ(first, ors[0]); + ASSERT_EQ(second, ors[1]); + ASSERT_EQ(third, ors[2]); + } +} + } // namespace graph } // namespace nebula diff --git a/src/validator/CMakeLists.txt b/src/validator/CMakeLists.txt index ffe1da36231dd95d7291e65f6eef70dc4afe0128..3baa16bd11814eb6507698ed7435653c63d0b427 100644 --- a/src/validator/CMakeLists.txt +++ b/src/validator/CMakeLists.txt @@ -29,6 +29,7 @@ nebula_add_library( GroupByValidator.cpp FindPathValidator.cpp IndexScanValidator.cpp + MatchValidator.cpp ) nebula_add_subdirectory(test) diff --git a/src/validator/MatchValidator.cpp b/src/validator/MatchValidator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c3266ad862ddbcb770969b30b70ddc96a2896ad3 --- /dev/null +++ b/src/validator/MatchValidator.cpp @@ -0,0 +1,670 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#include "validator/MatchValidator.h" +#include "visitor/RewriteMatchLabelVisitor.h" +#include "util/ExpressionUtils.h" + +namespace nebula { +namespace graph { + +Status MatchValidator::toPlan() { + NG_RETURN_IF_ERROR(buildScanNode()); + if (!edgeInfos_.empty()) { + NG_RETURN_IF_ERROR(buildSteps()); + } + NG_RETURN_IF_ERROR(buildGetTailVertices()); + if (!edgeInfos_.empty()) { + NG_RETURN_IF_ERROR(buildTailJoin()); + } + NG_RETURN_IF_ERROR(buildFilter()); + NG_RETURN_IF_ERROR(buildReturn()); + + return Status::OK(); +} + + +Status MatchValidator::validateImpl() { + auto *sentence = static_cast<MatchSentence*>(sentence_); + NG_RETURN_IF_ERROR(validatePath(sentence->path())); + if (sentence->where() != nullptr) { + NG_RETURN_IF_ERROR(validateFilter(sentence->where()->filter())); + } + NG_RETURN_IF_ERROR(validateReturn(sentence->ret())); + return analyzeStartPoint(); +} + + +Status MatchValidator::validatePath(const MatchPath *path) { + auto *sm = qctx_->schemaMng(); + auto steps = path->steps(); + + nodeInfos_.resize(steps + 1); + edgeInfos_.resize(steps); + for (auto i = 0u; i <= steps; i++) { + auto *node = path->node(i); + auto *label = node->label(); + auto *alias = node->alias(); + auto *props = node->props(); + auto anonymous = false; + if (label != nullptr) { + auto tid = sm->toTagID(space_.id, *label); + if (!tid.ok()) { + return Status::Error("`%s': Unknown tag", label->c_str()); + } + nodeInfos_[i].tid = tid.value(); + } + if (alias == nullptr) { + anonymous = true; + alias = saveObject(new std::string(anon_->getVar())); + } + if (!aliases_.emplace(*alias, kNode).second) { + return Status::Error("`%s': Redefined alias", alias->c_str()); + } + Expression *filter = nullptr; + if (props != nullptr) { + filter = makeSubFilter(*alias, props); + } + nodeInfos_[i].anonymous = anonymous; + nodeInfos_[i].label = label; + nodeInfos_[i].alias = alias; + nodeInfos_[i].props = props; + nodeInfos_[i].filter = filter; + } + + for (auto i = 0u; i < steps; i++) { + auto *edge = path->edge(i); + auto *type = edge->type(); + auto *alias = edge->alias(); + auto *props = edge->props(); + auto direction = edge->direction(); + auto anonymous = false; + if (direction != Direction::OUT_EDGE) { + return Status::SemanticError("Only outbound traversal supported"); + } + if (type != nullptr) { + auto etype = sm->toEdgeType(space_.id, *type); + if (!etype.ok()) { + return Status::SemanticError("`%s': Unknown edge type", type->c_str()); + } + edgeInfos_[i].edgeType = etype.value(); + } + if (alias == nullptr) { + anonymous = true; + alias = saveObject(new std::string(anon_->getVar())); + } + if (!aliases_.emplace(*alias, kEdge).second) { + return Status::SemanticError("`%s': Redefined alias", alias->c_str()); + } + Expression *filter = nullptr; + if (props != nullptr) { + filter = makeSubFilter(*alias, props); + } + edgeInfos_[i].anonymous = anonymous; + edgeInfos_[i].direction = direction; + edgeInfos_[i].type = type; + edgeInfos_[i].alias = alias; + edgeInfos_[i].props = props; + edgeInfos_[i].filter = filter; + } + + return Status::OK(); +} + + +Status MatchValidator::validateFilter(const Expression *filter) { + filter_ = ExpressionUtils::foldConstantExpr(filter); + NG_RETURN_IF_ERROR(validateAliases({filter_.get()})); + + return Status::OK(); +} + + +Status MatchValidator::validateReturn(MatchReturn *ret) { + // `RETURN *': return all named nodes or edges + if (ret->isAll()) { + auto makeColumn = [] (const std::string &name) { + auto *expr = new LabelExpression(name); + auto *alias = new std::string(name); + return new YieldColumn(expr, alias); + }; + + auto columns = new YieldColumns(); + auto steps = edgeInfos_.size(); + + if (!nodeInfos_[0].anonymous) { + columns->addColumn(makeColumn(*nodeInfos_[0].alias)); + } + + for (auto i = 0u; i < steps; i++) { + if (!edgeInfos_[i].anonymous) { + columns->addColumn(makeColumn(*edgeInfos_[i].alias)); + } + if (!nodeInfos_[i+1].anonymous) { + columns->addColumn(makeColumn(*nodeInfos_[i+1].alias)); + } + } + + if (columns->empty()) { + return Status::SemanticError("`RETURN *' not allowed if there is no alias"); + } + + ret->setColumns(columns); + } + + // Check all referencing expressions are valid + std::vector<const Expression*> exprs; + exprs.reserve(ret->columns()->size()); + for (auto *col : ret->columns()->columns()) { + exprs.push_back(col->expr()); + } + NG_RETURN_IF_ERROR(validateAliases(exprs)); + + return Status::OK(); +} + + +Status MatchValidator::validateAliases(const std::vector<const Expression*> &exprs) const { + static const std::unordered_set<Expression::Kind> kinds = { + Expression::Kind::kLabel, + Expression::Kind::kLabelAttribute + }; + + for (auto *expr : exprs) { + auto refExprs = ExpressionUtils::collectAll(expr, kinds); + if (refExprs.empty()) { + continue; + } + for (auto *refExpr : refExprs) { + auto kind = refExpr->kind(); + const std::string *name = nullptr; + if (kind == Expression::Kind::kLabel) { + name = static_cast<const LabelExpression*>(refExpr)->name(); + } else { + DCHECK(kind == Expression::Kind::kLabelAttribute); + name = static_cast<const LabelAttributeExpression*>(refExpr)->left()->name(); + } + DCHECK(name != nullptr); + if (aliases_.count(*name) != 1) { + return Status::SemanticError("Alias used but not defined: `%s'", name->c_str()); + } + } + } + return Status::OK(); +} + + +Status MatchValidator::analyzeStartPoint() { + // TODO(dutor) Originate from either node or edge at any position + startFromNode_ = true; + startIndex_ = 0; + startExpr_ = nullptr; + + auto &head = nodeInfos_[0]; + + if (head.label == nullptr) { + return Status::SemanticError("Head node must have a label"); + } + + Expression *filter = nullptr; + if (filter_ != nullptr) { + filter = makeIndexFilter(*head.label, *head.alias, filter_.get()); + } + if (filter == nullptr) { + if (head.props != nullptr && !head.props->items().empty()) { + filter = makeIndexFilter(*head.label, head.props); + } + } + if (filter == nullptr) { + return Status::SemanticError("Index cannot be deduced in props or filter"); + } + + + scanInfo_.filter = filter; + scanInfo_.schemaId = head.tid; + + return Status::OK(); +} + + +Expression* +MatchValidator::makeIndexFilter(const std::string &label, const MapExpression *map) const { + auto &items = map->items(); + Expression *root = new RelationalExpression(Expression::Kind::kRelEQ, + new TagPropertyExpression( + new std::string(label), + new std::string(*items[0].first)), + items[0].second->clone().release()); + for (auto i = 1u; i < items.size(); i++) { + auto *left = root; + auto *right = new RelationalExpression(Expression::Kind::kRelEQ, + new TagPropertyExpression( + new std::string(label), + new std::string(*items[i].first)), + items[i].second->clone().release()); + root = new LogicalExpression(Expression::Kind::kLogicalAnd, left, right); + } + saveObject(root); + + return root; +} + +Expression* +MatchValidator::makeIndexFilter(const std::string &label, + const std::string &alias, + const Expression *filter) const { + static const std::unordered_set<Expression::Kind> kinds = { + Expression::Kind::kRelEQ, + Expression::Kind::kRelLT, + Expression::Kind::kRelLE, + Expression::Kind::kRelGT, + Expression::Kind::kRelGE + }; + + std::vector<const Expression*> ands; + auto kind = filter->kind(); + if (kinds.count(kind) == 1) { + ands.emplace_back(filter); + } else if (kind == Expression::Kind::kLogicalAnd) { + ands = ExpressionUtils::pullAnds(filter); + } else { + return nullptr; + } + + std::vector<Expression*> relationals; + for (auto *item : ands) { + if (kinds.count(item->kind()) != 1) { + continue; + } + + auto *binary = static_cast<const BinaryExpression*>(item); + auto *left = binary->left(); + auto *right = binary->right(); + const LabelAttributeExpression *la = nullptr; + const ConstantExpression *constant = nullptr; + if (left->kind() == Expression::Kind::kLabelAttribute && + right->kind() == Expression::Kind::kConstant) { + la = static_cast<const LabelAttributeExpression*>(left); + constant = static_cast<const ConstantExpression*>(right); + } else if (right->kind() == Expression::Kind::kLabelAttribute && + left->kind() == Expression::Kind::kConstant) { + la = static_cast<const LabelAttributeExpression*>(right); + constant = static_cast<const ConstantExpression*>(left); + } else { + continue; + } + + if (*la->left()->name() != alias) { + continue; + } + + auto *tpExpr = new TagPropertyExpression( + new std::string(label), + new std::string(*la->right()->name())); + auto *newConstant = constant->clone().release(); + if (left->kind() == Expression::Kind::kLabelAttribute) { + auto *rel = new RelationalExpression(item->kind(), tpExpr, newConstant); + relationals.emplace_back(rel); + } else { + auto *rel = new RelationalExpression(item->kind(), newConstant, tpExpr); + relationals.emplace_back(rel); + } + } + + if (relationals.empty()) { + return nullptr; + } + + auto *root = relationals[0]; + for (auto i = 1u; i < relationals.size(); i++) { + auto *left = root; + root = new LogicalExpression(Expression::Kind::kLogicalAnd, left, relationals[i]); + } + + saveObject(root); + return root; +} + + +Expression* MatchValidator::makeSubFilter(const std::string &alias, + const MapExpression *map) const { + DCHECK(map != nullptr); + auto &items = map->items(); + DCHECK(!items.empty()); + Expression *root = nullptr; + root = new RelationalExpression(Expression::Kind::kRelEQ, + new LabelAttributeExpression( + new LabelExpression(alias), + new LabelExpression(*items[0].first)), + items[0].second->clone().release()); + for (auto i = 1u; i < items.size(); i++) { + auto *left = root; + auto *right = new RelationalExpression(Expression::Kind::kRelEQ, + new LabelAttributeExpression( + new LabelExpression(alias), + new LabelExpression(*items[i].first)), + items[i].second->clone().release()); + root = new LogicalExpression(Expression::Kind::kLogicalAnd, left, right); + } + + saveObject(root); + + return root; +} + + +Status MatchValidator::buildScanNode() { + if (!startFromNode_) { + return Status::SemanticError("Scan from edge not supported now"); + } + if (startIndex_ != 0) { + return Status::SemanticError("Only support scan from the head node"); + } + + using IQC = nebula::storage::cpp2::IndexQueryContext; + auto contexts = std::make_unique<std::vector<IQC>>(); + contexts->emplace_back(); + contexts->back().set_filter(Expression::encode(*scanInfo_.filter)); + auto columns = std::make_unique<std::vector<std::string>>(); + auto scan = IndexScan::make(qctx_, + nullptr, + space_.id, + std::move(contexts), + std::move(columns), + false, + scanInfo_.schemaId); + tail_ = scan; + root_ = scan; + + return Status::OK(); +} + + +Status MatchValidator::buildSteps() { + gnSrcExpr_ = new VariablePropertyExpression(new std::string(), + new std::string(kVid)); + saveObject(gnSrcExpr_); + NG_RETURN_IF_ERROR(buildStep()); + for (auto i = 1u; i < edgeInfos_.size(); i++) { + NG_RETURN_IF_ERROR(buildStep()); + NG_RETURN_IF_ERROR(buildStepJoin()); + } + + return Status::OK(); +} + + +Status MatchValidator::buildStep() { + curStep_++; + + auto &srcNodeInfo = nodeInfos_[curStep_]; + auto &edgeInfo = edgeInfos_[curStep_]; + auto *gn = GetNeighbors::make(qctx_, root_, space_.id); + gn->setSrc(gnSrcExpr_); + auto vertexProps = std::make_unique<std::vector<VertexProp>>(); + if (srcNodeInfo.label != nullptr) { + VertexProp vertexProp; + vertexProp.set_tag(srcNodeInfo.tid); + vertexProps->emplace_back(std::move(vertexProp)); + } + gn->setVertexProps(std::move(vertexProps)); + auto edgeProps = std::make_unique<std::vector<EdgeProp>>(); + if (edgeInfo.type != nullptr) { + EdgeProp edgeProp; + edgeProp.set_type(edgeInfo.edgeType); + edgeProps->emplace_back(std::move(edgeProp)); + } + gn->setEdgeProps(std::move(edgeProps)); + gn->setEdgeDirection(edgeInfo.direction); + + auto *yields = saveObject(new YieldColumns()); + yields->addColumn(new YieldColumn(new VertexExpression())); + yields->addColumn(new YieldColumn(new EdgeExpression())); + auto *project = Project::make(qctx_, gn, yields); + project->setInputVar(gn->outputVar()); + project->setColNames({*srcNodeInfo.alias, *edgeInfo.alias}); + + root_ = project; + + auto rewriter = [this] (const Expression *expr) { + DCHECK_EQ(expr->kind(), Expression::Kind::kLabelAttribute); + return rewrite(static_cast<const LabelAttributeExpression*>(expr)); + }; + + if (srcNodeInfo.filter != nullptr) { + RewriteMatchLabelVisitor visitor(rewriter); + srcNodeInfo.filter->accept(&visitor); + auto *node = Filter::make(qctx_, root_, srcNodeInfo.filter); + node->setInputVar(root_->outputVar()); + node->setColNames(root_->colNames()); + root_ = node; + } + + if (edgeInfo.filter != nullptr) { + RewriteMatchLabelVisitor visitor(rewriter); + edgeInfo.filter->accept(&visitor); + auto *node = Filter::make(qctx_, root_, edgeInfo.filter); + node->setInputVar(root_->outputVar()); + node->setColNames(root_->colNames()); + root_ = node; + } + + gnSrcExpr_ = new AttributeExpression( + new VariablePropertyExpression( + new std::string(project->outputVar()), + new std::string(*edgeInfo.alias)), + new LabelExpression("_dst")); + saveObject(gnSrcExpr_); + + prevStepRoot_ = thisStepRoot_; + thisStepRoot_ = root_; + + return Status::OK(); +} + + +Status MatchValidator::buildGetTailVertices() { + Expression *src = nullptr; + if (!edgeInfos_.empty()) { + src = new AttributeExpression( + new VariablePropertyExpression( + new std::string(), + new std::string(*edgeInfos_[curStep_].alias)), + new LabelExpression("_dst")); + } else { + src = new VariablePropertyExpression(new std::string(), + new std::string(kVid)); + } + saveObject(src); + + auto &nodeInfo = nodeInfos_[curStep_ + 1]; + std::vector<VertexProp> props; + if (nodeInfo.label != nullptr) { + VertexProp prop; + prop.set_tag(nodeInfo.tid); + props.emplace_back(prop); + } + + auto *gv = GetVertices::make(qctx_, root_, space_.id, src, std::move(props), {}, true); + if (thisStepRoot_ != nullptr) { + gv->setInputVar(thisStepRoot_->outputVar()); + } + + auto *yields = saveObject(new YieldColumns()); + yields->addColumn(new YieldColumn(new VertexExpression())); + auto *project = Project::make(qctx_, gv, yields); + project->setInputVar(gv->outputVar()); + project->setColNames({*nodeInfo.alias}); + + auto *dedup = Dedup::make(qctx_, project); + dedup->setInputVar(project->outputVar()); + dedup->setColNames(project->colNames()); + root_ = dedup; + + return Status::OK(); +} + + +Status MatchValidator::buildStepJoin() { + auto prevStep = curStep_ - 1; + auto key = new AttributeExpression( + new VariablePropertyExpression( + new std::string(prevStepRoot_->outputVar()), + new std::string(*edgeInfos_[prevStep].alias)), + new LabelExpression("_dst")); + auto probe = new AttributeExpression( + new VariablePropertyExpression( + new std::string(thisStepRoot_->outputVar()), + new std::string(*nodeInfos_[curStep_].alias)), + new LabelExpression("_vid")); + auto *join = DataJoin::make(qctx_, + root_, + {prevStepRoot_->outputVar(), 0}, + {thisStepRoot_->outputVar(), 0}, + {key}, + {probe}); + auto leftColNames = prevStepRoot_->colNames(); + auto rightColNames = thisStepRoot_->colNames(); + std::vector<std::string> colNames; + colNames.reserve(leftColNames.size() + rightColNames.size()); + for (auto &name : leftColNames) { + colNames.emplace_back(std::move(name)); + } + for (auto &name : rightColNames) { + colNames.emplace_back(std::move(name)); + } + join->setColNames(std::move(colNames)); + root_ = join; + thisStepRoot_ = root_; + + return Status::OK(); +} + + +Status MatchValidator::buildTailJoin() { + auto key = new AttributeExpression( + new VariablePropertyExpression( + new std::string(thisStepRoot_->outputVar()), + new std::string(*edgeInfos_[curStep_].alias)), + new LabelExpression("_dst")); + auto probe = new AttributeExpression( + new VariablePropertyExpression( + new std::string(root_->outputVar()), + new std::string(*nodeInfos_[curStep_ + 1].alias)), + new LabelExpression("_vid")); + auto *join = DataJoin::make(qctx_, + root_, + {thisStepRoot_->outputVar(), 0}, + {root_->outputVar(), 0}, + {key}, + {probe}); + auto colNames = thisStepRoot_->colNames(); + colNames.emplace_back(*nodeInfos_[curStep_ + 1].alias); + join->setColNames(std::move(colNames)); + root_ = join; + + return Status::OK(); +} + + +Status MatchValidator::buildFilter() { + auto *sentence = static_cast<MatchSentence*>(sentence_); + auto *clause = sentence->where(); + if (clause == nullptr) { + return Status::OK(); + } + auto *filter = clause->filter(); + auto kind = filter->kind(); + // TODO(dutor) Find a better way to identify where an expr is a boolean one + if (kind == Expression::Kind::kLabel || + kind == Expression::Kind::kLabelAttribute) { + return Status::SemanticError("Filter should be a boolean expression"); + } + + auto newFilter = filter->clone(); + auto rewriter = [this] (const Expression *expr) { + if (expr->kind() == Expression::Kind::kLabel) { + return rewrite(static_cast<const LabelExpression*>(expr)); + } else { + return rewrite(static_cast<const LabelAttributeExpression*>(expr)); + } + }; + RewriteMatchLabelVisitor visitor(std::move(rewriter)); + newFilter->accept(&visitor); + + auto *node = Filter::make(qctx_, root_, saveObject(newFilter.release())); + node->setInputVar(root_->outputVar()); + node->setColNames(root_->colNames()); + + root_ = node; + + return Status::OK(); +} + + +Status MatchValidator::buildReturn() { + auto *sentence = static_cast<MatchSentence*>(sentence_); + auto *yields = new YieldColumns(); + std::vector<std::string> colNames; + + for (auto *col : sentence->ret()->columns()->columns()) { + auto kind = col->expr()->kind(); + YieldColumn *newColumn = nullptr; + if (kind == Expression::Kind::kLabel) { + auto *label = static_cast<const LabelExpression*>(col->expr()); + newColumn = new YieldColumn(rewrite(label)); + } else if (kind == Expression::Kind::kLabelAttribute) { + auto *la = static_cast<const LabelAttributeExpression*>(col->expr()); + newColumn = new YieldColumn(rewrite(la)); + } else { + auto newExpr = col->expr()->clone(); + auto rewriter = [this] (const Expression *expr) { + if (expr->kind() == Expression::Kind::kLabel) { + return rewrite(static_cast<const LabelExpression*>(expr)); + } else { + return rewrite(static_cast<const LabelAttributeExpression*>(expr)); + } + }; + RewriteMatchLabelVisitor visitor(std::move(rewriter)); + newExpr->accept(&visitor); + newColumn = new YieldColumn(newExpr.release()); + } + yields->addColumn(newColumn); + if (col->alias() != nullptr) { + colNames.emplace_back(*col->alias()); + } else { + colNames.emplace_back(col->expr()->toString()); + } + } + + auto *project = Project::make(qctx_, root_, yields); + project->setInputVar(root_->outputVar()); + project->setColNames(std::move(colNames)); + root_ = project; + + return Status::OK(); +} + + +Expression* MatchValidator::rewrite(const LabelExpression *label) const { + auto *expr = new VariablePropertyExpression( + new std::string(), + new std::string(*label->name())); + return expr; +} + + +Expression *MatchValidator::rewrite(const LabelAttributeExpression *la) const { + auto *expr = new AttributeExpression( + new VariablePropertyExpression( + new std::string(), + new std::string(*la->left()->name())), + new LabelExpression(*la->right()->name())); + return expr; +} + +} // namespace graph +} // namespace nebula diff --git a/src/validator/MatchValidator.h b/src/validator/MatchValidator.h new file mode 100644 index 0000000000000000000000000000000000000000..32d9e9c4b13aada606ab07124c3791bb67ef7c66 --- /dev/null +++ b/src/validator/MatchValidator.h @@ -0,0 +1,126 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#ifndef VALIDATOR_MATCHVALIDATOR_H_ +#define VALIDATOR_MATCHVALIDATOR_H_ + +#include "common/base/Base.h" +#include "validator/TraversalValidator.h" +#include "util/AnonVarGenerator.h" +#include "planner/Query.h" + +namespace nebula { +namespace graph { + +class MatchValidator final : public TraversalValidator { +public: + MatchValidator(Sentence *sentence, QueryContext *context) + : TraversalValidator(sentence, context) { + anon_ = vctx_->anonVarGen(); + } + +private: + Status validateImpl() override; + + Status toPlan() override; + + Status validatePath(const MatchPath *path); + + Status validateFilter(const Expression *filter); + + Status validateReturn(MatchReturn *ret); + + Status validateAliases(const std::vector<const Expression*> &exprs) const; + + Status analyzeStartPoint(); + + Status ananyzeFilterForIndexing(); + + Expression* makeSubFilter(const std::string &alias, const MapExpression *map) const; + + Expression* makeIndexFilter(const std::string &label, + const MapExpression *map) const; + Expression* makeIndexFilter(const std::string &label, + const std::string &alias, + const Expression *filter) const; + + Status buildScanNode(); + + Status buildSteps(); + + Status buildStep(); + + Status buildGetTailVertices(); + + Status buildStepJoin(); + + Status buildTailJoin(); + + Status buildFilter(); + + Status buildReturn(); + + Expression* rewrite(const LabelExpression*) const; + + Expression* rewrite(const LabelAttributeExpression*) const; + + template <typename T> + T* saveObject(T *obj) const { + return qctx_->objPool()->add(obj); + } + +private: + using VertexProp = nebula::storage::cpp2::VertexProp; + using EdgeProp = nebula::storage::cpp2::EdgeProp; + using Direction = MatchEdge::Direction; + struct NodeInfo { + TagID tid{0}; + bool anonymous{false}; + const std::string *label{nullptr}; + const std::string *alias{nullptr}; + const MapExpression *props{nullptr}; + Expression *filter{nullptr}; + }; + + struct EdgeInfo { + bool anonymous{false}; + EdgeType edgeType{0}; + MatchEdge::Direction direction{MatchEdge::Direction::OUT_EDGE}; + const std::string *type{nullptr}; + const std::string *alias{nullptr}; + const MapExpression *props{nullptr}; + Expression *filter{nullptr}; + }; + + enum AliasType { + kNode, kEdge, kPath + }; + + struct ScanInfo { + Expression *filter{nullptr}; + int32_t schemaId{0}; + }; + +private: + bool startFromNode_{true}; + int32_t startIndex_{0}; + int32_t curStep_{-1}; + PlanNode *thisStepRoot_{nullptr}; + PlanNode *prevStepRoot_{nullptr}; + Expression *startExpr_{nullptr}; + Expression *gnSrcExpr_{nullptr}; + std::vector<NodeInfo> nodeInfos_; + std::vector<EdgeInfo> edgeInfos_; + ScanInfo scanInfo_; + std::unordered_map<std::string, AliasType> aliases_; + AnonVarGenerator *anon_{nullptr}; + std::unique_ptr<Expression> filter_; +}; + +} // namespace graph +} // namespace nebula + +#endif // VALIDATOR_MATCHVALIDATOR_H_ diff --git a/src/validator/Validator.cpp b/src/validator/Validator.cpp index 3ad7fabcb2ff30d788ca3f9e900c7522d0948f22..baa642de2137db4117937c41ba9e65ca2c497a40 100644 --- a/src/validator/Validator.cpp +++ b/src/validator/Validator.cpp @@ -35,6 +35,8 @@ #include "validator/YieldValidator.h" #include "visitor/DeducePropsVisitor.h" #include "visitor/DeduceTypeVisitor.h" +#include "validator/GroupByValidator.h" +#include "validator/MatchValidator.h" #include "visitor/EvaluableExprVisitor.h" #include "validator/IndexScanValidator.h" @@ -165,6 +167,8 @@ std::unique_ptr<Validator> Validator::makeValidator(Sentence* sentence, QueryCon return std::make_unique<ShowConfigsValidator>(sentence, context); case Sentence::Kind::kFindPath: return std::make_unique<FindPathValidator>(sentence, context); + case Sentence::Kind::kMatch: + return std::make_unique<MatchValidator>(sentence, context); case Sentence::Kind::kCreateTagIndex: return std::make_unique<CreateTagIndexValidator>(sentence, context); case Sentence::Kind::kShowCreateTagIndex: @@ -191,7 +195,6 @@ std::unique_ptr<Validator> Validator::makeValidator(Sentence* sentence, QueryCon return std::make_unique<DropEdgeIndexValidator>(sentence, context); case Sentence::Kind::kLookup: return std::make_unique<IndexScanValidator>(sentence, context); - case Sentence::Kind::kMatch: case Sentence::Kind::kUnknown: case Sentence::Kind::kDownload: case Sentence::Kind::kIngest: diff --git a/src/visitor/CMakeLists.txt b/src/visitor/CMakeLists.txt index 2f051dbf710c2103eec78fbbf799f5edecf50dda..a7e61a16b0ea7e16a236aeb708459fabc7caefc7 100644 --- a/src/visitor/CMakeLists.txt +++ b/src/visitor/CMakeLists.txt @@ -16,6 +16,7 @@ nebula_add_library( RewriteLabelAttrVisitor.cpp RewriteInputPropVisitor.cpp RewriteSymExprVisitor.cpp + RewriteMatchLabelVisitor.cpp ) nebula_add_subdirectory(test) diff --git a/src/visitor/CollectAllExprsVisitor.cpp b/src/visitor/CollectAllExprsVisitor.cpp index 9683fedd62170c889e67d391da6b4645278d8575..73232ef6b0c98137987aec8034adfba477a041d0 100644 --- a/src/visitor/CollectAllExprsVisitor.cpp +++ b/src/visitor/CollectAllExprsVisitor.cpp @@ -111,6 +111,14 @@ void CollectAllExprsVisitor::visit(LabelExpression *expr) { collectExpr(expr); } +void CollectAllExprsVisitor::visit(LabelAttributeExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(AttributeExpression *expr) { + collectExpr(expr); +} + void CollectAllExprsVisitor::visit(VertexExpression *expr) { collectExpr(expr); } diff --git a/src/visitor/CollectAllExprsVisitor.h b/src/visitor/CollectAllExprsVisitor.h index ab2fb962dc0bf7f7ea11495e170ab75dd29ed458..5ea8b05f6a086162f68ee5b2fa64bd56b19e0b45 100644 --- a/src/visitor/CollectAllExprsVisitor.h +++ b/src/visitor/CollectAllExprsVisitor.h @@ -51,6 +51,8 @@ private: void visit(VariableExpression* expr) override; void visit(VersionedVariableExpression* expr) override; void visit(LabelExpression* expr) override; + void visit(LabelAttributeExpression* expr) override; + void visit(AttributeExpression* expr) override; void visit(VertexExpression* expr) override; void visit(EdgeExpression* expr) override; diff --git a/src/visitor/RewriteMatchLabelVisitor.cpp b/src/visitor/RewriteMatchLabelVisitor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..febe8439c1d7db956e9169a78e2d1fd64d944251 --- /dev/null +++ b/src/visitor/RewriteMatchLabelVisitor.cpp @@ -0,0 +1,135 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#include "visitor/RewriteMatchLabelVisitor.h" + +namespace nebula { +namespace graph { + +void RewriteMatchLabelVisitor::visit(TypeCastingExpression *expr) { + if (isLabel(expr->operand())) { + expr->setOperand(rewriter_(expr->operand())); + } else { + expr->operand()->accept(this); + } +} + + +void RewriteMatchLabelVisitor::visit(UnaryExpression *expr) { + if (isLabel(expr->operand())) { + expr->setOperand(rewriter_(expr->operand())); + } else { + expr->operand()->accept(this); + } +} + + +void RewriteMatchLabelVisitor::visit(FunctionCallExpression *expr) { + for (auto &arg : expr->args()->args()) { + if (isLabel(arg.get())) { + arg.reset(rewriter_(arg.get())); + } else { + arg->accept(this); + } + } +} + + +void RewriteMatchLabelVisitor::visit(AttributeExpression *expr) { + if (isLabel(expr->left())) { + expr->setLeft(rewriter_(expr->left())); + } else { + expr->left()->accept(this); + } +} + + +void RewriteMatchLabelVisitor::visit(ListExpression *expr) { + auto newItems = rewriteExprList(expr->items()); + if (!newItems.empty()) { + expr->setItems(std::move(newItems)); + } +} + + +void RewriteMatchLabelVisitor::visit(SetExpression *expr) { + auto newItems = rewriteExprList(expr->items()); + if (!newItems.empty()) { + expr->setItems(std::move(newItems)); + } +} + + +void RewriteMatchLabelVisitor::visit(MapExpression *expr) { + auto &items = expr->items(); + auto iter = std::find_if(items.cbegin(), items.cend(), [] (auto &pair) { + return isLabel(pair.second.get()); + }); + if (iter == items.cend()) { + return; + } + + std::vector<MapExpression::Item> newItems; + newItems.reserve(items.size()); + for (auto &pair : items) { + MapExpression::Item newItem; + newItem.first.reset(new std::string(*pair.first)); + if (isLabel(pair.second.get())) { + newItem.second.reset(rewriter_(pair.second.get())); + } else { + newItem.second = pair.second->clone(); + newItem.second->accept(this); + } + newItems.emplace_back(std::move(newItem)); + } + + expr->setItems(std::move(newItems)); +} + + +void RewriteMatchLabelVisitor::visitBinaryExpr(BinaryExpression *expr) { + if (isLabel(expr->left())) { + expr->setLeft(rewriter_(expr->left())); + } else { + expr->left()->accept(this); + } + if (isLabel(expr->right())) { + expr->setRight(rewriter_(expr->right())); + } else { + expr->right()->accept(this); + } +} + + +std::vector<std::unique_ptr<Expression>> +RewriteMatchLabelVisitor::rewriteExprList(const std::vector<std::unique_ptr<Expression>> &list) { + std::vector<std::unique_ptr<Expression>> newList; + auto iter = std::find_if(list.cbegin(), list.cend(), [] (auto &expr) { + return isLabel(expr.get()); + }); + if (iter != list.cend()) { + std::for_each(list.cbegin(), list.cend(), [this] (auto &expr) { + const_cast<Expression*>(expr.get())->accept(this); + }); + return newList; + } + + newList.reserve(list.size()); + for (auto &expr : list) { + if (isLabel(expr.get())) { + newList.emplace_back(rewriter_(expr.get())); + } else { + auto newExpr = expr->clone(); + newExpr->accept(this); + newList.emplace_back(std::move(newExpr)); + } + } + + return newList; +} + +} // namespace graph +} // namespace nebula diff --git a/src/visitor/RewriteMatchLabelVisitor.h b/src/visitor/RewriteMatchLabelVisitor.h new file mode 100644 index 0000000000000000000000000000000000000000..fc83b490811a0db9710f0df24dbe39a11e495426 --- /dev/null +++ b/src/visitor/RewriteMatchLabelVisitor.h @@ -0,0 +1,75 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#ifndef VISITOR_REWRITEMATCHLABELVISITOR_H_ +#define VISITOR_REWRITEMATCHLABELVISITOR_H_ + +#include <vector> +#include <functional> +#include "visitor/ExprVisitorImpl.h" + +namespace nebula { +namespace graph { + +class RewriteMatchLabelVisitor final : public ExprVisitorImpl { +public: + using Rewriter = std::function<Expression*(const Expression*)>; + explicit RewriteMatchLabelVisitor(Rewriter rewriter) + : rewriter_(std::move(rewriter)) { + } + +private: + bool ok() const override { + return true; + } + + static bool isLabel(const Expression *expr) { + return expr->kind() == Expression::Kind::kLabel + || expr->kind() == Expression::Kind::kLabelAttribute; + } + +private: + using ExprVisitorImpl::visit; + void visit(TypeCastingExpression*) override; + void visit(UnaryExpression*) override; + void visit(FunctionCallExpression*) override; + void visit(ListExpression*) override; + void visit(SetExpression*) override; + void visit(MapExpression*) override; + void visit(ConstantExpression*) override {} + void visit(LabelExpression*) override {} + void visit(AttributeExpression*) override; + void visit(UUIDExpression*) override {} + void visit(LabelAttributeExpression*) override {} + void visit(VariableExpression*) override {} + void visit(VersionedVariableExpression*) override {} + void visit(TagPropertyExpression*) override {} + void visit(EdgePropertyExpression*) override {} + void visit(InputPropertyExpression*) override {} + void visit(VariablePropertyExpression*) override {} + void visit(DestPropertyExpression*) override {} + void visit(SourcePropertyExpression*) override {} + void visit(EdgeSrcIdExpression*) override {} + void visit(EdgeTypeExpression*) override {} + void visit(EdgeRankExpression*) override {} + void visit(EdgeDstIdExpression*) override {} + void visit(VertexExpression*) override {} + void visit(EdgeExpression*) override {} + + void visitBinaryExpr(BinaryExpression *) override; + + std::vector<std::unique_ptr<Expression>> + rewriteExprList(const std::vector<std::unique_ptr<Expression>> &list); + +private: + Rewriter rewriter_; +}; + +} // namespace graph +} // namespace nebula + + +#endif // VISITOR_REWRITEMATCHLABELVISITOR_H_