From 93f4c0a22172040b5870c5174d59325e6725a3ec Mon Sep 17 00:00:00 2001 From: dutor <440396+dutor@users.noreply.github.com> Date: Mon, 9 Nov 2020 17:59:41 +0800 Subject: [PATCH] Flatten LogicalExpression (#400) * Flatten LogicalExpression * remove left/right interface Co-authored-by: cpw <13495049+CPWstatic@users.noreply.github.com> --- src/optimizer/rule/IndexScanRule.cpp | 5 ++-- src/util/ExpressionUtils.cpp | 26 ++++++++++---------- src/validator/IndexScanValidator.cpp | 5 ++-- src/visitor/DeduceTypeVisitor.cpp | 28 ++++++++++++++++++++-- src/visitor/ExprVisitorImpl.cpp | 7 +++++- src/visitor/ExtractFilterExprVisitor.cpp | 30 ++++++++++++++---------- src/visitor/FoldConstantExprVisitor.cpp | 26 +++++++++++++++++++- src/visitor/RewriteInputPropVisitor.cpp | 8 ++++++- src/visitor/RewriteSymExprVisitor.cpp | 8 ++++++- 9 files changed, 109 insertions(+), 34 deletions(-) diff --git a/src/optimizer/rule/IndexScanRule.cpp b/src/optimizer/rule/IndexScanRule.cpp index 9a50f3d5..2c6e599a 100644 --- a/src/optimizer/rule/IndexScanRule.cpp +++ b/src/optimizer/rule/IndexScanRule.cpp @@ -308,8 +308,9 @@ Status IndexScanRule::analyzeExpression(Expression* expr, Expression::encode(*expr).c_str()); return Status::NotSupported(errorMsg); } - NG_RETURN_IF_ERROR(analyzeExpression(lExpr->left(), items, kind, isEdge)); - NG_RETURN_IF_ERROR(analyzeExpression(lExpr->right(), items, kind, isEdge)); + // TODO(dutor) Deal with n-ary operands + NG_RETURN_IF_ERROR(analyzeExpression(lExpr->operand(0), items, kind, isEdge)); + NG_RETURN_IF_ERROR(analyzeExpression(lExpr->operand(1), items, kind, isEdge)); break; } case Expression::Kind::kRelLE: diff --git a/src/util/ExpressionUtils.cpp b/src/util/ExpressionUtils.cpp index 10a4cc31..82db2976 100644 --- a/src/util/ExpressionUtils.cpp +++ b/src/util/ExpressionUtils.cpp @@ -26,17 +26,18 @@ std::vector<const Expression*> ExpressionUtils::pullAnds(const Expression *expr) auto *root = static_cast<const LogicalExpression*>(expr); std::vector<const Expression*> operands; - if (root->left()->kind() != Expression::Kind::kLogicalAnd) { - operands.emplace_back(root->left()); + // TODO(dutor) Deal with n-ary operands + if (root->operand(0)->kind() != Expression::Kind::kLogicalAnd) { + operands.emplace_back(root->operand(0)); } else { - auto ands = pullAnds(root->left()); + auto ands = pullAnds(root->operand(0)); operands.insert(operands.end(), ands.begin(), ands.end()); } - if (root->right()->kind() != Expression::Kind::kLogicalAnd) { - operands.emplace_back(root->right()); + if (root->operand(1)->kind() != Expression::Kind::kLogicalAnd) { + operands.emplace_back(root->operand(1)); } else { - auto ands = pullAnds(root->right()); + auto ands = pullAnds(root->operand(1)); operands.insert(operands.end(), ands.begin(), ands.end()); } @@ -48,17 +49,18 @@ std::vector<const Expression*> ExpressionUtils::pullOrs(const Expression *expr) auto *root = static_cast<const LogicalExpression*>(expr); std::vector<const Expression*> operands; - if (root->left()->kind() != Expression::Kind::kLogicalOr) { - operands.emplace_back(root->left()); + // TODO(dutor) Deal with n-ary operands + if (root->operand(0)->kind() != Expression::Kind::kLogicalOr) { + operands.emplace_back(root->operand(0)); } else { - auto ands = pullOrs(root->left()); + auto ands = pullOrs(root->operand(0)); operands.insert(operands.end(), ands.begin(), ands.end()); } - if (root->right()->kind() != Expression::Kind::kLogicalOr) { - operands.emplace_back(root->right()); + if (root->operand(1)->kind() != Expression::Kind::kLogicalOr) { + operands.emplace_back(root->operand(1)); } else { - auto ands = pullOrs(root->right()); + auto ands = pullOrs(root->operand(1)); operands.insert(operands.end(), ands.begin(), ands.end()); } diff --git a/src/validator/IndexScanValidator.cpp b/src/validator/IndexScanValidator.cpp index def908cc..75809325 100644 --- a/src/validator/IndexScanValidator.cpp +++ b/src/validator/IndexScanValidator.cpp @@ -98,10 +98,11 @@ Status IndexScanValidator::checkFilter(Expression* expr, const std::string& from switch (expr->kind()) { case Expression::Kind::kLogicalOr : case Expression::Kind::kLogicalAnd : { + // TODO(dutor) Deal with n-ary operands auto lExpr = static_cast<LogicalExpression*>(expr); - auto ret = checkFilter(lExpr->left(), from); + auto ret = checkFilter(lExpr->operand(0), from); NG_RETURN_IF_ERROR(ret); - ret = checkFilter(lExpr->right(), from); + ret = checkFilter(lExpr->operand(1), from); NG_RETURN_IF_ERROR(ret); break; } diff --git a/src/visitor/DeduceTypeVisitor.cpp b/src/visitor/DeduceTypeVisitor.cpp index e72e86d8..512c1a18 100644 --- a/src/visitor/DeduceTypeVisitor.cpp +++ b/src/visitor/DeduceTypeVisitor.cpp @@ -60,6 +60,29 @@ static const std::unordered_map<Value::Type, Value> kConstantValues = { } \ type_ = detectVal.type() +#define DETECT_NARYEXPR_TYPE(OP) \ + do { \ + auto &operands = expr->operands(); \ + operands[0]->accept(this); \ + if (!ok()) return; \ + auto prev = type_; \ + for (auto i = 1u; i < operands.size(); i++) { \ + operands[i]->accept(this); \ + if (!ok()) return; \ + auto current = type_; \ + auto detectValue = kConstantValues.at(prev) OP kConstantValues.at(current); \ + if (detectValue.isBadNull()) { \ + std::stringstream ss; \ + ss << "`" << expr->toString() << "' is not a valid expression, " \ + << "can not apply `" << #OP << "' to `" << prev << "' and `" << current << "'."; \ + status_ = Status::SemanticError(ss.str()); \ + return; \ + } \ + prev = detectValue.type(); \ + } \ + type_ = prev; \ + } while (false) + #define DETECT_UNARYEXPR_TYPE(OP) \ auto detectVal = OP kConstantValues.at(type_); \ if (detectVal.isBadNull()) { \ @@ -301,12 +324,12 @@ void DeduceTypeVisitor::visit(AttributeExpression *expr) { void DeduceTypeVisitor::visit(LogicalExpression *expr) { switch (expr->kind()) { case Expression::Kind::kLogicalAnd: { - DETECT_BIEXPR_TYPE(&&); + DETECT_NARYEXPR_TYPE(&&); break; } case Expression::Kind::kLogicalXor: case Expression::Kind::kLogicalOr: { - DETECT_BIEXPR_TYPE(||); + DETECT_NARYEXPR_TYPE(||); break; } default: { @@ -529,6 +552,7 @@ void DeduceTypeVisitor::visitVertexPropertyExpr(PropertyExpression *expr) { void DeduceTypeVisitor::visit(PathBuildExpression *) { type_ = Value::Type::PATH; } +#undef DETECT_NARYEXPR_TYPE #undef DETECT_UNARYEXPR_TYPE #undef DETECT_BIEXPR_TYPE diff --git a/src/visitor/ExprVisitorImpl.cpp b/src/visitor/ExprVisitorImpl.cpp index 118fe48c..ca6fea86 100644 --- a/src/visitor/ExprVisitorImpl.cpp +++ b/src/visitor/ExprVisitorImpl.cpp @@ -37,7 +37,12 @@ void ExprVisitorImpl::visit(AttributeExpression *expr) { } void ExprVisitorImpl::visit(LogicalExpression *expr) { - visitBinaryExpr(expr); + for (auto &operand : expr->operands()) { + operand->accept(this); + if (!ok()) { + break; + } + } } void ExprVisitorImpl::visit(LabelAttributeExpression *expr) { diff --git a/src/visitor/ExtractFilterExprVisitor.cpp b/src/visitor/ExtractFilterExprVisitor.cpp index b31a91a7..45250300 100644 --- a/src/visitor/ExtractFilterExprVisitor.cpp +++ b/src/visitor/ExtractFilterExprVisitor.cpp @@ -78,20 +78,26 @@ void ExtractFilterExprVisitor::visit(EdgeExpression *) { } void ExtractFilterExprVisitor::visit(LogicalExpression *expr) { + // TODO(dutor) It's buggy when there are multi-level embedded logical expressions if (expr->kind() == Expression::Kind::kLogicalAnd) { - expr->left()->accept(this); - auto canBePushedLeft = canBePushed_; - expr->right()->accept(this); - auto canBePushedRight = canBePushed_; - canBePushed_ = canBePushedLeft || canBePushedRight; - if (canBePushed_) { - if (!canBePushedLeft) { - remainedExpr_ = expr->left()->clone(); - expr->setLeft(new ConstantExpression(true)); - } else if (!canBePushedRight) { - remainedExpr_ = expr->right()->clone(); - expr->setRight(new ConstantExpression(true)); + auto &operands = expr->operands(); + std::vector<bool> flags(operands.size(), false); + auto canBePushed = false; + for (auto i = 0u; i < operands.size(); i++) { + operands[i]->accept(this); + flags[i] = canBePushed_; + canBePushed = canBePushed || canBePushed_; + } + if (canBePushed) { + auto remainedExpr = std::make_unique<LogicalExpression>(Expression::Kind::kLogicalAnd); + for (auto i = 0u; i < operands.size(); i++) { + if (flags[i]) { + continue; + } + remainedExpr->addOperand(operands[i]->clone().release()); + expr->setOperand(i, new ConstantExpression(true)); } + remainedExpr_ = std::move(remainedExpr); } } else { ExprVisitorImpl::visit(expr); diff --git a/src/visitor/FoldConstantExprVisitor.cpp b/src/visitor/FoldConstantExprVisitor.cpp index 75979921..a3ec2d5c 100644 --- a/src/visitor/FoldConstantExprVisitor.cpp +++ b/src/visitor/FoldConstantExprVisitor.cpp @@ -65,7 +65,31 @@ void FoldConstantExprVisitor::visit(AttributeExpression *expr) { } void FoldConstantExprVisitor::visit(LogicalExpression *expr) { - visitBinaryExpr(expr); + auto &operands = expr->operands(); + auto foldable = true; + // auto shortCircuit = false; + for (auto i = 0u; i < operands.size(); i++) { + auto *operand = operands[i].get(); + operand->accept(this); + if (canBeFolded_) { + auto *newExpr = fold(operand); + expr->setOperand(i, newExpr); + /* + if (newExpr->value().isBool()) { + auto value = newExpr->value().getBool(); + if ((value && expr->kind() == Expression::Kind::kLogicalOr) || + (!value && expr->kind() == Expression::Kind::kLogicalAnd)) { + shortCircuit = true; + break; + } + } + */ + } else { + foldable = false; + } + } + // canBeFolded_ = foldable || shortCircuit; + canBeFolded_ = foldable; } // function call diff --git a/src/visitor/RewriteInputPropVisitor.cpp b/src/visitor/RewriteInputPropVisitor.cpp index 78de007a..bbd2f7b8 100644 --- a/src/visitor/RewriteInputPropVisitor.cpp +++ b/src/visitor/RewriteInputPropVisitor.cpp @@ -62,7 +62,13 @@ void RewriteInputPropVisitor::visit(RelationalExpression* expr) { } void RewriteInputPropVisitor::visit(LogicalExpression* expr) { - visitBinaryExpr(expr); + auto &operands = expr->operands(); + for (auto i = 0u; i < operands.size(); i++) { + operands[i]->accept(this); + if (ok()) { + expr->setOperand(i, result_.release()); + } + } } void RewriteInputPropVisitor::visit(UnaryExpression* expr) { diff --git a/src/visitor/RewriteSymExprVisitor.cpp b/src/visitor/RewriteSymExprVisitor.cpp index e97f84de..2b4c0abd 100644 --- a/src/visitor/RewriteSymExprVisitor.cpp +++ b/src/visitor/RewriteSymExprVisitor.cpp @@ -71,7 +71,13 @@ void RewriteSymExprVisitor::visit(AttributeExpression *expr) { } void RewriteSymExprVisitor::visit(LogicalExpression *expr) { - visitBinaryExpr(expr); + auto &operands = expr->operands(); + for (auto i = 0u; i < operands.size(); i++) { + operands[i]->accept(this); + if (expr_) { + expr->setOperand(i, expr_.release()); + } + } } // function call -- GitLab