Skip to content
Snippets Groups Projects
Unverified Commit fe5029c4 authored by Yee's avatar Yee Committed by GitHub
Browse files

Add rewrite label attr expr visitor (#259)

* Add rewrite label attr expr visitor

* Fix review comments

* Rename
parent 7417badf
No related branches found
No related tags found
No related merge requests found
......@@ -10,11 +10,13 @@
#include "common/expression/BinaryExpression.h"
#include "common/expression/Expression.h"
#include "common/expression/FunctionCallExpression.h"
#include "common/expression/LabelExpression.h"
#include "common/expression/PropertyExpression.h"
#include "common/expression/TypeCastingExpression.h"
#include "common/expression/UnaryExpression.h"
#include "common/expression/LabelExpression.h"
#include "common/expression/LabelAttributeExpression.h"
#include "visitor/CollectAllExprsVisitor.h"
#include "visitor/FindAnyExprVisitor.h"
#include "visitor/RewriteLabelAttrVisitor.h"
namespace nebula {
namespace graph {
......@@ -23,180 +25,74 @@ class ExpressionUtils {
public:
explicit ExpressionUtils(...) = delete;
// return true for continue, false return directly
using Visitor = std::function<bool(const Expression*)>;
using MutableVisitor = std::function<bool(Expression*)>;
// preorder traverse in fact for tail call optimization
// if want to do some thing like eval, don't try it
template <typename T,
typename V,
typename = std::enable_if_t<std::is_same<std::remove_const_t<T>, Expression>::value>>
static bool traverse(T* expr, V visitor) {
if (!visitor(expr)) {
return false;
}
switch (expr->kind()) {
case Expression::Kind::kDstProperty:
case Expression::Kind::kSrcProperty:
case Expression::Kind::kTagProperty:
case Expression::Kind::kEdgeProperty:
case Expression::Kind::kEdgeSrc:
case Expression::Kind::kEdgeType:
case Expression::Kind::kEdgeRank:
case Expression::Kind::kEdgeDst:
case Expression::Kind::kInputProperty:
case Expression::Kind::kVarProperty:
case Expression::Kind::kVertex:
case Expression::Kind::kEdge:
case Expression::Kind::kUUID:
case Expression::Kind::kVar:
case Expression::Kind::kVersionedVar:
case Expression::Kind::kLabelAttribute:
case Expression::Kind::kConstant: {
return true;
}
case Expression::Kind::kAdd:
case Expression::Kind::kMinus:
case Expression::Kind::kMultiply:
case Expression::Kind::kDivision:
case Expression::Kind::kMod:
case Expression::Kind::kRelEQ:
case Expression::Kind::kRelNE:
case Expression::Kind::kRelLT:
case Expression::Kind::kRelLE:
case Expression::Kind::kRelGT:
case Expression::Kind::kRelGE:
case Expression::Kind::kRelIn:
case Expression::Kind::kRelNotIn:
case Expression::Kind::kContains:
case Expression::Kind::kLogicalAnd:
case Expression::Kind::kLogicalOr:
case Expression::Kind::kLogicalXor: {
using ToType = keep_const_t<T, BinaryExpression>;
auto biExpr = static_cast<ToType*>(expr);
if (!traverse(biExpr->left(), visitor)) {
return false;
}
return traverse(biExpr->right(), visitor);
}
case Expression::Kind::kUnaryIncr:
case Expression::Kind::kUnaryDecr:
case Expression::Kind::kUnaryPlus:
case Expression::Kind::kUnaryNegate:
case Expression::Kind::kUnaryNot: {
using ToType = keep_const_t<T, UnaryExpression>;
auto unaryExpr = static_cast<ToType*>(expr);
return traverse(unaryExpr->operand(), visitor);
}
case Expression::Kind::kTypeCasting: {
using ToType = keep_const_t<T, TypeCastingExpression>;
auto typeCastingExpr = static_cast<ToType*>(expr);
return traverse(typeCastingExpr->operand(), visitor);
}
case Expression::Kind::kFunctionCall: {
using ToType = keep_const_t<T, FunctionCallExpression>;
auto funcExpr = static_cast<ToType*>(expr);
for (auto& arg : funcExpr->args()->args()) {
if (!traverse(arg.get(), visitor)) {
return false;
}
}
return true;
}
case Expression::Kind::kList: // FIXME(dutor)
case Expression::Kind::kSet:
case Expression::Kind::kMap:
case Expression::Kind::kSubscript:
case Expression::Kind::kAttribute:
case Expression::Kind::kLabel: {
return false;
}
}
DLOG(FATAL) << "Impossible expression kind " << static_cast<int>(expr->kind());
return false;
}
static inline bool isKindOf(const Expression* expr,
const std::unordered_set<Expression::Kind>& expected) {
return expected.find(expr->kind()) != expected.end();
}
// null for not found
static const Expression* findAnyKind(const Expression* self,
const std::unordered_set<Expression::Kind>& expected) {
const Expression* found = nullptr;
traverse(self, [&expected, &found](const Expression* expr) -> bool {
if (isKindOf(expr, expected)) {
found = expr;
return false; // Already find so return now
}
return true; // Not find so continue traverse
});
return found;
static const Expression* findAny(const Expression* self,
const std::unordered_set<Expression::Kind>& expected) {
FindAnyExprVisitor visitor(expected);
const_cast<Expression*>(self)->accept(&visitor);
return visitor.expr();
}
// Find all expression fit any kind
// Empty for not found any one
static std::vector<const Expression*> findAnyKindInAll(
static std::vector<const Expression*> collectAll(
const Expression* self,
const std::unordered_set<Expression::Kind>& expected) {
std::vector<const Expression*> exprs;
traverse(self, [&expected, &exprs](const Expression* expr) -> bool {
if (isKindOf(expr, expected)) {
exprs.emplace_back(expr);
}
return true; // Not return always to traverse entire expression tree
});
return exprs;
CollectAllExprsVisitor visitor(expected);
const_cast<Expression*>(self)->accept(&visitor);
return std::move(visitor).exprs();
}
static bool hasAnyKind(const Expression* expr,
const std::unordered_set<Expression::Kind>& expected) {
return findAnyKind(expr, expected) != nullptr;
static bool hasAny(const Expression* expr,
const std::unordered_set<Expression::Kind>& expected) {
return findAny(expr, expected) != nullptr;
}
// Require data from input/variable
static bool hasInput(const Expression* expr) {
return hasAnyKind(expr,
{Expression::Kind::kInputProperty,
Expression::Kind::kVarProperty,
Expression::Kind::kVar,
Expression::Kind::kVersionedVar});
return hasAny(expr,
{Expression::Kind::kInputProperty,
Expression::Kind::kVarProperty,
Expression::Kind::kVar,
Expression::Kind::kVersionedVar});
}
// require data from graph storage
static const Expression* findStorage(const Expression* expr) {
return findAnyKind(expr,
{Expression::Kind::kTagProperty,
Expression::Kind::kEdgeProperty,
Expression::Kind::kDstProperty,
Expression::Kind::kSrcProperty,
Expression::Kind::kEdgeSrc,
Expression::Kind::kEdgeType,
Expression::Kind::kEdgeRank,
Expression::Kind::kEdgeDst,
Expression::Kind::kVertex,
Expression::Kind::kEdge});
return findAny(expr,
{Expression::Kind::kTagProperty,
Expression::Kind::kEdgeProperty,
Expression::Kind::kDstProperty,
Expression::Kind::kSrcProperty,
Expression::Kind::kEdgeSrc,
Expression::Kind::kEdgeType,
Expression::Kind::kEdgeRank,
Expression::Kind::kEdgeDst,
Expression::Kind::kVertex,
Expression::Kind::kEdge});
}
static std::vector<const Expression*> findAllStorage(const Expression* expr) {
return findAnyKindInAll(expr,
{Expression::Kind::kTagProperty,
Expression::Kind::kEdgeProperty,
Expression::Kind::kDstProperty,
Expression::Kind::kSrcProperty,
Expression::Kind::kEdgeSrc,
Expression::Kind::kEdgeType,
Expression::Kind::kEdgeRank,
Expression::Kind::kEdgeDst,
Expression::Kind::kVertex,
Expression::Kind::kEdge});
return collectAll(expr,
{Expression::Kind::kTagProperty,
Expression::Kind::kEdgeProperty,
Expression::Kind::kDstProperty,
Expression::Kind::kSrcProperty,
Expression::Kind::kEdgeSrc,
Expression::Kind::kEdgeType,
Expression::Kind::kEdgeRank,
Expression::Kind::kEdgeDst,
Expression::Kind::kVertex,
Expression::Kind::kEdge});
}
static std::vector<const Expression*> findAllInputVariableProp(const Expression* expr) {
return findAnyKindInAll(expr,
{Expression::Kind::kInputProperty, Expression::Kind::kVarProperty});
return collectAll(expr, {Expression::Kind::kInputProperty, Expression::Kind::kVarProperty});
}
static bool hasStorage(const Expression* expr) {
......@@ -218,23 +114,22 @@ public:
}
static bool isConstExpr(const Expression* expr) {
return !hasAnyKind(expr,
{Expression::Kind::kInputProperty,
Expression::Kind::kVarProperty,
Expression::Kind::kVar,
Expression::Kind::kVersionedVar,
Expression::Kind::kLabelAttribute,
Expression::Kind::kTagProperty,
Expression::Kind::kEdgeProperty,
Expression::Kind::kDstProperty,
Expression::Kind::kSrcProperty,
Expression::Kind::kEdgeSrc,
Expression::Kind::kEdgeType,
Expression::Kind::kEdgeRank,
Expression::Kind::kEdgeDst,
Expression::Kind::kVertex,
Expression::Kind::kEdge});
return !hasAny(expr,
{Expression::Kind::kInputProperty,
Expression::Kind::kVarProperty,
Expression::Kind::kVar,
Expression::Kind::kVersionedVar,
Expression::Kind::kLabelAttribute,
Expression::Kind::kTagProperty,
Expression::Kind::kEdgeProperty,
Expression::Kind::kDstProperty,
Expression::Kind::kSrcProperty,
Expression::Kind::kEdgeSrc,
Expression::Kind::kEdgeType,
Expression::Kind::kEdgeRank,
Expression::Kind::kEdgeDst,
Expression::Kind::kVertex,
Expression::Kind::kEdge});
}
// clone expression
......@@ -251,99 +146,8 @@ public:
typename = std::enable_if_t<std::is_same<To, EdgePropertyExpression>::value ||
std::is_same<To, TagPropertyExpression>::value>>
static void rewriteLabelAttribute(Expression* expr) {
traverse(expr, [](Expression* current) -> bool {
switch (current->kind()) {
case Expression::Kind::kDstProperty:
case Expression::Kind::kSrcProperty:
case Expression::Kind::kLabelAttribute:
case Expression::Kind::kTagProperty:
case Expression::Kind::kEdgeProperty:
case Expression::Kind::kEdgeSrc:
case Expression::Kind::kEdgeType:
case Expression::Kind::kEdgeRank:
case Expression::Kind::kEdgeDst:
case Expression::Kind::kVertex:
case Expression::Kind::kEdge:
case Expression::Kind::kInputProperty:
case Expression::Kind::kVarProperty:
case Expression::Kind::kUUID:
case Expression::Kind::kVar:
case Expression::Kind::kVersionedVar:
case Expression::Kind::kConstant: {
return true;
}
case Expression::Kind::kAdd:
case Expression::Kind::kMinus:
case Expression::Kind::kMultiply:
case Expression::Kind::kDivision:
case Expression::Kind::kMod:
case Expression::Kind::kRelEQ:
case Expression::Kind::kRelNE:
case Expression::Kind::kRelLT:
case Expression::Kind::kRelLE:
case Expression::Kind::kRelGT:
case Expression::Kind::kRelGE:
case Expression::Kind::kRelIn:
case Expression::Kind::kRelNotIn:
case Expression::Kind::kContains:
case Expression::Kind::kLogicalAnd:
case Expression::Kind::kLogicalOr:
case Expression::Kind::kLogicalXor: {
auto* biExpr = static_cast<BinaryExpression*>(current);
if (biExpr->left()->kind() == Expression::Kind::kLabelAttribute) {
auto* laExpr = static_cast<LabelAttributeExpression*>(biExpr->left());
biExpr->setLeft(rewriteLabelAttribute<To>(laExpr));
}
if (biExpr->right()->kind() == Expression::Kind::kLabelAttribute) {
auto* laExpr = static_cast<LabelAttributeExpression*>(biExpr->right());
biExpr->setRight(rewriteLabelAttribute<To>(laExpr));
}
return true;
}
case Expression::Kind::kUnaryIncr:
case Expression::Kind::kUnaryDecr:
case Expression::Kind::kUnaryPlus:
case Expression::Kind::kUnaryNegate:
case Expression::Kind::kUnaryNot: {
auto* unaryExpr = static_cast<UnaryExpression*>(current);
if (unaryExpr->operand()->kind() == Expression::Kind::kLabelAttribute) {
auto* laExpr =
static_cast<LabelAttributeExpression*>(unaryExpr->operand());
unaryExpr->setOperand(rewriteLabelAttribute<To>(laExpr));
}
return true;
}
case Expression::Kind::kTypeCasting: {
auto* typeCastingExpr = static_cast<TypeCastingExpression*>(current);
if (typeCastingExpr->operand()->kind() == Expression::Kind::kLabelAttribute) {
auto* laExpr =
static_cast<LabelAttributeExpression*>(typeCastingExpr->operand());
typeCastingExpr->setOperand(rewriteLabelAttribute<To>(laExpr));
}
return true;
}
case Expression::Kind::kFunctionCall: {
auto* funcExpr = static_cast<FunctionCallExpression*>(current);
for (auto& arg : funcExpr->args()->args()) {
if (arg->kind() == Expression::Kind::kLabelAttribute) {
auto* laExpr = static_cast<LabelAttributeExpression*>(arg.get());
arg.reset(rewriteLabelAttribute<To>(laExpr));
}
}
return true;
}
case Expression::Kind::kList: // FIXME(dutor)
case Expression::Kind::kSet:
case Expression::Kind::kMap:
case Expression::Kind::kSubscript:
case Expression::Kind::kAttribute:
case Expression::Kind::kLabel: {
return false;
}
} // switch
DLOG(FATAL) << "Impossible expression kind " << static_cast<int>(current->kind());
return false;
}); // traverse
RewriteLabelAttrVisitor visitor(std::is_same<To, TagPropertyExpression>::value);
expr->accept(&visitor);
}
template <typename To,
......@@ -353,16 +157,9 @@ public:
return new To(new std::string(std::move(*expr->left()->name())),
new std::string(std::move(*expr->right()->name())));
}
private:
// keep const or non-const with T
template <typename T, typename To>
using keep_const_t = std::conditional_t<std::is_const<T>::value,
const std::remove_const_t<To>,
std::remove_const_t<To>>;
};
} // namespace graph
} // namespace nebula
#endif // _UTIL_EXPRESSION_UTILS_H_
#endif // _UTIL_EXPRESSION_UTILS_H_
......@@ -13,9 +13,35 @@ nebula_add_test(
$<TARGET_OBJECTS:common_function_manager_obj>
$<TARGET_OBJECTS:common_time_obj>
$<TARGET_OBJECTS:common_time_function_obj>
$<TARGET_OBJECTS:common_meta_thrift_obj>
$<TARGET_OBJECTS:common_meta_client_obj>
$<TARGET_OBJECTS:common_meta_obj>
$<TARGET_OBJECTS:common_storage_thrift_obj>
$<TARGET_OBJECTS:common_graph_thrift_obj>
$<TARGET_OBJECTS:common_conf_obj>
$<TARGET_OBJECTS:common_fs_obj>
$<TARGET_OBJECTS:common_thrift_obj>
$<TARGET_OBJECTS:common_common_thrift_obj>
$<TARGET_OBJECTS:common_thread_obj>
$<TARGET_OBJECTS:common_file_based_cluster_id_man_obj>
$<TARGET_OBJECTS:common_charset_obj>
$<TARGET_OBJECTS:common_encryption_obj>
$<TARGET_OBJECTS:common_http_client_obj>
$<TARGET_OBJECTS:common_process_obj>
$<TARGET_OBJECTS:common_agg_function_obj>
$<TARGET_OBJECTS:idgenerator_obj>
$<TARGET_OBJECTS:expr_visitor_obj>
$<TARGET_OBJECTS:session_obj>
$<TARGET_OBJECTS:graph_auth_obj>
$<TARGET_OBJECTS:graph_flags_obj>
$<TARGET_OBJECTS:util_obj>
$<TARGET_OBJECTS:planner_obj>
$<TARGET_OBJECTS:parser_obj>
$<TARGET_OBJECTS:context_obj>
$<TARGET_OBJECTS:validator_obj>
LIBRARIES
gtest
gtest_main
${THRIFT_LIBRARIES}
proxygenlib
)
......@@ -21,56 +21,56 @@ TEST_F(ExpressionUtilsTest, CheckComponent) {
const auto root = std::make_unique<ConstantExpression>();
ASSERT_TRUE(ExpressionUtils::isKindOf(root.get(), {Expression::Kind::kConstant}));
ASSERT_TRUE(ExpressionUtils::hasAnyKind(root.get(), {Expression::Kind::kConstant}));
ASSERT_TRUE(ExpressionUtils::hasAny(root.get(), {Expression::Kind::kConstant}));
ASSERT_TRUE(ExpressionUtils::isKindOf(
root.get(), {Expression::Kind::kConstant, Expression::Kind::kAdd}));
ASSERT_TRUE(ExpressionUtils::hasAnyKind(
root.get(), {Expression::Kind::kConstant, Expression::Kind::kAdd}));
ASSERT_TRUE(ExpressionUtils::hasAny(root.get(),
{Expression::Kind::kConstant, Expression::Kind::kAdd}));
ASSERT_FALSE(ExpressionUtils::isKindOf(root.get(), {Expression::Kind::kAdd}));
ASSERT_FALSE(ExpressionUtils::hasAnyKind(root.get(), {Expression::Kind::kAdd}));
ASSERT_FALSE(ExpressionUtils::hasAny(root.get(), {Expression::Kind::kAdd}));
ASSERT_FALSE(ExpressionUtils::isKindOf(
root.get(), {Expression::Kind::kDivision, Expression::Kind::kAdd}));
ASSERT_FALSE(ExpressionUtils::hasAnyKind(
ASSERT_FALSE(ExpressionUtils::hasAny(
root.get(), {Expression::Kind::kDstProperty, Expression::Kind::kAdd}));
// find
const Expression *found =
ExpressionUtils::findAnyKind(root.get(), {Expression::Kind::kConstant});
ExpressionUtils::findAny(root.get(), {Expression::Kind::kConstant});
ASSERT_EQ(found, root.get());
found = ExpressionUtils::findAnyKind(
found = ExpressionUtils::findAny(
root.get(),
{Expression::Kind::kConstant, Expression::Kind::kAdd, Expression::Kind::kEdgeProperty});
ASSERT_EQ(found, root.get());
found = ExpressionUtils::findAnyKind(root.get(), {Expression::Kind::kEdgeDst});
found = ExpressionUtils::findAny(root.get(), {Expression::Kind::kEdgeDst});
ASSERT_EQ(found, nullptr);
found = ExpressionUtils::findAnyKind(
found = ExpressionUtils::findAny(
root.get(), {Expression::Kind::kEdgeRank, Expression::Kind::kInputProperty});
ASSERT_EQ(found, nullptr);
// find all
const auto willFoundAll = std::vector<const Expression *>{root.get()};
std::vector<const Expression *> founds =
ExpressionUtils::findAnyKindInAll(root.get(), {Expression::Kind::kConstant});
ExpressionUtils::collectAll(root.get(), {Expression::Kind::kConstant});
ASSERT_EQ(founds, willFoundAll);
founds = ExpressionUtils::findAnyKindInAll(
founds = ExpressionUtils::collectAll(
root.get(),
{Expression::Kind::kAdd, Expression::Kind::kConstant, Expression::Kind::kEdgeDst});
ASSERT_EQ(founds, willFoundAll);
founds = ExpressionUtils::findAnyKindInAll(root.get(), {Expression::Kind::kSrcProperty});
founds = ExpressionUtils::collectAll(root.get(), {Expression::Kind::kSrcProperty});
ASSERT_TRUE(founds.empty());
founds = ExpressionUtils::findAnyKindInAll(root.get(),
{Expression::Kind::kUnaryNegate,
Expression::Kind::kEdgeDst,
Expression::Kind::kEdgeDst});
founds = ExpressionUtils::collectAll(root.get(),
{Expression::Kind::kUnaryNegate,
Expression::Kind::kEdgeDst,
Expression::Kind::kEdgeDst});
ASSERT_TRUE(founds.empty());
}
......@@ -83,54 +83,54 @@ TEST_F(ExpressionUtilsTest, CheckComponent) {
new TypeCastingExpression(Value::Type::BOOL, new ConstantExpression())));
ASSERT_TRUE(ExpressionUtils::isKindOf(root.get(), {Expression::Kind::kTypeCasting}));
ASSERT_TRUE(ExpressionUtils::hasAnyKind(root.get(), {Expression::Kind::kConstant}));
ASSERT_TRUE(ExpressionUtils::hasAny(root.get(), {Expression::Kind::kConstant}));
ASSERT_TRUE(ExpressionUtils::isKindOf(
root.get(), {Expression::Kind::kTypeCasting, Expression::Kind::kAdd}));
ASSERT_TRUE(ExpressionUtils::hasAnyKind(
ASSERT_TRUE(ExpressionUtils::hasAny(
root.get(), {Expression::Kind::kTypeCasting, Expression::Kind::kAdd}));
ASSERT_FALSE(ExpressionUtils::isKindOf(root.get(), {Expression::Kind::kAdd}));
ASSERT_FALSE(ExpressionUtils::hasAnyKind(root.get(), {Expression::Kind::kAdd}));
ASSERT_FALSE(ExpressionUtils::hasAny(root.get(), {Expression::Kind::kAdd}));
ASSERT_FALSE(ExpressionUtils::isKindOf(
root.get(), {Expression::Kind::kDivision, Expression::Kind::kAdd}));
ASSERT_FALSE(ExpressionUtils::hasAnyKind(
ASSERT_FALSE(ExpressionUtils::hasAny(
root.get(), {Expression::Kind::kDstProperty, Expression::Kind::kAdd}));
// found
const Expression *found =
ExpressionUtils::findAnyKind(root.get(), {Expression::Kind::kTypeCasting});
ExpressionUtils::findAny(root.get(), {Expression::Kind::kTypeCasting});
ASSERT_EQ(found, root.get());
found = ExpressionUtils::findAnyKind(root.get(),
{Expression::Kind::kFunctionCall,
Expression::Kind::kTypeCasting,
Expression::Kind::kLogicalAnd});
found = ExpressionUtils::findAny(root.get(),
{Expression::Kind::kFunctionCall,
Expression::Kind::kTypeCasting,
Expression::Kind::kLogicalAnd});
ASSERT_EQ(found, root.get());
found = ExpressionUtils::findAnyKind(root.get(), {Expression::Kind::kDivision});
found = ExpressionUtils::findAny(root.get(), {Expression::Kind::kDivision});
ASSERT_EQ(found, nullptr);
found = ExpressionUtils::findAnyKind(root.get(),
{Expression::Kind::kLogicalXor,
Expression::Kind::kRelGE,
Expression::Kind::kEdgeProperty});
found = ExpressionUtils::findAny(root.get(),
{Expression::Kind::kLogicalXor,
Expression::Kind::kRelGE,
Expression::Kind::kEdgeProperty});
ASSERT_EQ(found, nullptr);
// found all
std::vector<const Expression *> founds =
ExpressionUtils::findAnyKindInAll(root.get(), {Expression::Kind::kConstant});
ExpressionUtils::collectAll(root.get(), {Expression::Kind::kConstant});
ASSERT_EQ(founds.size(), 1);
founds = ExpressionUtils::findAnyKindInAll(
founds = ExpressionUtils::collectAll(
root.get(), {Expression::Kind::kFunctionCall, Expression::Kind::kTypeCasting});
ASSERT_EQ(founds.size(), 3);
founds = ExpressionUtils::findAnyKindInAll(root.get(), {Expression::Kind::kAdd});
founds = ExpressionUtils::collectAll(root.get(), {Expression::Kind::kAdd});
ASSERT_TRUE(founds.empty());
founds = ExpressionUtils::findAnyKindInAll(
founds = ExpressionUtils::collectAll(
root.get(), {Expression::Kind::kRelLE, Expression::Kind::kDstProperty});
ASSERT_TRUE(founds.empty());
}
......@@ -151,54 +151,53 @@ TEST_F(ExpressionUtilsTest, CheckComponent) {
new ConstantExpression(2)));
ASSERT_TRUE(ExpressionUtils::isKindOf(root.get(), {Expression::Kind::kAdd}));
ASSERT_TRUE(ExpressionUtils::hasAnyKind(root.get(), {Expression::Kind::kMinus}));
ASSERT_TRUE(ExpressionUtils::hasAny(root.get(), {Expression::Kind::kMinus}));
ASSERT_TRUE(ExpressionUtils::isKindOf(
root.get(), {Expression::Kind::kTypeCasting, Expression::Kind::kAdd}));
ASSERT_TRUE(ExpressionUtils::hasAnyKind(
ASSERT_TRUE(ExpressionUtils::hasAny(
root.get(), {Expression::Kind::kLabelAttribute, Expression::Kind::kDivision}));
ASSERT_FALSE(ExpressionUtils::isKindOf(root.get(), {Expression::Kind::kConstant}));
ASSERT_FALSE(ExpressionUtils::hasAnyKind(root.get(), {Expression::Kind::kFunctionCall}));
ASSERT_FALSE(ExpressionUtils::hasAny(root.get(), {Expression::Kind::kFunctionCall}));
ASSERT_FALSE(ExpressionUtils::isKindOf(
root.get(), {Expression::Kind::kDivision, Expression::Kind::kEdgeProperty}));
ASSERT_FALSE(ExpressionUtils::hasAnyKind(
ASSERT_FALSE(ExpressionUtils::hasAny(
root.get(), {Expression::Kind::kDstProperty, Expression::Kind::kLogicalAnd}));
// found
const Expression *found =
ExpressionUtils::findAnyKind(root.get(), {Expression::Kind::kAdd});
const Expression *found = ExpressionUtils::findAny(root.get(), {Expression::Kind::kAdd});
ASSERT_EQ(found, root.get());
found = ExpressionUtils::findAnyKind(root.get(),
{Expression::Kind::kFunctionCall,
Expression::Kind::kRelLE,
Expression::Kind::kMultiply});
found = ExpressionUtils::findAny(root.get(),
{Expression::Kind::kFunctionCall,
Expression::Kind::kRelLE,
Expression::Kind::kMultiply});
ASSERT_NE(found, nullptr);
found = ExpressionUtils::findAnyKind(root.get(), {Expression::Kind::kInputProperty});
found = ExpressionUtils::findAny(root.get(), {Expression::Kind::kInputProperty});
ASSERT_EQ(found, nullptr);
found = ExpressionUtils::findAnyKind(root.get(),
{Expression::Kind::kLogicalXor,
Expression::Kind::kEdgeRank,
Expression::Kind::kUnaryNot});
found = ExpressionUtils::findAny(root.get(),
{Expression::Kind::kLogicalXor,
Expression::Kind::kEdgeRank,
Expression::Kind::kUnaryNot});
ASSERT_EQ(found, nullptr);
// found all
std::vector<const Expression *> founds =
ExpressionUtils::findAnyKindInAll(root.get(), {Expression::Kind::kConstant});
ExpressionUtils::collectAll(root.get(), {Expression::Kind::kConstant});
ASSERT_EQ(founds.size(), 6);
founds = ExpressionUtils::findAnyKindInAll(
founds = ExpressionUtils::collectAll(
root.get(), {Expression::Kind::kDivision, Expression::Kind::kMinus});
ASSERT_EQ(founds.size(), 2);
founds = ExpressionUtils::findAnyKindInAll(root.get(), {Expression::Kind::kEdgeDst});
founds = ExpressionUtils::collectAll(root.get(), {Expression::Kind::kEdgeDst});
ASSERT_TRUE(founds.empty());
founds = ExpressionUtils::findAnyKindInAll(
founds = ExpressionUtils::collectAll(
root.get(), {Expression::Kind::kLogicalAnd, Expression::Kind::kUnaryNegate});
ASSERT_TRUE(founds.empty());
}
......
......@@ -251,11 +251,11 @@ Status FetchEdgesValidator::preparePropertiesWithoutYield() {
/*static*/
const Expression *FetchEdgesValidator::findInvalidYieldExpression(const Expression *root) {
return ExpressionUtils::findAnyKind(root,
{Expression::Kind::kInputProperty,
Expression::Kind::kVarProperty,
Expression::Kind::kSrcProperty,
Expression::Kind::kDstProperty});
return ExpressionUtils::findAny(root,
{Expression::Kind::kInputProperty,
Expression::Kind::kVarProperty,
Expression::Kind::kSrcProperty,
Expression::Kind::kDstProperty});
}
// TODO(shylock) optimize dedup input when distinct given
......
......@@ -236,15 +236,15 @@ Status FetchVerticesValidator::preparePropertiesWithoutYield() {
/*static*/
const Expression *FetchVerticesValidator::findInvalidYieldExpression(const Expression *root) {
return ExpressionUtils::findAnyKind(root,
{Expression::Kind::kInputProperty,
Expression::Kind::kVarProperty,
Expression::Kind::kSrcProperty,
Expression::Kind::kDstProperty,
Expression::Kind::kEdgeSrc,
Expression::Kind::kEdgeType,
Expression::Kind::kEdgeRank,
Expression::Kind::kEdgeDst});
return ExpressionUtils::findAny(root,
{Expression::Kind::kInputProperty,
Expression::Kind::kVarProperty,
Expression::Kind::kSrcProperty,
Expression::Kind::kDstProperty,
Expression::Kind::kEdgeSrc,
Expression::Kind::kEdgeType,
Expression::Kind::kEdgeRank,
Expression::Kind::kEdgeDst});
}
// TODO(shylock) optimize dedup input when distinct given
......
......@@ -6,6 +6,9 @@
nebula_add_library(
expr_visitor_obj OBJECT
ExprVisitorImpl.cpp
CollectAllExprsVisitor.cpp
DeducePropsVisitor.cpp
DeduceTypeVisitor.cpp
FindAnyExprVisitor.cpp
RewriteLabelAttrVisitor.cpp
)
/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/
#include "visitor/CollectAllExprsVisitor.h"
namespace nebula {
namespace graph {
CollectAllExprsVisitor::CollectAllExprsVisitor(
const std::unordered_set<Expression::Kind> &exprKinds)
: exprKinds_(exprKinds) {}
void CollectAllExprsVisitor::visit(TypeCastingExpression *expr) {
collectExpr(expr);
expr->operand()->accept(this);
}
void CollectAllExprsVisitor::visit(UnaryExpression *expr) {
collectExpr(expr);
expr->operand()->accept(this);
}
void CollectAllExprsVisitor::visit(FunctionCallExpression *expr) {
collectExpr(expr);
for (const auto &arg : expr->args()->args()) {
arg->accept(this);
}
}
void CollectAllExprsVisitor::visit(ListExpression *expr) {
collectExpr(expr);
for (auto item : expr->items()) {
const_cast<Expression *>(item)->accept(this);
}
}
void CollectAllExprsVisitor::visit(SetExpression *expr) {
collectExpr(expr);
for (auto item : expr->items()) {
const_cast<Expression *>(item)->accept(this);
}
}
void CollectAllExprsVisitor::visit(MapExpression *expr) {
collectExpr(expr);
for (const auto &pair : expr->items()) {
const_cast<Expression *>(pair.second)->accept(this);
}
}
void CollectAllExprsVisitor::visit(ConstantExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(EdgePropertyExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(TagPropertyExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(InputPropertyExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(VariablePropertyExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(SourcePropertyExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(DestPropertyExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(EdgeSrcIdExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(EdgeTypeExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(EdgeRankExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(EdgeDstIdExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(UUIDExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(VariableExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(VersionedVariableExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(LabelExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(VertexExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(EdgeExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visitBinaryExpr(BinaryExpression *expr) {
collectExpr(expr);
expr->left()->accept(this);
expr->right()->accept(this);
}
void CollectAllExprsVisitor::collectExpr(const Expression *expr) {
if (exprKinds_.find(expr->kind()) != exprKinds_.cend()) {
exprs_.push_back(expr);
}
}
} // namespace graph
} // namespace nebula
/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/
#ifndef VISITOR_COLLECTALLEXPRSVISITOR_H_
#define VISITOR_COLLECTALLEXPRSVISITOR_H_
#include <unordered_set>
#include "common/expression/Expression.h"
#include "visitor/ExprVisitorImpl.h"
namespace nebula {
namespace graph {
class CollectAllExprsVisitor final : public ExprVisitorImpl {
public:
explicit CollectAllExprsVisitor(const std::unordered_set<Expression::Kind>& exprKinds);
bool ok() const override {
return !exprKinds_.empty();
}
std::vector<const Expression*> exprs() && {
return std::move(exprs_);
}
private:
using ExprVisitorImpl::visit;
void visit(TypeCastingExpression* expr) override;
void visit(UnaryExpression* expr) override;
void visit(FunctionCallExpression* expr) override;
void visit(ListExpression* expr) override;
void visit(SetExpression* expr) override;
void visit(MapExpression* expr) override;
void visit(ConstantExpression* expr) override;
void visit(EdgePropertyExpression* expr) override;
void visit(TagPropertyExpression* expr) override;
void visit(InputPropertyExpression* expr) override;
void visit(VariablePropertyExpression* expr) override;
void visit(SourcePropertyExpression* expr) override;
void visit(DestPropertyExpression* expr) override;
void visit(EdgeSrcIdExpression* expr) override;
void visit(EdgeTypeExpression* expr) override;
void visit(EdgeRankExpression* expr) override;
void visit(EdgeDstIdExpression* expr) override;
void visit(UUIDExpression* expr) override;
void visit(VariableExpression* expr) override;
void visit(VersionedVariableExpression* expr) override;
void visit(LabelExpression* expr) override;
void visit(VertexExpression* expr) override;
void visit(EdgeExpression* expr) override;
void visitBinaryExpr(BinaryExpression* expr) override;
void collectExpr(const Expression* expr);
const std::unordered_set<Expression::Kind>& exprKinds_;
std::vector<const Expression*> exprs_;
};
} // namespace graph
} // namespace nebula
#endif // VISITOR_COLLECTALLEXPRSVISITOR_H_
/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/
#include "visitor/FindAnyExprVisitor.h"
namespace nebula {
namespace graph {
FindAnyExprVisitor::FindAnyExprVisitor(const std::unordered_set<Expression::Kind> &kinds)
: kinds_(kinds) {
DCHECK(!kinds.empty());
}
void FindAnyExprVisitor::visit(TypeCastingExpression *expr) {
findExpr(expr);
if (found_) return;
expr->operand()->accept(this);
}
void FindAnyExprVisitor::visit(UnaryExpression *expr) {
findExpr(expr);
if (found_) return;
expr->operand()->accept(this);
}
void FindAnyExprVisitor::visit(FunctionCallExpression *expr) {
findExpr(expr);
if (found_) return;
for (const auto &arg : expr->args()->args()) {
arg->accept(this);
if (found_) return;
}
}
void FindAnyExprVisitor::visit(ListExpression *expr) {
findExpr(expr);
if (found_) return;
for (const auto &item : expr->items()) {
const_cast<Expression *>(item)->accept(this);
if (found_) return;
}
}
void FindAnyExprVisitor::visit(SetExpression *expr) {
findExpr(expr);
if (found_) return;
for (const auto &item : expr->items()) {
const_cast<Expression *>(item)->accept(this);
if (found_) return;
}
}
void FindAnyExprVisitor::visit(MapExpression *expr) {
findExpr(expr);
if (found_) return;
for (const auto &pair : expr->items()) {
const_cast<Expression *>(pair.second)->accept(this);
if (found_) return;
}
}
void FindAnyExprVisitor::visit(ConstantExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(EdgePropertyExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(TagPropertyExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(InputPropertyExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(VariablePropertyExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(SourcePropertyExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(DestPropertyExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(EdgeSrcIdExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(EdgeTypeExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(EdgeRankExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(EdgeDstIdExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(UUIDExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(VariableExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(VersionedVariableExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(LabelExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(VertexExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visit(EdgeExpression *expr) {
findExpr(expr);
}
void FindAnyExprVisitor::visitBinaryExpr(BinaryExpression *expr) {
findExpr(expr);
if (found_) return;
expr->left()->accept(this);
if (found_) return;
expr->right()->accept(this);
}
void FindAnyExprVisitor::findExpr(const Expression *expr) {
found_ = kinds_.find(expr->kind()) != kinds_.cend();
if (found_) {
expr_ = expr;
}
}
} // namespace graph
} // namespace nebula
/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/
#ifndef VISITOR_FINDANYEXPRVISITOR_H_
#define VISITOR_FINDANYEXPRVISITOR_H_
#include <unordered_set>
#include "common/expression/Expression.h"
#include "visitor/ExprVisitorImpl.h"
namespace nebula {
namespace graph {
class FindAnyExprVisitor final : public ExprVisitorImpl {
public:
explicit FindAnyExprVisitor(const std::unordered_set<Expression::Kind>& kinds);
bool ok() const override {
// continue if not found
return !found_;
}
const Expression* expr() const {
return expr_;
}
private:
using ExprVisitorImpl::visit;
void visit(TypeCastingExpression* expr) override;
void visit(UnaryExpression* expr) override;
void visit(FunctionCallExpression* expr) override;
void visit(ListExpression* expr) override;
void visit(SetExpression* expr) override;
void visit(MapExpression* expr) override;
void visit(ConstantExpression* expr) override;
void visit(EdgePropertyExpression* expr) override;
void visit(TagPropertyExpression* expr) override;
void visit(InputPropertyExpression* expr) override;
void visit(VariablePropertyExpression* expr) override;
void visit(SourcePropertyExpression* expr) override;
void visit(DestPropertyExpression* expr) override;
void visit(EdgeSrcIdExpression* expr) override;
void visit(EdgeTypeExpression* expr) override;
void visit(EdgeRankExpression* expr) override;
void visit(EdgeDstIdExpression* expr) override;
void visit(UUIDExpression* expr) override;
void visit(VariableExpression* expr) override;
void visit(VersionedVariableExpression* expr) override;
void visit(LabelExpression* expr) override;
void visit(VertexExpression* expr) override;
void visit(EdgeExpression* expr) override;
void visitBinaryExpr(BinaryExpression* expr) override;
void findExpr(const Expression* expr);
bool found_{false};
const Expression* expr_{nullptr};
const std::unordered_set<Expression::Kind>& kinds_;
};
} // namespace graph
} // namespace nebula
#endif // VISITOR_FINDANYEXPRVISITOR_H_
/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/
#include "visitor/RewriteLabelAttrVisitor.h"
#include "common/base/Logging.h"
namespace nebula {
namespace graph {
RewriteLabelAttrVisitor::RewriteLabelAttrVisitor(bool isTag) : isTag_(isTag) {}
void RewriteLabelAttrVisitor::visit(TypeCastingExpression* expr) {
if (isLabelAttrExpr(expr->operand())) {
auto operand = static_cast<LabelAttributeExpression*>(expr->operand());
expr->setOperand(createExpr(operand));
} else {
expr->operand()->accept(this);
}
}
void RewriteLabelAttrVisitor::visit(UnaryExpression* expr) {
if (isLabelAttrExpr(expr->operand())) {
auto operand = static_cast<LabelAttributeExpression*>(expr->operand());
expr->setOperand(createExpr(operand));
} else {
expr->operand()->accept(this);
}
}
void RewriteLabelAttrVisitor::visit(FunctionCallExpression* expr) {
for (auto& arg : expr->args()->args()) {
if (isLabelAttrExpr(arg.get())) {
auto newArg = static_cast<LabelAttributeExpression*>(arg.get());
arg.reset(createExpr(newArg));
} else {
arg->accept(this);
}
}
}
void RewriteLabelAttrVisitor::visit(ListExpression* expr) {
auto newItems = rewriteExprList(expr->items());
if (!newItems.empty()) {
expr->setItems(std::move(newItems));
}
}
void RewriteLabelAttrVisitor::visit(SetExpression* expr) {
auto newItems = rewriteExprList(expr->items());
if (!newItems.empty()) {
expr->setItems(std::move(newItems));
}
}
void RewriteLabelAttrVisitor::visit(MapExpression* expr) {
auto items = expr->items();
auto found = std::find_if(
items.cbegin(), items.cend(), [](auto& pair) { return isLabelAttrExpr(pair.second); });
if (found == items.cend()) {
std::for_each(items.begin(), items.end(), [this](auto& pair) {
const_cast<Expression*>(pair.second)->accept(this);
});
return;
}
std::vector<MapExpression::Item> newItems;
newItems.reserve(items.size());
for (auto& pair : items) {
MapExpression::Item newItem;
newItem.first.reset(new std::string(*pair.first));
if (isLabelAttrExpr(pair.second)) {
auto symExpr = static_cast<const LabelAttributeExpression*>(pair.second);
newItem.second.reset(createExpr(symExpr));
} else {
newItem.second = Expression::decode(pair.second->encode());
newItem.second->accept(this);
}
newItems.emplace_back(std::move(newItem));
}
expr->setItems(std::move(newItems));
}
void RewriteLabelAttrVisitor::visitBinaryExpr(BinaryExpression* expr) {
if (isLabelAttrExpr(expr->left())) {
auto left = static_cast<const LabelAttributeExpression*>(expr->left());
expr->setLeft(createExpr(left));
} else {
expr->left()->accept(this);
}
if (isLabelAttrExpr(expr->right())) {
auto right = static_cast<const LabelAttributeExpression*>(expr->right());
expr->setRight(createExpr(right));
} else {
expr->right()->accept(this);
}
}
std::vector<std::unique_ptr<Expression>> RewriteLabelAttrVisitor::rewriteExprList(
const std::vector<const Expression*>& exprs) {
std::vector<std::unique_ptr<Expression>> newExprs;
auto found = std::find_if(exprs.cbegin(), exprs.cend(), isLabelAttrExpr);
if (found == exprs.cend()) {
std::for_each(exprs.cbegin(), exprs.cend(), [this](auto expr) {
const_cast<Expression*>(expr)->accept(this);
});
return newExprs;
}
newExprs.reserve(exprs.size());
for (auto item : exprs) {
if (isLabelAttrExpr(item)) {
auto symExpr = static_cast<const LabelAttributeExpression*>(item);
newExprs.emplace_back(createExpr(symExpr));
} else {
auto newExpr = Expression::decode(item->encode());
newExpr->accept(this);
newExprs.emplace_back(std::move(newExpr));
}
}
return newExprs;
}
Expression* RewriteLabelAttrVisitor::createExpr(const LabelAttributeExpression* expr) {
auto leftName = new std::string(*expr->left()->name());
auto rightName = new std::string(*expr->right()->name());
if (isTag_) {
return new TagPropertyExpression(leftName, rightName);
}
return new EdgePropertyExpression(leftName, rightName);
}
bool RewriteLabelAttrVisitor::isLabelAttrExpr(const Expression* expr) {
return expr->kind() == Expression::Kind::kLabelAttribute;
}
} // namespace graph
} // namespace nebula
/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/
#ifndef VISITOR_REWRITELABELATTRVISITOR_H_
#define VISITOR_REWRITELABELATTRVISITOR_H_
#include <memory>
#include <vector>
#include "visitor/ExprVisitorImpl.h"
namespace nebula {
namespace graph {
class RewriteLabelAttrVisitor final : public ExprVisitorImpl {
public:
explicit RewriteLabelAttrVisitor(bool isTag);
bool ok() const override {
return true;
}
private:
using ExprVisitorImpl::visit;
void visit(TypeCastingExpression *expr) override;
void visit(UnaryExpression *expr) override;
void visit(FunctionCallExpression *expr) override;
void visit(ListExpression *expr) override;
void visit(SetExpression *expr) override;
void visit(MapExpression *expr) override;
void visit(ConstantExpression *) override {}
void visit(LabelExpression *) override {}
void visit(UUIDExpression *) override {}
void visit(LabelAttributeExpression *) override {}
void visit(VariableExpression *) override {}
void visit(VersionedVariableExpression *) override {}
void visit(TagPropertyExpression *) override {}
void visit(EdgePropertyExpression *) override {}
void visit(InputPropertyExpression *) override {}
void visit(VariablePropertyExpression *) override {}
void visit(DestPropertyExpression *) override {}
void visit(SourcePropertyExpression *) override {}
void visit(EdgeSrcIdExpression *) override {}
void visit(EdgeTypeExpression *) override {}
void visit(EdgeRankExpression *) override {}
void visit(EdgeDstIdExpression *) override {}
void visit(VertexExpression *) override {}
void visit(EdgeExpression *) override {}
void visitBinaryExpr(BinaryExpression *expr) override;
Expression *createExpr(const LabelAttributeExpression *expr);
std::vector<std::unique_ptr<Expression>> rewriteExprList(
const std::vector<const Expression *> &exprs);
static bool isLabelAttrExpr(const Expression *expr);
bool isTag_{false};
};
} // namespace graph
} // namespace nebula
#endif // VISITOR_REWRITELABELATTRVISITOR_H_
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment