From fe5029c48ce56fe94ea3777443256808cfcfcb5f Mon Sep 17 00:00:00 2001 From: Yee <2520865+yixinglu@users.noreply.github.com> Date: Mon, 7 Sep 2020 10:21:57 +0800 Subject: [PATCH] Add rewrite label attr expr visitor (#259) * Add rewrite label attr expr visitor * Fix review comments * Rename --- src/util/ExpressionUtils.h | 329 +++++------------------ src/util/test/CMakeLists.txt | 26 ++ src/util/test/ExpressionUtilsTest.cpp | 105 ++++---- src/validator/FetchEdgesValidator.cpp | 10 +- src/validator/FetchVerticesValidator.cpp | 18 +- src/visitor/CMakeLists.txt | 3 + src/visitor/CollectAllExprsVisitor.cpp | 135 ++++++++++ src/visitor/CollectAllExprsVisitor.h | 67 +++++ src/visitor/FindAnyExprVisitor.cpp | 149 ++++++++++ src/visitor/FindAnyExprVisitor.h | 71 +++++ src/visitor/RewriteLabelAttrVisitor.cpp | 142 ++++++++++ src/visitor/RewriteLabelAttrVisitor.h | 67 +++++ 12 files changed, 789 insertions(+), 333 deletions(-) create mode 100644 src/visitor/CollectAllExprsVisitor.cpp create mode 100644 src/visitor/CollectAllExprsVisitor.h create mode 100644 src/visitor/FindAnyExprVisitor.cpp create mode 100644 src/visitor/FindAnyExprVisitor.h create mode 100644 src/visitor/RewriteLabelAttrVisitor.cpp create mode 100644 src/visitor/RewriteLabelAttrVisitor.h diff --git a/src/util/ExpressionUtils.h b/src/util/ExpressionUtils.h index c3317467..4fac1e48 100644 --- a/src/util/ExpressionUtils.h +++ b/src/util/ExpressionUtils.h @@ -10,11 +10,13 @@ #include "common/expression/BinaryExpression.h" #include "common/expression/Expression.h" #include "common/expression/FunctionCallExpression.h" +#include "common/expression/LabelExpression.h" #include "common/expression/PropertyExpression.h" #include "common/expression/TypeCastingExpression.h" #include "common/expression/UnaryExpression.h" -#include "common/expression/LabelExpression.h" -#include "common/expression/LabelAttributeExpression.h" +#include "visitor/CollectAllExprsVisitor.h" +#include "visitor/FindAnyExprVisitor.h" +#include "visitor/RewriteLabelAttrVisitor.h" namespace nebula { namespace graph { @@ -23,180 +25,74 @@ class ExpressionUtils { public: explicit ExpressionUtils(...) = delete; - // return true for continue, false return directly - using Visitor = std::function<bool(const Expression*)>; - using MutableVisitor = std::function<bool(Expression*)>; - - // preorder traverse in fact for tail call optimization - // if want to do some thing like eval, don't try it - template <typename T, - typename V, - typename = std::enable_if_t<std::is_same<std::remove_const_t<T>, Expression>::value>> - static bool traverse(T* expr, V visitor) { - if (!visitor(expr)) { - return false; - } - switch (expr->kind()) { - case Expression::Kind::kDstProperty: - case Expression::Kind::kSrcProperty: - case Expression::Kind::kTagProperty: - case Expression::Kind::kEdgeProperty: - case Expression::Kind::kEdgeSrc: - case Expression::Kind::kEdgeType: - case Expression::Kind::kEdgeRank: - case Expression::Kind::kEdgeDst: - case Expression::Kind::kInputProperty: - case Expression::Kind::kVarProperty: - case Expression::Kind::kVertex: - case Expression::Kind::kEdge: - case Expression::Kind::kUUID: - case Expression::Kind::kVar: - case Expression::Kind::kVersionedVar: - case Expression::Kind::kLabelAttribute: - case Expression::Kind::kConstant: { - return true; - } - case Expression::Kind::kAdd: - case Expression::Kind::kMinus: - case Expression::Kind::kMultiply: - case Expression::Kind::kDivision: - case Expression::Kind::kMod: - case Expression::Kind::kRelEQ: - case Expression::Kind::kRelNE: - case Expression::Kind::kRelLT: - case Expression::Kind::kRelLE: - case Expression::Kind::kRelGT: - case Expression::Kind::kRelGE: - case Expression::Kind::kRelIn: - case Expression::Kind::kRelNotIn: - case Expression::Kind::kContains: - case Expression::Kind::kLogicalAnd: - case Expression::Kind::kLogicalOr: - case Expression::Kind::kLogicalXor: { - using ToType = keep_const_t<T, BinaryExpression>; - auto biExpr = static_cast<ToType*>(expr); - if (!traverse(biExpr->left(), visitor)) { - return false; - } - return traverse(biExpr->right(), visitor); - } - case Expression::Kind::kUnaryIncr: - case Expression::Kind::kUnaryDecr: - case Expression::Kind::kUnaryPlus: - case Expression::Kind::kUnaryNegate: - case Expression::Kind::kUnaryNot: { - using ToType = keep_const_t<T, UnaryExpression>; - auto unaryExpr = static_cast<ToType*>(expr); - return traverse(unaryExpr->operand(), visitor); - } - case Expression::Kind::kTypeCasting: { - using ToType = keep_const_t<T, TypeCastingExpression>; - auto typeCastingExpr = static_cast<ToType*>(expr); - return traverse(typeCastingExpr->operand(), visitor); - } - case Expression::Kind::kFunctionCall: { - using ToType = keep_const_t<T, FunctionCallExpression>; - auto funcExpr = static_cast<ToType*>(expr); - for (auto& arg : funcExpr->args()->args()) { - if (!traverse(arg.get(), visitor)) { - return false; - } - } - return true; - } - case Expression::Kind::kList: // FIXME(dutor) - case Expression::Kind::kSet: - case Expression::Kind::kMap: - case Expression::Kind::kSubscript: - case Expression::Kind::kAttribute: - case Expression::Kind::kLabel: { - return false; - } - } - DLOG(FATAL) << "Impossible expression kind " << static_cast<int>(expr->kind()); - return false; - } - static inline bool isKindOf(const Expression* expr, const std::unordered_set<Expression::Kind>& expected) { return expected.find(expr->kind()) != expected.end(); } // null for not found - static const Expression* findAnyKind(const Expression* self, - const std::unordered_set<Expression::Kind>& expected) { - const Expression* found = nullptr; - traverse(self, [&expected, &found](const Expression* expr) -> bool { - if (isKindOf(expr, expected)) { - found = expr; - return false; // Already find so return now - } - return true; // Not find so continue traverse - }); - return found; + static const Expression* findAny(const Expression* self, + const std::unordered_set<Expression::Kind>& expected) { + FindAnyExprVisitor visitor(expected); + const_cast<Expression*>(self)->accept(&visitor); + return visitor.expr(); } // Find all expression fit any kind // Empty for not found any one - static std::vector<const Expression*> findAnyKindInAll( + static std::vector<const Expression*> collectAll( const Expression* self, const std::unordered_set<Expression::Kind>& expected) { - std::vector<const Expression*> exprs; - traverse(self, [&expected, &exprs](const Expression* expr) -> bool { - if (isKindOf(expr, expected)) { - exprs.emplace_back(expr); - } - return true; // Not return always to traverse entire expression tree - }); - return exprs; + CollectAllExprsVisitor visitor(expected); + const_cast<Expression*>(self)->accept(&visitor); + return std::move(visitor).exprs(); } - static bool hasAnyKind(const Expression* expr, - const std::unordered_set<Expression::Kind>& expected) { - return findAnyKind(expr, expected) != nullptr; + static bool hasAny(const Expression* expr, + const std::unordered_set<Expression::Kind>& expected) { + return findAny(expr, expected) != nullptr; } // Require data from input/variable static bool hasInput(const Expression* expr) { - return hasAnyKind(expr, - {Expression::Kind::kInputProperty, - Expression::Kind::kVarProperty, - Expression::Kind::kVar, - Expression::Kind::kVersionedVar}); + return hasAny(expr, + {Expression::Kind::kInputProperty, + Expression::Kind::kVarProperty, + Expression::Kind::kVar, + Expression::Kind::kVersionedVar}); } // require data from graph storage static const Expression* findStorage(const Expression* expr) { - return findAnyKind(expr, - {Expression::Kind::kTagProperty, - Expression::Kind::kEdgeProperty, - Expression::Kind::kDstProperty, - Expression::Kind::kSrcProperty, - Expression::Kind::kEdgeSrc, - Expression::Kind::kEdgeType, - Expression::Kind::kEdgeRank, - Expression::Kind::kEdgeDst, - Expression::Kind::kVertex, - Expression::Kind::kEdge}); + return findAny(expr, + {Expression::Kind::kTagProperty, + Expression::Kind::kEdgeProperty, + Expression::Kind::kDstProperty, + Expression::Kind::kSrcProperty, + Expression::Kind::kEdgeSrc, + Expression::Kind::kEdgeType, + Expression::Kind::kEdgeRank, + Expression::Kind::kEdgeDst, + Expression::Kind::kVertex, + Expression::Kind::kEdge}); } static std::vector<const Expression*> findAllStorage(const Expression* expr) { - return findAnyKindInAll(expr, - {Expression::Kind::kTagProperty, - Expression::Kind::kEdgeProperty, - Expression::Kind::kDstProperty, - Expression::Kind::kSrcProperty, - Expression::Kind::kEdgeSrc, - Expression::Kind::kEdgeType, - Expression::Kind::kEdgeRank, - Expression::Kind::kEdgeDst, - Expression::Kind::kVertex, - Expression::Kind::kEdge}); + return collectAll(expr, + {Expression::Kind::kTagProperty, + Expression::Kind::kEdgeProperty, + Expression::Kind::kDstProperty, + Expression::Kind::kSrcProperty, + Expression::Kind::kEdgeSrc, + Expression::Kind::kEdgeType, + Expression::Kind::kEdgeRank, + Expression::Kind::kEdgeDst, + Expression::Kind::kVertex, + Expression::Kind::kEdge}); } static std::vector<const Expression*> findAllInputVariableProp(const Expression* expr) { - return findAnyKindInAll(expr, - {Expression::Kind::kInputProperty, Expression::Kind::kVarProperty}); + return collectAll(expr, {Expression::Kind::kInputProperty, Expression::Kind::kVarProperty}); } static bool hasStorage(const Expression* expr) { @@ -218,23 +114,22 @@ public: } static bool isConstExpr(const Expression* expr) { - return !hasAnyKind(expr, - {Expression::Kind::kInputProperty, - Expression::Kind::kVarProperty, - Expression::Kind::kVar, - Expression::Kind::kVersionedVar, - - Expression::Kind::kLabelAttribute, - Expression::Kind::kTagProperty, - Expression::Kind::kEdgeProperty, - Expression::Kind::kDstProperty, - Expression::Kind::kSrcProperty, - Expression::Kind::kEdgeSrc, - Expression::Kind::kEdgeType, - Expression::Kind::kEdgeRank, - Expression::Kind::kEdgeDst, - Expression::Kind::kVertex, - Expression::Kind::kEdge}); + return !hasAny(expr, + {Expression::Kind::kInputProperty, + Expression::Kind::kVarProperty, + Expression::Kind::kVar, + Expression::Kind::kVersionedVar, + Expression::Kind::kLabelAttribute, + Expression::Kind::kTagProperty, + Expression::Kind::kEdgeProperty, + Expression::Kind::kDstProperty, + Expression::Kind::kSrcProperty, + Expression::Kind::kEdgeSrc, + Expression::Kind::kEdgeType, + Expression::Kind::kEdgeRank, + Expression::Kind::kEdgeDst, + Expression::Kind::kVertex, + Expression::Kind::kEdge}); } // clone expression @@ -251,99 +146,8 @@ public: typename = std::enable_if_t<std::is_same<To, EdgePropertyExpression>::value || std::is_same<To, TagPropertyExpression>::value>> static void rewriteLabelAttribute(Expression* expr) { - traverse(expr, [](Expression* current) -> bool { - switch (current->kind()) { - case Expression::Kind::kDstProperty: - case Expression::Kind::kSrcProperty: - case Expression::Kind::kLabelAttribute: - case Expression::Kind::kTagProperty: - case Expression::Kind::kEdgeProperty: - case Expression::Kind::kEdgeSrc: - case Expression::Kind::kEdgeType: - case Expression::Kind::kEdgeRank: - case Expression::Kind::kEdgeDst: - case Expression::Kind::kVertex: - case Expression::Kind::kEdge: - case Expression::Kind::kInputProperty: - case Expression::Kind::kVarProperty: - case Expression::Kind::kUUID: - case Expression::Kind::kVar: - case Expression::Kind::kVersionedVar: - case Expression::Kind::kConstant: { - return true; - } - case Expression::Kind::kAdd: - case Expression::Kind::kMinus: - case Expression::Kind::kMultiply: - case Expression::Kind::kDivision: - case Expression::Kind::kMod: - case Expression::Kind::kRelEQ: - case Expression::Kind::kRelNE: - case Expression::Kind::kRelLT: - case Expression::Kind::kRelLE: - case Expression::Kind::kRelGT: - case Expression::Kind::kRelGE: - case Expression::Kind::kRelIn: - case Expression::Kind::kRelNotIn: - case Expression::Kind::kContains: - case Expression::Kind::kLogicalAnd: - case Expression::Kind::kLogicalOr: - case Expression::Kind::kLogicalXor: { - auto* biExpr = static_cast<BinaryExpression*>(current); - if (biExpr->left()->kind() == Expression::Kind::kLabelAttribute) { - auto* laExpr = static_cast<LabelAttributeExpression*>(biExpr->left()); - biExpr->setLeft(rewriteLabelAttribute<To>(laExpr)); - } - if (biExpr->right()->kind() == Expression::Kind::kLabelAttribute) { - auto* laExpr = static_cast<LabelAttributeExpression*>(biExpr->right()); - biExpr->setRight(rewriteLabelAttribute<To>(laExpr)); - } - return true; - } - case Expression::Kind::kUnaryIncr: - case Expression::Kind::kUnaryDecr: - case Expression::Kind::kUnaryPlus: - case Expression::Kind::kUnaryNegate: - case Expression::Kind::kUnaryNot: { - auto* unaryExpr = static_cast<UnaryExpression*>(current); - if (unaryExpr->operand()->kind() == Expression::Kind::kLabelAttribute) { - auto* laExpr = - static_cast<LabelAttributeExpression*>(unaryExpr->operand()); - unaryExpr->setOperand(rewriteLabelAttribute<To>(laExpr)); - } - return true; - } - case Expression::Kind::kTypeCasting: { - auto* typeCastingExpr = static_cast<TypeCastingExpression*>(current); - if (typeCastingExpr->operand()->kind() == Expression::Kind::kLabelAttribute) { - auto* laExpr = - static_cast<LabelAttributeExpression*>(typeCastingExpr->operand()); - typeCastingExpr->setOperand(rewriteLabelAttribute<To>(laExpr)); - } - return true; - } - case Expression::Kind::kFunctionCall: { - auto* funcExpr = static_cast<FunctionCallExpression*>(current); - for (auto& arg : funcExpr->args()->args()) { - if (arg->kind() == Expression::Kind::kLabelAttribute) { - auto* laExpr = static_cast<LabelAttributeExpression*>(arg.get()); - arg.reset(rewriteLabelAttribute<To>(laExpr)); - } - } - return true; - } - case Expression::Kind::kList: // FIXME(dutor) - case Expression::Kind::kSet: - case Expression::Kind::kMap: - case Expression::Kind::kSubscript: - case Expression::Kind::kAttribute: - case Expression::Kind::kLabel: { - return false; - } - } // switch - DLOG(FATAL) << "Impossible expression kind " << static_cast<int>(current->kind()); - return false; - }); // traverse + RewriteLabelAttrVisitor visitor(std::is_same<To, TagPropertyExpression>::value); + expr->accept(&visitor); } template <typename To, @@ -353,16 +157,9 @@ public: return new To(new std::string(std::move(*expr->left()->name())), new std::string(std::move(*expr->right()->name()))); } - -private: - // keep const or non-const with T - template <typename T, typename To> - using keep_const_t = std::conditional_t<std::is_const<T>::value, - const std::remove_const_t<To>, - std::remove_const_t<To>>; }; } // namespace graph } // namespace nebula -#endif // _UTIL_EXPRESSION_UTILS_H_ +#endif // _UTIL_EXPRESSION_UTILS_H_ diff --git a/src/util/test/CMakeLists.txt b/src/util/test/CMakeLists.txt index a47381d5..c2cb340d 100644 --- a/src/util/test/CMakeLists.txt +++ b/src/util/test/CMakeLists.txt @@ -13,9 +13,35 @@ nebula_add_test( $<TARGET_OBJECTS:common_function_manager_obj> $<TARGET_OBJECTS:common_time_obj> $<TARGET_OBJECTS:common_time_function_obj> + $<TARGET_OBJECTS:common_meta_thrift_obj> + $<TARGET_OBJECTS:common_meta_client_obj> + $<TARGET_OBJECTS:common_meta_obj> + $<TARGET_OBJECTS:common_storage_thrift_obj> + $<TARGET_OBJECTS:common_graph_thrift_obj> + $<TARGET_OBJECTS:common_conf_obj> + $<TARGET_OBJECTS:common_fs_obj> + $<TARGET_OBJECTS:common_thrift_obj> + $<TARGET_OBJECTS:common_common_thrift_obj> + $<TARGET_OBJECTS:common_thread_obj> + $<TARGET_OBJECTS:common_file_based_cluster_id_man_obj> + $<TARGET_OBJECTS:common_charset_obj> + $<TARGET_OBJECTS:common_encryption_obj> + $<TARGET_OBJECTS:common_http_client_obj> + $<TARGET_OBJECTS:common_process_obj> + $<TARGET_OBJECTS:common_agg_function_obj> $<TARGET_OBJECTS:idgenerator_obj> + $<TARGET_OBJECTS:expr_visitor_obj> + $<TARGET_OBJECTS:session_obj> + $<TARGET_OBJECTS:graph_auth_obj> + $<TARGET_OBJECTS:graph_flags_obj> + $<TARGET_OBJECTS:util_obj> + $<TARGET_OBJECTS:planner_obj> + $<TARGET_OBJECTS:parser_obj> + $<TARGET_OBJECTS:context_obj> + $<TARGET_OBJECTS:validator_obj> LIBRARIES gtest gtest_main ${THRIFT_LIBRARIES} + proxygenlib ) diff --git a/src/util/test/ExpressionUtilsTest.cpp b/src/util/test/ExpressionUtilsTest.cpp index b2ff556c..d4286255 100644 --- a/src/util/test/ExpressionUtilsTest.cpp +++ b/src/util/test/ExpressionUtilsTest.cpp @@ -21,56 +21,56 @@ TEST_F(ExpressionUtilsTest, CheckComponent) { const auto root = std::make_unique<ConstantExpression>(); ASSERT_TRUE(ExpressionUtils::isKindOf(root.get(), {Expression::Kind::kConstant})); - ASSERT_TRUE(ExpressionUtils::hasAnyKind(root.get(), {Expression::Kind::kConstant})); + ASSERT_TRUE(ExpressionUtils::hasAny(root.get(), {Expression::Kind::kConstant})); ASSERT_TRUE(ExpressionUtils::isKindOf( root.get(), {Expression::Kind::kConstant, Expression::Kind::kAdd})); - ASSERT_TRUE(ExpressionUtils::hasAnyKind( - root.get(), {Expression::Kind::kConstant, Expression::Kind::kAdd})); + ASSERT_TRUE(ExpressionUtils::hasAny(root.get(), + {Expression::Kind::kConstant, Expression::Kind::kAdd})); ASSERT_FALSE(ExpressionUtils::isKindOf(root.get(), {Expression::Kind::kAdd})); - ASSERT_FALSE(ExpressionUtils::hasAnyKind(root.get(), {Expression::Kind::kAdd})); + ASSERT_FALSE(ExpressionUtils::hasAny(root.get(), {Expression::Kind::kAdd})); ASSERT_FALSE(ExpressionUtils::isKindOf( root.get(), {Expression::Kind::kDivision, Expression::Kind::kAdd})); - ASSERT_FALSE(ExpressionUtils::hasAnyKind( + ASSERT_FALSE(ExpressionUtils::hasAny( root.get(), {Expression::Kind::kDstProperty, Expression::Kind::kAdd})); // find const Expression *found = - ExpressionUtils::findAnyKind(root.get(), {Expression::Kind::kConstant}); + ExpressionUtils::findAny(root.get(), {Expression::Kind::kConstant}); ASSERT_EQ(found, root.get()); - found = ExpressionUtils::findAnyKind( + found = ExpressionUtils::findAny( root.get(), {Expression::Kind::kConstant, Expression::Kind::kAdd, Expression::Kind::kEdgeProperty}); ASSERT_EQ(found, root.get()); - found = ExpressionUtils::findAnyKind(root.get(), {Expression::Kind::kEdgeDst}); + found = ExpressionUtils::findAny(root.get(), {Expression::Kind::kEdgeDst}); ASSERT_EQ(found, nullptr); - found = ExpressionUtils::findAnyKind( + found = ExpressionUtils::findAny( root.get(), {Expression::Kind::kEdgeRank, Expression::Kind::kInputProperty}); ASSERT_EQ(found, nullptr); // find all const auto willFoundAll = std::vector<const Expression *>{root.get()}; std::vector<const Expression *> founds = - ExpressionUtils::findAnyKindInAll(root.get(), {Expression::Kind::kConstant}); + ExpressionUtils::collectAll(root.get(), {Expression::Kind::kConstant}); ASSERT_EQ(founds, willFoundAll); - founds = ExpressionUtils::findAnyKindInAll( + founds = ExpressionUtils::collectAll( root.get(), {Expression::Kind::kAdd, Expression::Kind::kConstant, Expression::Kind::kEdgeDst}); ASSERT_EQ(founds, willFoundAll); - founds = ExpressionUtils::findAnyKindInAll(root.get(), {Expression::Kind::kSrcProperty}); + founds = ExpressionUtils::collectAll(root.get(), {Expression::Kind::kSrcProperty}); ASSERT_TRUE(founds.empty()); - founds = ExpressionUtils::findAnyKindInAll(root.get(), - {Expression::Kind::kUnaryNegate, - Expression::Kind::kEdgeDst, - Expression::Kind::kEdgeDst}); + founds = ExpressionUtils::collectAll(root.get(), + {Expression::Kind::kUnaryNegate, + Expression::Kind::kEdgeDst, + Expression::Kind::kEdgeDst}); ASSERT_TRUE(founds.empty()); } @@ -83,54 +83,54 @@ TEST_F(ExpressionUtilsTest, CheckComponent) { new TypeCastingExpression(Value::Type::BOOL, new ConstantExpression()))); ASSERT_TRUE(ExpressionUtils::isKindOf(root.get(), {Expression::Kind::kTypeCasting})); - ASSERT_TRUE(ExpressionUtils::hasAnyKind(root.get(), {Expression::Kind::kConstant})); + ASSERT_TRUE(ExpressionUtils::hasAny(root.get(), {Expression::Kind::kConstant})); ASSERT_TRUE(ExpressionUtils::isKindOf( root.get(), {Expression::Kind::kTypeCasting, Expression::Kind::kAdd})); - ASSERT_TRUE(ExpressionUtils::hasAnyKind( + ASSERT_TRUE(ExpressionUtils::hasAny( root.get(), {Expression::Kind::kTypeCasting, Expression::Kind::kAdd})); ASSERT_FALSE(ExpressionUtils::isKindOf(root.get(), {Expression::Kind::kAdd})); - ASSERT_FALSE(ExpressionUtils::hasAnyKind(root.get(), {Expression::Kind::kAdd})); + ASSERT_FALSE(ExpressionUtils::hasAny(root.get(), {Expression::Kind::kAdd})); ASSERT_FALSE(ExpressionUtils::isKindOf( root.get(), {Expression::Kind::kDivision, Expression::Kind::kAdd})); - ASSERT_FALSE(ExpressionUtils::hasAnyKind( + ASSERT_FALSE(ExpressionUtils::hasAny( root.get(), {Expression::Kind::kDstProperty, Expression::Kind::kAdd})); // found const Expression *found = - ExpressionUtils::findAnyKind(root.get(), {Expression::Kind::kTypeCasting}); + ExpressionUtils::findAny(root.get(), {Expression::Kind::kTypeCasting}); ASSERT_EQ(found, root.get()); - found = ExpressionUtils::findAnyKind(root.get(), - {Expression::Kind::kFunctionCall, - Expression::Kind::kTypeCasting, - Expression::Kind::kLogicalAnd}); + found = ExpressionUtils::findAny(root.get(), + {Expression::Kind::kFunctionCall, + Expression::Kind::kTypeCasting, + Expression::Kind::kLogicalAnd}); ASSERT_EQ(found, root.get()); - found = ExpressionUtils::findAnyKind(root.get(), {Expression::Kind::kDivision}); + found = ExpressionUtils::findAny(root.get(), {Expression::Kind::kDivision}); ASSERT_EQ(found, nullptr); - found = ExpressionUtils::findAnyKind(root.get(), - {Expression::Kind::kLogicalXor, - Expression::Kind::kRelGE, - Expression::Kind::kEdgeProperty}); + found = ExpressionUtils::findAny(root.get(), + {Expression::Kind::kLogicalXor, + Expression::Kind::kRelGE, + Expression::Kind::kEdgeProperty}); ASSERT_EQ(found, nullptr); // found all std::vector<const Expression *> founds = - ExpressionUtils::findAnyKindInAll(root.get(), {Expression::Kind::kConstant}); + ExpressionUtils::collectAll(root.get(), {Expression::Kind::kConstant}); ASSERT_EQ(founds.size(), 1); - founds = ExpressionUtils::findAnyKindInAll( + founds = ExpressionUtils::collectAll( root.get(), {Expression::Kind::kFunctionCall, Expression::Kind::kTypeCasting}); ASSERT_EQ(founds.size(), 3); - founds = ExpressionUtils::findAnyKindInAll(root.get(), {Expression::Kind::kAdd}); + founds = ExpressionUtils::collectAll(root.get(), {Expression::Kind::kAdd}); ASSERT_TRUE(founds.empty()); - founds = ExpressionUtils::findAnyKindInAll( + founds = ExpressionUtils::collectAll( root.get(), {Expression::Kind::kRelLE, Expression::Kind::kDstProperty}); ASSERT_TRUE(founds.empty()); } @@ -151,54 +151,53 @@ TEST_F(ExpressionUtilsTest, CheckComponent) { new ConstantExpression(2))); ASSERT_TRUE(ExpressionUtils::isKindOf(root.get(), {Expression::Kind::kAdd})); - ASSERT_TRUE(ExpressionUtils::hasAnyKind(root.get(), {Expression::Kind::kMinus})); + ASSERT_TRUE(ExpressionUtils::hasAny(root.get(), {Expression::Kind::kMinus})); ASSERT_TRUE(ExpressionUtils::isKindOf( root.get(), {Expression::Kind::kTypeCasting, Expression::Kind::kAdd})); - ASSERT_TRUE(ExpressionUtils::hasAnyKind( + ASSERT_TRUE(ExpressionUtils::hasAny( root.get(), {Expression::Kind::kLabelAttribute, Expression::Kind::kDivision})); ASSERT_FALSE(ExpressionUtils::isKindOf(root.get(), {Expression::Kind::kConstant})); - ASSERT_FALSE(ExpressionUtils::hasAnyKind(root.get(), {Expression::Kind::kFunctionCall})); + ASSERT_FALSE(ExpressionUtils::hasAny(root.get(), {Expression::Kind::kFunctionCall})); ASSERT_FALSE(ExpressionUtils::isKindOf( root.get(), {Expression::Kind::kDivision, Expression::Kind::kEdgeProperty})); - ASSERT_FALSE(ExpressionUtils::hasAnyKind( + ASSERT_FALSE(ExpressionUtils::hasAny( root.get(), {Expression::Kind::kDstProperty, Expression::Kind::kLogicalAnd})); // found - const Expression *found = - ExpressionUtils::findAnyKind(root.get(), {Expression::Kind::kAdd}); + const Expression *found = ExpressionUtils::findAny(root.get(), {Expression::Kind::kAdd}); ASSERT_EQ(found, root.get()); - found = ExpressionUtils::findAnyKind(root.get(), - {Expression::Kind::kFunctionCall, - Expression::Kind::kRelLE, - Expression::Kind::kMultiply}); + found = ExpressionUtils::findAny(root.get(), + {Expression::Kind::kFunctionCall, + Expression::Kind::kRelLE, + Expression::Kind::kMultiply}); ASSERT_NE(found, nullptr); - found = ExpressionUtils::findAnyKind(root.get(), {Expression::Kind::kInputProperty}); + found = ExpressionUtils::findAny(root.get(), {Expression::Kind::kInputProperty}); ASSERT_EQ(found, nullptr); - found = ExpressionUtils::findAnyKind(root.get(), - {Expression::Kind::kLogicalXor, - Expression::Kind::kEdgeRank, - Expression::Kind::kUnaryNot}); + found = ExpressionUtils::findAny(root.get(), + {Expression::Kind::kLogicalXor, + Expression::Kind::kEdgeRank, + Expression::Kind::kUnaryNot}); ASSERT_EQ(found, nullptr); // found all std::vector<const Expression *> founds = - ExpressionUtils::findAnyKindInAll(root.get(), {Expression::Kind::kConstant}); + ExpressionUtils::collectAll(root.get(), {Expression::Kind::kConstant}); ASSERT_EQ(founds.size(), 6); - founds = ExpressionUtils::findAnyKindInAll( + founds = ExpressionUtils::collectAll( root.get(), {Expression::Kind::kDivision, Expression::Kind::kMinus}); ASSERT_EQ(founds.size(), 2); - founds = ExpressionUtils::findAnyKindInAll(root.get(), {Expression::Kind::kEdgeDst}); + founds = ExpressionUtils::collectAll(root.get(), {Expression::Kind::kEdgeDst}); ASSERT_TRUE(founds.empty()); - founds = ExpressionUtils::findAnyKindInAll( + founds = ExpressionUtils::collectAll( root.get(), {Expression::Kind::kLogicalAnd, Expression::Kind::kUnaryNegate}); ASSERT_TRUE(founds.empty()); } diff --git a/src/validator/FetchEdgesValidator.cpp b/src/validator/FetchEdgesValidator.cpp index 5420ab82..3fbda6ff 100644 --- a/src/validator/FetchEdgesValidator.cpp +++ b/src/validator/FetchEdgesValidator.cpp @@ -251,11 +251,11 @@ Status FetchEdgesValidator::preparePropertiesWithoutYield() { /*static*/ const Expression *FetchEdgesValidator::findInvalidYieldExpression(const Expression *root) { - return ExpressionUtils::findAnyKind(root, - {Expression::Kind::kInputProperty, - Expression::Kind::kVarProperty, - Expression::Kind::kSrcProperty, - Expression::Kind::kDstProperty}); + return ExpressionUtils::findAny(root, + {Expression::Kind::kInputProperty, + Expression::Kind::kVarProperty, + Expression::Kind::kSrcProperty, + Expression::Kind::kDstProperty}); } // TODO(shylock) optimize dedup input when distinct given diff --git a/src/validator/FetchVerticesValidator.cpp b/src/validator/FetchVerticesValidator.cpp index 0d2c7622..2d9fa505 100644 --- a/src/validator/FetchVerticesValidator.cpp +++ b/src/validator/FetchVerticesValidator.cpp @@ -236,15 +236,15 @@ Status FetchVerticesValidator::preparePropertiesWithoutYield() { /*static*/ const Expression *FetchVerticesValidator::findInvalidYieldExpression(const Expression *root) { - return ExpressionUtils::findAnyKind(root, - {Expression::Kind::kInputProperty, - Expression::Kind::kVarProperty, - Expression::Kind::kSrcProperty, - Expression::Kind::kDstProperty, - Expression::Kind::kEdgeSrc, - Expression::Kind::kEdgeType, - Expression::Kind::kEdgeRank, - Expression::Kind::kEdgeDst}); + return ExpressionUtils::findAny(root, + {Expression::Kind::kInputProperty, + Expression::Kind::kVarProperty, + Expression::Kind::kSrcProperty, + Expression::Kind::kDstProperty, + Expression::Kind::kEdgeSrc, + Expression::Kind::kEdgeType, + Expression::Kind::kEdgeRank, + Expression::Kind::kEdgeDst}); } // TODO(shylock) optimize dedup input when distinct given diff --git a/src/visitor/CMakeLists.txt b/src/visitor/CMakeLists.txt index af97ecf1..6216ba9d 100644 --- a/src/visitor/CMakeLists.txt +++ b/src/visitor/CMakeLists.txt @@ -6,6 +6,9 @@ nebula_add_library( expr_visitor_obj OBJECT ExprVisitorImpl.cpp + CollectAllExprsVisitor.cpp DeducePropsVisitor.cpp DeduceTypeVisitor.cpp + FindAnyExprVisitor.cpp + RewriteLabelAttrVisitor.cpp ) diff --git a/src/visitor/CollectAllExprsVisitor.cpp b/src/visitor/CollectAllExprsVisitor.cpp new file mode 100644 index 00000000..d1f6a7ee --- /dev/null +++ b/src/visitor/CollectAllExprsVisitor.cpp @@ -0,0 +1,135 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#include "visitor/CollectAllExprsVisitor.h" + +namespace nebula { +namespace graph { + +CollectAllExprsVisitor::CollectAllExprsVisitor( + const std::unordered_set<Expression::Kind> &exprKinds) + : exprKinds_(exprKinds) {} + +void CollectAllExprsVisitor::visit(TypeCastingExpression *expr) { + collectExpr(expr); + expr->operand()->accept(this); +} + +void CollectAllExprsVisitor::visit(UnaryExpression *expr) { + collectExpr(expr); + expr->operand()->accept(this); +} + +void CollectAllExprsVisitor::visit(FunctionCallExpression *expr) { + collectExpr(expr); + for (const auto &arg : expr->args()->args()) { + arg->accept(this); + } +} + +void CollectAllExprsVisitor::visit(ListExpression *expr) { + collectExpr(expr); + for (auto item : expr->items()) { + const_cast<Expression *>(item)->accept(this); + } +} + +void CollectAllExprsVisitor::visit(SetExpression *expr) { + collectExpr(expr); + for (auto item : expr->items()) { + const_cast<Expression *>(item)->accept(this); + } +} + +void CollectAllExprsVisitor::visit(MapExpression *expr) { + collectExpr(expr); + for (const auto &pair : expr->items()) { + const_cast<Expression *>(pair.second)->accept(this); + } +} + +void CollectAllExprsVisitor::visit(ConstantExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(EdgePropertyExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(TagPropertyExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(InputPropertyExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(VariablePropertyExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(SourcePropertyExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(DestPropertyExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(EdgeSrcIdExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(EdgeTypeExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(EdgeRankExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(EdgeDstIdExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(UUIDExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(VariableExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(VersionedVariableExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(LabelExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(VertexExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visit(EdgeExpression *expr) { + collectExpr(expr); +} + +void CollectAllExprsVisitor::visitBinaryExpr(BinaryExpression *expr) { + collectExpr(expr); + expr->left()->accept(this); + expr->right()->accept(this); +} + +void CollectAllExprsVisitor::collectExpr(const Expression *expr) { + if (exprKinds_.find(expr->kind()) != exprKinds_.cend()) { + exprs_.push_back(expr); + } +} + +} // namespace graph +} // namespace nebula diff --git a/src/visitor/CollectAllExprsVisitor.h b/src/visitor/CollectAllExprsVisitor.h new file mode 100644 index 00000000..ab2fb962 --- /dev/null +++ b/src/visitor/CollectAllExprsVisitor.h @@ -0,0 +1,67 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#ifndef VISITOR_COLLECTALLEXPRSVISITOR_H_ +#define VISITOR_COLLECTALLEXPRSVISITOR_H_ + +#include <unordered_set> + +#include "common/expression/Expression.h" +#include "visitor/ExprVisitorImpl.h" + +namespace nebula { +namespace graph { + +class CollectAllExprsVisitor final : public ExprVisitorImpl { +public: + explicit CollectAllExprsVisitor(const std::unordered_set<Expression::Kind>& exprKinds); + bool ok() const override { + return !exprKinds_.empty(); + } + + std::vector<const Expression*> exprs() && { + return std::move(exprs_); + } + +private: + using ExprVisitorImpl::visit; + + void visit(TypeCastingExpression* expr) override; + void visit(UnaryExpression* expr) override; + void visit(FunctionCallExpression* expr) override; + void visit(ListExpression* expr) override; + void visit(SetExpression* expr) override; + void visit(MapExpression* expr) override; + + void visit(ConstantExpression* expr) override; + void visit(EdgePropertyExpression* expr) override; + void visit(TagPropertyExpression* expr) override; + void visit(InputPropertyExpression* expr) override; + void visit(VariablePropertyExpression* expr) override; + void visit(SourcePropertyExpression* expr) override; + void visit(DestPropertyExpression* expr) override; + void visit(EdgeSrcIdExpression* expr) override; + void visit(EdgeTypeExpression* expr) override; + void visit(EdgeRankExpression* expr) override; + void visit(EdgeDstIdExpression* expr) override; + void visit(UUIDExpression* expr) override; + void visit(VariableExpression* expr) override; + void visit(VersionedVariableExpression* expr) override; + void visit(LabelExpression* expr) override; + void visit(VertexExpression* expr) override; + void visit(EdgeExpression* expr) override; + + void visitBinaryExpr(BinaryExpression* expr) override; + void collectExpr(const Expression* expr); + + const std::unordered_set<Expression::Kind>& exprKinds_; + std::vector<const Expression*> exprs_; +}; + +} // namespace graph +} // namespace nebula + +#endif // VISITOR_COLLECTALLEXPRSVISITOR_H_ diff --git a/src/visitor/FindAnyExprVisitor.cpp b/src/visitor/FindAnyExprVisitor.cpp new file mode 100644 index 00000000..448aa066 --- /dev/null +++ b/src/visitor/FindAnyExprVisitor.cpp @@ -0,0 +1,149 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#include "visitor/FindAnyExprVisitor.h" + +namespace nebula { +namespace graph { + +FindAnyExprVisitor::FindAnyExprVisitor(const std::unordered_set<Expression::Kind> &kinds) + : kinds_(kinds) { + DCHECK(!kinds.empty()); +} + +void FindAnyExprVisitor::visit(TypeCastingExpression *expr) { + findExpr(expr); + if (found_) return; + expr->operand()->accept(this); +} + +void FindAnyExprVisitor::visit(UnaryExpression *expr) { + findExpr(expr); + if (found_) return; + expr->operand()->accept(this); +} + +void FindAnyExprVisitor::visit(FunctionCallExpression *expr) { + findExpr(expr); + if (found_) return; + for (const auto &arg : expr->args()->args()) { + arg->accept(this); + if (found_) return; + } +} + +void FindAnyExprVisitor::visit(ListExpression *expr) { + findExpr(expr); + if (found_) return; + for (const auto &item : expr->items()) { + const_cast<Expression *>(item)->accept(this); + if (found_) return; + } +} + +void FindAnyExprVisitor::visit(SetExpression *expr) { + findExpr(expr); + if (found_) return; + for (const auto &item : expr->items()) { + const_cast<Expression *>(item)->accept(this); + if (found_) return; + } +} + +void FindAnyExprVisitor::visit(MapExpression *expr) { + findExpr(expr); + if (found_) return; + for (const auto &pair : expr->items()) { + const_cast<Expression *>(pair.second)->accept(this); + if (found_) return; + } +} + +void FindAnyExprVisitor::visit(ConstantExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(EdgePropertyExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(TagPropertyExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(InputPropertyExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(VariablePropertyExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(SourcePropertyExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(DestPropertyExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(EdgeSrcIdExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(EdgeTypeExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(EdgeRankExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(EdgeDstIdExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(UUIDExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(VariableExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(VersionedVariableExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(LabelExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(VertexExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visit(EdgeExpression *expr) { + findExpr(expr); +} + +void FindAnyExprVisitor::visitBinaryExpr(BinaryExpression *expr) { + findExpr(expr); + if (found_) return; + expr->left()->accept(this); + if (found_) return; + expr->right()->accept(this); +} + +void FindAnyExprVisitor::findExpr(const Expression *expr) { + found_ = kinds_.find(expr->kind()) != kinds_.cend(); + if (found_) { + expr_ = expr; + } +} + +} // namespace graph +} // namespace nebula diff --git a/src/visitor/FindAnyExprVisitor.h b/src/visitor/FindAnyExprVisitor.h new file mode 100644 index 00000000..580ec338 --- /dev/null +++ b/src/visitor/FindAnyExprVisitor.h @@ -0,0 +1,71 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#ifndef VISITOR_FINDANYEXPRVISITOR_H_ +#define VISITOR_FINDANYEXPRVISITOR_H_ + +#include <unordered_set> + +#include "common/expression/Expression.h" +#include "visitor/ExprVisitorImpl.h" + +namespace nebula { +namespace graph { + +class FindAnyExprVisitor final : public ExprVisitorImpl { +public: + explicit FindAnyExprVisitor(const std::unordered_set<Expression::Kind>& kinds); + + bool ok() const override { + // continue if not found + return !found_; + } + + const Expression* expr() const { + return expr_; + } + +private: + using ExprVisitorImpl::visit; + + void visit(TypeCastingExpression* expr) override; + void visit(UnaryExpression* expr) override; + void visit(FunctionCallExpression* expr) override; + void visit(ListExpression* expr) override; + void visit(SetExpression* expr) override; + void visit(MapExpression* expr) override; + + void visit(ConstantExpression* expr) override; + void visit(EdgePropertyExpression* expr) override; + void visit(TagPropertyExpression* expr) override; + void visit(InputPropertyExpression* expr) override; + void visit(VariablePropertyExpression* expr) override; + void visit(SourcePropertyExpression* expr) override; + void visit(DestPropertyExpression* expr) override; + void visit(EdgeSrcIdExpression* expr) override; + void visit(EdgeTypeExpression* expr) override; + void visit(EdgeRankExpression* expr) override; + void visit(EdgeDstIdExpression* expr) override; + void visit(UUIDExpression* expr) override; + void visit(VariableExpression* expr) override; + void visit(VersionedVariableExpression* expr) override; + void visit(LabelExpression* expr) override; + void visit(VertexExpression* expr) override; + void visit(EdgeExpression* expr) override; + + void visitBinaryExpr(BinaryExpression* expr) override; + + void findExpr(const Expression* expr); + + bool found_{false}; + const Expression* expr_{nullptr}; + const std::unordered_set<Expression::Kind>& kinds_; +}; + +} // namespace graph +} // namespace nebula + +#endif // VISITOR_FINDANYEXPRVISITOR_H_ diff --git a/src/visitor/RewriteLabelAttrVisitor.cpp b/src/visitor/RewriteLabelAttrVisitor.cpp new file mode 100644 index 00000000..de55dfe7 --- /dev/null +++ b/src/visitor/RewriteLabelAttrVisitor.cpp @@ -0,0 +1,142 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#include "visitor/RewriteLabelAttrVisitor.h" + +#include "common/base/Logging.h" + +namespace nebula { +namespace graph { + +RewriteLabelAttrVisitor::RewriteLabelAttrVisitor(bool isTag) : isTag_(isTag) {} + +void RewriteLabelAttrVisitor::visit(TypeCastingExpression* expr) { + if (isLabelAttrExpr(expr->operand())) { + auto operand = static_cast<LabelAttributeExpression*>(expr->operand()); + expr->setOperand(createExpr(operand)); + } else { + expr->operand()->accept(this); + } +} + +void RewriteLabelAttrVisitor::visit(UnaryExpression* expr) { + if (isLabelAttrExpr(expr->operand())) { + auto operand = static_cast<LabelAttributeExpression*>(expr->operand()); + expr->setOperand(createExpr(operand)); + } else { + expr->operand()->accept(this); + } +} + +void RewriteLabelAttrVisitor::visit(FunctionCallExpression* expr) { + for (auto& arg : expr->args()->args()) { + if (isLabelAttrExpr(arg.get())) { + auto newArg = static_cast<LabelAttributeExpression*>(arg.get()); + arg.reset(createExpr(newArg)); + } else { + arg->accept(this); + } + } +} + +void RewriteLabelAttrVisitor::visit(ListExpression* expr) { + auto newItems = rewriteExprList(expr->items()); + if (!newItems.empty()) { + expr->setItems(std::move(newItems)); + } +} + +void RewriteLabelAttrVisitor::visit(SetExpression* expr) { + auto newItems = rewriteExprList(expr->items()); + if (!newItems.empty()) { + expr->setItems(std::move(newItems)); + } +} + +void RewriteLabelAttrVisitor::visit(MapExpression* expr) { + auto items = expr->items(); + auto found = std::find_if( + items.cbegin(), items.cend(), [](auto& pair) { return isLabelAttrExpr(pair.second); }); + if (found == items.cend()) { + std::for_each(items.begin(), items.end(), [this](auto& pair) { + const_cast<Expression*>(pair.second)->accept(this); + }); + return; + } + + std::vector<MapExpression::Item> newItems; + newItems.reserve(items.size()); + for (auto& pair : items) { + MapExpression::Item newItem; + newItem.first.reset(new std::string(*pair.first)); + if (isLabelAttrExpr(pair.second)) { + auto symExpr = static_cast<const LabelAttributeExpression*>(pair.second); + newItem.second.reset(createExpr(symExpr)); + } else { + newItem.second = Expression::decode(pair.second->encode()); + newItem.second->accept(this); + } + newItems.emplace_back(std::move(newItem)); + } + expr->setItems(std::move(newItems)); +} + +void RewriteLabelAttrVisitor::visitBinaryExpr(BinaryExpression* expr) { + if (isLabelAttrExpr(expr->left())) { + auto left = static_cast<const LabelAttributeExpression*>(expr->left()); + expr->setLeft(createExpr(left)); + } else { + expr->left()->accept(this); + } + if (isLabelAttrExpr(expr->right())) { + auto right = static_cast<const LabelAttributeExpression*>(expr->right()); + expr->setRight(createExpr(right)); + } else { + expr->right()->accept(this); + } +} + +std::vector<std::unique_ptr<Expression>> RewriteLabelAttrVisitor::rewriteExprList( + const std::vector<const Expression*>& exprs) { + std::vector<std::unique_ptr<Expression>> newExprs; + + auto found = std::find_if(exprs.cbegin(), exprs.cend(), isLabelAttrExpr); + if (found == exprs.cend()) { + std::for_each(exprs.cbegin(), exprs.cend(), [this](auto expr) { + const_cast<Expression*>(expr)->accept(this); + }); + return newExprs; + } + + newExprs.reserve(exprs.size()); + for (auto item : exprs) { + if (isLabelAttrExpr(item)) { + auto symExpr = static_cast<const LabelAttributeExpression*>(item); + newExprs.emplace_back(createExpr(symExpr)); + } else { + auto newExpr = Expression::decode(item->encode()); + newExpr->accept(this); + newExprs.emplace_back(std::move(newExpr)); + } + } + return newExprs; +} + +Expression* RewriteLabelAttrVisitor::createExpr(const LabelAttributeExpression* expr) { + auto leftName = new std::string(*expr->left()->name()); + auto rightName = new std::string(*expr->right()->name()); + if (isTag_) { + return new TagPropertyExpression(leftName, rightName); + } + return new EdgePropertyExpression(leftName, rightName); +} + +bool RewriteLabelAttrVisitor::isLabelAttrExpr(const Expression* expr) { + return expr->kind() == Expression::Kind::kLabelAttribute; +} + +} // namespace graph +} // namespace nebula diff --git a/src/visitor/RewriteLabelAttrVisitor.h b/src/visitor/RewriteLabelAttrVisitor.h new file mode 100644 index 00000000..9697c255 --- /dev/null +++ b/src/visitor/RewriteLabelAttrVisitor.h @@ -0,0 +1,67 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#ifndef VISITOR_REWRITELABELATTRVISITOR_H_ +#define VISITOR_REWRITELABELATTRVISITOR_H_ + +#include <memory> +#include <vector> + +#include "visitor/ExprVisitorImpl.h" + +namespace nebula { +namespace graph { + +class RewriteLabelAttrVisitor final : public ExprVisitorImpl { +public: + explicit RewriteLabelAttrVisitor(bool isTag); + + bool ok() const override { + return true; + } + +private: + using ExprVisitorImpl::visit; + + void visit(TypeCastingExpression *expr) override; + void visit(UnaryExpression *expr) override; + void visit(FunctionCallExpression *expr) override; + void visit(ListExpression *expr) override; + void visit(SetExpression *expr) override; + void visit(MapExpression *expr) override; + void visit(ConstantExpression *) override {} + void visit(LabelExpression *) override {} + void visit(UUIDExpression *) override {} + void visit(LabelAttributeExpression *) override {} + void visit(VariableExpression *) override {} + void visit(VersionedVariableExpression *) override {} + void visit(TagPropertyExpression *) override {} + void visit(EdgePropertyExpression *) override {} + void visit(InputPropertyExpression *) override {} + void visit(VariablePropertyExpression *) override {} + void visit(DestPropertyExpression *) override {} + void visit(SourcePropertyExpression *) override {} + void visit(EdgeSrcIdExpression *) override {} + void visit(EdgeTypeExpression *) override {} + void visit(EdgeRankExpression *) override {} + void visit(EdgeDstIdExpression *) override {} + void visit(VertexExpression *) override {} + void visit(EdgeExpression *) override {} + + void visitBinaryExpr(BinaryExpression *expr) override; + + Expression *createExpr(const LabelAttributeExpression *expr); + std::vector<std::unique_ptr<Expression>> rewriteExprList( + const std::vector<const Expression *> &exprs); + static bool isLabelAttrExpr(const Expression *expr); + + bool isTag_{false}; +}; + +} // namespace graph +} // namespace nebula + +#endif // VISITOR_REWRITELABELATTRVISITOR_H_ -- GitLab