Skip to content
Snippets Groups Projects
Unverified Commit 4c85decf authored by jie.wang's avatar jie.wang Committed by GitHub
Browse files

Support case expression (#353)

* support case expression

* fix FoldConstant

* add more tests

* resolve conflicts
parent 4099cdb7
No related branches found
No related tags found
No related merge requests found
Showing
with 300 additions and 49 deletions
......@@ -104,6 +104,7 @@ reserved_key_words = [
'KW_RECOVER',
'KW_EXPLAIN',
'KW_UNWIND',
'KW_CASE',
]
......
......@@ -20,6 +20,7 @@
#include "common/expression/AttributeExpression.h"
#include "common/expression/LabelAttributeExpression.h"
#include "common/expression/VariableExpression.h"
#include "common/expression/CaseExpression.h"
#include "util/SchemaUtil.h"
namespace nebula {
......@@ -117,6 +118,7 @@ static constexpr size_t MAX_ABS_INTEGER = 9223372036854775808ULL;
MatchStepRange *match_step_range;
nebula::meta::cpp2::IndexFieldDef *index_field;
nebula::IndexFieldList *index_field_list;
CaseList *case_list;
}
/* destructors */
......@@ -156,11 +158,12 @@ static constexpr size_t MAX_ABS_INTEGER = 9223372036854775808ULL;
%token KW_CONTAINS
%token KW_STARTS KW_ENDS
%token KW_UNWIND KW_SKIP KW_OPTIONAL
%token KW_CASE KW_THEN KW_ELSE KW_END
/* symbols */
%token L_PAREN R_PAREN L_BRACKET R_BRACKET L_BRACE R_BRACE COMMA
%token PIPE ASSIGN
%token DOT DOT_DOT COLON SEMICOLON L_ARROW R_ARROW AT
%token DOT DOT_DOT COLON QM SEMICOLON L_ARROW R_ARROW AT
%token ID_PROP TYPE_PROP SRC_ID_PROP DST_ID_PROP RANK_PROP INPUT_REF DST_REF SRC_REF
/* token type specification */
......@@ -177,6 +180,8 @@ static constexpr size_t MAX_ABS_INTEGER = 9223372036854775808ULL;
%type <expr> edge_prop_expression
%type <expr> input_prop_expression
%type <expr> var_prop_expression
%type <expr> generic_case_expression
%type <expr> conditional_expression
%type <expr> vid_ref_expression
%type <expr> vid
%type <expr> function_call_expression
......@@ -187,6 +192,7 @@ static constexpr size_t MAX_ABS_INTEGER = 9223372036854775808ULL;
%type <expr> container_expression
%type <expr> subscript_expression
%type <expr> attribute_expression
%type <expr> case_expression
%type <expr> compound_expression
%type <argument_list> argument_list opt_argument_list
%type <type> type_spec
......@@ -237,6 +243,9 @@ static constexpr size_t MAX_ABS_INTEGER = 9223372036854775808ULL;
%type <both_in_out_clause> both_in_out_clause
%type <expression_list> expression_list
%type <map_item_list> map_item_list
%type <case_list> when_then_list
%type <expr> case_condition
%type <expr> case_default
%type <match_path> match_path_pattern
%type <match_path> match_path
......@@ -306,6 +315,7 @@ static constexpr size_t MAX_ABS_INTEGER = 9223372036854775808ULL;
%type <boolval> opt_if_not_exists
%type <boolval> opt_if_exists
%left QM COLON
%left KW_OR KW_XOR
%left KW_AND
%right KW_NOT
......@@ -403,6 +413,9 @@ unreserved_keyword
| KW_BOTH { $$ = new std::string("both"); }
| KW_OUT { $$ = new std::string("out"); }
| KW_SUBGRAPH { $$ = new std::string("subgraph"); }
| KW_THEN { $$ = new std::string("then"); }
| KW_ELSE { $$ = new std::string("else"); }
| KW_END { $$ = new std::string("end"); }
;
agg_function
......@@ -529,6 +542,9 @@ expression
| expression KW_XOR expression {
$$ = new LogicalExpression(Expression::Kind::kLogicalXor, $1, $3);
}
| case_expression {
$$ = $1;
}
;
compound_expression
......@@ -589,6 +605,63 @@ attribute_expression
}
;
case_expression
: generic_case_expression {
$$ = $1;
}
| conditional_expression {
$$ = $1;
}
;
generic_case_expression
: KW_CASE case_condition when_then_list case_default KW_END {
auto expr = new CaseExpression($3);
expr->setCondition($2);
expr->setDefault($4);
$$ = expr;
}
;
conditional_expression
: expression QM expression COLON expression {
auto cases = new CaseList();
cases->add($1, $3);
auto expr = new CaseExpression(cases, false);
expr->setDefault($5);
$$ = expr;
}
;
case_condition
: %empty {
$$ = nullptr;
}
| expression {
$$ = $1;
}
;
case_default
: %empty {
$$ = nullptr;
}
| KW_ELSE expression {
$$ = $2;
}
;
when_then_list
: KW_WHEN expression KW_THEN expression {
$$ = new CaseList();
$$->add($2, $4);
}
| when_then_list KW_WHEN expression KW_THEN expression {
$1->add($3, $5);
$$ = $1;
}
;
input_prop_expression
: INPUT_REF DOT name_label {
$$ = new InputPropertyExpression($3);
......
......@@ -69,7 +69,7 @@ IP_OCTET ([0-9]|[1-9][0-9]|1[0-9][0-9]|2[0-4][0-9]|25[0-5])
"LOOKUP" { return TokenType::KW_LOOKUP; }
"ALTER" { return TokenType::KW_ALTER; }
"STEPS" { return TokenType::KW_STEPS; }
"STEP" { return TokenType::KW_STEPS; }
"STEP" { return TokenType::KW_STEPS; }
"OVER" { return TokenType::KW_OVER; }
"UPTO" { return TokenType::KW_UPTO; }
"REVERSELY" { return TokenType::KW_REVERSELY; }
......@@ -134,6 +134,7 @@ IP_OCTET ([0-9]|[1-9][0-9]|1[0-9][0-9]|2[0-4][0-9]|25[0-5])
"EXPLAIN" { return TokenType::KW_EXPLAIN; }
"PROFILE" { return TokenType::KW_PROFILE; }
"FORMAT" { return TokenType::KW_FORMAT; }
"CASE" { return TokenType::KW_CASE; }
/**
......@@ -212,6 +213,9 @@ IP_OCTET ([0-9]|[1-9][0-9]|1[0-9][0-9]|2[0-4][0-9]|25[0-5])
"UNWIND" { return TokenType::KW_UNWIND;}
"SKIP" { return TokenType::KW_SKIP;}
"OPTIONAL" { return TokenType::KW_OPTIONAL;}
"THEN" { return TokenType::KW_THEN; }
"ELSE" { return TokenType::KW_ELSE; }
"END" { return TokenType::KW_END; }
"TRUE" { yylval->boolval = true; return TokenType::BOOL; }
......@@ -223,6 +227,7 @@ IP_OCTET ([0-9]|[1-9][0-9]|1[0-9][0-9]|2[0-4][0-9]|25[0-5])
":" { return TokenType::COLON; }
";" { return TokenType::SEMICOLON; }
"@" { return TokenType::AT; }
"?" { return TokenType::QM; }
"+" { return TokenType::PLUS; }
"-" { return TokenType::MINUS; }
......
......@@ -311,6 +311,18 @@ TEST_F(ExpressionParsingTest, Associativity) {
make<ConstantExpression>(false)));
add("!!false", ast);
auto cases = new CaseList();
cases->add(make<ConstantExpression>(3), make<ConstantExpression>(4));
ast = make<CaseExpression>(cases);
static_cast<CaseExpression*>(ast)->setCondition(make<LabelExpression>("a"));
auto cases2 = new CaseList();
cases2->add(make<ConstantExpression>(5), make<ConstantExpression>(6));
auto ast2 = make<CaseExpression>(cases2);
ast2->setCondition(make<LabelExpression>("b"));
ast2->setDefault(make<ConstantExpression>(7));
static_cast<CaseExpression*>(ast)->setDefault(ast2);
add("CASE a WHEN 3 THEN 4 ELSE CASE b WHEN 5 THEN 6 ELSE 7 END END", ast);
run();
}
......
......@@ -127,6 +127,20 @@ void CollectAllExprsVisitor::visit(EdgeExpression *expr) {
collectExpr(expr);
}
void CollectAllExprsVisitor::visit(CaseExpression *expr) {
collectExpr(expr);
if (expr->hasCondition()) {
expr->condition()->accept(this);
}
if (expr->hasDefault()) {
expr->defaultResult()->accept(this);
}
for (const auto &whenThen : expr->cases()) {
whenThen.when->accept(this);
whenThen.then->accept(this);
}
}
void CollectAllExprsVisitor::visitBinaryExpr(BinaryExpression *expr) {
collectExpr(expr);
expr->left()->accept(this);
......
......@@ -55,8 +55,7 @@ private:
void visit(AttributeExpression* expr) override;
void visit(VertexExpression* expr) override;
void visit(EdgeExpression* expr) override;
// TODO : CaseExpression
void visit(CaseExpression *) override {};
void visit(CaseExpression* expr) override;
void visitBinaryExpr(BinaryExpression* expr) override;
void collectExpr(const Expression* expr);
......
......@@ -122,8 +122,6 @@ private:
void visit(ConstantExpression* expr) override;
void visit(VertexExpression* expr) override;
void visit(EdgeExpression* expr) override;
// TODO : CaseExpression
void visit(CaseExpression*) override {};
void visitEdgePropExpr(PropertyExpression* expr);
void reportError(const Expression* expr);
......
......@@ -377,6 +377,32 @@ void DeduceTypeVisitor::visit(EdgeExpression *) {
type_ = Value::Type::EDGE;
}
void DeduceTypeVisitor::visit(CaseExpression *expr) {
if (expr->hasCondition()) {
expr->condition()->accept(this);
if (!ok()) return;
}
if (expr->hasDefault()) {
expr->defaultResult()->accept(this);
if (!ok()) return;
}
for (const auto &whenThen : expr->cases()) {
whenThen.when->accept(this);
if (!ok()) return;
if (!expr->hasCondition() && type_ != Value::Type::BOOL) {
status_ = Status::SemanticError(
"`%s': Invalid expression type, expecting expression of type BOOL",
expr->toString().c_str());
return;
}
whenThen.then->accept(this);
if (!ok()) return;
}
// NOTE: we are not able to deduce the return type of case expression currently
type_ = Value::Type::__EMPTY__;
}
void DeduceTypeVisitor::visitVertexPropertyExpr(PropertyExpression *expr) {
auto *tag = expr->sym();
auto tagId = qctx_->schemaMng()->toTagID(space_, *tag);
......
......@@ -73,8 +73,8 @@ private:
// vertex/edge expression
void visit(VertexExpression *expr) override;
void visit(EdgeExpression *expr) override;
// TODO : CaseExpression
void visit(CaseExpression *) override {};
// case expression
void visit(CaseExpression *expr) override;
void visitVertexPropertyExpr(PropertyExpression *expr);
......
......@@ -89,11 +89,6 @@ private:
isEvaluable_ = false;
}
// TODO : CaseExpression
void visit(CaseExpression *) override {
isEvaluable_ = false;
}
bool isEvaluable_{true};
};
......
......@@ -90,6 +90,33 @@ void ExprVisitorImpl::visit(MapExpression *expr) {
}
}
// case expression
void ExprVisitorImpl::visit(CaseExpression *expr) {
DCHECK(ok());
if (expr->hasCondition()) {
expr->condition()->accept(this);
if (!ok()) {
return;
}
}
if (expr->hasDefault()) {
expr->defaultResult()->accept(this);
if (!ok()) {
return;
}
}
for (const auto &whenThen : expr->cases()) {
whenThen.when->accept(this);
if (!ok()) {
break;
}
whenThen.then->accept(this);
if (!ok()) {
break;
}
}
}
void ExprVisitorImpl::visitBinaryExpr(BinaryExpression *expr) {
DCHECK(ok());
expr->left()->accept(this);
......
......@@ -29,6 +29,8 @@ public:
void visit(ListExpression *expr) override;
void visit(SetExpression *expr) override;
void visit(MapExpression *expr) override;
// case expression
void visit(CaseExpression *expr) override;
protected:
using ExprVisitor::visit;
......
......@@ -46,8 +46,6 @@ private:
void visit(VertexExpression *) override;
void visit(EdgeExpression *) override;
void visit(LogicalExpression *) override;
// TODO : CaseExpression
void visit(CaseExpression *) override {};
bool canBePushed_{true};
std::unique_ptr<Expression> remainedExpr_;
......
......@@ -192,7 +192,7 @@ void ExtractPropExprVisitor::visit(DestPropertyExpression* expr) {
}
}
void ExtractPropExprVisitor::reportError(const Expression *expr) {
void ExtractPropExprVisitor::reportError(const Expression* expr) {
std::stringstream ss;
ss << "Not supported expression `" << expr->toString() << "' for ExtractPropsExpression.";
status_ = Status::SemanticError(ss.str());
......
......@@ -60,8 +60,6 @@ private:
void visit(EdgeExpression *) override;
// binary expression
void visit(SubscriptExpression *) override;
// TODO : CaseExpression
void visit(CaseExpression *) override {};
void visitVertexEdgePropExpr(PropertyExpression *);
void visitPropertyExpr(PropertyExpression *);
......
......@@ -62,6 +62,25 @@ void FindAnyExprVisitor::visit(MapExpression *expr) {
}
}
void FindAnyExprVisitor::visit(CaseExpression *expr) {
findExpr(expr);
if (found_) return;
if (expr->hasCondition()) {
expr->condition()->accept(this);
if (found_) return;
}
if (expr->hasDefault()) {
expr->defaultResult()->accept(this);
if (found_) return;
}
for (const auto &whenThen : expr->cases()) {
whenThen.when->accept(this);
if (found_) return;
whenThen.then->accept(this);
if (found_) return;
}
}
void FindAnyExprVisitor::visit(ConstantExpression *expr) {
findExpr(expr);
}
......
......@@ -37,6 +37,7 @@ private:
void visit(ListExpression* expr) override;
void visit(SetExpression* expr) override;
void visit(MapExpression* expr) override;
void visit(CaseExpression* expr) override;
void visit(ConstantExpression* expr) override;
void visit(EdgePropertyExpression* expr) override;
......@@ -55,8 +56,6 @@ private:
void visit(LabelExpression* expr) override;
void visit(VertexExpression* expr) override;
void visit(EdgeExpression* expr) override;
// TODO : CaseExpression
void visit(CaseExpression*) override {};
void visitBinaryExpr(BinaryExpression* expr) override;
......
......@@ -19,16 +19,20 @@ void FoldConstantExprVisitor::visit(ConstantExpression *expr) {
}
void FoldConstantExprVisitor::visit(UnaryExpression *expr) {
expr->operand()->accept(this);
if (canBeFolded_ && expr->operand()->kind() != Expression::Kind::kConstant) {
expr->setOperand(fold(expr->operand()));
if (!isConstant(expr->operand())) {
expr->operand()->accept(this);
if (canBeFolded_) {
expr->setOperand(fold(expr->operand()));
}
}
}
void FoldConstantExprVisitor::visit(TypeCastingExpression *expr) {
expr->operand()->accept(this);
if (canBeFolded_ && expr->operand()->kind() != Expression::Kind::kConstant) {
expr->setOperand(fold(expr->operand()));
if (!isConstant(expr->operand())) {
expr->operand()->accept(this);
if (canBeFolded_) {
expr->setOperand(fold(expr->operand()));
}
}
}
......@@ -68,7 +72,7 @@ void FoldConstantExprVisitor::visit(LogicalExpression *expr) {
void FoldConstantExprVisitor::visit(FunctionCallExpression *expr) {
bool canBeFolded = true;
for (auto &arg : expr->args()->args()) {
if (arg->kind() != Expression::Kind::kConstant) {
if (!isConstant(arg.get())) {
arg->accept(this);
if (canBeFolded_) {
arg.reset(fold(arg.get()));
......@@ -110,13 +114,14 @@ void FoldConstantExprVisitor::visit(ListExpression *expr) {
bool canBeFolded = true;
for (size_t i = 0; i < items.size(); ++i) {
auto item = items[i].get();
item->accept(this);
if (!canBeFolded_) {
canBeFolded = false;
if (isConstant(item)) {
continue;
}
if (item->kind() != Expression::Kind::kConstant) {
item->accept(this);
if (canBeFolded_) {
expr->setItem(i, std::unique_ptr<Expression>{fold(item)});
} else {
canBeFolded = false;
}
}
canBeFolded_ = canBeFolded;
......@@ -127,13 +132,14 @@ void FoldConstantExprVisitor::visit(SetExpression *expr) {
bool canBeFolded = true;
for (size_t i = 0; i < items.size(); ++i) {
auto item = items[i].get();
item->accept(this);
if (!canBeFolded_) {
canBeFolded = false;
if (isConstant(item)) {
continue;
}
if (item->kind() != Expression::Kind::kConstant) {
item->accept(this);
if (canBeFolded_) {
expr->setItem(i, std::unique_ptr<Expression>{fold(item)});
} else {
canBeFolded = false;
}
}
canBeFolded_ = canBeFolded;
......@@ -145,15 +151,59 @@ void FoldConstantExprVisitor::visit(MapExpression *expr) {
for (size_t i = 0; i < items.size(); ++i) {
auto &pair = items[i];
auto item = const_cast<Expression *>(pair.second.get());
item->accept(this);
if (!canBeFolded_) {
canBeFolded = false;
if (isConstant(item)) {
continue;
}
if (item->kind() != Expression::Kind::kConstant) {
item->accept(this);
if (canBeFolded_) {
auto key = std::make_unique<std::string>(*pair.first);
auto val = std::unique_ptr<Expression>(fold(item));
expr->setItem(i, std::make_pair(std::move(key), std::move(val)));
} else {
canBeFolded = false;
}
}
canBeFolded_ = canBeFolded;
}
// case Expression
void FoldConstantExprVisitor::visit(CaseExpression *expr) {
bool canBeFolded = true;
if (expr->hasCondition() && !isConstant(expr->condition())) {
expr->condition()->accept(this);
if (canBeFolded_) {
expr->setCondition(fold(expr->condition()));
} else {
canBeFolded = false;
}
}
if (expr->hasDefault() && !isConstant(expr->defaultResult())) {
expr->defaultResult()->accept(this);
if (canBeFolded_) {
expr->setDefault(fold(expr->defaultResult()));
} else {
canBeFolded = false;
}
}
auto &cases = expr->cases();
for (size_t i = 0; i < cases.size(); ++i) {
auto when = cases[i].when.get();
auto then = cases[i].then.get();
if (!isConstant(when)) {
when->accept(this);
if (canBeFolded_) {
expr->setWhen(i, fold(when));
} else {
canBeFolded = false;
}
}
if (!isConstant(then)) {
then->accept(this);
if (canBeFolded_) {
expr->setThen(i, fold(then));
} else {
canBeFolded = false;
}
}
}
canBeFolded_ = canBeFolded;
......@@ -222,15 +272,20 @@ void FoldConstantExprVisitor::visit(EdgeExpression *expr) {
}
void FoldConstantExprVisitor::visitBinaryExpr(BinaryExpression *expr) {
expr->left()->accept(this);
auto leftCanBeFolded = canBeFolded_;
if (leftCanBeFolded && expr->left()->kind() != Expression::Kind::kConstant) {
expr->setLeft(fold(expr->left()));
bool leftCanBeFolded = true, rightCanBeFolded = true;
if (!isConstant(expr->left())) {
expr->left()->accept(this);
leftCanBeFolded = canBeFolded_;
if (leftCanBeFolded) {
expr->setLeft(fold(expr->left()));
}
}
expr->right()->accept(this);
auto rightCanBeFolded = canBeFolded_;
if (rightCanBeFolded && expr->right()->kind() != Expression::Kind::kConstant) {
expr->setRight(fold(expr->right()));
if (!isConstant(expr->right())) {
expr->right()->accept(this);
rightCanBeFolded = canBeFolded_;
if (rightCanBeFolded) {
expr->setRight(fold(expr->right()));
}
}
canBeFolded_ = leftCanBeFolded && rightCanBeFolded;
}
......
......@@ -18,6 +18,10 @@ public:
return canBeFolded_;
}
bool isConstant(Expression *expr) const {
return expr->kind() == Expression::Kind::kConstant;
}
void visit(ConstantExpression *expr) override;
void visit(UnaryExpression *expr) override;
void visit(TypeCastingExpression *expr) override;
......@@ -53,8 +57,8 @@ public:
// vertex/edge expression
void visit(VertexExpression *expr) override;
void visit(EdgeExpression *expr) override;
// TODO : CaseExpression
void visit(CaseExpression*) override {};
// case expression
void visit(CaseExpression *expr) override;
void visitBinaryExpr(BinaryExpression *expr);
Expression *fold(Expression *expr) const;
......
......@@ -168,6 +168,32 @@ void RewriteInputPropVisitor::visit(TypeCastingExpression* expr) {
}
}
void RewriteInputPropVisitor::visit(CaseExpression* expr) {
if (expr->hasCondition()) {
expr->condition()->accept(this);
if (ok()) {
expr->setCondition(result_.release());
}
}
if (expr->hasDefault()) {
expr->defaultResult()->accept(this);
if (ok()) {
expr->setDefault(result_.release());
}
}
for (size_t i = 0; i < expr->cases().size(); ++i) {
const auto& whenThen = expr->cases()[i];
whenThen.when->accept(this);
if (ok()) {
expr->setWhen(i, result_.release());
}
whenThen.then->accept(this);
if (ok()) {
expr->setThen(i, result_.release());
}
}
}
void RewriteInputPropVisitor::visitBinaryExpr(BinaryExpression* expr) {
expr->left()->accept(this);
if (ok()) {
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment