diff --git a/.linters/cpp/checkKeyword.py b/.linters/cpp/checkKeyword.py index 8796d14dc1b213b17fb208306c51b9af56181e93..3ee523717cec9c9807173371868dd9e05aadc836 100755 --- a/.linters/cpp/checkKeyword.py +++ b/.linters/cpp/checkKeyword.py @@ -104,6 +104,7 @@ reserved_key_words = [ 'KW_RECOVER', 'KW_EXPLAIN', 'KW_UNWIND', + 'KW_CASE', ] diff --git a/src/parser/parser.yy b/src/parser/parser.yy index d7969cf1e027d5c20c7d564f89bf18fd4feb599d..9433e6141ee0f98bc9233aa66fbaa06069c67015 100644 --- a/src/parser/parser.yy +++ b/src/parser/parser.yy @@ -20,6 +20,7 @@ #include "common/expression/AttributeExpression.h" #include "common/expression/LabelAttributeExpression.h" #include "common/expression/VariableExpression.h" +#include "common/expression/CaseExpression.h" #include "util/SchemaUtil.h" namespace nebula { @@ -117,6 +118,7 @@ static constexpr size_t MAX_ABS_INTEGER = 9223372036854775808ULL; MatchStepRange *match_step_range; nebula::meta::cpp2::IndexFieldDef *index_field; nebula::IndexFieldList *index_field_list; + CaseList *case_list; } /* destructors */ @@ -156,11 +158,12 @@ static constexpr size_t MAX_ABS_INTEGER = 9223372036854775808ULL; %token KW_CONTAINS %token KW_STARTS KW_ENDS %token KW_UNWIND KW_SKIP KW_OPTIONAL +%token KW_CASE KW_THEN KW_ELSE KW_END /* symbols */ %token L_PAREN R_PAREN L_BRACKET R_BRACKET L_BRACE R_BRACE COMMA %token PIPE ASSIGN -%token DOT DOT_DOT COLON SEMICOLON L_ARROW R_ARROW AT +%token DOT DOT_DOT COLON QM SEMICOLON L_ARROW R_ARROW AT %token ID_PROP TYPE_PROP SRC_ID_PROP DST_ID_PROP RANK_PROP INPUT_REF DST_REF SRC_REF /* token type specification */ @@ -177,6 +180,8 @@ static constexpr size_t MAX_ABS_INTEGER = 9223372036854775808ULL; %type <expr> edge_prop_expression %type <expr> input_prop_expression %type <expr> var_prop_expression +%type <expr> generic_case_expression +%type <expr> conditional_expression %type <expr> vid_ref_expression %type <expr> vid %type <expr> function_call_expression @@ -187,6 +192,7 @@ static constexpr size_t MAX_ABS_INTEGER = 9223372036854775808ULL; %type <expr> container_expression %type <expr> subscript_expression %type <expr> attribute_expression +%type <expr> case_expression %type <expr> compound_expression %type <argument_list> argument_list opt_argument_list %type <type> type_spec @@ -237,6 +243,9 @@ static constexpr size_t MAX_ABS_INTEGER = 9223372036854775808ULL; %type <both_in_out_clause> both_in_out_clause %type <expression_list> expression_list %type <map_item_list> map_item_list +%type <case_list> when_then_list +%type <expr> case_condition +%type <expr> case_default %type <match_path> match_path_pattern %type <match_path> match_path @@ -306,6 +315,7 @@ static constexpr size_t MAX_ABS_INTEGER = 9223372036854775808ULL; %type <boolval> opt_if_not_exists %type <boolval> opt_if_exists +%left QM COLON %left KW_OR KW_XOR %left KW_AND %right KW_NOT @@ -403,6 +413,9 @@ unreserved_keyword | KW_BOTH { $$ = new std::string("both"); } | KW_OUT { $$ = new std::string("out"); } | KW_SUBGRAPH { $$ = new std::string("subgraph"); } + | KW_THEN { $$ = new std::string("then"); } + | KW_ELSE { $$ = new std::string("else"); } + | KW_END { $$ = new std::string("end"); } ; agg_function @@ -529,6 +542,9 @@ expression | expression KW_XOR expression { $$ = new LogicalExpression(Expression::Kind::kLogicalXor, $1, $3); } + | case_expression { + $$ = $1; + } ; compound_expression @@ -589,6 +605,63 @@ attribute_expression } ; +case_expression + : generic_case_expression { + $$ = $1; + } + | conditional_expression { + $$ = $1; + } + ; + +generic_case_expression + : KW_CASE case_condition when_then_list case_default KW_END { + auto expr = new CaseExpression($3); + expr->setCondition($2); + expr->setDefault($4); + $$ = expr; + } + ; + +conditional_expression + : expression QM expression COLON expression { + auto cases = new CaseList(); + cases->add($1, $3); + auto expr = new CaseExpression(cases, false); + expr->setDefault($5); + $$ = expr; + } + ; + +case_condition + : %empty { + $$ = nullptr; + } + | expression { + $$ = $1; + } + ; + +case_default + : %empty { + $$ = nullptr; + } + | KW_ELSE expression { + $$ = $2; + } + ; + +when_then_list + : KW_WHEN expression KW_THEN expression { + $$ = new CaseList(); + $$->add($2, $4); + } + | when_then_list KW_WHEN expression KW_THEN expression { + $1->add($3, $5); + $$ = $1; + } + ; + input_prop_expression : INPUT_REF DOT name_label { $$ = new InputPropertyExpression($3); diff --git a/src/parser/scanner.lex b/src/parser/scanner.lex index 7faae8421281eb35cb32ad103304ce900382e0a0..5fc13a994dcfc68482c5139765b1845b10b29e60 100644 --- a/src/parser/scanner.lex +++ b/src/parser/scanner.lex @@ -69,7 +69,7 @@ IP_OCTET ([0-9]|[1-9][0-9]|1[0-9][0-9]|2[0-4][0-9]|25[0-5]) "LOOKUP" { return TokenType::KW_LOOKUP; } "ALTER" { return TokenType::KW_ALTER; } "STEPS" { return TokenType::KW_STEPS; } -"STEP" { return TokenType::KW_STEPS; } +"STEP" { return TokenType::KW_STEPS; } "OVER" { return TokenType::KW_OVER; } "UPTO" { return TokenType::KW_UPTO; } "REVERSELY" { return TokenType::KW_REVERSELY; } @@ -134,6 +134,7 @@ IP_OCTET ([0-9]|[1-9][0-9]|1[0-9][0-9]|2[0-4][0-9]|25[0-5]) "EXPLAIN" { return TokenType::KW_EXPLAIN; } "PROFILE" { return TokenType::KW_PROFILE; } "FORMAT" { return TokenType::KW_FORMAT; } +"CASE" { return TokenType::KW_CASE; } /** @@ -212,6 +213,9 @@ IP_OCTET ([0-9]|[1-9][0-9]|1[0-9][0-9]|2[0-4][0-9]|25[0-5]) "UNWIND" { return TokenType::KW_UNWIND;} "SKIP" { return TokenType::KW_SKIP;} "OPTIONAL" { return TokenType::KW_OPTIONAL;} +"THEN" { return TokenType::KW_THEN; } +"ELSE" { return TokenType::KW_ELSE; } +"END" { return TokenType::KW_END; } "TRUE" { yylval->boolval = true; return TokenType::BOOL; } @@ -223,6 +227,7 @@ IP_OCTET ([0-9]|[1-9][0-9]|1[0-9][0-9]|2[0-4][0-9]|25[0-5]) ":" { return TokenType::COLON; } ";" { return TokenType::SEMICOLON; } "@" { return TokenType::AT; } +"?" { return TokenType::QM; } "+" { return TokenType::PLUS; } "-" { return TokenType::MINUS; } diff --git a/src/parser/test/ExpressionParsingTest.cpp b/src/parser/test/ExpressionParsingTest.cpp index ba2da6b373d0c9fe2dc6d613466432a385737764..a4d0dd79791a5040edaa13329e1e531b463d2da3 100644 --- a/src/parser/test/ExpressionParsingTest.cpp +++ b/src/parser/test/ExpressionParsingTest.cpp @@ -311,6 +311,18 @@ TEST_F(ExpressionParsingTest, Associativity) { make<ConstantExpression>(false))); add("!!false", ast); + auto cases = new CaseList(); + cases->add(make<ConstantExpression>(3), make<ConstantExpression>(4)); + ast = make<CaseExpression>(cases); + static_cast<CaseExpression*>(ast)->setCondition(make<LabelExpression>("a")); + auto cases2 = new CaseList(); + cases2->add(make<ConstantExpression>(5), make<ConstantExpression>(6)); + auto ast2 = make<CaseExpression>(cases2); + ast2->setCondition(make<LabelExpression>("b")); + ast2->setDefault(make<ConstantExpression>(7)); + static_cast<CaseExpression*>(ast)->setDefault(ast2); + add("CASE a WHEN 3 THEN 4 ELSE CASE b WHEN 5 THEN 6 ELSE 7 END END", ast); + run(); } diff --git a/src/visitor/CollectAllExprsVisitor.cpp b/src/visitor/CollectAllExprsVisitor.cpp index 73232ef6b0c98137987aec8034adfba477a041d0..687bc2e151e60f41106f97606e98477b556d0423 100644 --- a/src/visitor/CollectAllExprsVisitor.cpp +++ b/src/visitor/CollectAllExprsVisitor.cpp @@ -127,6 +127,20 @@ void CollectAllExprsVisitor::visit(EdgeExpression *expr) { collectExpr(expr); } +void CollectAllExprsVisitor::visit(CaseExpression *expr) { + collectExpr(expr); + if (expr->hasCondition()) { + expr->condition()->accept(this); + } + if (expr->hasDefault()) { + expr->defaultResult()->accept(this); + } + for (const auto &whenThen : expr->cases()) { + whenThen.when->accept(this); + whenThen.then->accept(this); + } +} + void CollectAllExprsVisitor::visitBinaryExpr(BinaryExpression *expr) { collectExpr(expr); expr->left()->accept(this); diff --git a/src/visitor/CollectAllExprsVisitor.h b/src/visitor/CollectAllExprsVisitor.h index d86b715b29d581f611313a0c7c9f57c0bdaaa937..eb9864ef3621755b8e29557107c8f348be8fe41b 100644 --- a/src/visitor/CollectAllExprsVisitor.h +++ b/src/visitor/CollectAllExprsVisitor.h @@ -55,8 +55,7 @@ private: void visit(AttributeExpression* expr) override; void visit(VertexExpression* expr) override; void visit(EdgeExpression* expr) override; - // TODO : CaseExpression - void visit(CaseExpression *) override {}; + void visit(CaseExpression* expr) override; void visitBinaryExpr(BinaryExpression* expr) override; void collectExpr(const Expression* expr); diff --git a/src/visitor/DeducePropsVisitor.h b/src/visitor/DeducePropsVisitor.h index 7d11bfd4d00c28ccec301d8cf183375f9066fc1d..ae0ecea539b7cab5610558d50fcde2b0e6a9d393 100644 --- a/src/visitor/DeducePropsVisitor.h +++ b/src/visitor/DeducePropsVisitor.h @@ -122,8 +122,6 @@ private: void visit(ConstantExpression* expr) override; void visit(VertexExpression* expr) override; void visit(EdgeExpression* expr) override; - // TODO : CaseExpression - void visit(CaseExpression*) override {}; void visitEdgePropExpr(PropertyExpression* expr); void reportError(const Expression* expr); diff --git a/src/visitor/DeduceTypeVisitor.cpp b/src/visitor/DeduceTypeVisitor.cpp index 3a42a89604f48ed222b528b41523477ed3f1de8c..67a3859656175339db013a3b0e672c8bf8de953f 100644 --- a/src/visitor/DeduceTypeVisitor.cpp +++ b/src/visitor/DeduceTypeVisitor.cpp @@ -377,6 +377,32 @@ void DeduceTypeVisitor::visit(EdgeExpression *) { type_ = Value::Type::EDGE; } +void DeduceTypeVisitor::visit(CaseExpression *expr) { + if (expr->hasCondition()) { + expr->condition()->accept(this); + if (!ok()) return; + } + if (expr->hasDefault()) { + expr->defaultResult()->accept(this); + if (!ok()) return; + } + + for (const auto &whenThen : expr->cases()) { + whenThen.when->accept(this); + if (!ok()) return; + if (!expr->hasCondition() && type_ != Value::Type::BOOL) { + status_ = Status::SemanticError( + "`%s': Invalid expression type, expecting expression of type BOOL", + expr->toString().c_str()); + return; + } + whenThen.then->accept(this); + if (!ok()) return; + } + // NOTE: we are not able to deduce the return type of case expression currently + type_ = Value::Type::__EMPTY__; +} + void DeduceTypeVisitor::visitVertexPropertyExpr(PropertyExpression *expr) { auto *tag = expr->sym(); auto tagId = qctx_->schemaMng()->toTagID(space_, *tag); diff --git a/src/visitor/DeduceTypeVisitor.h b/src/visitor/DeduceTypeVisitor.h index deaddcf638dc4fec31fc801f48ac926208bce1b8..e8de280073c02c98381b4b8b04196ed808b54527 100644 --- a/src/visitor/DeduceTypeVisitor.h +++ b/src/visitor/DeduceTypeVisitor.h @@ -73,8 +73,8 @@ private: // vertex/edge expression void visit(VertexExpression *expr) override; void visit(EdgeExpression *expr) override; - // TODO : CaseExpression - void visit(CaseExpression *) override {}; + // case expression + void visit(CaseExpression *expr) override; void visitVertexPropertyExpr(PropertyExpression *expr); diff --git a/src/visitor/EvaluableExprVisitor.h b/src/visitor/EvaluableExprVisitor.h index 1861f004378c4057631dc934641f7cce88ab2c52..b77f05a0bd9115be638e5fd74dc7cd22b31477b8 100644 --- a/src/visitor/EvaluableExprVisitor.h +++ b/src/visitor/EvaluableExprVisitor.h @@ -89,11 +89,6 @@ private: isEvaluable_ = false; } - // TODO : CaseExpression - void visit(CaseExpression *) override { - isEvaluable_ = false; - } - bool isEvaluable_{true}; }; diff --git a/src/visitor/ExprVisitorImpl.cpp b/src/visitor/ExprVisitorImpl.cpp index eda5a2ab72cbc4d4eaf78c448d314d25cfca5928..a3ef8ca5bcb6470ee428ec6206cc1265fd48375f 100644 --- a/src/visitor/ExprVisitorImpl.cpp +++ b/src/visitor/ExprVisitorImpl.cpp @@ -90,6 +90,33 @@ void ExprVisitorImpl::visit(MapExpression *expr) { } } +// case expression +void ExprVisitorImpl::visit(CaseExpression *expr) { + DCHECK(ok()); + if (expr->hasCondition()) { + expr->condition()->accept(this); + if (!ok()) { + return; + } + } + if (expr->hasDefault()) { + expr->defaultResult()->accept(this); + if (!ok()) { + return; + } + } + for (const auto &whenThen : expr->cases()) { + whenThen.when->accept(this); + if (!ok()) { + break; + } + whenThen.then->accept(this); + if (!ok()) { + break; + } + } +} + void ExprVisitorImpl::visitBinaryExpr(BinaryExpression *expr) { DCHECK(ok()); expr->left()->accept(this); diff --git a/src/visitor/ExprVisitorImpl.h b/src/visitor/ExprVisitorImpl.h index 748764d4bdcc6260612fc1ec2e72020fdf428bb3..7e87f4fc502d6351409504268341629479df4bee 100644 --- a/src/visitor/ExprVisitorImpl.h +++ b/src/visitor/ExprVisitorImpl.h @@ -29,6 +29,8 @@ public: void visit(ListExpression *expr) override; void visit(SetExpression *expr) override; void visit(MapExpression *expr) override; + // case expression + void visit(CaseExpression *expr) override; protected: using ExprVisitor::visit; diff --git a/src/visitor/ExtractFilterExprVisitor.h b/src/visitor/ExtractFilterExprVisitor.h index 9f8e96e16587975fa4d7c45f513c564e4b72f8eb..bf8c1d3afd9f922f458738ff96a718478ab1f971 100644 --- a/src/visitor/ExtractFilterExprVisitor.h +++ b/src/visitor/ExtractFilterExprVisitor.h @@ -46,8 +46,6 @@ private: void visit(VertexExpression *) override; void visit(EdgeExpression *) override; void visit(LogicalExpression *) override; - // TODO : CaseExpression - void visit(CaseExpression *) override {}; bool canBePushed_{true}; std::unique_ptr<Expression> remainedExpr_; diff --git a/src/visitor/ExtractPropExprVisitor.cpp b/src/visitor/ExtractPropExprVisitor.cpp index 303337d93ec950a2d30e5a29bac08c52dde6255a..06d19fe59f1d2eaf418a85c02dd31eb26a098260 100644 --- a/src/visitor/ExtractPropExprVisitor.cpp +++ b/src/visitor/ExtractPropExprVisitor.cpp @@ -192,7 +192,7 @@ void ExtractPropExprVisitor::visit(DestPropertyExpression* expr) { } } -void ExtractPropExprVisitor::reportError(const Expression *expr) { +void ExtractPropExprVisitor::reportError(const Expression* expr) { std::stringstream ss; ss << "Not supported expression `" << expr->toString() << "' for ExtractPropsExpression."; status_ = Status::SemanticError(ss.str()); diff --git a/src/visitor/ExtractPropExprVisitor.h b/src/visitor/ExtractPropExprVisitor.h index d01351dfa9c26ae9a99338a3f821a00f44060f9e..d5c521f9c59ebe090149972f2d69d6642fbd96e7 100644 --- a/src/visitor/ExtractPropExprVisitor.h +++ b/src/visitor/ExtractPropExprVisitor.h @@ -60,8 +60,6 @@ private: void visit(EdgeExpression *) override; // binary expression void visit(SubscriptExpression *) override; - // TODO : CaseExpression - void visit(CaseExpression *) override {}; void visitVertexEdgePropExpr(PropertyExpression *); void visitPropertyExpr(PropertyExpression *); diff --git a/src/visitor/FindAnyExprVisitor.cpp b/src/visitor/FindAnyExprVisitor.cpp index 766aae91c9f6a5fbcce941ab15148dbb44a99db2..2f0af4c9a073e45d3bf1339f2b355c7f3c8904c1 100644 --- a/src/visitor/FindAnyExprVisitor.cpp +++ b/src/visitor/FindAnyExprVisitor.cpp @@ -62,6 +62,25 @@ void FindAnyExprVisitor::visit(MapExpression *expr) { } } +void FindAnyExprVisitor::visit(CaseExpression *expr) { + findExpr(expr); + if (found_) return; + if (expr->hasCondition()) { + expr->condition()->accept(this); + if (found_) return; + } + if (expr->hasDefault()) { + expr->defaultResult()->accept(this); + if (found_) return; + } + for (const auto &whenThen : expr->cases()) { + whenThen.when->accept(this); + if (found_) return; + whenThen.then->accept(this); + if (found_) return; + } +} + void FindAnyExprVisitor::visit(ConstantExpression *expr) { findExpr(expr); } diff --git a/src/visitor/FindAnyExprVisitor.h b/src/visitor/FindAnyExprVisitor.h index 23727a461f058136d51cf3a49cc750285668b767..6372fd1ca5cd130df72457b959fc516485ab545a 100644 --- a/src/visitor/FindAnyExprVisitor.h +++ b/src/visitor/FindAnyExprVisitor.h @@ -37,6 +37,7 @@ private: void visit(ListExpression* expr) override; void visit(SetExpression* expr) override; void visit(MapExpression* expr) override; + void visit(CaseExpression* expr) override; void visit(ConstantExpression* expr) override; void visit(EdgePropertyExpression* expr) override; @@ -55,8 +56,6 @@ private: void visit(LabelExpression* expr) override; void visit(VertexExpression* expr) override; void visit(EdgeExpression* expr) override; - // TODO : CaseExpression - void visit(CaseExpression*) override {}; void visitBinaryExpr(BinaryExpression* expr) override; diff --git a/src/visitor/FoldConstantExprVisitor.cpp b/src/visitor/FoldConstantExprVisitor.cpp index 7fafff0498f383742b96c73479e181abd21084c8..64f9792e94f4abe90cab370519afac3f0c090b15 100644 --- a/src/visitor/FoldConstantExprVisitor.cpp +++ b/src/visitor/FoldConstantExprVisitor.cpp @@ -19,16 +19,20 @@ void FoldConstantExprVisitor::visit(ConstantExpression *expr) { } void FoldConstantExprVisitor::visit(UnaryExpression *expr) { - expr->operand()->accept(this); - if (canBeFolded_ && expr->operand()->kind() != Expression::Kind::kConstant) { - expr->setOperand(fold(expr->operand())); + if (!isConstant(expr->operand())) { + expr->operand()->accept(this); + if (canBeFolded_) { + expr->setOperand(fold(expr->operand())); + } } } void FoldConstantExprVisitor::visit(TypeCastingExpression *expr) { - expr->operand()->accept(this); - if (canBeFolded_ && expr->operand()->kind() != Expression::Kind::kConstant) { - expr->setOperand(fold(expr->operand())); + if (!isConstant(expr->operand())) { + expr->operand()->accept(this); + if (canBeFolded_) { + expr->setOperand(fold(expr->operand())); + } } } @@ -68,7 +72,7 @@ void FoldConstantExprVisitor::visit(LogicalExpression *expr) { void FoldConstantExprVisitor::visit(FunctionCallExpression *expr) { bool canBeFolded = true; for (auto &arg : expr->args()->args()) { - if (arg->kind() != Expression::Kind::kConstant) { + if (!isConstant(arg.get())) { arg->accept(this); if (canBeFolded_) { arg.reset(fold(arg.get())); @@ -110,13 +114,14 @@ void FoldConstantExprVisitor::visit(ListExpression *expr) { bool canBeFolded = true; for (size_t i = 0; i < items.size(); ++i) { auto item = items[i].get(); - item->accept(this); - if (!canBeFolded_) { - canBeFolded = false; + if (isConstant(item)) { continue; } - if (item->kind() != Expression::Kind::kConstant) { + item->accept(this); + if (canBeFolded_) { expr->setItem(i, std::unique_ptr<Expression>{fold(item)}); + } else { + canBeFolded = false; } } canBeFolded_ = canBeFolded; @@ -127,13 +132,14 @@ void FoldConstantExprVisitor::visit(SetExpression *expr) { bool canBeFolded = true; for (size_t i = 0; i < items.size(); ++i) { auto item = items[i].get(); - item->accept(this); - if (!canBeFolded_) { - canBeFolded = false; + if (isConstant(item)) { continue; } - if (item->kind() != Expression::Kind::kConstant) { + item->accept(this); + if (canBeFolded_) { expr->setItem(i, std::unique_ptr<Expression>{fold(item)}); + } else { + canBeFolded = false; } } canBeFolded_ = canBeFolded; @@ -145,15 +151,59 @@ void FoldConstantExprVisitor::visit(MapExpression *expr) { for (size_t i = 0; i < items.size(); ++i) { auto &pair = items[i]; auto item = const_cast<Expression *>(pair.second.get()); - item->accept(this); - if (!canBeFolded_) { - canBeFolded = false; + if (isConstant(item)) { continue; } - if (item->kind() != Expression::Kind::kConstant) { + item->accept(this); + if (canBeFolded_) { auto key = std::make_unique<std::string>(*pair.first); auto val = std::unique_ptr<Expression>(fold(item)); expr->setItem(i, std::make_pair(std::move(key), std::move(val))); + } else { + canBeFolded = false; + } + } + canBeFolded_ = canBeFolded; +} + +// case Expression +void FoldConstantExprVisitor::visit(CaseExpression *expr) { + bool canBeFolded = true; + if (expr->hasCondition() && !isConstant(expr->condition())) { + expr->condition()->accept(this); + if (canBeFolded_) { + expr->setCondition(fold(expr->condition())); + } else { + canBeFolded = false; + } + } + if (expr->hasDefault() && !isConstant(expr->defaultResult())) { + expr->defaultResult()->accept(this); + if (canBeFolded_) { + expr->setDefault(fold(expr->defaultResult())); + } else { + canBeFolded = false; + } + } + auto &cases = expr->cases(); + for (size_t i = 0; i < cases.size(); ++i) { + auto when = cases[i].when.get(); + auto then = cases[i].then.get(); + if (!isConstant(when)) { + when->accept(this); + if (canBeFolded_) { + expr->setWhen(i, fold(when)); + } else { + canBeFolded = false; + } + } + if (!isConstant(then)) { + then->accept(this); + if (canBeFolded_) { + expr->setThen(i, fold(then)); + } else { + canBeFolded = false; + } } } canBeFolded_ = canBeFolded; @@ -222,15 +272,20 @@ void FoldConstantExprVisitor::visit(EdgeExpression *expr) { } void FoldConstantExprVisitor::visitBinaryExpr(BinaryExpression *expr) { - expr->left()->accept(this); - auto leftCanBeFolded = canBeFolded_; - if (leftCanBeFolded && expr->left()->kind() != Expression::Kind::kConstant) { - expr->setLeft(fold(expr->left())); + bool leftCanBeFolded = true, rightCanBeFolded = true; + if (!isConstant(expr->left())) { + expr->left()->accept(this); + leftCanBeFolded = canBeFolded_; + if (leftCanBeFolded) { + expr->setLeft(fold(expr->left())); + } } - expr->right()->accept(this); - auto rightCanBeFolded = canBeFolded_; - if (rightCanBeFolded && expr->right()->kind() != Expression::Kind::kConstant) { - expr->setRight(fold(expr->right())); + if (!isConstant(expr->right())) { + expr->right()->accept(this); + rightCanBeFolded = canBeFolded_; + if (rightCanBeFolded) { + expr->setRight(fold(expr->right())); + } } canBeFolded_ = leftCanBeFolded && rightCanBeFolded; } diff --git a/src/visitor/FoldConstantExprVisitor.h b/src/visitor/FoldConstantExprVisitor.h index 04be0cad2ac06b4b7cdfdbb65ff49783e4213ac7..0667f2508ea32d0b520fd190666af49f0fbc560f 100644 --- a/src/visitor/FoldConstantExprVisitor.h +++ b/src/visitor/FoldConstantExprVisitor.h @@ -18,6 +18,10 @@ public: return canBeFolded_; } + bool isConstant(Expression *expr) const { + return expr->kind() == Expression::Kind::kConstant; + } + void visit(ConstantExpression *expr) override; void visit(UnaryExpression *expr) override; void visit(TypeCastingExpression *expr) override; @@ -53,8 +57,8 @@ public: // vertex/edge expression void visit(VertexExpression *expr) override; void visit(EdgeExpression *expr) override; - // TODO : CaseExpression - void visit(CaseExpression*) override {}; + // case expression + void visit(CaseExpression *expr) override; void visitBinaryExpr(BinaryExpression *expr); Expression *fold(Expression *expr) const; diff --git a/src/visitor/RewriteInputPropVisitor.cpp b/src/visitor/RewriteInputPropVisitor.cpp index 4c2bd8fe14b580f336154faebdf95be5758a6dd8..7221136af475ff78eea6e48a777d5308f3c31e57 100644 --- a/src/visitor/RewriteInputPropVisitor.cpp +++ b/src/visitor/RewriteInputPropVisitor.cpp @@ -168,6 +168,32 @@ void RewriteInputPropVisitor::visit(TypeCastingExpression* expr) { } } +void RewriteInputPropVisitor::visit(CaseExpression* expr) { + if (expr->hasCondition()) { + expr->condition()->accept(this); + if (ok()) { + expr->setCondition(result_.release()); + } + } + if (expr->hasDefault()) { + expr->defaultResult()->accept(this); + if (ok()) { + expr->setDefault(result_.release()); + } + } + for (size_t i = 0; i < expr->cases().size(); ++i) { + const auto& whenThen = expr->cases()[i]; + whenThen.when->accept(this); + if (ok()) { + expr->setWhen(i, result_.release()); + } + whenThen.then->accept(this); + if (ok()) { + expr->setThen(i, result_.release()); + } + } +} + void RewriteInputPropVisitor::visitBinaryExpr(BinaryExpression* expr) { expr->left()->accept(this); if (ok()) { diff --git a/src/visitor/RewriteInputPropVisitor.h b/src/visitor/RewriteInputPropVisitor.h index cac274e23b16ab0a36f66f78b6f2ebb687b3223e..a94468a9e19cd6f3321103e645a90131de68c9ce 100644 --- a/src/visitor/RewriteInputPropVisitor.h +++ b/src/visitor/RewriteInputPropVisitor.h @@ -68,8 +68,8 @@ private: // vertex/edge expression void visit(VertexExpression *) override; void visit(EdgeExpression *) override; - void visit(CaseExpression*) override {}; - + // case expression + void visit(CaseExpression *) override; void visitBinaryExpr(BinaryExpression *expr); void visitUnaryExpr(UnaryExpression *expr); diff --git a/src/visitor/RewriteLabelAttrVisitor.cpp b/src/visitor/RewriteLabelAttrVisitor.cpp index b43aeadad3c1eb5cd82148408c9fb6d3a03b3990..3942be6d0f46f7917713811108f4623b5e3fa1e8 100644 --- a/src/visitor/RewriteLabelAttrVisitor.cpp +++ b/src/visitor/RewriteLabelAttrVisitor.cpp @@ -84,6 +84,42 @@ void RewriteLabelAttrVisitor::visit(MapExpression* expr) { expr->setItems(std::move(newItems)); } +void RewriteLabelAttrVisitor::visit(CaseExpression* expr) { + if (expr->hasCondition()) { + if (isLabelAttrExpr(expr->condition())) { + auto newExpr = static_cast<LabelAttributeExpression*>(expr->condition()); + expr->setCondition(createExpr(newExpr)); + } else { + expr->condition()->accept(this); + } + } + if (expr->hasDefault()) { + if (isLabelAttrExpr(expr->defaultResult())) { + auto newExpr = static_cast<LabelAttributeExpression*>(expr->defaultResult()); + expr->setDefault(createExpr(newExpr)); + } else { + expr->defaultResult()->accept(this); + } + } + auto& cases = expr->cases(); + for (size_t i = 0; i < cases.size(); ++i) { + auto when = cases[i].when.get(); + auto then = cases[i].then.get(); + if (isLabelAttrExpr(when)) { + auto newExpr = static_cast<LabelAttributeExpression*>(when); + expr->setWhen(i, createExpr(newExpr)); + } else { + when->accept(this); + } + if (isLabelAttrExpr(then)) { + auto newExpr = static_cast<LabelAttributeExpression*>(then); + expr->setThen(i, createExpr(newExpr)); + } else { + then->accept(this); + } + } +} + void RewriteLabelAttrVisitor::visitBinaryExpr(BinaryExpression* expr) { if (isLabelAttrExpr(expr->left())) { auto left = static_cast<const LabelAttributeExpression*>(expr->left()); diff --git a/src/visitor/RewriteLabelAttrVisitor.h b/src/visitor/RewriteLabelAttrVisitor.h index a6b5795defd68baef69dfe4b81ff05678b13f417..b03140fc07081a4e2a176e660db893cb401a4e65 100644 --- a/src/visitor/RewriteLabelAttrVisitor.h +++ b/src/visitor/RewriteLabelAttrVisitor.h @@ -32,6 +32,7 @@ private: void visit(ListExpression *expr) override; void visit(SetExpression *expr) override; void visit(MapExpression *expr) override; + void visit(CaseExpression *) override; void visit(ConstantExpression *) override {} void visit(LabelExpression *) override {} void visit(UUIDExpression *) override {} @@ -50,8 +51,6 @@ private: void visit(EdgeDstIdExpression *) override {} void visit(VertexExpression *) override {} void visit(EdgeExpression *) override {} - // TODO : CaseExpression - void visit(CaseExpression *) override {} void visitBinaryExpr(BinaryExpression *expr) override; diff --git a/src/visitor/RewriteMatchLabelVisitor.cpp b/src/visitor/RewriteMatchLabelVisitor.cpp index febe8439c1d7db956e9169a78e2d1fd64d944251..87bbabdae7e9c798138546362b49b2f718d04c53 100644 --- a/src/visitor/RewriteMatchLabelVisitor.cpp +++ b/src/visitor/RewriteMatchLabelVisitor.cpp @@ -89,6 +89,37 @@ void RewriteMatchLabelVisitor::visit(MapExpression *expr) { expr->setItems(std::move(newItems)); } +void RewriteMatchLabelVisitor::visit(CaseExpression *expr) { + if (expr->hasCondition()) { + if (isLabel(expr->condition())) { + expr->setCondition(rewriter_(expr)); + } else { + expr->condition()->accept(this); + } + } + if (expr->hasDefault()) { + if (isLabel(expr->defaultResult())) { + expr->setDefault(rewriter_(expr)); + } else { + expr->defaultResult()->accept(this); + } + } + auto &cases = expr->cases(); + for (size_t i = 0; i < cases.size(); ++i) { + auto when = cases[i].when.get(); + auto then = cases[i].then.get(); + if (isLabel(when)) { + expr->setWhen(i, rewriter_(when)); + } else { + when->accept(this); + } + if (isLabel(then)) { + expr->setThen(i, rewriter_(then)); + } else { + then->accept(this); + } + } +} void RewriteMatchLabelVisitor::visitBinaryExpr(BinaryExpression *expr) { if (isLabel(expr->left())) { diff --git a/src/visitor/RewriteMatchLabelVisitor.h b/src/visitor/RewriteMatchLabelVisitor.h index af14ec081262b9e871f804b66afade07ab076566..16fe932bdae1e4b70751a30847d83422cb656755 100644 --- a/src/visitor/RewriteMatchLabelVisitor.h +++ b/src/visitor/RewriteMatchLabelVisitor.h @@ -39,7 +39,8 @@ private: void visit(ListExpression*) override; void visit(SetExpression*) override; void visit(MapExpression*) override; - void visit(ConstantExpression*) override {} + void visit(CaseExpression *) override; + void visit(ConstantExpression *) override {} void visit(LabelExpression*) override {} void visit(AttributeExpression*) override; void visit(UUIDExpression*) override {} @@ -58,8 +59,6 @@ private: void visit(EdgeDstIdExpression*) override {} void visit(VertexExpression*) override {} void visit(EdgeExpression*) override {} - // TODO : CaseExpression - void visit(CaseExpression*) override {} void visitBinaryExpr(BinaryExpression *) override; diff --git a/src/visitor/RewriteSymExprVisitor.cpp b/src/visitor/RewriteSymExprVisitor.cpp index 072e0a79501058cfc3ec659c2ad6c27143e77267..b9ba838477581e619e64b22528463d250488a67f 100644 --- a/src/visitor/RewriteSymExprVisitor.cpp +++ b/src/visitor/RewriteSymExprVisitor.cpp @@ -199,6 +199,34 @@ void RewriteSymExprVisitor::visit(EdgeExpression *expr) { expr_.reset(); } +void RewriteSymExprVisitor::visit(CaseExpression *expr) { + if (expr->hasCondition()) { + expr->condition()->accept(this); + if (expr_) { + expr->setCondition(expr_.release()); + } + } + if (expr->hasDefault()) { + expr->defaultResult()->accept(this); + if (expr_) { + expr->setDefault(expr_.release()); + } + } + auto &cases = expr->cases(); + for (size_t i = 0; i < cases.size(); ++i) { + auto when = cases[i].when.get(); + auto then = cases[i].then.get(); + when->accept(this); + if (expr_) { + expr->setWhen(i, expr_.release()); + } + then->accept(this); + if (expr_) { + expr->setThen(i, expr_.release()); + } + } +} + void RewriteSymExprVisitor::visitBinaryExpr(BinaryExpression *expr) { expr->left()->accept(this); if (expr_) { diff --git a/src/visitor/RewriteSymExprVisitor.h b/src/visitor/RewriteSymExprVisitor.h index 40b44d9f6910c8f4105f3e6c0303e4eb56fcc2eb..8027b222d7db5b970b32e66bc25e78666ea7ee08 100644 --- a/src/visitor/RewriteSymExprVisitor.h +++ b/src/visitor/RewriteSymExprVisitor.h @@ -64,8 +64,8 @@ public: // vertex/edge expression void visit(VertexExpression *expr) override; void visit(EdgeExpression *expr) override; - // TODO : CaseExpression - void visit(CaseExpression*) override {}; + // case expression + void visit(CaseExpression *expr) override; private: void visitBinaryExpr(BinaryExpression *expr); diff --git a/src/visitor/test/FoldConstantExprVisitorTest.cpp b/src/visitor/test/FoldConstantExprVisitorTest.cpp index f35131dc5974a3ae847fd36b19648dd94134e537..42a8e9f809e2c29c7f178d0bf9fafde18078c65a 100644 --- a/src/visitor/test/FoldConstantExprVisitorTest.cpp +++ b/src/visitor/test/FoldConstantExprVisitorTest.cpp @@ -100,6 +100,16 @@ public: return new VariableExpression(new std::string(name)); } + static CaseExpression *caseExpr(Expression *cond, Expression *defaltResult, + Expression *when, Expression *then) { + auto caseList = new CaseList; + caseList->add(when, then); + auto expr = new CaseExpression(caseList); + expr->setCondition(cond); + expr->setDefault(defaltResult); + return expr; + } + protected: ObjectPool pool; }; @@ -215,6 +225,26 @@ TEST_F(FoldConstantExprVisitorTest, TestMapExpr) { ASSERT(visitor.canBeFolded()); } +TEST_F(FoldConstantExprVisitorTest, TestCaseExpr) { + // CASE pow(2, (2+1)) WHEN (2+3) THEN (5-1) ELSE (7+8) + auto expr = pool.add(caseExpr(fnExpr("pow", {constantExpr(2), + addExpr(constantExpr(2), constantExpr(1))}), + addExpr(constantExpr(7), constantExpr(8)), + addExpr(constantExpr(2), constantExpr(3)), + minusExpr(constantExpr(5), constantExpr(1)))); + auto expected = pool.add(caseExpr(constantExpr(8), constantExpr(15), + constantExpr(5), constantExpr(4))); + FoldConstantExprVisitor visitor; + expr->accept(&visitor); + ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString(); + ASSERT(visitor.canBeFolded()); + + // CASE 8 WHEN 5 THEN 4 ELSE 15 => 15 + auto root = pool.add(visitor.fold(expr)); + auto rootExpected = pool.add(constantExpr(15)); + ASSERT_EQ(*root, *rootExpected) << root->toString() << " vs. " << rootExpected->toString(); +} + TEST_F(FoldConstantExprVisitorTest, TestFoldFunction) { // pure function // abs(-1) + 1 => 1 + 1 diff --git a/tests/query/v2/test_case_expression.py b/tests/query/v2/test_case_expression.py new file mode 100644 index 0000000000000000000000000000000000000000..29dc701d1c54b4a3781787d79ebec177e2aa87e5 --- /dev/null +++ b/tests/query/v2/test_case_expression.py @@ -0,0 +1,224 @@ +# --coding:utf-8-- +# +# Copyright (c) 2020 vesoft inc. All rights reserved. +# +# This source code is licensed under Apache 2.0 License, +# attached with Common Clause Condition 1.0, found in the LICENSES directory. + +from tests.common.nebula_test_suite import NebulaTestSuite +from tests.common.nebula_test_suite import T_NULL, T_EMPTY +import pytest + + +class TestCaseExpression(NebulaTestSuite): + @classmethod + def prepare(self): + self.use_nba() + + def cleanup(): + pass + + def test_generic_case_expression(self): + stmt = 'YIELD CASE 2 + 3 WHEN 4 THEN 0 WHEN 5 THEN 1 ELSE 2 END' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[1]] + self.check_out_of_order_result(resp, expected_data) + + stmt = 'YIELD CASE true WHEN false THEN 0 END' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[T_NULL]] + self.check_out_of_order_result(resp, expected_data) + + stmt = 'GO FROM "Jonathon Simmons" OVER serve YIELD $$.team.name as name, \ + CASE serve.end_year > 2017 WHEN true THEN "ok" ELSE "no" END' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [['Spurs', 'no'], ['Magic', 'ok'], ['76ers', 'ok']] + self.check_out_of_order_result(resp, expected_data) + + stmt = '''GO FROM "Boris Diaw" OVER serve YIELD \ + $^.player.name, serve.start_year, serve.end_year, \ + CASE serve.start_year > 2006 WHEN true THEN "new" ELSE "old" END, $$.team.name''' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [ + ["Boris Diaw", 2003, 2005, "old", "Hawks"], + ["Boris Diaw", 2005, 2008, "old", "Suns"], + ["Boris Diaw", 2008, 2012, "new", "Hornets"], + ["Boris Diaw", 2012, 2016, "new", "Spurs"], + ["Boris Diaw", 2016, 2017, "new", "Jazz"] + ] + self.check_out_of_order_result(resp, expected_data) + + # # we are not able to deduce the return type of case expression in where_clause + # stmt = '''GO FROM "Rajon Rondo" OVER serve WHERE \ + # CASE serve.start_year WHEN 2016 THEN true ELSE false END YIELD \ + # $^.player.name, serve.start_year, serve.end_year, $$.team.name''' + # resp = self.execute_query(stmt) + # self.check_resp_succeeded(resp) + # expected_data = [ + # ["Rajon Rondo", 2016, 2017, "Bulls"], + # ] + # self.check_out_of_order_result(resp, expected_data) + + stmt = 'YIELD CASE WHEN 4 > 5 THEN 0 WHEN 3+4==7 THEN 1 ELSE 2 END' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[1]] + self.check_out_of_order_result(resp, expected_data) + + stmt = 'YIELD CASE WHEN false THEN 0 ELSE 1 END' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[1]] + self.check_out_of_order_result(resp, expected_data) + + stmt = 'GO FROM "Tim Duncan" OVER serve YIELD $$.team.name as name, \ + CASE WHEN serve.start_year < 1998 THEN "old" ELSE "young" END' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [['Spurs', 'old']] + self.check_out_of_order_result(resp, expected_data) + + # # we are not able to deduce the return type of case expression in where_clause + # stmt = '''GO FROM "Rajon Rondo" OVER serve WHERE \ + # CASE WHEN serve.start_year > 2016 THEN true ELSE false END YIELD \ + # $^.player.name, serve.start_year, serve.end_year, $$.team.name''' + # resp = self.execute_query(stmt) + # self.check_resp_succeeded(resp) + # expected_data = [ + # ["Rajon Rondo", 2018, 2019, "Lakers"], + # ["Rajon Rondo", 2017, 2018, "Pelicans"] + # ] + # self.check_out_of_order_result(resp, expected_data) + + def test_conditional_case_expression(self): + stmt = 'YIELD 3 > 5 ? 0 : 1' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[1]] + self.check_out_of_order_result(resp, expected_data) + + stmt = 'YIELD true ? "yes" : "no"' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [["yes"]] + self.check_out_of_order_result(resp, expected_data) + + stmt = 'GO FROM "Tim Duncan" OVER serve YIELD $$.team.name as name, \ + serve.start_year < 1998 ? "old" : "young"' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [['Spurs', 'old']] + self.check_out_of_order_result(resp, expected_data) + + # # we are not able to deduce the return type of case expression in where_clause + # stmt = '''GO FROM "Rajon Rondo" OVER serve WHERE \ + # serve.start_year > 2016 ? true : false YIELD \ + # $^.player.name, serve.start_year, serve.end_year, $$.team.name''' + # resp = self.execute_query(stmt) + # self.check_resp_succeeded(resp) + # expected_data = [ + # ["Rajon Rondo", 2018, 2019, "Lakers"], + # ["Rajon Rondo", 2017, 2018, "Pelicans"] + # ] + # self.check_out_of_order_result(resp, expected_data) + + def test_generic_with_conditional_case_expression(self): + stmt = '''YIELD CASE 2 + 3 WHEN CASE 1 WHEN 1 \ + THEN 5 ELSE 4 END THEN 0 WHEN 5 THEN 1 ELSE 2 END''' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[0]] + self.check_out_of_order_result(resp, expected_data) + + stmt = 'YIELD CASE 2 + 3 WHEN 5 THEN CASE 1 WHEN 1 THEN 7 ELSE 4 END ELSE 2 END' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[7]] + self.check_out_of_order_result(resp, expected_data) + + stmt = 'YIELD CASE 2 + 3 WHEN 3 THEN 7 ELSE CASE 9 WHEN 8 THEN 10 ELSE 11 END END' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[11]] + self.check_out_of_order_result(resp, expected_data) + + stmt = 'YIELD CASE 3 > 2 ? 1 : 0 WHEN 1 THEN 5 ELSE 4 END' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[5]] + self.check_out_of_order_result(resp, expected_data) + + stmt = 'YIELD CASE 1 WHEN true ? 1 : 0 THEN 5 ELSE 4 END' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[5]] + self.check_out_of_order_result(resp, expected_data) + + stmt = 'YIELD CASE 1 WHEN 1 THEN 7 > 0 ? 6 : 9 ELSE 4 END' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[6]] + self.check_out_of_order_result(resp, expected_data) + + stmt = 'YIELD CASE 1 WHEN 2 THEN 6 ELSE false ? 4 : 9 END' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[9]] + self.check_out_of_order_result(resp, expected_data) + + stmt = '''YIELD CASE WHEN 2 > 7 THEN false ? 3 : 8 \ + ELSE CASE true WHEN false THEN 9 ELSE 11 END END''' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[11]] + self.check_out_of_order_result(resp, expected_data) + + stmt = '''YIELD CASE 3 WHEN 4 THEN 5 ELSE 6 END \ + > 11 ? 7 : CASE WHEN true THEN 8 ELSE 9 END''' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[8]] + self.check_out_of_order_result(resp, expected_data) + + stmt = '''YIELD 8 > 11 ? CASE WHEN true THEN 8 ELSE 9 END : \ + CASE 14 WHEN 8+6 THEN 0 ELSE 1 END''' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[0]] + self.check_out_of_order_result(resp, expected_data) + + stmt = '''YIELD CASE 3 WHEN 4 THEN 5 ELSE 6 END > (3 > 2 ? 8 : 9) ? \ + CASE WHEN true THEN 8 ELSE 9 END : \ + CASE 14 WHEN 8+6 THEN 0 ELSE 1 END''' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [[0]] + self.check_out_of_order_result(resp, expected_data) + + stmt = '''GO FROM "Jonathon Simmons" OVER serve YIELD $$.team.name as name, \ + CASE serve.end_year > 2017 WHEN true THEN 2017 < 2020 ? "ok" : "no" \ + ELSE CASE WHEN false THEN "good" ELSE "bad" END END''' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [['Spurs', 'bad'], ['Magic', 'ok'], ['76ers', 'ok']] + self.check_out_of_order_result(resp, expected_data) + + stmt = '''GO FROM "Boris Diaw" OVER serve YIELD \ + $^.player.name, serve.start_year, serve.end_year, \ + CASE serve.start_year > 2006 ? false : true \ + WHEN true THEN "new" ELSE CASE WHEN serve.start_year != 2012 THEN "old" \ + WHEN serve.start_year > 2009 THEN "bad" ELSE "good" END END, $$.team.name''' + resp = self.execute_query(stmt) + self.check_resp_succeeded(resp) + expected_data = [ + ["Boris Diaw", 2003, 2005, "new", "Hawks"], + ["Boris Diaw", 2005, 2008, "new", "Suns"], + ["Boris Diaw", 2008, 2012, "old", "Hornets"], + ["Boris Diaw", 2012, 2016, "bad", "Spurs"], + ["Boris Diaw", 2016, 2017, "old", "Jazz"] + ] + self.check_out_of_order_result(resp, expected_data) diff --git a/tests/query/v2/test_match.py b/tests/query/v2/test_match.py index baab080fcc63c1ad9c8685ba553be7b6ad22f6c5..341aac3c63ee530d5fa501a1c4093cd93af26e4d 100644 --- a/tests/query/v2/test_match.py +++ b/tests/query/v2/test_match.py @@ -161,12 +161,12 @@ class TestMatch(NebulaTestSuite): stmt = ''' MATCH (v1:player{name: "LeBron James"}) -[r:serve]-> (v2 {name: "Cavaliers"}) WHERE r.start_year <= 2005 AND r.end_year >= 2005 - RETURN r.start_year AS Start, r.end_year AS End + RETURN r.start_year AS Start_Year, r.end_year AS Start_Year ''' resp = self.execute_query(stmt) self.check_resp_succeeded(resp) expected = { - 'column_names': ['Start', 'End'], + 'column_names': ['Start_Year', 'Start_Year'], 'rows': [ [2003, 2010], ]