Skip to content
Snippets Groups Projects
Unverified Commit 7714d6ee authored by Yee's avatar Yee Committed by GitHub
Browse files

Add constant expression folding visitor (#288)

* Add constant expression folding visitor

* Fold function call expr

* Address comments
parent 6ea7ff10
No related branches found
No related tags found
No related merge requests found
...@@ -45,3 +45,4 @@ cmake-build-release/ ...@@ -45,3 +45,4 @@ cmake-build-release/
*.pot *.pot
*.py[co] *.py[co]
__pycache__ __pycache__
venv/
...@@ -4,7 +4,10 @@ ...@@ -4,7 +4,10 @@
# attached with Common Clause Condition 1.0, found in the LICENSES directory. # attached with Common Clause Condition 1.0, found in the LICENSES directory.
SET(CONTEXT_TEST_LIBS SET(CONTEXT_TEST_LIBS
$<TARGET_OBJECTS:common_agg_function_obj>
$<TARGET_OBJECTS:common_charset_obj>
$<TARGET_OBJECTS:common_datatypes_obj> $<TARGET_OBJECTS:common_datatypes_obj>
$<TARGET_OBJECTS:common_encryption_obj>
$<TARGET_OBJECTS:common_expression_obj> $<TARGET_OBJECTS:common_expression_obj>
$<TARGET_OBJECTS:common_function_manager_obj> $<TARGET_OBJECTS:common_function_manager_obj>
$<TARGET_OBJECTS:common_fs_obj> $<TARGET_OBJECTS:common_fs_obj>
...@@ -21,10 +24,15 @@ SET(CONTEXT_TEST_LIBS ...@@ -21,10 +24,15 @@ SET(CONTEXT_TEST_LIBS
$<TARGET_OBJECTS:common_graph_thrift_obj> $<TARGET_OBJECTS:common_graph_thrift_obj>
$<TARGET_OBJECTS:common_storage_thrift_obj> $<TARGET_OBJECTS:common_storage_thrift_obj>
$<TARGET_OBJECTS:common_time_function_obj> $<TARGET_OBJECTS:common_time_function_obj>
$<TARGET_OBJECTS:common_http_client_obj>
$<TARGET_OBJECTS:common_process_obj>
$<TARGET_OBJECTS:util_obj> $<TARGET_OBJECTS:util_obj>
$<TARGET_OBJECTS:context_obj> $<TARGET_OBJECTS:context_obj>
$<TARGET_OBJECTS:expr_visitor_obj>
$<TARGET_OBJECTS:parser_obj> $<TARGET_OBJECTS:parser_obj>
$<TARGET_OBJECTS:validator_obj>
$<TARGET_OBJECTS:graph_flags_obj> $<TARGET_OBJECTS:graph_flags_obj>
$<TARGET_OBJECTS:graph_auth_obj>
$<TARGET_OBJECTS:session_obj> $<TARGET_OBJECTS:session_obj>
$<TARGET_OBJECTS:planner_obj> $<TARGET_OBJECTS:planner_obj>
$<TARGET_OBJECTS:idgenerator_obj> $<TARGET_OBJECTS:idgenerator_obj>
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include "optimizer/OptimizerUtils.h" #include "optimizer/OptimizerUtils.h"
namespace nebula { namespace nebula {
namespace graph { namespace graph {
......
...@@ -7,9 +7,9 @@ set(PARSER_TEST_LIBS ...@@ -7,9 +7,9 @@ set(PARSER_TEST_LIBS
$<TARGET_OBJECTS:parser_obj> $<TARGET_OBJECTS:parser_obj>
$<TARGET_OBJECTS:common_time_function_obj> $<TARGET_OBJECTS:common_time_function_obj>
$<TARGET_OBJECTS:common_expression_obj> $<TARGET_OBJECTS:common_expression_obj>
$<TARGET_OBJECTS:common_encryption_obj>
$<TARGET_OBJECTS:common_network_obj> $<TARGET_OBJECTS:common_network_obj>
$<TARGET_OBJECTS:common_fs_obj> $<TARGET_OBJECTS:common_fs_obj>
$<TARGET_OBJECTS:common_time_obj>
$<TARGET_OBJECTS:common_stats_obj> $<TARGET_OBJECTS:common_stats_obj>
$<TARGET_OBJECTS:common_time_obj> $<TARGET_OBJECTS:common_time_obj>
$<TARGET_OBJECTS:common_common_thrift_obj> $<TARGET_OBJECTS:common_common_thrift_obj>
...@@ -18,18 +18,25 @@ set(PARSER_TEST_LIBS ...@@ -18,18 +18,25 @@ set(PARSER_TEST_LIBS
$<TARGET_OBJECTS:common_datatypes_obj> $<TARGET_OBJECTS:common_datatypes_obj>
$<TARGET_OBJECTS:common_base_obj> $<TARGET_OBJECTS:common_base_obj>
$<TARGET_OBJECTS:common_function_manager_obj> $<TARGET_OBJECTS:common_function_manager_obj>
$<TARGET_OBJECTS:common_agg_function_obj>
$<TARGET_OBJECTS:common_meta_thrift_obj> $<TARGET_OBJECTS:common_meta_thrift_obj>
$<TARGET_OBJECTS:common_graph_thrift_obj> $<TARGET_OBJECTS:common_graph_thrift_obj>
$<TARGET_OBJECTS:common_http_client_obj>
$<TARGET_OBJECTS:common_storage_thrift_obj> $<TARGET_OBJECTS:common_storage_thrift_obj>
$<TARGET_OBJECTS:common_meta_obj> $<TARGET_OBJECTS:common_meta_obj>
$<TARGET_OBJECTS:common_meta_client_obj> $<TARGET_OBJECTS:common_meta_client_obj>
$<TARGET_OBJECTS:common_conf_obj> $<TARGET_OBJECTS:common_conf_obj>
$<TARGET_OBJECTS:common_charset_obj>
$<TARGET_OBJECTS:common_file_based_cluster_id_man_obj> $<TARGET_OBJECTS:common_file_based_cluster_id_man_obj>
$<TARGET_OBJECTS:common_process_obj>
$<TARGET_OBJECTS:session_obj> $<TARGET_OBJECTS:session_obj>
$<TARGET_OBJECTS:graph_flags_obj> $<TARGET_OBJECTS:graph_flags_obj>
$<TARGET_OBJECTS:graph_auth_obj>
$<TARGET_OBJECTS:util_obj> $<TARGET_OBJECTS:util_obj>
$<TARGET_OBJECTS:expr_visitor_obj>
$<TARGET_OBJECTS:context_obj> $<TARGET_OBJECTS:context_obj>
$<TARGET_OBJECTS:planner_obj> $<TARGET_OBJECTS:planner_obj>
$<TARGET_OBJECTS:validator_obj>
$<TARGET_OBJECTS:idgenerator_obj> $<TARGET_OBJECTS:idgenerator_obj>
) )
...@@ -37,28 +44,28 @@ nebula_add_test( ...@@ -37,28 +44,28 @@ nebula_add_test(
NAME parser_test NAME parser_test
SOURCES ParserTest.cpp SOURCES ParserTest.cpp
OBJECTS ${PARSER_TEST_LIBS} OBJECTS ${PARSER_TEST_LIBS}
LIBRARIES gtest gtest_main ${THRIFT_LIBRARIES} LIBRARIES gtest gtest_main ${THRIFT_LIBRARIES} proxygenlib
) )
nebula_add_test( nebula_add_test(
NAME scanner_test NAME scanner_test
SOURCES ScannerTest.cpp SOURCES ScannerTest.cpp
OBJECTS ${PARSER_TEST_LIBS} OBJECTS ${PARSER_TEST_LIBS}
LIBRARIES gtest gtest_main ${THRIFT_LIBRARIES} LIBRARIES gtest gtest_main ${THRIFT_LIBRARIES} proxygenlib
) )
nebula_add_executable( nebula_add_executable(
NAME parser_bm NAME parser_bm
SOURCES ParserBenchmark.cpp SOURCES ParserBenchmark.cpp
OBJECTS ${PARSER_TEST_LIBS} OBJECTS ${PARSER_TEST_LIBS}
LIBRARIES follybenchmark boost_regex ${THRIFT_LIBRARIES} LIBRARIES follybenchmark boost_regex ${THRIFT_LIBRARIES} proxygenlib
) )
nebula_add_test( nebula_add_test(
NAME expression_parsing_test NAME expression_parsing_test
SOURCES ExpressionParsingTest.cpp SOURCES ExpressionParsingTest.cpp
OBJECTS ${PARSER_TEST_LIBS} OBJECTS ${PARSER_TEST_LIBS}
LIBRARIES gtest gtest_main ${THRIFT_LIBRARIES} LIBRARIES gtest gtest_main ${THRIFT_LIBRARIES} proxygenlib
) )
if(ENABLE_FUZZ_TEST) if(ENABLE_FUZZ_TEST)
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
nebula_add_library( nebula_add_library(
util_obj OBJECT util_obj OBJECT
ExpressionUtils.cpp
SchemaUtil.cpp SchemaUtil.cpp
ToJson.cpp ToJson.cpp
) )
......
/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/
#include "util/ExpressionUtils.h"
#include "visitor/FoldConstantExprVisitor.h"
namespace nebula {
namespace graph {
std::unique_ptr<Expression> ExpressionUtils::foldConstantExpr(const Expression *expr) {
auto newExpr = expr->clone();
FoldConstantExprVisitor visitor;
newExpr->accept(&visitor);
if (visitor.canBeFolded()) {
return std::unique_ptr<Expression>(visitor.fold(newExpr.get()));
}
return newExpr;
}
} // namespace graph
} // namespace nebula
...@@ -53,30 +53,6 @@ public: ...@@ -53,30 +53,6 @@ public:
return findAny(expr, expected) != nullptr; return findAny(expr, expected) != nullptr;
} }
// Require data from input/variable
static bool hasInput(const Expression* expr) {
return hasAny(expr,
{Expression::Kind::kInputProperty,
Expression::Kind::kVarProperty,
Expression::Kind::kVar,
Expression::Kind::kVersionedVar});
}
// require data from graph storage
static const Expression* findStorage(const Expression* expr) {
return findAny(expr,
{Expression::Kind::kTagProperty,
Expression::Kind::kEdgeProperty,
Expression::Kind::kDstProperty,
Expression::Kind::kSrcProperty,
Expression::Kind::kEdgeSrc,
Expression::Kind::kEdgeType,
Expression::Kind::kEdgeRank,
Expression::Kind::kEdgeDst,
Expression::Kind::kVertex,
Expression::Kind::kEdge});
}
static std::vector<const Expression*> findAllStorage(const Expression* expr) { static std::vector<const Expression*> findAllStorage(const Expression* expr) {
return collectAll(expr, return collectAll(expr,
{Expression::Kind::kTagProperty, {Expression::Kind::kTagProperty,
...@@ -95,24 +71,6 @@ public: ...@@ -95,24 +71,6 @@ public:
return collectAll(expr, {Expression::Kind::kInputProperty, Expression::Kind::kVarProperty}); return collectAll(expr, {Expression::Kind::kInputProperty, Expression::Kind::kVarProperty});
} }
static bool hasStorage(const Expression* expr) {
return findStorage(expr) != nullptr;
}
static bool isStorage(const Expression* expr) {
return isKindOf(expr,
{Expression::Kind::kTagProperty,
Expression::Kind::kEdgeProperty,
Expression::Kind::kDstProperty,
Expression::Kind::kSrcProperty,
Expression::Kind::kEdgeSrc,
Expression::Kind::kEdgeType,
Expression::Kind::kEdgeRank,
Expression::Kind::kEdgeDst,
Expression::Kind::kVertex,
Expression::Kind::kEdge});
}
static bool isConstExpr(const Expression* expr) { static bool isConstExpr(const Expression* expr) {
return !hasAny(expr, return !hasAny(expr,
{Expression::Kind::kInputProperty, {Expression::Kind::kInputProperty,
...@@ -148,6 +106,9 @@ public: ...@@ -148,6 +106,9 @@ public:
return new To(new std::string(std::move(*expr->left()->name())), return new To(new std::string(std::move(*expr->left()->name())),
new std::string(std::move(*expr->right()->name()))); new std::string(std::move(*expr->right()->name())));
} }
// Clone and fold constant expression
static std::unique_ptr<Expression> foldConstantExpr(const Expression* expr);
}; };
} // namespace graph } // namespace graph
......
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
#include "parser/Clauses.h" #include "parser/Clauses.h"
#include "parser/TraverseSentences.h" #include "parser/TraverseSentences.h"
#include "planner/Query.h" #include "planner/Query.h"
#include "util/ExpressionUtils.h"
#include "visitor/FoldConstantExprVisitor.h"
namespace nebula { namespace nebula {
namespace graph { namespace graph {
...@@ -70,7 +72,7 @@ Status YieldValidator::checkAggFunAndBuildGroupItems(const YieldClause *clause) ...@@ -70,7 +72,7 @@ Status YieldValidator::checkAggFunAndBuildGroupItems(const YieldClause *clause)
} }
Status YieldValidator::checkInputProps() const { Status YieldValidator::checkInputProps() const {
auto& inputProps = const_cast<ExpressionProps*>(&exprProps_)->inputProps(); auto &inputProps = const_cast<ExpressionProps *>(&exprProps_)->inputProps();
if (inputs_.empty() && !inputProps.empty()) { if (inputs_.empty() && !inputProps.empty()) {
return Status::SemanticError("no inputs for yield columns."); return Status::SemanticError("no inputs for yield columns.");
} }
...@@ -82,7 +84,7 @@ Status YieldValidator::checkInputProps() const { ...@@ -82,7 +84,7 @@ Status YieldValidator::checkInputProps() const {
} }
Status YieldValidator::checkVarProps() const { Status YieldValidator::checkVarProps() const {
auto& varProps = const_cast<ExpressionProps*>(&exprProps_)->varProps(); auto &varProps = const_cast<ExpressionProps *>(&exprProps_)->varProps();
for (auto &pair : varProps) { for (auto &pair : varProps) {
auto &var = pair.first; auto &var = pair.first;
if (!vctx_->existVar(var)) { if (!vctx_->existVar(var)) {
...@@ -102,6 +104,7 @@ Status YieldValidator::makeOutputColumn(YieldColumn *column) { ...@@ -102,6 +104,7 @@ Status YieldValidator::makeOutputColumn(YieldColumn *column) {
auto expr = column->expr(); auto expr = column->expr();
DCHECK(expr != nullptr); DCHECK(expr != nullptr);
NG_RETURN_IF_ERROR(deduceProps(expr, exprProps_)); NG_RETURN_IF_ERROR(deduceProps(expr, exprProps_));
auto status = deduceExprType(expr); auto status = deduceExprType(expr);
...@@ -111,6 +114,13 @@ Status YieldValidator::makeOutputColumn(YieldColumn *column) { ...@@ -111,6 +114,13 @@ Status YieldValidator::makeOutputColumn(YieldColumn *column) {
auto name = deduceColName(column); auto name = deduceColName(column);
outputColumnNames_.emplace_back(name); outputColumnNames_.emplace_back(name);
// Constant expression folding must be after type deduction
FoldConstantExprVisitor visitor;
expr->accept(&visitor);
if (visitor.canBeFolded()) {
column->setExpr(visitor.fold(expr));
}
outputs_.emplace_back(name, type); outputs_.emplace_back(name, type);
return Status::OK(); return Status::OK();
} }
...@@ -121,7 +131,7 @@ void YieldValidator::genConstantExprValues() { ...@@ -121,7 +131,7 @@ void YieldValidator::genConstantExprValues() {
ds.colNames = outputColumnNames_; ds.colNames = outputColumnNames_;
QueryExpressionContext ctx; QueryExpressionContext ctx;
Row row; Row row;
for (auto& column : columns_->columns()) { for (auto &column : columns_->columns()) {
row.values.emplace_back(Expression::eval(column->expr(), ctx(nullptr))); row.values.emplace_back(Expression::eval(column->expr(), ctx(nullptr)));
} }
ds.emplace_back(std::move(row)); ds.emplace_back(std::move(row));
...@@ -192,6 +202,8 @@ Status YieldValidator::validateWhere(const WhereClause *clause) { ...@@ -192,6 +202,8 @@ Status YieldValidator::validateWhere(const WhereClause *clause) {
} }
if (filter != nullptr) { if (filter != nullptr) {
NG_RETURN_IF_ERROR(deduceProps(filter, exprProps_)); NG_RETURN_IF_ERROR(deduceProps(filter, exprProps_));
auto newFilter = ExpressionUtils::foldConstantExpr(filter);
filterCondition_ = qctx_->objPool()->add(newFilter.release());
} }
return Status::OK(); return Status::OK();
} }
...@@ -201,7 +213,7 @@ Status YieldValidator::toPlan() { ...@@ -201,7 +213,7 @@ Status YieldValidator::toPlan() {
Filter *filter = nullptr; Filter *filter = nullptr;
if (yield->where()) { if (yield->where()) {
filter = Filter::make(qctx_, nullptr, yield->where()->filter()); filter = Filter::make(qctx_, nullptr, filterCondition_);
std::vector<std::string> colNames(inputs_.size()); std::vector<std::string> colNames(inputs_.size());
std::transform( std::transform(
inputs_.cbegin(), inputs_.cend(), colNames.begin(), [](auto &in) { return in.first; }); inputs_.cbegin(), inputs_.cend(), colNames.begin(), [](auto &in) { return in.first; });
...@@ -247,4 +259,3 @@ Status YieldValidator::toPlan() { ...@@ -247,4 +259,3 @@ Status YieldValidator::toPlan() {
} // namespace graph } // namespace graph
} // namespace nebula } // namespace nebula
...@@ -48,8 +48,9 @@ private: ...@@ -48,8 +48,9 @@ private:
YieldColumns *columns_{nullptr}; YieldColumns *columns_{nullptr};
std::vector<std::string> outputColumnNames_; std::vector<std::string> outputColumnNames_;
std::vector<Aggregate::GroupItem> groupItems_; std::vector<Aggregate::GroupItem> groupItems_;
ExpressionProps exprProps_; ExpressionProps exprProps_;
std::string constantExprVar_; std::string constantExprVar_;
Expression *filterCondition_{nullptr};
}; };
} // namespace graph } // namespace graph
......
...@@ -11,5 +11,8 @@ nebula_add_library( ...@@ -11,5 +11,8 @@ nebula_add_library(
DeduceTypeVisitor.cpp DeduceTypeVisitor.cpp
ExtractFilterExprVisitor.cpp ExtractFilterExprVisitor.cpp
FindAnyExprVisitor.cpp FindAnyExprVisitor.cpp
FoldConstantExprVisitor.cpp
RewriteLabelAttrVisitor.cpp RewriteLabelAttrVisitor.cpp
) )
nebula_add_subdirectory(test)
/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/
#include "visitor/FoldConstantExprVisitor.h"
#include "context/QueryExpressionContext.h"
namespace nebula {
namespace graph {
void FoldConstantExprVisitor::visit(ConstantExpression *expr) {
UNUSED(expr);
canBeFolded_ = true;
}
void FoldConstantExprVisitor::visit(UnaryExpression *expr) {
expr->operand()->accept(this);
if (canBeFolded_ && expr->operand()->kind() != Expression::Kind::kConstant) {
expr->setOperand(fold(expr->operand()));
}
}
void FoldConstantExprVisitor::visit(TypeCastingExpression *expr) {
expr->operand()->accept(this);
if (canBeFolded_ && expr->operand()->kind() != Expression::Kind::kConstant) {
expr->setOperand(fold(expr->operand()));
}
}
void FoldConstantExprVisitor::visit(LabelExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
void FoldConstantExprVisitor::visit(LabelAttributeExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
// binary expression
void FoldConstantExprVisitor::visit(ArithmeticExpression *expr) {
visitBinaryExpr(expr);
}
void FoldConstantExprVisitor::visit(RelationalExpression *expr) {
visitBinaryExpr(expr);
}
void FoldConstantExprVisitor::visit(SubscriptExpression *expr) {
visitBinaryExpr(expr);
}
void FoldConstantExprVisitor::visit(AttributeExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
void FoldConstantExprVisitor::visit(LogicalExpression *expr) {
visitBinaryExpr(expr);
}
// function call
void FoldConstantExprVisitor::visit(FunctionCallExpression *expr) {
bool canBeFolded = true;
for (auto &arg : expr->args()->args()) {
if (arg->kind() != Expression::Kind::kConstant) {
arg->accept(this);
if (canBeFolded_) {
arg.reset(fold(arg.get()));
} else {
canBeFolded = false;
}
}
}
canBeFolded_ = canBeFolded;
}
void FoldConstantExprVisitor::visit(UUIDExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
// variable expression
void FoldConstantExprVisitor::visit(VariableExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
void FoldConstantExprVisitor::visit(VersionedVariableExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
// container expression
void FoldConstantExprVisitor::visit(ListExpression *expr) {
auto items = expr->items();
bool canBeFolded = true;
for (size_t i = 0; i < items.size(); ++i) {
auto item = const_cast<Expression *>(items[i]);
item->accept(this);
if (!canBeFolded_) {
canBeFolded = false;
continue;
}
if (item->kind() != Expression::Kind::kConstant) {
expr->setItem(i, fold(item));
}
}
canBeFolded_ = canBeFolded;
}
void FoldConstantExprVisitor::visit(SetExpression *expr) {
auto items = expr->items();
bool canBeFolded = true;
for (size_t i = 0; i < items.size(); ++i) {
auto item = const_cast<Expression *>(items[i]);
item->accept(this);
if (!canBeFolded_) {
canBeFolded = false;
continue;
}
if (item->kind() != Expression::Kind::kConstant) {
expr->setItem(i, fold(item));
}
}
canBeFolded_ = canBeFolded;
}
void FoldConstantExprVisitor::visit(MapExpression *expr) {
auto items = expr->items();
bool canBeFolded = true;
for (size_t i = 0; i < items.size(); ++i) {
auto &pair = items[i];
auto item = const_cast<Expression *>(pair.second);
if (!canBeFolded_) {
canBeFolded = false;
continue;
}
if (item->kind() != Expression::Kind::kConstant) {
auto key = std::make_unique<std::string>(*pair.first);
auto val = std::unique_ptr<Expression>(fold(item));
expr->setItem(i, std::make_pair(std::move(key), std::move(val)));
}
}
canBeFolded_ = canBeFolded;
}
// property Expression
void FoldConstantExprVisitor::visit(TagPropertyExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
void FoldConstantExprVisitor::visit(EdgePropertyExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
void FoldConstantExprVisitor::visit(InputPropertyExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
void FoldConstantExprVisitor::visit(VariablePropertyExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
void FoldConstantExprVisitor::visit(DestPropertyExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
void FoldConstantExprVisitor::visit(SourcePropertyExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
void FoldConstantExprVisitor::visit(EdgeSrcIdExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
void FoldConstantExprVisitor::visit(EdgeTypeExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
void FoldConstantExprVisitor::visit(EdgeRankExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
void FoldConstantExprVisitor::visit(EdgeDstIdExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
// vertex/edge expression
void FoldConstantExprVisitor::visit(VertexExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
void FoldConstantExprVisitor::visit(EdgeExpression *expr) {
UNUSED(expr);
canBeFolded_ = false;
}
void FoldConstantExprVisitor::visitBinaryExpr(BinaryExpression *expr) {
expr->left()->accept(this);
auto leftCanBeFolded = canBeFolded_;
if (leftCanBeFolded && expr->left()->kind() != Expression::Kind::kConstant) {
expr->setLeft(fold(expr->left()));
}
expr->right()->accept(this);
auto rightCanBeFolded = canBeFolded_;
if (rightCanBeFolded && expr->right()->kind() != Expression::Kind::kConstant) {
expr->setRight(fold(expr->right()));
}
canBeFolded_ = leftCanBeFolded && rightCanBeFolded;
}
Expression *FoldConstantExprVisitor::fold(Expression *expr) const {
QueryExpressionContext ctx;
auto value = expr->eval(ctx(nullptr));
return new ConstantExpression(std::move(value));
}
} // namespace graph
} // namespace nebula
/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/
#ifndef VISITOR_FOLDCONSTANTEXPRVISITOR_H_
#define VISITOR_FOLDCONSTANTEXPRVISITOR_H_
#include "common/expression/ExprVisitor.h"
namespace nebula {
namespace graph {
class FoldConstantExprVisitor final : public ExprVisitor {
public:
bool canBeFolded() const {
return canBeFolded_;
}
void visit(ConstantExpression *expr) override;
void visit(UnaryExpression *expr) override;
void visit(TypeCastingExpression *expr) override;
void visit(LabelExpression *expr) override;
void visit(LabelAttributeExpression *expr) override;
// binary expression
void visit(ArithmeticExpression *expr) override;
void visit(RelationalExpression *expr) override;
void visit(SubscriptExpression *expr) override;
void visit(AttributeExpression *expr) override;
void visit(LogicalExpression *expr) override;
// function call
void visit(FunctionCallExpression *expr) override;
void visit(UUIDExpression *expr) override;
// variable expression
void visit(VariableExpression *expr) override;
void visit(VersionedVariableExpression *expr) override;
// container expression
void visit(ListExpression *expr) override;
void visit(SetExpression *expr) override;
void visit(MapExpression *expr) override;
// property Expression
void visit(TagPropertyExpression *expr) override;
void visit(EdgePropertyExpression *expr) override;
void visit(InputPropertyExpression *expr) override;
void visit(VariablePropertyExpression *expr) override;
void visit(DestPropertyExpression *expr) override;
void visit(SourcePropertyExpression *expr) override;
void visit(EdgeSrcIdExpression *expr) override;
void visit(EdgeTypeExpression *expr) override;
void visit(EdgeRankExpression *expr) override;
void visit(EdgeDstIdExpression *expr) override;
// vertex/edge expression
void visit(VertexExpression *expr) override;
void visit(EdgeExpression *expr) override;
void visitBinaryExpr(BinaryExpression *expr);
Expression *fold(Expression *expr) const;
private:
bool canBeFolded_{false};
};
} // namespace graph
} // namespace nebula
#endif // VISITOR_FOLDCONSTANTEXPRVISITOR_H_
# Copyright (c) 2020 vesoft inc. All rights reserved.
#
# This source code is licensed under Apache 2.0 License,
# attached with Common Clause Condition 1.0, found in the LICENSES directory.
nebula_add_test(
NAME
expr_visitor_test
SOURCES
FoldConstantExprVisitorTest.cpp
OBJECTS
$<TARGET_OBJECTS:mock_schema_obj>
$<TARGET_OBJECTS:util_obj>
$<TARGET_OBJECTS:validator_obj>
$<TARGET_OBJECTS:expr_visitor_obj>
$<TARGET_OBJECTS:planner_obj>
$<TARGET_OBJECTS:session_obj>
$<TARGET_OBJECTS:graph_flags_obj>
$<TARGET_OBJECTS:parser_obj>
$<TARGET_OBJECTS:idgenerator_obj>
$<TARGET_OBJECTS:context_obj>
$<TARGET_OBJECTS:graph_auth_obj>
$<TARGET_OBJECTS:common_time_function_obj>
$<TARGET_OBJECTS:common_expression_obj>
$<TARGET_OBJECTS:common_network_obj>
$<TARGET_OBJECTS:common_fs_obj>
$<TARGET_OBJECTS:common_time_obj>
$<TARGET_OBJECTS:common_stats_obj>
$<TARGET_OBJECTS:common_time_obj>
$<TARGET_OBJECTS:common_common_thrift_obj>
$<TARGET_OBJECTS:common_graph_thrift_obj>
$<TARGET_OBJECTS:common_storage_thrift_obj>
$<TARGET_OBJECTS:common_thrift_obj>
$<TARGET_OBJECTS:common_thread_obj>
$<TARGET_OBJECTS:common_datatypes_obj>
$<TARGET_OBJECTS:common_base_obj>
$<TARGET_OBJECTS:common_meta_thrift_obj>
$<TARGET_OBJECTS:common_meta_obj>
$<TARGET_OBJECTS:common_graph_thrift_obj>
$<TARGET_OBJECTS:common_charset_obj>
$<TARGET_OBJECTS:common_meta_client_obj>
$<TARGET_OBJECTS:common_file_based_cluster_id_man_obj>
$<TARGET_OBJECTS:common_function_manager_obj>
$<TARGET_OBJECTS:common_agg_function_obj>
$<TARGET_OBJECTS:common_conf_obj>
$<TARGET_OBJECTS:common_encryption_obj>
$<TARGET_OBJECTS:common_http_client_obj>
$<TARGET_OBJECTS:common_process_obj>
LIBRARIES
gtest
gtest_main
${THRIFT_LIBRARIES}
wangle
proxygenhttpserver
proxygenlib
)
/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/
#include "visitor/FoldConstantExprVisitor.h"
#include <gtest/gtest.h>
#include "util/ObjectPool.h"
using Type = nebula::Value::Type;
namespace nebula {
namespace graph {
class FoldConstantExprVisitorTest : public ::testing::Test {
public:
void TearDown() override {
pool.clear();
}
static ConstantExpression *constant(Value value) {
return new ConstantExpression(std::move(value));
}
static ArithmeticExpression *add(Expression *lhs, Expression *rhs) {
return new ArithmeticExpression(Expression::Kind::kAdd, lhs, rhs);
}
static ArithmeticExpression *minus(Expression *lhs, Expression *rhs) {
return new ArithmeticExpression(Expression::Kind::kMinus, lhs, rhs);
}
static RelationalExpression *gt(Expression *lhs, Expression *rhs) {
return new RelationalExpression(Expression::Kind::kRelGT, lhs, rhs);
}
static RelationalExpression *eq(Expression *lhs, Expression *rhs) {
return new RelationalExpression(Expression::Kind::kRelEQ, lhs, rhs);
}
static TypeCastingExpression *cast(Type type, Expression *expr) {
return new TypeCastingExpression(type, expr);
}
static UnaryExpression *not_(Expression *expr) {
return new UnaryExpression(Expression::Kind::kUnaryNot, expr);
}
static LogicalExpression *and_(Expression *lhs, Expression *rhs) {
return new LogicalExpression(Expression::Kind::kLogicalAnd, lhs, rhs);
}
static LogicalExpression *or_(Expression *lhs, Expression *rhs) {
return new LogicalExpression(Expression::Kind::kLogicalOr, lhs, rhs);
}
static ListExpression *list_(std::initializer_list<Expression *> exprs) {
auto exprList = new ExpressionList;
for (auto expr : exprs) {
exprList->add(expr);
}
return new ListExpression(exprList);
}
static SubscriptExpression *sub(Expression *lhs, Expression *rhs) {
return new SubscriptExpression(lhs, rhs);
}
static FunctionCallExpression *fn(std::string fn, std::initializer_list<Expression *> args) {
auto argsList = new ArgumentList;
for (auto arg : args) {
argsList->addArgument(std::unique_ptr<Expression>(arg));
}
return new FunctionCallExpression(new std::string(std::move(fn)), argsList);
}
static VariableExpression *var(const std::string &name) {
return new VariableExpression(new std::string(name));
}
protected:
ObjectPool pool;
};
TEST_F(FoldConstantExprVisitorTest, TestArithmeticExpr) {
// (5 - 1) + 2 => 4 + 2
auto expr = pool.add(add(minus(constant(5), constant(1)), constant(2)));
FoldConstantExprVisitor visitor;
expr->accept(&visitor);
auto expected = pool.add(add(constant(4), constant(2)));
ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString();
ASSERT(visitor.canBeFolded());
// 4+2 => 6
auto root = pool.add(visitor.fold(expr));
auto rootExpected = pool.add(constant(6));
ASSERT_EQ(*root, *rootExpected) << root->toString() << " vs. " << rootExpected->toString();
}
TEST_F(FoldConstantExprVisitorTest, TestRelationExpr) {
// false == !(3 > (1+1)) => false == false
auto expr = pool.add(eq(constant(false), not_(gt(constant(3), add(constant(1), constant(1))))));
auto expected = pool.add(eq(constant(false), constant(false)));
FoldConstantExprVisitor visitor;
expr->accept(&visitor);
ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString();
ASSERT(visitor.canBeFolded());
// false==false => true
auto root = pool.add(visitor.fold(expr));
auto rootExpected = pool.add(constant(true));
ASSERT_EQ(*root, *rootExpected) << root->toString() << " vs. " << rootExpected->toString();
}
TEST_F(FoldConstantExprVisitorTest, TestLogicalExpr) {
// false && (false || (3 > (1 + 1))) => false && true
auto expr = pool.add(and_(
constant(false), or_(constant(false), gt(constant(3), add(constant(1), constant(1))))));
auto expected = pool.add(and_(constant(false), constant(true)));
FoldConstantExprVisitor visitor;
expr->accept(&visitor);
ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString();
ASSERT(visitor.canBeFolded());
// false && true => false
auto root = pool.add(visitor.fold(expr));
auto rootExpected = pool.add(constant(false));
ASSERT_EQ(*root, *rootExpected) << root->toString() << " vs. " << rootExpected->toString();
}
TEST_F(FoldConstantExprVisitorTest, TestSubscriptExpr) {
// 1 + [1, pow(2, 2+1), 2][2-1] => 1 + 8
auto expr = pool.add(add(constant(1),
sub(list_({constant(1),
fn("pow", {constant(2), add(constant(2), constant(1))}),
constant(2)}),
minus(constant(2), constant(1)))));
auto expected = pool.add(add(constant(1), constant(8)));
FoldConstantExprVisitor visitor;
expr->accept(&visitor);
ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString();
ASSERT(visitor.canBeFolded());
// 1+8 => 9
auto root = pool.add(visitor.fold(expr));
auto rootExpected = pool.add(constant(9));
ASSERT_EQ(*root, *rootExpected) << root->toString() << " vs. " << rootExpected->toString();
}
TEST_F(FoldConstantExprVisitorTest, TestFoldFailed) {
// function call
{
// pow($v, (1+2)) => pow($v, 3)
auto expr = pool.add(fn("pow", {var("v"), add(constant(1), constant(2))}));
auto expected = pool.add(fn("pow", {var("v"), constant(3)}));
FoldConstantExprVisitor visitor;
expr->accept(&visitor);
ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString();
ASSERT_FALSE(visitor.canBeFolded());
}
// list
{
// [$v, pow(1, 2), 1+2][2-1] => [$v, 1, 3][0]
auto expr = pool.add(sub(
list_({var("v"), fn("pow", {constant(1), constant(2)}), add(constant(1), constant(2))}),
minus(constant(1), constant(1))));
auto expected = pool.add(sub(list_({var("v"), constant(1), constant(3)}), constant(0)));
FoldConstantExprVisitor visitor;
expr->accept(&visitor);
ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString();
ASSERT_FALSE(visitor.canBeFolded());
}
}
} // namespace graph
} // namespace nebula
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment