diff --git a/src/visitor/FoldConstantExprVisitor.cpp b/src/visitor/FoldConstantExprVisitor.cpp index 290df88d999a0da32df15941ec1a12a020dfb30f..6d3e036ccb673a67df17442b5c19e9bed033b234 100644 --- a/src/visitor/FoldConstantExprVisitor.cpp +++ b/src/visitor/FoldConstantExprVisitor.cpp @@ -135,6 +135,7 @@ void FoldConstantExprVisitor::visit(MapExpression *expr) { for (size_t i = 0; i < items.size(); ++i) { auto &pair = items[i]; auto item = const_cast<Expression *>(pair.second.get()); + item->accept(this); if (!canBeFolded_) { canBeFolded = false; continue; diff --git a/src/visitor/test/FoldConstantExprVisitorTest.cpp b/src/visitor/test/FoldConstantExprVisitorTest.cpp index 8cf7587ce6fb25b5f8732ca0b3b998eb088e6568..b636e44de0a9f75ac51c79251edf871d56ff32bb 100644 --- a/src/visitor/test/FoldConstantExprVisitorTest.cpp +++ b/src/visitor/test/FoldConstantExprVisitorTest.cpp @@ -22,43 +22,43 @@ public: pool.clear(); } - static ConstantExpression *constant(Value value) { + static ConstantExpression *constantExpr(Value value) { return new ConstantExpression(std::move(value)); } - static ArithmeticExpression *add(Expression *lhs, Expression *rhs) { + static ArithmeticExpression *addExpr(Expression *lhs, Expression *rhs) { return new ArithmeticExpression(Expression::Kind::kAdd, lhs, rhs); } - static ArithmeticExpression *minus(Expression *lhs, Expression *rhs) { + static ArithmeticExpression *minusExpr(Expression *lhs, Expression *rhs) { return new ArithmeticExpression(Expression::Kind::kMinus, lhs, rhs); } - static RelationalExpression *gt(Expression *lhs, Expression *rhs) { + static RelationalExpression *gtExpr(Expression *lhs, Expression *rhs) { return new RelationalExpression(Expression::Kind::kRelGT, lhs, rhs); } - static RelationalExpression *eq(Expression *lhs, Expression *rhs) { + static RelationalExpression *eqExpr(Expression *lhs, Expression *rhs) { return new RelationalExpression(Expression::Kind::kRelEQ, lhs, rhs); } - static TypeCastingExpression *cast(Type type, Expression *expr) { + static TypeCastingExpression *castExpr(Type type, Expression *expr) { return new TypeCastingExpression(type, expr); } - static UnaryExpression *not_(Expression *expr) { + static UnaryExpression *notExpr(Expression *expr) { return new UnaryExpression(Expression::Kind::kUnaryNot, expr); } - static LogicalExpression *and_(Expression *lhs, Expression *rhs) { + static LogicalExpression *andExpr(Expression *lhs, Expression *rhs) { return new LogicalExpression(Expression::Kind::kLogicalAnd, lhs, rhs); } - static LogicalExpression *or_(Expression *lhs, Expression *rhs) { + static LogicalExpression *orExpr(Expression *lhs, Expression *rhs) { return new LogicalExpression(Expression::Kind::kLogicalOr, lhs, rhs); } - static ListExpression *list_(std::initializer_list<Expression *> exprs) { + static ListExpression *listExpr(std::initializer_list<Expression *> exprs) { auto exprList = new ExpressionList; for (auto expr : exprs) { exprList->add(expr); @@ -66,11 +66,29 @@ public: return new ListExpression(exprList); } - static SubscriptExpression *sub(Expression *lhs, Expression *rhs) { + static SetExpression *setExpr(std::initializer_list<Expression *> exprs) { + auto exprList = new ExpressionList; + for (auto expr : exprs) { + exprList->add(expr); + } + return new SetExpression(exprList); + } + + static MapExpression *mapExpr( + std::initializer_list<std::pair<std::string, Expression *>> exprs) { + auto mapItemList = new MapItemList; + for (auto expr : exprs) { + mapItemList->add(new std::string(expr.first), expr.second); + } + return new MapExpression(mapItemList); + } + + static SubscriptExpression *subExpr(Expression *lhs, Expression *rhs) { return new SubscriptExpression(lhs, rhs); } - static FunctionCallExpression *fn(std::string fn, std::initializer_list<Expression *> args) { + static FunctionCallExpression *fnExpr(std::string fn, + std::initializer_list<Expression *> args) { auto argsList = new ArgumentList; for (auto arg : args) { argsList->addArgument(std::unique_ptr<Expression>(arg)); @@ -78,7 +96,7 @@ public: return new FunctionCallExpression(new std::string(std::move(fn)), argsList); } - static VariableExpression *var(const std::string &name) { + static VariableExpression *varExpr(const std::string &name) { return new VariableExpression(new std::string(name)); } @@ -88,23 +106,25 @@ protected: TEST_F(FoldConstantExprVisitorTest, TestArithmeticExpr) { // (5 - 1) + 2 => 4 + 2 - auto expr = pool.add(add(minus(constant(5), constant(1)), constant(2))); + auto expr = pool.add(addExpr(minusExpr(constantExpr(5), constantExpr(1)), constantExpr(2))); FoldConstantExprVisitor visitor; expr->accept(&visitor); - auto expected = pool.add(add(constant(4), constant(2))); + auto expected = pool.add(addExpr(constantExpr(4), constantExpr(2))); ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString(); ASSERT(visitor.canBeFolded()); // 4+2 => 6 auto root = pool.add(visitor.fold(expr)); - auto rootExpected = pool.add(constant(6)); + auto rootExpected = pool.add(constantExpr(6)); ASSERT_EQ(*root, *rootExpected) << root->toString() << " vs. " << rootExpected->toString(); } TEST_F(FoldConstantExprVisitorTest, TestRelationExpr) { // false == !(3 > (1+1)) => false == false - auto expr = pool.add(eq(constant(false), not_(gt(constant(3), add(constant(1), constant(1)))))); - auto expected = pool.add(eq(constant(false), constant(false))); + auto expr = pool.add( + eqExpr(constantExpr(false), + notExpr(gtExpr(constantExpr(3), addExpr(constantExpr(1), constantExpr(1)))))); + auto expected = pool.add(eqExpr(constantExpr(false), constantExpr(false))); FoldConstantExprVisitor visitor; expr->accept(&visitor); ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString(); @@ -112,15 +132,17 @@ TEST_F(FoldConstantExprVisitorTest, TestRelationExpr) { // false==false => true auto root = pool.add(visitor.fold(expr)); - auto rootExpected = pool.add(constant(true)); + auto rootExpected = pool.add(constantExpr(true)); ASSERT_EQ(*root, *rootExpected) << root->toString() << " vs. " << rootExpected->toString(); } TEST_F(FoldConstantExprVisitorTest, TestLogicalExpr) { // false && (false || (3 > (1 + 1))) => false && true - auto expr = pool.add(and_( - constant(false), or_(constant(false), gt(constant(3), add(constant(1), constant(1)))))); - auto expected = pool.add(and_(constant(false), constant(true))); + auto expr = pool.add( + andExpr(constantExpr(false), + orExpr(constantExpr(false), + gtExpr(constantExpr(3), addExpr(constantExpr(1), constantExpr(1)))))); + auto expected = pool.add(andExpr(constantExpr(false), constantExpr(true))); FoldConstantExprVisitor visitor; expr->accept(&visitor); ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString(); @@ -128,18 +150,20 @@ TEST_F(FoldConstantExprVisitorTest, TestLogicalExpr) { // false && true => false auto root = pool.add(visitor.fold(expr)); - auto rootExpected = pool.add(constant(false)); + auto rootExpected = pool.add(constantExpr(false)); ASSERT_EQ(*root, *rootExpected) << root->toString() << " vs. " << rootExpected->toString(); } TEST_F(FoldConstantExprVisitorTest, TestSubscriptExpr) { // 1 + [1, pow(2, 2+1), 2][2-1] => 1 + 8 - auto expr = pool.add(add(constant(1), - sub(list_({constant(1), - fn("pow", {constant(2), add(constant(2), constant(1))}), - constant(2)}), - minus(constant(2), constant(1))))); - auto expected = pool.add(add(constant(1), constant(8))); + auto expr = pool.add(addExpr( + constantExpr(1), + subExpr( + listExpr({constantExpr(1), + fnExpr("pow", {constantExpr(2), addExpr(constantExpr(2), constantExpr(1))}), + constantExpr(2)}), + minusExpr(constantExpr(2), constantExpr(1))))); + auto expected = pool.add(addExpr(constantExpr(1), constantExpr(8))); FoldConstantExprVisitor visitor; expr->accept(&visitor); ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString(); @@ -147,16 +171,57 @@ TEST_F(FoldConstantExprVisitorTest, TestSubscriptExpr) { // 1+8 => 9 auto root = pool.add(visitor.fold(expr)); - auto rootExpected = pool.add(constant(9)); + auto rootExpected = pool.add(constantExpr(9)); ASSERT_EQ(*root, *rootExpected) << root->toString() << " vs. " << rootExpected->toString(); } +TEST_F(FoldConstantExprVisitorTest, TestListExpr) { + // [3+4, pow(2, 2+1), 2] => [7, 8, 2] + auto expr = pool.add( + listExpr({addExpr(constantExpr(3), constantExpr(4)), + fnExpr("pow", {constantExpr(2), addExpr(constantExpr(2), constantExpr(1))}), + constantExpr(2)})); + auto expected = pool.add(listExpr({constantExpr(7), constantExpr(8), constantExpr(2)})); + FoldConstantExprVisitor visitor; + expr->accept(&visitor); + ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString(); + ASSERT(visitor.canBeFolded()); +} + +TEST_F(FoldConstantExprVisitorTest, TestSetExpr) { + // {sqrt(19-3), pow(2, 3+1), 2} => {4, 16, 2} + auto expr = pool.add( + setExpr({fnExpr("sqrt", {minusExpr(constantExpr(19), constantExpr(3))}), + fnExpr("pow", {constantExpr(2), addExpr(constantExpr(3), constantExpr(1))}), + constantExpr(2)})); + auto expected = pool.add(setExpr({constantExpr(4), constantExpr(16), constantExpr(2)})); + FoldConstantExprVisitor visitor; + expr->accept(&visitor); + ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString(); + ASSERT(visitor.canBeFolded()); +} + +TEST_F(FoldConstantExprVisitorTest, TestMapExpr) { + // {"jack":1, "tom":pow(2, 2+1), "jerry":5-1} => {"jack":1, "tom":8, "jerry":4} + auto expr = pool.add(mapExpr( + {{"jack", constantExpr(1)}, + {"tom", fnExpr("pow", {constantExpr(2), addExpr(constantExpr(2), constantExpr(1))})}, + {"jerry", minusExpr(constantExpr(5), constantExpr(1))}})); + auto expected = pool.add( + mapExpr({{"jack", constantExpr(1)}, {"tom", constantExpr(8)}, {"jerry", constantExpr(4)}})); + FoldConstantExprVisitor visitor; + expr->accept(&visitor); + ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString(); + ASSERT(visitor.canBeFolded()); +} + TEST_F(FoldConstantExprVisitorTest, TestFoldFailed) { // function call { // pow($v, (1+2)) => pow($v, 3) - auto expr = pool.add(fn("pow", {var("v"), add(constant(1), constant(2))})); - auto expected = pool.add(fn("pow", {var("v"), constant(3)})); + auto expr = + pool.add(fnExpr("pow", {varExpr("v"), addExpr(constantExpr(1), constantExpr(2))})); + auto expected = pool.add(fnExpr("pow", {varExpr("v"), constantExpr(3)})); FoldConstantExprVisitor visitor; expr->accept(&visitor); ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString(); @@ -165,10 +230,12 @@ TEST_F(FoldConstantExprVisitorTest, TestFoldFailed) { // list { // [$v, pow(1, 2), 1+2][2-1] => [$v, 1, 3][0] - auto expr = pool.add(sub( - list_({var("v"), fn("pow", {constant(1), constant(2)}), add(constant(1), constant(2))}), - minus(constant(1), constant(1)))); - auto expected = pool.add(sub(list_({var("v"), constant(1), constant(3)}), constant(0))); + auto expr = pool.add(subExpr(listExpr({varExpr("v"), + fnExpr("pow", {constantExpr(1), constantExpr(2)}), + addExpr(constantExpr(1), constantExpr(2))}), + minusExpr(constantExpr(1), constantExpr(1)))); + auto expected = pool.add( + subExpr(listExpr({varExpr("v"), constantExpr(1), constantExpr(3)}), constantExpr(0))); FoldConstantExprVisitor visitor; expr->accept(&visitor); ASSERT_EQ(*expr, *expected) << expr->toString() << " vs. " << expected->toString();