Skip to content
Snippets Groups Projects
Unverified Commit d8c433e2 authored by jie.wang's avatar jie.wang Committed by GitHub
Browse files

Support reduce expression (#479)

* add reduce expression

* format tests
parent e59523f4
No related branches found
No related tags found
No related merge requests found
Showing
with 241 additions and 15 deletions
......@@ -27,6 +27,9 @@
#include "common/expression/ListComprehensionExpression.h"
#include "common/expression/AggregateExpression.h"
#include "common/expression/ReduceExpression.h"
#include "util/ParserUtil.h"
#include "context/QueryContext.h"
#include "util/SchemaUtil.h"
#include "util/ParserUtil.h"
#include "context/QueryContext.h"
......@@ -184,6 +187,7 @@ static constexpr size_t MAX_ABS_INTEGER = 9223372036854775808ULL;
%token KW_AUTO KW_FUZZY KW_PREFIX KW_REGEXP KW_WILDCARD
%token KW_TEXT KW_SEARCH KW_CLIENTS KW_SIGN KW_SERVICE KW_TEXT_SEARCH
%token KW_ANY KW_SINGLE KW_NONE
%token KW_REDUCE
/* symbols */
%token L_PAREN R_PAREN L_BRACKET R_BRACKET L_BRACE R_BRACE COMMA
......@@ -219,6 +223,7 @@ static constexpr size_t MAX_ABS_INTEGER = 9223372036854775808ULL;
%type <expr> case_expression
%type <expr> predicate_expression
%type <expr> list_comprehension_expression
%type <expr> reduce_expression
%type <expr> compound_expression
%type <expr> aggregate_expression
%type <expr> text_search_expression
......@@ -461,6 +466,7 @@ unreserved_keyword
| KW_ANY { $$ = new std::string("any"); }
| KW_SINGLE { $$ = new std::string("single"); }
| KW_NONE { $$ = new std::string("none"); }
| KW_REDUCE { $$ = new std::string("reduce"); }
| KW_SHORTEST { $$ = new std::string("shortest"); }
| KW_NOLOOP { $$ = new std::string("noloop"); }
| KW_COUNT_DISTINCT { $$ = new std::string("count_distinct"); }
......@@ -643,6 +649,9 @@ expression
| list_comprehension_expression {
$$ = $1;
}
| reduce_expression {
$$ = $1;
}
;
compound_expression
......@@ -780,6 +789,7 @@ predicate_expression
$$ = expr;
delete $3;
}
;
list_comprehension_expression
: L_BRACKET expression KW_IN expression KW_WHERE expression R_BRACKET {
......@@ -814,6 +824,14 @@ list_comprehension_expression
}
;
reduce_expression
: KW_REDUCE L_PAREN name_label ASSIGN expression COMMA name_label KW_IN expression PIPE expression R_PAREN {
auto *expr = new ReduceExpression($3, $5, $7, $9, $11);
nebula::graph::ParserUtil::rewriteReduce(qctx, expr, *$3, *$7);
$$ = expr;
}
;
input_prop_expression
: INPUT_REF DOT name_label {
$$ = new InputPropertyExpression($3);
......
......@@ -169,6 +169,7 @@ IP_OCTET ([0-9]|[1-9][0-9]|1[0-9][0-9]|2[0-4][0-9]|25[0-5])
"ANY" { return TokenType::KW_ANY; }
"SINGLE" { return TokenType::KW_SINGLE; }
"NONE" { return TokenType::KW_NONE; }
"REDUCE" { return TokenType::KW_REDUCE; }
"LEADER" { return TokenType::KW_LEADER; }
"UUID" { return TokenType::KW_UUID; }
"DATA" { return TokenType::KW_DATA; }
......
......@@ -498,6 +498,8 @@ TEST(Scanner, Basic) {
CHECK_SEMANTIC_TYPE("else", TokenType::KW_ELSE),
CHECK_SEMANTIC_TYPE("END", TokenType::KW_END),
CHECK_SEMANTIC_TYPE("end", TokenType::KW_END),
CHECK_SEMANTIC_TYPE("REDUCE", TokenType::KW_REDUCE),
CHECK_SEMANTIC_TYPE("reduce", TokenType::KW_REDUCE),
CHECK_SEMANTIC_TYPE("_type", TokenType::TYPE_PROP),
CHECK_SEMANTIC_TYPE("_id", TokenType::ID_PROP),
......
......@@ -125,6 +125,63 @@ public:
}
pred->setFilter(newFilter);
}
static void rewriteReduce(QueryContext *qctx,
ReduceExpression *reduce,
const std::string &oldAccName,
const std::string &oldVarName) {
const auto &newAccName = qctx->vctx()->anonVarGen()->getVar();
qctx->ectx()->setValue(newAccName, Value());
const auto &newVarName = qctx->vctx()->anonVarGen()->getVar();
qctx->ectx()->setValue(newVarName, Value());
auto rewriter = [oldAccName, newAccName, oldVarName, newVarName](
const Expression *expr) {
Expression *ret = nullptr;
if (expr->kind() == Expression::Kind::kLabel) {
auto *label = static_cast<const LabelExpression *>(expr);
if (*label->name() == oldAccName) {
ret = new VariableExpression(new std::string(newAccName));
} else if (*label->name() == oldVarName) {
ret = new VariableExpression(new std::string(newVarName));
} else {
ret = label->clone().release();
}
} else {
DCHECK(expr->kind() == Expression::Kind::kLabelAttribute);
auto *la = static_cast<const LabelAttributeExpression *>(expr);
if (*la->left()->name() == oldAccName) {
const auto &value = la->right()->value();
ret =
new AttributeExpression(new VariableExpression(new std::string(newAccName)),
new ConstantExpression(value));
} else if (*la->left()->name() == oldVarName) {
const auto &value = la->right()->value();
ret =
new AttributeExpression(new VariableExpression(new std::string(newVarName)),
new ConstantExpression(value));
} else {
ret = la->clone().release();
}
}
return ret;
};
RewriteMatchLabelVisitor visitor(rewriter);
reduce->setOriginString(new std::string(reduce->makeString()));
reduce->setAccumulator(new std::string(newAccName));
reduce->setInnerVar(new std::string(newVarName));
Expression *mapping = reduce->mapping();
Expression *newMapping = nullptr;
if (isLabel(mapping)) {
newMapping = rewriter(mapping);
} else {
newMapping = mapping->clone().release();
newMapping->accept(&visitor);
}
reduce->setMapping(newMapping);
}
};
} // namespace graph
......
......@@ -163,6 +163,13 @@ void CollectAllExprsVisitor::visit(PredicateExpression *expr) {
expr->filter()->accept(this);
}
void CollectAllExprsVisitor::visit(ReduceExpression *expr) {
collectExpr(expr);
expr->initial()->accept(this);
expr->collection()->accept(this);
expr->mapping()->accept(this);
}
void CollectAllExprsVisitor::visitBinaryExpr(BinaryExpression *expr) {
collectExpr(expr);
expr->left()->accept(this);
......
......@@ -58,6 +58,7 @@ private:
void visit(CaseExpression* expr) override;
void visit(PredicateExpression* expr) override;
void visit(ListComprehensionExpression* expr) override;
void visit(ReduceExpression* expr) override;
void visit(ColumnExpression* expr) override;
......
......@@ -602,6 +602,9 @@ void DeduceTypeVisitor::visit(CaseExpression *expr) {
}
void DeduceTypeVisitor::visit(PredicateExpression *expr) {
expr->filter()->accept(this);
if (!ok()) return;
expr->collection()->accept(this);
if (!ok()) return;
if (type_ == Value::Type::NULLVALUE || type_ == Value::Type::__EMPTY__) {
......@@ -610,17 +613,28 @@ void DeduceTypeVisitor::visit(PredicateExpression *expr) {
if (type_ != Value::Type::LIST) {
std::stringstream ss;
ss << "`" << expr->toString().c_str()
<< "': Invalid colletion type, expected type of LIST, but was:" << type_;
<< "': Invalid colletion type, expected type of LIST, but was: " << type_;
status_ = Status::SemanticError(ss.str());
return;
}
expr->filter()->accept(this);
if (!ok()) return;
type_ = Value::Type::BOOL;
}
void DeduceTypeVisitor::visit(ListComprehensionExpression *expr) {
if (expr->hasFilter()) {
expr->filter()->accept(this);
if (!ok()) {
return;
}
}
if (expr->hasMapping()) {
expr->mapping()->accept(this);
if (!ok()) {
return;
}
}
expr->collection()->accept(this);
if (!ok()) {
return;
......@@ -633,25 +647,37 @@ void DeduceTypeVisitor::visit(ListComprehensionExpression *expr) {
if (type_ != Value::Type::LIST) {
std::stringstream ss;
ss << "`" << expr->toString().c_str()
<< "': Invalid colletion type, expected type of LIST, but was:" << type_;
<< "': Invalid colletion type, expected type of LIST, but was: " << type_;
status_ = Status::SemanticError(ss.str());
return;
}
if (expr->hasFilter()) {
expr->filter()->accept(this);
if (!ok()) {
return;
}
type_ = Value::Type::LIST;
}
void DeduceTypeVisitor::visit(ReduceExpression *expr) {
expr->initial()->accept(this);
if (!ok()) return;
expr->mapping()->accept(this);
if (!ok()) return;
expr->collection()->accept(this);
if (!ok()) return;
if (type_ == Value::Type::NULLVALUE || type_ == Value::Type::__EMPTY__) {
return;
}
if (expr->hasMapping()) {
expr->mapping()->accept(this);
if (!ok()) {
return;
}
if (type_ != Value::Type::LIST) {
std::stringstream ss;
ss << "`" << expr->toString().c_str()
<< "': Invalid colletion type, expected type of LIST, but was: " << type_;
status_ = Status::SemanticError(ss.str());
return;
}
type_ = Value::Type::LIST;
// Will not deduce the actual value type returned by reduce expression.
type_ = Value::Type::__EMPTY__;
}
void DeduceTypeVisitor::visitVertexPropertyExpr(PropertyExpression *expr) {
......
......@@ -88,6 +88,8 @@ private:
void visit(PredicateExpression *expr) override;
// list comprehension expression
void visit(ListComprehensionExpression *) override;
// reduce expression
void visit(ReduceExpression *expr) override;
void visitVertexPropertyExpr(PropertyExpression *expr);
......
......@@ -167,5 +167,15 @@ void ExprVisitorImpl::visit(ListComprehensionExpression *expr) {
}
}
void ExprVisitorImpl::visit(ReduceExpression *expr) {
DCHECK(ok());
expr->initial()->accept(this);
if (!ok()) return;
expr->collection()->accept(this);
if (!ok()) return;
expr->mapping()->accept(this);
if (!ok()) return;
}
} // namespace graph
} // namespace nebula
......@@ -38,6 +38,8 @@ public:
void visit(PredicateExpression *expr) override;
// list comprehension expression
void visit(ListComprehensionExpression *expr) override;
// reduce expression
void visit(ReduceExpression *expr) override;
protected:
using ExprVisitor::visit;
......
......@@ -93,6 +93,16 @@ void FindAnyExprVisitor::visit(PredicateExpression *expr) {
expr->collection()->accept(this);
if (found_) return;
expr->filter()->accept(this);
}
void FindAnyExprVisitor::visit(ReduceExpression *expr) {
findExpr(expr);
if (found_) return;
expr->initial()->accept(this);
if (found_) return;
expr->collection()->accept(this);
if (found_) return;
expr->mapping()->accept(this);
if (found_) return;
}
......
......@@ -40,6 +40,7 @@ private:
void visit(MapExpression* expr) override;
void visit(CaseExpression* expr) override;
void visit(PredicateExpression* expr) override;
void visit(ReduceExpression* expr) override;
void visit(ConstantExpression* expr) override;
void visit(EdgePropertyExpression* expr) override;
......
......@@ -402,5 +402,34 @@ void FoldConstantExprVisitor::visit(PredicateExpression *expr) {
canBeFolded_ = canBeFolded;
}
void FoldConstantExprVisitor::visit(ReduceExpression *expr) {
bool canBeFolded = true;
if (!isConstant(expr->initial())) {
expr->initial()->accept(this);
if (canBeFolded_) {
expr->setInitial(fold(expr->initial()));
} else {
canBeFolded = false;
}
}
if (!isConstant(expr->collection())) {
expr->collection()->accept(this);
if (canBeFolded_) {
expr->setCollection(fold(expr->collection()));
} else {
canBeFolded = false;
}
}
if (!isConstant(expr->mapping())) {
expr->mapping()->accept(this);
if (canBeFolded_) {
expr->setMapping(fold(expr->mapping()));
} else {
canBeFolded = false;
}
}
canBeFolded_ = canBeFolded;
}
} // namespace graph
} // namespace nebula
......@@ -68,6 +68,8 @@ public:
void visit(PredicateExpression *expr) override;
// list comprehension expression
void visit(ListComprehensionExpression *) override;
// reduce expression
void visit(ReduceExpression *expr) override;
void visitBinaryExpr(BinaryExpression *expr);
Expression *fold(Expression *expr) const;
......
......@@ -283,5 +283,20 @@ void RewriteInputPropVisitor::visit(PredicateExpression* expr) {
}
}
void RewriteInputPropVisitor::visit(ReduceExpression* expr) {
expr->initial()->accept(this);
if (ok()) {
expr->setInitial(result_.release());
}
expr->collection()->accept(this);
if (ok()) {
expr->setCollection(result_.release());
}
expr->mapping()->accept(this);
if (ok()) {
expr->setMapping(result_.release());
}
}
} // namespace graph
} // namespace nebula
......@@ -80,6 +80,8 @@ private:
void visit(PredicateExpression *) override;
// list comprehension expression
void visit(ListComprehensionExpression *) override;
// reduce expression
void visit(ReduceExpression* expr) override;
void visitBinaryExpr(BinaryExpression *expr);
void visitUnaryExpr(UnaryExpression *expr);
......
......@@ -160,6 +160,27 @@ void RewriteLabelAttrVisitor::visit(PredicateExpression* expr) {
}
}
void RewriteLabelAttrVisitor::visit(ReduceExpression* expr) {
if (isLabelAttrExpr(expr->initial())) {
auto newExpr = static_cast<LabelAttributeExpression*>(expr->initial());
expr->setCollection(createExpr(newExpr));
} else {
expr->initial()->accept(this);
}
if (isLabelAttrExpr(expr->collection())) {
auto newExpr = static_cast<LabelAttributeExpression*>(expr->collection());
expr->setCollection(createExpr(newExpr));
} else {
expr->collection()->accept(this);
}
if (isLabelAttrExpr(expr->mapping())) {
auto newExpr = static_cast<LabelAttributeExpression*>(expr->mapping());
expr->setMapping(createExpr(newExpr));
} else {
expr->mapping()->accept(this);
}
}
void RewriteLabelAttrVisitor::visitBinaryExpr(BinaryExpression* expr) {
if (isLabelAttrExpr(expr->left())) {
auto left = static_cast<const LabelAttributeExpression*>(expr->left());
......
......@@ -35,6 +35,7 @@ private:
void visit(CaseExpression *) override;
void visit(PredicateExpression *) override;
void visit(ListComprehensionExpression *) override;
void visit(ReduceExpression * expr) override;
void visit(ConstantExpression *) override {}
void visit(LabelExpression *) override {}
void visit(UUIDExpression *) override {}
......
......@@ -164,6 +164,24 @@ void RewriteMatchLabelVisitor::visit(PredicateExpression *expr) {
}
}
void RewriteMatchLabelVisitor::visit(ReduceExpression *expr) {
if (isLabel(expr->initial())) {
expr->setInitial(rewriter_(expr));
} else {
expr->initial()->accept(this);
}
if (isLabel(expr->collection())) {
expr->setCollection(rewriter_(expr));
} else {
expr->collection()->accept(this);
}
if (isLabel(expr->mapping())) {
expr->setMapping(rewriter_(expr));
} else {
expr->mapping()->accept(this);
}
}
void RewriteMatchLabelVisitor::visitBinaryExpr(BinaryExpression *expr) {
if (isLabel(expr->left())) {
expr->setLeft(rewriter_(expr->left()));
......
......@@ -41,6 +41,7 @@ private:
void visit(SetExpression*) override;
void visit(MapExpression*) override;
void visit(CaseExpression *) override;
void visit(ReduceExpression *) override;
void visit(ConstantExpression *) override {}
void visit(LabelExpression*) override {}
void visit(AttributeExpression*) override;
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment