From 93f4c0a22172040b5870c5174d59325e6725a3ec Mon Sep 17 00:00:00 2001
From: dutor <440396+dutor@users.noreply.github.com>
Date: Mon, 9 Nov 2020 17:59:41 +0800
Subject: [PATCH] Flatten LogicalExpression (#400)

* Flatten LogicalExpression

* remove left/right interface

Co-authored-by: cpw <13495049+CPWstatic@users.noreply.github.com>
---
 src/optimizer/rule/IndexScanRule.cpp     |  5 ++--
 src/util/ExpressionUtils.cpp             | 26 ++++++++++----------
 src/validator/IndexScanValidator.cpp     |  5 ++--
 src/visitor/DeduceTypeVisitor.cpp        | 28 ++++++++++++++++++++--
 src/visitor/ExprVisitorImpl.cpp          |  7 +++++-
 src/visitor/ExtractFilterExprVisitor.cpp | 30 ++++++++++++++----------
 src/visitor/FoldConstantExprVisitor.cpp  | 26 +++++++++++++++++++-
 src/visitor/RewriteInputPropVisitor.cpp  |  8 ++++++-
 src/visitor/RewriteSymExprVisitor.cpp    |  8 ++++++-
 9 files changed, 109 insertions(+), 34 deletions(-)

diff --git a/src/optimizer/rule/IndexScanRule.cpp b/src/optimizer/rule/IndexScanRule.cpp
index 9a50f3d5..2c6e599a 100644
--- a/src/optimizer/rule/IndexScanRule.cpp
+++ b/src/optimizer/rule/IndexScanRule.cpp
@@ -308,8 +308,9 @@ Status IndexScanRule::analyzeExpression(Expression* expr,
                                                    Expression::encode(*expr).c_str());
                 return Status::NotSupported(errorMsg);
             }
-            NG_RETURN_IF_ERROR(analyzeExpression(lExpr->left(), items, kind, isEdge));
-            NG_RETURN_IF_ERROR(analyzeExpression(lExpr->right(), items, kind, isEdge));
+            // TODO(dutor) Deal with n-ary operands
+            NG_RETURN_IF_ERROR(analyzeExpression(lExpr->operand(0), items, kind, isEdge));
+            NG_RETURN_IF_ERROR(analyzeExpression(lExpr->operand(1), items, kind, isEdge));
             break;
         }
         case Expression::Kind::kRelLE:
diff --git a/src/util/ExpressionUtils.cpp b/src/util/ExpressionUtils.cpp
index 10a4cc31..82db2976 100644
--- a/src/util/ExpressionUtils.cpp
+++ b/src/util/ExpressionUtils.cpp
@@ -26,17 +26,18 @@ std::vector<const Expression*> ExpressionUtils::pullAnds(const Expression *expr)
     auto *root = static_cast<const LogicalExpression*>(expr);
     std::vector<const Expression*> operands;
 
-    if (root->left()->kind() != Expression::Kind::kLogicalAnd) {
-        operands.emplace_back(root->left());
+    // TODO(dutor) Deal with n-ary operands
+    if (root->operand(0)->kind() != Expression::Kind::kLogicalAnd) {
+        operands.emplace_back(root->operand(0));
     } else {
-        auto ands = pullAnds(root->left());
+        auto ands = pullAnds(root->operand(0));
         operands.insert(operands.end(), ands.begin(), ands.end());
     }
 
-    if (root->right()->kind() != Expression::Kind::kLogicalAnd) {
-        operands.emplace_back(root->right());
+    if (root->operand(1)->kind() != Expression::Kind::kLogicalAnd) {
+        operands.emplace_back(root->operand(1));
     } else {
-        auto ands = pullAnds(root->right());
+        auto ands = pullAnds(root->operand(1));
         operands.insert(operands.end(), ands.begin(), ands.end());
     }
 
@@ -48,17 +49,18 @@ std::vector<const Expression*> ExpressionUtils::pullOrs(const Expression *expr)
     auto *root = static_cast<const LogicalExpression*>(expr);
     std::vector<const Expression*> operands;
 
-    if (root->left()->kind() != Expression::Kind::kLogicalOr) {
-        operands.emplace_back(root->left());
+    // TODO(dutor) Deal with n-ary operands
+    if (root->operand(0)->kind() != Expression::Kind::kLogicalOr) {
+        operands.emplace_back(root->operand(0));
     } else {
-        auto ands = pullOrs(root->left());
+        auto ands = pullOrs(root->operand(0));
         operands.insert(operands.end(), ands.begin(), ands.end());
     }
 
-    if (root->right()->kind() != Expression::Kind::kLogicalOr) {
-        operands.emplace_back(root->right());
+    if (root->operand(1)->kind() != Expression::Kind::kLogicalOr) {
+        operands.emplace_back(root->operand(1));
     } else {
-        auto ands = pullOrs(root->right());
+        auto ands = pullOrs(root->operand(1));
         operands.insert(operands.end(), ands.begin(), ands.end());
     }
 
diff --git a/src/validator/IndexScanValidator.cpp b/src/validator/IndexScanValidator.cpp
index def908cc..75809325 100644
--- a/src/validator/IndexScanValidator.cpp
+++ b/src/validator/IndexScanValidator.cpp
@@ -98,10 +98,11 @@ Status IndexScanValidator::checkFilter(Expression* expr, const std::string& from
     switch (expr->kind()) {
         case Expression::Kind::kLogicalOr :
         case Expression::Kind::kLogicalAnd : {
+            // TODO(dutor) Deal with n-ary operands
             auto lExpr = static_cast<LogicalExpression*>(expr);
-            auto ret = checkFilter(lExpr->left(), from);
+            auto ret = checkFilter(lExpr->operand(0), from);
             NG_RETURN_IF_ERROR(ret);
-            ret = checkFilter(lExpr->right(), from);
+            ret = checkFilter(lExpr->operand(1), from);
             NG_RETURN_IF_ERROR(ret);
             break;
         }
diff --git a/src/visitor/DeduceTypeVisitor.cpp b/src/visitor/DeduceTypeVisitor.cpp
index e72e86d8..512c1a18 100644
--- a/src/visitor/DeduceTypeVisitor.cpp
+++ b/src/visitor/DeduceTypeVisitor.cpp
@@ -60,6 +60,29 @@ static const std::unordered_map<Value::Type, Value> kConstantValues = {
     }                                                                                              \
     type_ = detectVal.type()
 
+#define DETECT_NARYEXPR_TYPE(OP)                                                                   \
+    do {                                                                                           \
+        auto &operands = expr->operands();                                                         \
+        operands[0]->accept(this);                                                                 \
+        if (!ok()) return;                                                                         \
+        auto prev = type_;                                                                         \
+        for (auto i = 1u; i < operands.size(); i++) {                                              \
+            operands[i]->accept(this);                                                             \
+            if (!ok()) return;                                                                     \
+            auto current = type_;                                                                  \
+            auto detectValue = kConstantValues.at(prev) OP kConstantValues.at(current);            \
+            if (detectValue.isBadNull()) {                                                         \
+                std::stringstream ss;                                                              \
+                ss << "`" << expr->toString() << "' is not a valid expression, "                   \
+                << "can not apply `" << #OP << "' to `" << prev << "' and `" << current << "'.";   \
+                status_ = Status::SemanticError(ss.str());                                         \
+                return;                                                                            \
+            }                                                                                      \
+            prev = detectValue.type();                                                             \
+        }                                                                                          \
+        type_ = prev;                                                                              \
+    } while (false)
+
 #define DETECT_UNARYEXPR_TYPE(OP)                                                                  \
     auto detectVal = OP kConstantValues.at(type_);                                                 \
     if (detectVal.isBadNull()) {                                                                   \
@@ -301,12 +324,12 @@ void DeduceTypeVisitor::visit(AttributeExpression *expr) {
 void DeduceTypeVisitor::visit(LogicalExpression *expr) {
     switch (expr->kind()) {
         case Expression::Kind::kLogicalAnd: {
-            DETECT_BIEXPR_TYPE(&&);
+            DETECT_NARYEXPR_TYPE(&&);
             break;
         }
         case Expression::Kind::kLogicalXor:
         case Expression::Kind::kLogicalOr: {
-            DETECT_BIEXPR_TYPE(||);
+            DETECT_NARYEXPR_TYPE(||);
             break;
         }
         default: {
@@ -529,6 +552,7 @@ void DeduceTypeVisitor::visitVertexPropertyExpr(PropertyExpression *expr) {
 void DeduceTypeVisitor::visit(PathBuildExpression *) {
     type_ = Value::Type::PATH;
 }
+#undef DETECT_NARYEXPR_TYPE
 #undef DETECT_UNARYEXPR_TYPE
 #undef DETECT_BIEXPR_TYPE
 
diff --git a/src/visitor/ExprVisitorImpl.cpp b/src/visitor/ExprVisitorImpl.cpp
index 118fe48c..ca6fea86 100644
--- a/src/visitor/ExprVisitorImpl.cpp
+++ b/src/visitor/ExprVisitorImpl.cpp
@@ -37,7 +37,12 @@ void ExprVisitorImpl::visit(AttributeExpression *expr) {
 }
 
 void ExprVisitorImpl::visit(LogicalExpression *expr) {
-    visitBinaryExpr(expr);
+    for (auto &operand : expr->operands()) {
+        operand->accept(this);
+        if (!ok()) {
+            break;
+        }
+    }
 }
 
 void ExprVisitorImpl::visit(LabelAttributeExpression *expr) {
diff --git a/src/visitor/ExtractFilterExprVisitor.cpp b/src/visitor/ExtractFilterExprVisitor.cpp
index b31a91a7..45250300 100644
--- a/src/visitor/ExtractFilterExprVisitor.cpp
+++ b/src/visitor/ExtractFilterExprVisitor.cpp
@@ -78,20 +78,26 @@ void ExtractFilterExprVisitor::visit(EdgeExpression *) {
 }
 
 void ExtractFilterExprVisitor::visit(LogicalExpression *expr) {
+    // TODO(dutor) It's buggy when there are multi-level embedded logical expressions
     if (expr->kind() == Expression::Kind::kLogicalAnd) {
-        expr->left()->accept(this);
-        auto canBePushedLeft = canBePushed_;
-        expr->right()->accept(this);
-        auto canBePushedRight = canBePushed_;
-        canBePushed_ = canBePushedLeft || canBePushedRight;
-        if (canBePushed_) {
-            if (!canBePushedLeft) {
-                remainedExpr_ = expr->left()->clone();
-                expr->setLeft(new ConstantExpression(true));
-            } else if (!canBePushedRight) {
-                remainedExpr_ = expr->right()->clone();
-                expr->setRight(new ConstantExpression(true));
+        auto &operands = expr->operands();
+        std::vector<bool> flags(operands.size(), false);
+        auto canBePushed = false;
+        for (auto i = 0u; i < operands.size(); i++) {
+            operands[i]->accept(this);
+            flags[i] = canBePushed_;
+            canBePushed = canBePushed || canBePushed_;
+        }
+        if (canBePushed) {
+            auto remainedExpr = std::make_unique<LogicalExpression>(Expression::Kind::kLogicalAnd);
+            for (auto i = 0u; i < operands.size(); i++) {
+                if (flags[i]) {
+                    continue;
+                }
+                remainedExpr->addOperand(operands[i]->clone().release());
+                expr->setOperand(i, new ConstantExpression(true));
             }
+            remainedExpr_ = std::move(remainedExpr);
         }
     } else {
         ExprVisitorImpl::visit(expr);
diff --git a/src/visitor/FoldConstantExprVisitor.cpp b/src/visitor/FoldConstantExprVisitor.cpp
index 75979921..a3ec2d5c 100644
--- a/src/visitor/FoldConstantExprVisitor.cpp
+++ b/src/visitor/FoldConstantExprVisitor.cpp
@@ -65,7 +65,31 @@ void FoldConstantExprVisitor::visit(AttributeExpression *expr) {
 }
 
 void FoldConstantExprVisitor::visit(LogicalExpression *expr) {
-    visitBinaryExpr(expr);
+    auto &operands = expr->operands();
+    auto foldable = true;
+    // auto shortCircuit = false;
+    for (auto i = 0u; i < operands.size(); i++) {
+        auto *operand = operands[i].get();
+        operand->accept(this);
+        if (canBeFolded_) {
+            auto *newExpr = fold(operand);
+            expr->setOperand(i, newExpr);
+            /*
+            if (newExpr->value().isBool()) {
+                auto value = newExpr->value().getBool();
+                if ((value && expr->kind() == Expression::Kind::kLogicalOr) ||
+                        (!value && expr->kind() == Expression::Kind::kLogicalAnd)) {
+                    shortCircuit = true;
+                    break;
+                }
+            }
+            */
+        } else {
+            foldable = false;
+        }
+    }
+    // canBeFolded_ = foldable || shortCircuit;
+    canBeFolded_ = foldable;
 }
 
 // function call
diff --git a/src/visitor/RewriteInputPropVisitor.cpp b/src/visitor/RewriteInputPropVisitor.cpp
index 78de007a..bbd2f7b8 100644
--- a/src/visitor/RewriteInputPropVisitor.cpp
+++ b/src/visitor/RewriteInputPropVisitor.cpp
@@ -62,7 +62,13 @@ void RewriteInputPropVisitor::visit(RelationalExpression* expr) {
 }
 
 void RewriteInputPropVisitor::visit(LogicalExpression* expr) {
-    visitBinaryExpr(expr);
+    auto &operands = expr->operands();
+    for (auto i = 0u; i < operands.size(); i++) {
+        operands[i]->accept(this);
+        if (ok()) {
+            expr->setOperand(i, result_.release());
+        }
+    }
 }
 
 void RewriteInputPropVisitor::visit(UnaryExpression* expr) {
diff --git a/src/visitor/RewriteSymExprVisitor.cpp b/src/visitor/RewriteSymExprVisitor.cpp
index e97f84de..2b4c0abd 100644
--- a/src/visitor/RewriteSymExprVisitor.cpp
+++ b/src/visitor/RewriteSymExprVisitor.cpp
@@ -71,7 +71,13 @@ void RewriteSymExprVisitor::visit(AttributeExpression *expr) {
 }
 
 void RewriteSymExprVisitor::visit(LogicalExpression *expr) {
-    visitBinaryExpr(expr);
+    auto &operands = expr->operands();
+    for (auto i = 0u; i < operands.size(); i++) {
+        operands[i]->accept(this);
+        if (expr_) {
+            expr->setOperand(i, expr_.release());
+        }
+    }
 }
 
 // function call
-- 
GitLab