diff --git a/src/optimizer/rule/IndexScanRule.cpp b/src/optimizer/rule/IndexScanRule.cpp index 9a50f3d55186bde68e53e533a0aee9a49ea99a04..2c6e599a5896c2d8249e949da8cd5a7c979164f2 100644 --- a/src/optimizer/rule/IndexScanRule.cpp +++ b/src/optimizer/rule/IndexScanRule.cpp @@ -308,8 +308,9 @@ Status IndexScanRule::analyzeExpression(Expression* expr, Expression::encode(*expr).c_str()); return Status::NotSupported(errorMsg); } - NG_RETURN_IF_ERROR(analyzeExpression(lExpr->left(), items, kind, isEdge)); - NG_RETURN_IF_ERROR(analyzeExpression(lExpr->right(), items, kind, isEdge)); + // TODO(dutor) Deal with n-ary operands + NG_RETURN_IF_ERROR(analyzeExpression(lExpr->operand(0), items, kind, isEdge)); + NG_RETURN_IF_ERROR(analyzeExpression(lExpr->operand(1), items, kind, isEdge)); break; } case Expression::Kind::kRelLE: diff --git a/src/util/ExpressionUtils.cpp b/src/util/ExpressionUtils.cpp index 10a4cc31091fb65c9ce7ff00cc547ec48c0099d6..82db2976731f2a8342202305a9270f4de1e7c4ba 100644 --- a/src/util/ExpressionUtils.cpp +++ b/src/util/ExpressionUtils.cpp @@ -26,17 +26,18 @@ std::vector<const Expression*> ExpressionUtils::pullAnds(const Expression *expr) auto *root = static_cast<const LogicalExpression*>(expr); std::vector<const Expression*> operands; - if (root->left()->kind() != Expression::Kind::kLogicalAnd) { - operands.emplace_back(root->left()); + // TODO(dutor) Deal with n-ary operands + if (root->operand(0)->kind() != Expression::Kind::kLogicalAnd) { + operands.emplace_back(root->operand(0)); } else { - auto ands = pullAnds(root->left()); + auto ands = pullAnds(root->operand(0)); operands.insert(operands.end(), ands.begin(), ands.end()); } - if (root->right()->kind() != Expression::Kind::kLogicalAnd) { - operands.emplace_back(root->right()); + if (root->operand(1)->kind() != Expression::Kind::kLogicalAnd) { + operands.emplace_back(root->operand(1)); } else { - auto ands = pullAnds(root->right()); + auto ands = pullAnds(root->operand(1)); operands.insert(operands.end(), ands.begin(), ands.end()); } @@ -48,17 +49,18 @@ std::vector<const Expression*> ExpressionUtils::pullOrs(const Expression *expr) auto *root = static_cast<const LogicalExpression*>(expr); std::vector<const Expression*> operands; - if (root->left()->kind() != Expression::Kind::kLogicalOr) { - operands.emplace_back(root->left()); + // TODO(dutor) Deal with n-ary operands + if (root->operand(0)->kind() != Expression::Kind::kLogicalOr) { + operands.emplace_back(root->operand(0)); } else { - auto ands = pullOrs(root->left()); + auto ands = pullOrs(root->operand(0)); operands.insert(operands.end(), ands.begin(), ands.end()); } - if (root->right()->kind() != Expression::Kind::kLogicalOr) { - operands.emplace_back(root->right()); + if (root->operand(1)->kind() != Expression::Kind::kLogicalOr) { + operands.emplace_back(root->operand(1)); } else { - auto ands = pullOrs(root->right()); + auto ands = pullOrs(root->operand(1)); operands.insert(operands.end(), ands.begin(), ands.end()); } diff --git a/src/validator/IndexScanValidator.cpp b/src/validator/IndexScanValidator.cpp index def908ccc1a3a8f00c06353060ac56996f8a4b36..758093256a33878d2d0cb079d5f821f12ebe0ece 100644 --- a/src/validator/IndexScanValidator.cpp +++ b/src/validator/IndexScanValidator.cpp @@ -98,10 +98,11 @@ Status IndexScanValidator::checkFilter(Expression* expr, const std::string& from switch (expr->kind()) { case Expression::Kind::kLogicalOr : case Expression::Kind::kLogicalAnd : { + // TODO(dutor) Deal with n-ary operands auto lExpr = static_cast<LogicalExpression*>(expr); - auto ret = checkFilter(lExpr->left(), from); + auto ret = checkFilter(lExpr->operand(0), from); NG_RETURN_IF_ERROR(ret); - ret = checkFilter(lExpr->right(), from); + ret = checkFilter(lExpr->operand(1), from); NG_RETURN_IF_ERROR(ret); break; } diff --git a/src/visitor/DeduceTypeVisitor.cpp b/src/visitor/DeduceTypeVisitor.cpp index e72e86d81b235f9489cf7d26da5a6461be41bfd3..512c1a181abd37894d56d7722f5a5b10791d671b 100644 --- a/src/visitor/DeduceTypeVisitor.cpp +++ b/src/visitor/DeduceTypeVisitor.cpp @@ -60,6 +60,29 @@ static const std::unordered_map<Value::Type, Value> kConstantValues = { } \ type_ = detectVal.type() +#define DETECT_NARYEXPR_TYPE(OP) \ + do { \ + auto &operands = expr->operands(); \ + operands[0]->accept(this); \ + if (!ok()) return; \ + auto prev = type_; \ + for (auto i = 1u; i < operands.size(); i++) { \ + operands[i]->accept(this); \ + if (!ok()) return; \ + auto current = type_; \ + auto detectValue = kConstantValues.at(prev) OP kConstantValues.at(current); \ + if (detectValue.isBadNull()) { \ + std::stringstream ss; \ + ss << "`" << expr->toString() << "' is not a valid expression, " \ + << "can not apply `" << #OP << "' to `" << prev << "' and `" << current << "'."; \ + status_ = Status::SemanticError(ss.str()); \ + return; \ + } \ + prev = detectValue.type(); \ + } \ + type_ = prev; \ + } while (false) + #define DETECT_UNARYEXPR_TYPE(OP) \ auto detectVal = OP kConstantValues.at(type_); \ if (detectVal.isBadNull()) { \ @@ -301,12 +324,12 @@ void DeduceTypeVisitor::visit(AttributeExpression *expr) { void DeduceTypeVisitor::visit(LogicalExpression *expr) { switch (expr->kind()) { case Expression::Kind::kLogicalAnd: { - DETECT_BIEXPR_TYPE(&&); + DETECT_NARYEXPR_TYPE(&&); break; } case Expression::Kind::kLogicalXor: case Expression::Kind::kLogicalOr: { - DETECT_BIEXPR_TYPE(||); + DETECT_NARYEXPR_TYPE(||); break; } default: { @@ -529,6 +552,7 @@ void DeduceTypeVisitor::visitVertexPropertyExpr(PropertyExpression *expr) { void DeduceTypeVisitor::visit(PathBuildExpression *) { type_ = Value::Type::PATH; } +#undef DETECT_NARYEXPR_TYPE #undef DETECT_UNARYEXPR_TYPE #undef DETECT_BIEXPR_TYPE diff --git a/src/visitor/ExprVisitorImpl.cpp b/src/visitor/ExprVisitorImpl.cpp index 118fe48c986db0c7820d80b18de243c845b62694..ca6fea86b9bbd8273b415382dcea4c7e5728d9a4 100644 --- a/src/visitor/ExprVisitorImpl.cpp +++ b/src/visitor/ExprVisitorImpl.cpp @@ -37,7 +37,12 @@ void ExprVisitorImpl::visit(AttributeExpression *expr) { } void ExprVisitorImpl::visit(LogicalExpression *expr) { - visitBinaryExpr(expr); + for (auto &operand : expr->operands()) { + operand->accept(this); + if (!ok()) { + break; + } + } } void ExprVisitorImpl::visit(LabelAttributeExpression *expr) { diff --git a/src/visitor/ExtractFilterExprVisitor.cpp b/src/visitor/ExtractFilterExprVisitor.cpp index b31a91a76cf1cca7fcbf446bb05b324fabf60533..452503001ef799c6d17324931f51983ccf33464f 100644 --- a/src/visitor/ExtractFilterExprVisitor.cpp +++ b/src/visitor/ExtractFilterExprVisitor.cpp @@ -78,20 +78,26 @@ void ExtractFilterExprVisitor::visit(EdgeExpression *) { } void ExtractFilterExprVisitor::visit(LogicalExpression *expr) { + // TODO(dutor) It's buggy when there are multi-level embedded logical expressions if (expr->kind() == Expression::Kind::kLogicalAnd) { - expr->left()->accept(this); - auto canBePushedLeft = canBePushed_; - expr->right()->accept(this); - auto canBePushedRight = canBePushed_; - canBePushed_ = canBePushedLeft || canBePushedRight; - if (canBePushed_) { - if (!canBePushedLeft) { - remainedExpr_ = expr->left()->clone(); - expr->setLeft(new ConstantExpression(true)); - } else if (!canBePushedRight) { - remainedExpr_ = expr->right()->clone(); - expr->setRight(new ConstantExpression(true)); + auto &operands = expr->operands(); + std::vector<bool> flags(operands.size(), false); + auto canBePushed = false; + for (auto i = 0u; i < operands.size(); i++) { + operands[i]->accept(this); + flags[i] = canBePushed_; + canBePushed = canBePushed || canBePushed_; + } + if (canBePushed) { + auto remainedExpr = std::make_unique<LogicalExpression>(Expression::Kind::kLogicalAnd); + for (auto i = 0u; i < operands.size(); i++) { + if (flags[i]) { + continue; + } + remainedExpr->addOperand(operands[i]->clone().release()); + expr->setOperand(i, new ConstantExpression(true)); } + remainedExpr_ = std::move(remainedExpr); } } else { ExprVisitorImpl::visit(expr); diff --git a/src/visitor/FoldConstantExprVisitor.cpp b/src/visitor/FoldConstantExprVisitor.cpp index 75979921f93e2fa85274df9594cf34144680b504..a3ec2d5c81d9df7a81cbd759789a7cd00a28f64b 100644 --- a/src/visitor/FoldConstantExprVisitor.cpp +++ b/src/visitor/FoldConstantExprVisitor.cpp @@ -65,7 +65,31 @@ void FoldConstantExprVisitor::visit(AttributeExpression *expr) { } void FoldConstantExprVisitor::visit(LogicalExpression *expr) { - visitBinaryExpr(expr); + auto &operands = expr->operands(); + auto foldable = true; + // auto shortCircuit = false; + for (auto i = 0u; i < operands.size(); i++) { + auto *operand = operands[i].get(); + operand->accept(this); + if (canBeFolded_) { + auto *newExpr = fold(operand); + expr->setOperand(i, newExpr); + /* + if (newExpr->value().isBool()) { + auto value = newExpr->value().getBool(); + if ((value && expr->kind() == Expression::Kind::kLogicalOr) || + (!value && expr->kind() == Expression::Kind::kLogicalAnd)) { + shortCircuit = true; + break; + } + } + */ + } else { + foldable = false; + } + } + // canBeFolded_ = foldable || shortCircuit; + canBeFolded_ = foldable; } // function call diff --git a/src/visitor/RewriteInputPropVisitor.cpp b/src/visitor/RewriteInputPropVisitor.cpp index 78de007a090ee5922998301118f1dc6479df2615..bbd2f7b8718b3acf3ba13a035e01db8bd826d707 100644 --- a/src/visitor/RewriteInputPropVisitor.cpp +++ b/src/visitor/RewriteInputPropVisitor.cpp @@ -62,7 +62,13 @@ void RewriteInputPropVisitor::visit(RelationalExpression* expr) { } void RewriteInputPropVisitor::visit(LogicalExpression* expr) { - visitBinaryExpr(expr); + auto &operands = expr->operands(); + for (auto i = 0u; i < operands.size(); i++) { + operands[i]->accept(this); + if (ok()) { + expr->setOperand(i, result_.release()); + } + } } void RewriteInputPropVisitor::visit(UnaryExpression* expr) { diff --git a/src/visitor/RewriteSymExprVisitor.cpp b/src/visitor/RewriteSymExprVisitor.cpp index e97f84debb050b79fb59c9b0a6acf7783cbc2402..2b4c0abd140cfc6c5b8588d0a3daa9411c8b8fe4 100644 --- a/src/visitor/RewriteSymExprVisitor.cpp +++ b/src/visitor/RewriteSymExprVisitor.cpp @@ -71,7 +71,13 @@ void RewriteSymExprVisitor::visit(AttributeExpression *expr) { } void RewriteSymExprVisitor::visit(LogicalExpression *expr) { - visitBinaryExpr(expr); + auto &operands = expr->operands(); + for (auto i = 0u; i < operands.size(); i++) { + operands[i]->accept(this); + if (expr_) { + expr->setOperand(i, expr_.release()); + } + } } // function call