diff --git a/.gitignore b/.gitignore index 0fed2568ca761649a0ca00c82b1bd22e5269aaf8..3c4daa6363d91e85f0054e652fa95e6b2436e9de 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,4 @@ cmake-build-release/ *.pot *.py[co] __pycache__ +venv/ diff --git a/src/context/test/CMakeLists.txt b/src/context/test/CMakeLists.txt index 82eb6b828eac265b42341f051cb88548113fce00..77621617c03d2dbea8a933e53f093d8657206d61 100644 --- a/src/context/test/CMakeLists.txt +++ b/src/context/test/CMakeLists.txt @@ -4,7 +4,10 @@ # attached with Common Clause Condition 1.0, found in the LICENSES directory. SET(CONTEXT_TEST_LIBS + $<TARGET_OBJECTS:common_agg_function_obj> + $<TARGET_OBJECTS:common_charset_obj> $<TARGET_OBJECTS:common_datatypes_obj> + $<TARGET_OBJECTS:common_encryption_obj> $<TARGET_OBJECTS:common_expression_obj> $<TARGET_OBJECTS:common_function_manager_obj> $<TARGET_OBJECTS:common_fs_obj> @@ -21,10 +24,15 @@ SET(CONTEXT_TEST_LIBS $<TARGET_OBJECTS:common_graph_thrift_obj> $<TARGET_OBJECTS:common_storage_thrift_obj> $<TARGET_OBJECTS:common_time_function_obj> + $<TARGET_OBJECTS:common_http_client_obj> + $<TARGET_OBJECTS:common_process_obj> $<TARGET_OBJECTS:util_obj> $<TARGET_OBJECTS:context_obj> + $<TARGET_OBJECTS:expr_visitor_obj> $<TARGET_OBJECTS:parser_obj> + $<TARGET_OBJECTS:validator_obj> $<TARGET_OBJECTS:graph_flags_obj> + $<TARGET_OBJECTS:graph_auth_obj> $<TARGET_OBJECTS:session_obj> $<TARGET_OBJECTS:planner_obj> $<TARGET_OBJECTS:idgenerator_obj> diff --git a/src/optimizer/OptimizerUtils.cpp b/src/optimizer/OptimizerUtils.cpp index ec8ab1d46c43d2722ed86cd69c58972641823b01..f2054a8c697cf07281326ce6b3ba78e82e5bde1e 100644 --- a/src/optimizer/OptimizerUtils.cpp +++ b/src/optimizer/OptimizerUtils.cpp @@ -5,6 +5,7 @@ */ #include "optimizer/OptimizerUtils.h" + namespace nebula { namespace graph { diff --git a/src/parser/test/CMakeLists.txt b/src/parser/test/CMakeLists.txt index c55dff8d505df082b119ec8d57bcbbf03ec15008..dc3f5053f15d3df03a3f79b0af366957e938e19e 100644 --- a/src/parser/test/CMakeLists.txt +++ b/src/parser/test/CMakeLists.txt @@ -7,9 +7,9 @@ set(PARSER_TEST_LIBS $<TARGET_OBJECTS:parser_obj> $<TARGET_OBJECTS:common_time_function_obj> $<TARGET_OBJECTS:common_expression_obj> + $<TARGET_OBJECTS:common_encryption_obj> $<TARGET_OBJECTS:common_network_obj> $<TARGET_OBJECTS:common_fs_obj> - $<TARGET_OBJECTS:common_time_obj> $<TARGET_OBJECTS:common_stats_obj> $<TARGET_OBJECTS:common_time_obj> $<TARGET_OBJECTS:common_common_thrift_obj> @@ -18,18 +18,25 @@ set(PARSER_TEST_LIBS $<TARGET_OBJECTS:common_datatypes_obj> $<TARGET_OBJECTS:common_base_obj> $<TARGET_OBJECTS:common_function_manager_obj> + $<TARGET_OBJECTS:common_agg_function_obj> $<TARGET_OBJECTS:common_meta_thrift_obj> $<TARGET_OBJECTS:common_graph_thrift_obj> + $<TARGET_OBJECTS:common_http_client_obj> $<TARGET_OBJECTS:common_storage_thrift_obj> $<TARGET_OBJECTS:common_meta_obj> $<TARGET_OBJECTS:common_meta_client_obj> $<TARGET_OBJECTS:common_conf_obj> + $<TARGET_OBJECTS:common_charset_obj> $<TARGET_OBJECTS:common_file_based_cluster_id_man_obj> + $<TARGET_OBJECTS:common_process_obj> $<TARGET_OBJECTS:session_obj> $<TARGET_OBJECTS:graph_flags_obj> + $<TARGET_OBJECTS:graph_auth_obj> $<TARGET_OBJECTS:util_obj> + $<TARGET_OBJECTS:expr_visitor_obj> $<TARGET_OBJECTS:context_obj> $<TARGET_OBJECTS:planner_obj> + $<TARGET_OBJECTS:validator_obj> $<TARGET_OBJECTS:idgenerator_obj> ) @@ -37,28 +44,28 @@ nebula_add_test( NAME parser_test SOURCES ParserTest.cpp OBJECTS ${PARSER_TEST_LIBS} - LIBRARIES gtest gtest_main ${THRIFT_LIBRARIES} + LIBRARIES gtest gtest_main ${THRIFT_LIBRARIES} proxygenlib ) nebula_add_test( NAME scanner_test SOURCES ScannerTest.cpp OBJECTS ${PARSER_TEST_LIBS} - LIBRARIES gtest gtest_main ${THRIFT_LIBRARIES} + LIBRARIES gtest gtest_main ${THRIFT_LIBRARIES} proxygenlib ) nebula_add_executable( NAME parser_bm SOURCES ParserBenchmark.cpp OBJECTS ${PARSER_TEST_LIBS} - LIBRARIES follybenchmark boost_regex ${THRIFT_LIBRARIES} + LIBRARIES follybenchmark boost_regex ${THRIFT_LIBRARIES} proxygenlib ) nebula_add_test( NAME expression_parsing_test SOURCES ExpressionParsingTest.cpp OBJECTS ${PARSER_TEST_LIBS} - LIBRARIES gtest gtest_main ${THRIFT_LIBRARIES} + LIBRARIES gtest gtest_main ${THRIFT_LIBRARIES} proxygenlib ) if(ENABLE_FUZZ_TEST) diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index f333c400b835911d23f056624d3cd5b5e2c00e8b..b5dedf77aba3c6fb7506b8b96c6b7e5ba729b289 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -6,6 +6,7 @@ nebula_add_library( util_obj OBJECT + ExpressionUtils.cpp SchemaUtil.cpp ToJson.cpp ) diff --git a/src/util/ExpressionUtils.cpp b/src/util/ExpressionUtils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eb20ee2d98bb639cc03a6583e9e19deacda8f105 --- /dev/null +++ b/src/util/ExpressionUtils.cpp @@ -0,0 +1,25 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#include "util/ExpressionUtils.h" + +#include "visitor/FoldConstantExprVisitor.h" + +namespace nebula { +namespace graph { + +std::unique_ptr<Expression> ExpressionUtils::foldConstantExpr(const Expression *expr) { + auto newExpr = expr->clone(); + FoldConstantExprVisitor visitor; + newExpr->accept(&visitor); + if (visitor.canBeFolded()) { + return std::unique_ptr<Expression>(visitor.fold(newExpr.get())); + } + return newExpr; +} + +} // namespace graph +} // namespace nebula diff --git a/src/util/ExpressionUtils.h b/src/util/ExpressionUtils.h index b79551d30f416b8f72c688a808ae3b1d7aec18d6..b404719c1a4438f56b8edc7429d47a4981b996f7 100644 --- a/src/util/ExpressionUtils.h +++ b/src/util/ExpressionUtils.h @@ -53,30 +53,6 @@ public: return findAny(expr, expected) != nullptr; } - // Require data from input/variable - static bool hasInput(const Expression* expr) { - return hasAny(expr, - {Expression::Kind::kInputProperty, - Expression::Kind::kVarProperty, - Expression::Kind::kVar, - Expression::Kind::kVersionedVar}); - } - - // require data from graph storage - static const Expression* findStorage(const Expression* expr) { - return findAny(expr, - {Expression::Kind::kTagProperty, - Expression::Kind::kEdgeProperty, - Expression::Kind::kDstProperty, - Expression::Kind::kSrcProperty, - Expression::Kind::kEdgeSrc, - Expression::Kind::kEdgeType, - Expression::Kind::kEdgeRank, - Expression::Kind::kEdgeDst, - Expression::Kind::kVertex, - Expression::Kind::kEdge}); - } - static std::vector<const Expression*> findAllStorage(const Expression* expr) { return collectAll(expr, {Expression::Kind::kTagProperty, @@ -95,24 +71,6 @@ public: return collectAll(expr, {Expression::Kind::kInputProperty, Expression::Kind::kVarProperty}); } - static bool hasStorage(const Expression* expr) { - return findStorage(expr) != nullptr; - } - - static bool isStorage(const Expression* expr) { - return isKindOf(expr, - {Expression::Kind::kTagProperty, - Expression::Kind::kEdgeProperty, - Expression::Kind::kDstProperty, - Expression::Kind::kSrcProperty, - Expression::Kind::kEdgeSrc, - Expression::Kind::kEdgeType, - Expression::Kind::kEdgeRank, - Expression::Kind::kEdgeDst, - Expression::Kind::kVertex, - Expression::Kind::kEdge}); - } - static bool isConstExpr(const Expression* expr) { return !hasAny(expr, {Expression::Kind::kInputProperty, @@ -148,6 +106,9 @@ public: return new To(new std::string(std::move(*expr->left()->name())), new std::string(std::move(*expr->right()->name()))); } + + // Clone and fold constant expression + static std::unique_ptr<Expression> foldConstantExpr(const Expression* expr); }; } // namespace graph diff --git a/src/validator/YieldValidator.cpp b/src/validator/YieldValidator.cpp index 932fc387af9cdc70c1a5cc0a126241e00d6410c2..c482614ef91fbbaf22befd86f8b2951ea20a38e4 100644 --- a/src/validator/YieldValidator.cpp +++ b/src/validator/YieldValidator.cpp @@ -11,6 +11,8 @@ #include "parser/Clauses.h" #include "parser/TraverseSentences.h" #include "planner/Query.h" +#include "util/ExpressionUtils.h" +#include "visitor/FoldConstantExprVisitor.h" namespace nebula { namespace graph { @@ -70,7 +72,7 @@ Status YieldValidator::checkAggFunAndBuildGroupItems(const YieldClause *clause) } Status YieldValidator::checkInputProps() const { - auto& inputProps = const_cast<ExpressionProps*>(&exprProps_)->inputProps(); + auto &inputProps = const_cast<ExpressionProps *>(&exprProps_)->inputProps(); if (inputs_.empty() && !inputProps.empty()) { return Status::SemanticError("no inputs for yield columns."); } @@ -82,7 +84,7 @@ Status YieldValidator::checkInputProps() const { } Status YieldValidator::checkVarProps() const { - auto& varProps = const_cast<ExpressionProps*>(&exprProps_)->varProps(); + auto &varProps = const_cast<ExpressionProps *>(&exprProps_)->varProps(); for (auto &pair : varProps) { auto &var = pair.first; if (!vctx_->existVar(var)) { @@ -102,6 +104,7 @@ Status YieldValidator::makeOutputColumn(YieldColumn *column) { auto expr = column->expr(); DCHECK(expr != nullptr); + NG_RETURN_IF_ERROR(deduceProps(expr, exprProps_)); auto status = deduceExprType(expr); @@ -111,6 +114,13 @@ Status YieldValidator::makeOutputColumn(YieldColumn *column) { auto name = deduceColName(column); outputColumnNames_.emplace_back(name); + // Constant expression folding must be after type deduction + FoldConstantExprVisitor visitor; + expr->accept(&visitor); + if (visitor.canBeFolded()) { + column->setExpr(visitor.fold(expr)); + } + outputs_.emplace_back(name, type); return Status::OK(); } @@ -121,7 +131,7 @@ void YieldValidator::genConstantExprValues() { ds.colNames = outputColumnNames_; QueryExpressionContext ctx; Row row; - for (auto& column : columns_->columns()) { + for (auto &column : columns_->columns()) { row.values.emplace_back(Expression::eval(column->expr(), ctx(nullptr))); } ds.emplace_back(std::move(row)); @@ -192,6 +202,8 @@ Status YieldValidator::validateWhere(const WhereClause *clause) { } if (filter != nullptr) { NG_RETURN_IF_ERROR(deduceProps(filter, exprProps_)); + auto newFilter = ExpressionUtils::foldConstantExpr(filter); + filterCondition_ = qctx_->objPool()->add(newFilter.release()); } return Status::OK(); } @@ -201,7 +213,7 @@ Status YieldValidator::toPlan() { Filter *filter = nullptr; if (yield->where()) { - filter = Filter::make(qctx_, nullptr, yield->where()->filter()); + filter = Filter::make(qctx_, nullptr, filterCondition_); std::vector<std::string> colNames(inputs_.size()); std::transform( inputs_.cbegin(), inputs_.cend(), colNames.begin(), [](auto &in) { return in.first; }); @@ -247,4 +259,3 @@ Status YieldValidator::toPlan() { } // namespace graph } // namespace nebula - diff --git a/src/validator/YieldValidator.h b/src/validator/YieldValidator.h index 1defb2edc07832f404511436e0b16066c3658dfe..7d2aeb9f38d2e6947934b12d994b246e7967f85d 100644 --- a/src/validator/YieldValidator.h +++ b/src/validator/YieldValidator.h @@ -48,8 +48,9 @@ private: YieldColumns *columns_{nullptr}; std::vector<std::string> outputColumnNames_; std::vector<Aggregate::GroupItem> groupItems_; - ExpressionProps exprProps_; + ExpressionProps exprProps_; std::string constantExprVar_; + Expression *filterCondition_{nullptr}; }; } // namespace graph diff --git a/src/visitor/CMakeLists.txt b/src/visitor/CMakeLists.txt index acc59d6e78eded151c3c2b098aa93ffc74299fff..b34ae76ea9e6e476e91aac41ffb6fc0eeb9139cb 100644 --- a/src/visitor/CMakeLists.txt +++ b/src/visitor/CMakeLists.txt @@ -11,5 +11,8 @@ nebula_add_library( DeduceTypeVisitor.cpp ExtractFilterExprVisitor.cpp FindAnyExprVisitor.cpp + FoldConstantExprVisitor.cpp RewriteLabelAttrVisitor.cpp ) + +nebula_add_subdirectory(test) diff --git a/src/visitor/FoldConstantExprVisitor.cpp b/src/visitor/FoldConstantExprVisitor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6d8e7559f3b3e16807b3ecfe77ff74a5681f698a --- /dev/null +++ b/src/visitor/FoldConstantExprVisitor.cpp @@ -0,0 +1,234 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#include "visitor/FoldConstantExprVisitor.h" + +#include "context/QueryExpressionContext.h" + +namespace nebula { +namespace graph { + +void FoldConstantExprVisitor::visit(ConstantExpression *expr) { + UNUSED(expr); + canBeFolded_ = true; +} + +void FoldConstantExprVisitor::visit(UnaryExpression *expr) { + expr->operand()->accept(this); + if (canBeFolded_ && expr->operand()->kind() != Expression::Kind::kConstant) { + expr->setOperand(fold(expr->operand())); + } +} + +void FoldConstantExprVisitor::visit(TypeCastingExpression *expr) { + expr->operand()->accept(this); + if (canBeFolded_ && expr->operand()->kind() != Expression::Kind::kConstant) { + expr->setOperand(fold(expr->operand())); + } +} + +void FoldConstantExprVisitor::visit(LabelExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +void FoldConstantExprVisitor::visit(LabelAttributeExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +// binary expression +void FoldConstantExprVisitor::visit(ArithmeticExpression *expr) { + visitBinaryExpr(expr); +} + +void FoldConstantExprVisitor::visit(RelationalExpression *expr) { + visitBinaryExpr(expr); +} + +void FoldConstantExprVisitor::visit(SubscriptExpression *expr) { + visitBinaryExpr(expr); +} + +void FoldConstantExprVisitor::visit(AttributeExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +void FoldConstantExprVisitor::visit(LogicalExpression *expr) { + visitBinaryExpr(expr); +} + +// function call +void FoldConstantExprVisitor::visit(FunctionCallExpression *expr) { + bool canBeFolded = true; + for (auto &arg : expr->args()->args()) { + if (arg->kind() != Expression::Kind::kConstant) { + arg->accept(this); + if (canBeFolded_) { + arg.reset(fold(arg.get())); + } else { + canBeFolded = false; + } + } + } + canBeFolded_ = canBeFolded; +} + +void FoldConstantExprVisitor::visit(UUIDExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +// variable expression +void FoldConstantExprVisitor::visit(VariableExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +void FoldConstantExprVisitor::visit(VersionedVariableExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +// container expression +void FoldConstantExprVisitor::visit(ListExpression *expr) { + auto items = expr->items(); + bool canBeFolded = true; + for (size_t i = 0; i < items.size(); ++i) { + auto item = const_cast<Expression *>(items[i]); + item->accept(this); + if (!canBeFolded_) { + canBeFolded = false; + continue; + } + if (item->kind() != Expression::Kind::kConstant) { + expr->setItem(i, fold(item)); + } + } + canBeFolded_ = canBeFolded; +} + +void FoldConstantExprVisitor::visit(SetExpression *expr) { + auto items = expr->items(); + bool canBeFolded = true; + for (size_t i = 0; i < items.size(); ++i) { + auto item = const_cast<Expression *>(items[i]); + item->accept(this); + if (!canBeFolded_) { + canBeFolded = false; + continue; + } + if (item->kind() != Expression::Kind::kConstant) { + expr->setItem(i, fold(item)); + } + } + canBeFolded_ = canBeFolded; +} + +void FoldConstantExprVisitor::visit(MapExpression *expr) { + auto items = expr->items(); + bool canBeFolded = true; + for (size_t i = 0; i < items.size(); ++i) { + auto &pair = items[i]; + auto item = const_cast<Expression *>(pair.second); + if (!canBeFolded_) { + canBeFolded = false; + continue; + } + if (item->kind() != Expression::Kind::kConstant) { + auto key = std::make_unique<std::string>(*pair.first); + auto val = std::unique_ptr<Expression>(fold(item)); + expr->setItem(i, std::make_pair(std::move(key), std::move(val))); + } + } + canBeFolded_ = canBeFolded; +} + +// property Expression +void FoldConstantExprVisitor::visit(TagPropertyExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +void FoldConstantExprVisitor::visit(EdgePropertyExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +void FoldConstantExprVisitor::visit(InputPropertyExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +void FoldConstantExprVisitor::visit(VariablePropertyExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +void FoldConstantExprVisitor::visit(DestPropertyExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +void FoldConstantExprVisitor::visit(SourcePropertyExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +void FoldConstantExprVisitor::visit(EdgeSrcIdExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +void FoldConstantExprVisitor::visit(EdgeTypeExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +void FoldConstantExprVisitor::visit(EdgeRankExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +void FoldConstantExprVisitor::visit(EdgeDstIdExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +// vertex/edge expression +void FoldConstantExprVisitor::visit(VertexExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +void FoldConstantExprVisitor::visit(EdgeExpression *expr) { + UNUSED(expr); + canBeFolded_ = false; +} + +void FoldConstantExprVisitor::visitBinaryExpr(BinaryExpression *expr) { + expr->left()->accept(this); + auto leftCanBeFolded = canBeFolded_; + if (leftCanBeFolded && expr->left()->kind() != Expression::Kind::kConstant) { + expr->setLeft(fold(expr->left())); + } + expr->right()->accept(this); + auto rightCanBeFolded = canBeFolded_; + if (rightCanBeFolded && expr->right()->kind() != Expression::Kind::kConstant) { + expr->setRight(fold(expr->right())); + } + canBeFolded_ = leftCanBeFolded && rightCanBeFolded; +} + +Expression *FoldConstantExprVisitor::fold(Expression *expr) const { + QueryExpressionContext ctx; + auto value = expr->eval(ctx(nullptr)); + return new ConstantExpression(std::move(value)); +} + +} // namespace graph +} // namespace nebula diff --git a/src/visitor/FoldConstantExprVisitor.h b/src/visitor/FoldConstantExprVisitor.h new file mode 100644 index 0000000000000000000000000000000000000000..a99ae13abc00e57e37dd4e7f29c52305b678d04a --- /dev/null +++ b/src/visitor/FoldConstantExprVisitor.h @@ -0,0 +1,67 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#ifndef VISITOR_FOLDCONSTANTEXPRVISITOR_H_ +#define VISITOR_FOLDCONSTANTEXPRVISITOR_H_ + +#include "common/expression/ExprVisitor.h" + +namespace nebula { +namespace graph { + +class FoldConstantExprVisitor final : public ExprVisitor { +public: + bool canBeFolded() const { + return canBeFolded_; + } + + void visit(ConstantExpression *expr) override; + void visit(UnaryExpression *expr) override; + void visit(TypeCastingExpression *expr) override; + void visit(LabelExpression *expr) override; + void visit(LabelAttributeExpression *expr) override; + // binary expression + void visit(ArithmeticExpression *expr) override; + void visit(RelationalExpression *expr) override; + void visit(SubscriptExpression *expr) override; + void visit(AttributeExpression *expr) override; + void visit(LogicalExpression *expr) override; + // function call + void visit(FunctionCallExpression *expr) override; + void visit(UUIDExpression *expr) override; + // variable expression + void visit(VariableExpression *expr) override; + void visit(VersionedVariableExpression *expr) override; + // container expression + void visit(ListExpression *expr) override; + void visit(SetExpression *expr) override; + void visit(MapExpression *expr) override; + // property Expression + void visit(TagPropertyExpression *expr) override; + void visit(EdgePropertyExpression *expr) override; + void visit(InputPropertyExpression *expr) override; + void visit(VariablePropertyExpression *expr) override; + void visit(DestPropertyExpression *expr) override; + void visit(SourcePropertyExpression *expr) override; + void visit(EdgeSrcIdExpression *expr) override; + void visit(EdgeTypeExpression *expr) override; + void visit(EdgeRankExpression *expr) override; + void visit(EdgeDstIdExpression *expr) override; + // vertex/edge expression + void visit(VertexExpression *expr) override; + void visit(EdgeExpression *expr) override; + + void visitBinaryExpr(BinaryExpression *expr); + Expression *fold(Expression *expr) const; + +private: + bool canBeFolded_{false}; +}; + +} // namespace graph +} // namespace nebula + +#endif // VISITOR_FOLDCONSTANTEXPRVISITOR_H_ diff --git a/src/visitor/test/CMakeLists.txt b/src/visitor/test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d3d5e5036d85bb7623e300f5755e2d99e6a5dae2 --- /dev/null +++ b/src/visitor/test/CMakeLists.txt @@ -0,0 +1,56 @@ +# Copyright (c) 2020 vesoft inc. All rights reserved. +# +# This source code is licensed under Apache 2.0 License, +# attached with Common Clause Condition 1.0, found in the LICENSES directory. + +nebula_add_test( + NAME + expr_visitor_test + SOURCES + FoldConstantExprVisitorTest.cpp + OBJECTS + $<TARGET_OBJECTS:mock_schema_obj> + $<TARGET_OBJECTS:util_obj> + $<TARGET_OBJECTS:validator_obj> + $<TARGET_OBJECTS:expr_visitor_obj> + $<TARGET_OBJECTS:planner_obj> + $<TARGET_OBJECTS:session_obj> + $<TARGET_OBJECTS:graph_flags_obj> + $<TARGET_OBJECTS:parser_obj> + $<TARGET_OBJECTS:idgenerator_obj> + $<TARGET_OBJECTS:context_obj> + $<TARGET_OBJECTS:graph_auth_obj> + $<TARGET_OBJECTS:common_time_function_obj> + $<TARGET_OBJECTS:common_expression_obj> + $<TARGET_OBJECTS:common_network_obj> + $<TARGET_OBJECTS:common_fs_obj> + $<TARGET_OBJECTS:common_time_obj> + $<TARGET_OBJECTS:common_stats_obj> + $<TARGET_OBJECTS:common_time_obj> + $<TARGET_OBJECTS:common_common_thrift_obj> + $<TARGET_OBJECTS:common_graph_thrift_obj> + $<TARGET_OBJECTS:common_storage_thrift_obj> + $<TARGET_OBJECTS:common_thrift_obj> + $<TARGET_OBJECTS:common_thread_obj> + $<TARGET_OBJECTS:common_datatypes_obj> + $<TARGET_OBJECTS:common_base_obj> + $<TARGET_OBJECTS:common_meta_thrift_obj> + $<TARGET_OBJECTS:common_meta_obj> + $<TARGET_OBJECTS:common_graph_thrift_obj> + $<TARGET_OBJECTS:common_charset_obj> + $<TARGET_OBJECTS:common_meta_client_obj> + $<TARGET_OBJECTS:common_file_based_cluster_id_man_obj> + $<TARGET_OBJECTS:common_function_manager_obj> + $<TARGET_OBJECTS:common_agg_function_obj> + $<TARGET_OBJECTS:common_conf_obj> + $<TARGET_OBJECTS:common_encryption_obj> + $<TARGET_OBJECTS:common_http_client_obj> + $<TARGET_OBJECTS:common_process_obj> + LIBRARIES + gtest + gtest_main + ${THRIFT_LIBRARIES} + wangle + proxygenhttpserver + proxygenlib +) diff --git a/src/visitor/test/FoldConstantExprVisitorTest.cpp b/src/visitor/test/FoldConstantExprVisitorTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8cf7587ce6fb25b5f8732ca0b3b998eb088e6568 --- /dev/null +++ b/src/visitor/test/FoldConstantExprVisitorTest.cpp @@ -0,0 +1,180 @@ + +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License, + * attached with Common Clause Condition 1.0, found in the LICENSES directory. + */ + +#include "visitor/FoldConstantExprVisitor.h" + +#include <gtest/gtest.h> + +#include "util/ObjectPool.h" + +using Type = nebula::Value::Type; + +namespace nebula { +namespace graph { + +class FoldConstantExprVisitorTest : public ::testing::Test { +public: + void TearDown() override { + pool.clear(); + } + + static ConstantExpression *constant(Value value) { + return new ConstantExpression(std::move(value)); + } + + static ArithmeticExpression *add(Expression *lhs, Expression *rhs) { + return new ArithmeticExpression(Expression::Kind::kAdd, lhs, rhs); + } + + static ArithmeticExpression *minus(Expression *lhs, Expression *rhs) { + return new ArithmeticExpression(Expression::Kind::kMinus, lhs, rhs); + } + + static RelationalExpression *gt(Expression *lhs, Expression *rhs) { + return new RelationalExpression(Expression::Kind::kRelGT, lhs, rhs); + } + + static RelationalExpression *eq(Expression *lhs, Expression *rhs) { + return new RelationalExpression(Expression::Kind::kRelEQ, lhs, rhs); + } + + static TypeCastingExpression *cast(Type type, Expression *expr) { + return new TypeCastingExpression(type, expr); + } + + static UnaryExpression *not_(Expression *expr) { + return new UnaryExpression(Expression::Kind::kUnaryNot, expr); + } + + static LogicalExpression *and_(Expression *lhs, Expression *rhs) { + return new LogicalExpression(Expression::Kind::kLogicalAnd, lhs, rhs); + } + + static LogicalExpression *or_(Expression *lhs, Expression *rhs) { + return new LogicalExpression(Expression::Kind::kLogicalOr, lhs, rhs); + } + + static ListExpression *list_(std::initializer_list<Expression *> exprs) { + auto exprList = new ExpressionList; + for (auto expr : exprs) { + exprList->add(expr); + } + return new ListExpression(exprList); + } + + static SubscriptExpression *sub(Expression *lhs, Expression *rhs) { + return new SubscriptExpression(lhs, rhs); + } + + static FunctionCallExpression *fn(std::string fn, std::initializer_list<Expression *> args) { + auto argsList = new ArgumentList; + for (auto arg : args) { + argsList->addArgument(std::unique_ptr<Expression>(arg)); + } + return new FunctionCallExpression(new std::string(std::move(fn)), argsList); + } + + static VariableExpression *var(const std::string &name) { + return new VariableExpression(new std::string(name)); + } + +protected: + ObjectPool pool; +}; + +TEST_F(FoldConstantExprVisitorTest, TestArithmeticExpr) { + // (5 - 1) + 2 => 4 + 2 + auto expr = pool.add(add(minus(constant(5), constant(1)), constant(2))); + FoldConstantExprVisitor visitor; + expr->accept(&visitor); + auto expected = pool.add(add(constant(4), constant(2))); + ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString(); + ASSERT(visitor.canBeFolded()); + + // 4+2 => 6 + auto root = pool.add(visitor.fold(expr)); + auto rootExpected = pool.add(constant(6)); + ASSERT_EQ(*root, *rootExpected) << root->toString() << " vs. " << rootExpected->toString(); +} + +TEST_F(FoldConstantExprVisitorTest, TestRelationExpr) { + // false == !(3 > (1+1)) => false == false + auto expr = pool.add(eq(constant(false), not_(gt(constant(3), add(constant(1), constant(1)))))); + auto expected = pool.add(eq(constant(false), constant(false))); + FoldConstantExprVisitor visitor; + expr->accept(&visitor); + ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString(); + ASSERT(visitor.canBeFolded()); + + // false==false => true + auto root = pool.add(visitor.fold(expr)); + auto rootExpected = pool.add(constant(true)); + ASSERT_EQ(*root, *rootExpected) << root->toString() << " vs. " << rootExpected->toString(); +} + +TEST_F(FoldConstantExprVisitorTest, TestLogicalExpr) { + // false && (false || (3 > (1 + 1))) => false && true + auto expr = pool.add(and_( + constant(false), or_(constant(false), gt(constant(3), add(constant(1), constant(1)))))); + auto expected = pool.add(and_(constant(false), constant(true))); + FoldConstantExprVisitor visitor; + expr->accept(&visitor); + ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString(); + ASSERT(visitor.canBeFolded()); + + // false && true => false + auto root = pool.add(visitor.fold(expr)); + auto rootExpected = pool.add(constant(false)); + ASSERT_EQ(*root, *rootExpected) << root->toString() << " vs. " << rootExpected->toString(); +} + +TEST_F(FoldConstantExprVisitorTest, TestSubscriptExpr) { + // 1 + [1, pow(2, 2+1), 2][2-1] => 1 + 8 + auto expr = pool.add(add(constant(1), + sub(list_({constant(1), + fn("pow", {constant(2), add(constant(2), constant(1))}), + constant(2)}), + minus(constant(2), constant(1))))); + auto expected = pool.add(add(constant(1), constant(8))); + FoldConstantExprVisitor visitor; + expr->accept(&visitor); + ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString(); + ASSERT(visitor.canBeFolded()); + + // 1+8 => 9 + auto root = pool.add(visitor.fold(expr)); + auto rootExpected = pool.add(constant(9)); + ASSERT_EQ(*root, *rootExpected) << root->toString() << " vs. " << rootExpected->toString(); +} + +TEST_F(FoldConstantExprVisitorTest, TestFoldFailed) { + // function call + { + // pow($v, (1+2)) => pow($v, 3) + auto expr = pool.add(fn("pow", {var("v"), add(constant(1), constant(2))})); + auto expected = pool.add(fn("pow", {var("v"), constant(3)})); + FoldConstantExprVisitor visitor; + expr->accept(&visitor); + ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString(); + ASSERT_FALSE(visitor.canBeFolded()); + } + // list + { + // [$v, pow(1, 2), 1+2][2-1] => [$v, 1, 3][0] + auto expr = pool.add(sub( + list_({var("v"), fn("pow", {constant(1), constant(2)}), add(constant(1), constant(2))}), + minus(constant(1), constant(1)))); + auto expected = pool.add(sub(list_({var("v"), constant(1), constant(3)}), constant(0))); + FoldConstantExprVisitor visitor; + expr->accept(&visitor); + ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString(); + ASSERT_FALSE(visitor.canBeFolded()); + } +} + +} // namespace graph +} // namespace nebula