From 75cfdb7c7d1d53038202332a7ece82b282aaefaa Mon Sep 17 00:00:00 2001
From: cpw <13495049+CPWstatic@users.noreply.github.com>
Date: Thu, 15 Oct 2020 12:00:48 +0800
Subject: [PATCH] Fix go test. (#334)

* Fix GO integration tests (IT).

* Fix GO unit tests (UT).

* Fix GO M TO N steps plan.

* Fix unit tests (UT).

* Fix symbol test.
---
 src/validator/GoValidator.cpp             | 77 +++++++++++++++++-----
 src/validator/GoValidator.h               |  2 +
 src/validator/TraversalValidator.cpp      | 10 ---
 src/validator/test/QueryValidatorTest.cpp | 30 ++++++---
 src/validator/test/SymbolsTest.cpp        | 80 ++++++++++++++---------
 tests/query/v1/test_go.py                 |  3 +-
 6 files changed, 135 insertions(+), 67 deletions(-)

diff --git a/src/validator/GoValidator.cpp b/src/validator/GoValidator.cpp
index a9de983b..675b1dbd 100644
--- a/src/validator/GoValidator.cpp
+++ b/src/validator/GoValidator.cpp
@@ -240,19 +240,26 @@ Status GoValidator::buildNStepsPlan() {
     VLOG(1) << gn->outputVar();
 
     PlanNode* dedupDstVids = projectDstVidsFromGN(gn, startVidsVar);
+    PlanNode* loopBody = dedupDstVids;
+
+    PlanNode* dedupSrcDstVids = nullptr;
+    if (from_.fromType != FromType::kInstantExpr) {
+         dedupSrcDstVids = projectSrcDstVidsFromGN(dedupDstVids, gn);
+         loopBody = dedupSrcDstVids;
+    }
 
     // Trace to the start vid if starts from a runtime start vid.
     PlanNode* projectFromJoin = nullptr;
     if (from_.fromType != FromType::kInstantExpr  &&
-        projectLeftVarForJoin != nullptr && dedupDstVids != nullptr) {
-        projectFromJoin = traceToStartVid(projectLeftVarForJoin, dedupDstVids);
+        projectLeftVarForJoin != nullptr && dedupSrcDstVids != nullptr) {
+        projectFromJoin = traceToStartVid(projectLeftVarForJoin, dedupSrcDstVids);
+        loopBody = projectFromJoin;
     }
 
     auto* loop = Loop::make(
         qctx_,
-        projectLeftVarForJoin == nullptr ? dedupStartVid
-                                         : projectLeftVarForJoin,  // dep
-        projectFromJoin == nullptr ? dedupDstVids : projectFromJoin,  // body
+        projectLeftVarForJoin == nullptr ? dedupStartVid : projectLeftVarForJoin,   // dep
+        loopBody,                                                                   // body
         buildNStepLoopCondition(steps_.steps - 1));
 
     auto status = oneStep(loop, dedupDstVids->outputVar(), projectFromJoin);
@@ -299,13 +306,17 @@ Status GoValidator::buildMToNPlan() {
 
     PlanNode* dependencyForProjectResult = dedupDstVids;
 
-    // Trace to the start vid if $-.prop was declared.
+    PlanNode* dedupSrcDstVids = nullptr;
+    if (from_.fromType != FromType::kInstantExpr) {
+        dedupSrcDstVids = projectSrcDstVidsFromGN(dedupDstVids, gn);
+        dependencyForProjectResult = dedupSrcDstVids;
+    }
+
+    // Trace to the start vid if starts from a runtime start vid.
     PlanNode* projectFromJoin = nullptr;
-    if (!exprProps_.inputProps().empty() || !exprProps_.varProps().empty()) {
-        if ((!exprProps_.inputProps().empty() || !exprProps_.varProps().empty()) &&
-            projectLeftVarForJoin != nullptr && dedupDstVids != nullptr) {
-            projectFromJoin = traceToStartVid(projectLeftVarForJoin, dedupDstVids);
-        }
+    if (from_.fromType != FromType::kInstantExpr  &&
+        projectLeftVarForJoin != nullptr && dedupSrcDstVids != nullptr) {
+        projectFromJoin = traceToStartVid(projectLeftVarForJoin, dedupSrcDstVids);
     }
 
     // Get the src props and edge props if $-.prop, $var.prop, $$.tag.prop were declared.
@@ -313,6 +324,9 @@ Status GoValidator::buildMToNPlan() {
     if (!exprProps_.inputProps().empty() || !exprProps_.varProps().empty() ||
         !exprProps_.dstTagProps().empty()) {
         PlanNode* depForProject = dedupDstVids;
+        if (dedupSrcDstVids != nullptr) {
+            depForProject = projectFromJoin;
+        }
         if (projectFromJoin != nullptr) {
             depForProject = projectFromJoin;
         }
@@ -341,9 +355,12 @@ Status GoValidator::buildMToNPlan() {
     if (filter_ != nullptr) {
         auto* filterNode = Filter::make(qctx_, dependencyForProjectResult,
                     newFilter_ != nullptr ? newFilter_ : filter_);
-        filterNode->setInputVar(
-            dependencyForProjectResult == dedupDstVids ?
-                gn->outputVar() : dependencyForProjectResult->outputVar());
+        if (dependencyForProjectResult == dedupDstVids ||
+            dependencyForProjectResult == dedupSrcDstVids) {
+            filterNode->setInputVar(gn->outputVar());
+        } else {
+            filterNode->setInputVar(dependencyForProjectResult->outputVar());
+        }
         filterNode->setColNames(dependencyForProjectResult->colNames());
         dependencyForProjectResult = filterNode;
     }
@@ -351,9 +368,12 @@ Status GoValidator::buildMToNPlan() {
     SingleInputNode* projectResult =
         Project::make(qctx_, dependencyForProjectResult,
         newYieldCols_ != nullptr ? newYieldCols_ : yields_);
-    projectResult->setInputVar(
-            dependencyForProjectResult == dedupDstVids ?
-                gn->outputVar() : dependencyForProjectResult->outputVar());
+    if (dependencyForProjectResult == dedupDstVids ||
+        dependencyForProjectResult == dedupSrcDstVids) {
+        projectResult->setInputVar(gn->outputVar());
+    } else {
+        projectResult->setInputVar(dependencyForProjectResult->outputVar());
+    }
     projectResult->setColNames(std::vector<std::string>(colNames_));
 
     SingleInputNode* dedupNode = nullptr;
@@ -798,5 +818,28 @@ Status GoValidator::buildColumns() {
     return Status::OK();
 }
 
+// Builds a Project -> Dedup pair yielding (edge "*" kDst as kVid, input kVid as a fresh
+// anonymous column); the generated name is stored in srcVidColName_ as a side effect.
+PlanNode* GoValidator::projectSrcDstVidsFromGN(PlanNode* dep, PlanNode* gn) {
+    auto* yieldCols = qctx_->objPool()->add(new YieldColumns());
+    // First column: property kDst of edge "*", exposed under the name kVid.
+    auto* dstExpr = new EdgePropertyExpression(new std::string("*"), new std::string(kDst));
+    yieldCols->addColumn(new YieldColumn(dstExpr, new std::string(kVid)));
+    // Second column: the input's kVid, exposed under a newly generated anonymous name.
+    srcVidColName_ = vctx_->anonColGen()->getCol();
+    auto* srcExpr = new InputPropertyExpression(new std::string(kVid));
+    yieldCols->addColumn(new YieldColumn(srcExpr, new std::string(srcVidColName_)));
+
+    Project* srcDstProject = Project::make(qctx_, dep, yieldCols);
+    srcDstProject->setInputVar(gn->outputVar());
+    srcDstProject->setColNames(deduceColNames(yieldCols));
+    VLOG(1) << srcDstProject->outputVar();
+
+    // The Dedup node reads the projection's output and keeps its column names.
+    auto* dedup = Dedup::make(qctx_, srcDstProject);
+    dedup->setInputVar(srcDstProject->outputVar());
+    dedup->setColNames(srcDstProject->colNames());
+    return dedup;
+}
 }  // namespace graph
 }  // namespace nebula
diff --git a/src/validator/GoValidator.h b/src/validator/GoValidator.h
index 5bc9eb57..c6e467af 100644
--- a/src/validator/GoValidator.h
+++ b/src/validator/GoValidator.h
@@ -65,6 +65,8 @@ private:
 
     PlanNode* buildJoinDstProps(PlanNode* projectSrcDstProps);
 
+    PlanNode* projectSrcDstVidsFromGN(PlanNode* dep, PlanNode* gn);
+
 private:
     Over                                                    over_;
     Expression*                                             filter_{nullptr};
diff --git a/src/validator/TraversalValidator.cpp b/src/validator/TraversalValidator.cpp
index 89a5c581..74a7ffee 100644
--- a/src/validator/TraversalValidator.cpp
+++ b/src/validator/TraversalValidator.cpp
@@ -137,8 +137,6 @@ Status TraversalValidator::validateStep(const StepClause* clause, Steps& step) {
     return Status::OK();
 }
 
-
-
 PlanNode* TraversalValidator::projectDstVidsFromGN(PlanNode* gn, const std::string& outputVar) {
     Project* project = nullptr;
     auto* columns = qctx_->objPool()->add(new YieldColumns());
@@ -147,14 +145,6 @@ PlanNode* TraversalValidator::projectDstVidsFromGN(PlanNode* gn, const std::stri
         new std::string(kVid));
     columns->addColumn(column);
 
-    srcVidColName_ = vctx_->anonColGen()->getCol();
-    if (from_.fromType != FromType::kInstantExpr) {
-        column =
-            new YieldColumn(new InputPropertyExpression(new std::string(kVid)),
-                            new std::string(srcVidColName_));
-        columns->addColumn(column);
-    }
-
     project = Project::make(qctx_, gn, columns);
     project->setInputVar(gn->outputVar());
     project->setColNames(deduceColNames(columns));
diff --git a/src/validator/test/QueryValidatorTest.cpp b/src/validator/test/QueryValidatorTest.cpp
index 3274ebeb..f36b85c8 100644
--- a/src/validator/test/QueryValidatorTest.cpp
+++ b/src/validator/test/QueryValidatorTest.cpp
@@ -171,8 +171,10 @@ TEST_F(QueryValidatorTest, GoWithPipe) {
             PK::kProject,
             PK::kProject,
             PK::kGetNeighbors,
-            PK::kGetNeighbors,
+            PK::kDedup,
             PK::kStart,
+            PK::kProject,
+            PK::kGetNeighbors,
             PK::kStart
         };
         EXPECT_TRUE(checkResult(query, expected));
@@ -287,8 +289,10 @@ TEST_F(QueryValidatorTest, GoWithPipe) {
             PK::kProject,
             PK::kProject,
             PK::kGetNeighbors,
-            PK::kGetNeighbors,
+            PK::kDedup,
             PK::kStart,
+            PK::kProject,
+            PK::kGetNeighbors,
             PK::kStart,
         };
         EXPECT_TRUE(checkResult(query, expected));
@@ -316,8 +320,10 @@ TEST_F(QueryValidatorTest, GoWithPipe) {
             PK::kProject,
             PK::kProject,
             PK::kGetNeighbors,
-            PK::kGetNeighbors,
+            PK::kDedup,
             PK::kStart,
+            PK::kProject,
+            PK::kGetNeighbors,
             PK::kStart,
         };
         EXPECT_TRUE(checkResult(query, expected));
@@ -347,8 +353,10 @@ TEST_F(QueryValidatorTest, GoWithPipe) {
             PK::kProject,
             PK::kProject,
             PK::kGetNeighbors,
-            PK::kGetNeighbors,
+            PK::kDedup,
             PK::kStart,
+            PK::kProject,
+            PK::kGetNeighbors,
             PK::kStart,
         };
         EXPECT_TRUE(checkResult(query, expected));
@@ -464,11 +472,13 @@ TEST_F(QueryValidatorTest, GoWithPipe) {
             PK::kProject,
             PK::kProject,
             PK::kDataJoin,
-            PK::kGetNeighbors,
+            PK::kDedup,
+            PK::kProject,
             PK::kProject,
-            PK::kStart,
             PK::kGetVertices,
+            PK::kGetNeighbors,
             PK::kDedup,
+            PK::kStart,
             PK::kProject,
             PK::kProject,
             PK::kGetNeighbors,
@@ -506,11 +516,13 @@ TEST_F(QueryValidatorTest, GoWithPipe) {
             PK::kProject,
             PK::kProject,
             PK::kDataJoin,
-            PK::kGetNeighbors,
+            PK::kDedup,
+            PK::kProject,
             PK::kProject,
-            PK::kStart,
             PK::kGetVertices,
+            PK::kGetNeighbors,
             PK::kDedup,
+            PK::kStart,
             PK::kProject,
             PK::kProject,
             PK::kGetNeighbors,
@@ -1026,6 +1038,8 @@ TEST_F(QueryValidatorTest, GoMToN) {
             PK::kDataJoin,
             PK::kDedup,
             PK::kProject,
+            PK::kDedup,
+            PK::kProject,
             PK::kGetNeighbors,
             PK::kStart,
         };
diff --git a/src/validator/test/SymbolsTest.cpp b/src/validator/test/SymbolsTest.cpp
index 35248a59..829e2c7d 100644
--- a/src/validator/test/SymbolsTest.cpp
+++ b/src/validator/test/SymbolsTest.cpp
@@ -52,14 +52,14 @@ TEST_F(SymbolsTest, Variables) {
         auto* symTable = qctx->symTable();
 
         {
-            auto varName = "__Start_19";
+            auto varName = "__Start_21";
             auto* variable = symTable->getVar(varName);
             EXPECT_NE(variable, nullptr);
             EXPECT_EQ(variable->name, varName);
             EXPECT_EQ(variable->type, Value::Type::DATASET);
             EXPECT_TRUE(variable->colNames.empty());
             EXPECT_TRUE(checkNodes(variable->readBy, {}));
-            EXPECT_TRUE(checkNodes(variable->writtenBy, {19}));
+            EXPECT_TRUE(checkNodes(variable->writtenBy, {21}));
         }
         {
             auto varName = "__GetNeighbors_0";
@@ -78,7 +78,7 @@ TEST_F(SymbolsTest, Variables) {
             EXPECT_EQ(variable->name, varName);
             EXPECT_EQ(variable->type, Value::Type::DATASET);
             EXPECT_EQ(variable->colNames, std::vector<std::string>({"id"}));
-            EXPECT_TRUE(checkNodes(variable->readBy, {3, 5, 17}));
+            EXPECT_TRUE(checkNodes(variable->readBy, {3, 5, 19}));
             EXPECT_TRUE(checkNodes(variable->writtenBy, {1}));
         }
         {
@@ -97,8 +97,8 @@ TEST_F(SymbolsTest, Variables) {
             EXPECT_NE(variable, nullptr);
             EXPECT_EQ(variable->name, varName);
             EXPECT_EQ(variable->type, Value::Type::DATASET);
-            EXPECT_EQ(variable->colNames, std::vector<std::string>({"_vid", "__UNAMED_COL_2"}));
-            EXPECT_TRUE(checkNodes(variable->readBy, {7, 10, 14}));
+            EXPECT_EQ(variable->colNames, std::vector<std::string>({"_vid"}));
+            EXPECT_TRUE(checkNodes(variable->readBy, {7, 16}));
             EXPECT_TRUE(checkNodes(variable->writtenBy, {4, 9}));
         }
         {
@@ -118,8 +118,8 @@ TEST_F(SymbolsTest, Variables) {
             EXPECT_EQ(variable->name, varName);
             EXPECT_EQ(variable->type, Value::Type::DATASET);
             EXPECT_EQ(variable->colNames, std::vector<std::string>({"id", "__UNAMED_COL_1"}));
-            EXPECT_TRUE(checkNodes(variable->readBy, {10, 13, 16}));
-            EXPECT_TRUE(checkNodes(variable->writtenBy, {6, 12}));
+            EXPECT_TRUE(checkNodes(variable->readBy, {12, 15, 18}));
+            EXPECT_TRUE(checkNodes(variable->writtenBy, {6, 14}));
         }
         {
             auto varName = "__Start_2";
@@ -138,7 +138,7 @@ TEST_F(SymbolsTest, Variables) {
             EXPECT_EQ(variable->name, varName);
             EXPECT_EQ(variable->type, Value::Type::DATASET);
             EXPECT_TRUE(variable->colNames.empty());
-            EXPECT_TRUE(checkNodes(variable->readBy, {8}));
+            EXPECT_TRUE(checkNodes(variable->readBy, {8, 10}));
             EXPECT_TRUE(checkNodes(variable->writtenBy, {7}));
         }
         {
@@ -147,74 +147,94 @@ TEST_F(SymbolsTest, Variables) {
             EXPECT_NE(variable, nullptr);
             EXPECT_EQ(variable->name, varName);
             EXPECT_EQ(variable->type, Value::Type::DATASET);
-            EXPECT_EQ(variable->colNames, std::vector<std::string>({"_vid", "__UNAMED_COL_2"}));
+            EXPECT_EQ(variable->colNames, std::vector<std::string>({"_vid"}));
             EXPECT_TRUE(checkNodes(variable->readBy, {9}));
             EXPECT_TRUE(checkNodes(variable->writtenBy, {8}));
         }
         {
-            auto varName = "__DataJoin_10";
+            auto varName = "__Project_10";
             auto* variable = symTable->getVar(varName);
             EXPECT_NE(variable, nullptr);
             EXPECT_EQ(variable->name, varName);
             EXPECT_EQ(variable->type, Value::Type::DATASET);
-            EXPECT_EQ(variable->colNames,
-                      std::vector<std::string>({"id", "__UNAMED_COL_1", "_vid", "__UNAMED_COL_2"}));
+            EXPECT_EQ(variable->colNames, std::vector<std::string>({"_vid", "__UNAMED_COL_2"}));
             EXPECT_TRUE(checkNodes(variable->readBy, {11}));
             EXPECT_TRUE(checkNodes(variable->writtenBy, {10}));
         }
         {
-            auto varName = "__Project_11";
+            auto varName = "__Dedup_11";
             auto* variable = symTable->getVar(varName);
             EXPECT_NE(variable, nullptr);
             EXPECT_EQ(variable->name, varName);
             EXPECT_EQ(variable->type, Value::Type::DATASET);
-            EXPECT_EQ(variable->colNames, std::vector<std::string>({"id", "__UNAMED_COL_1"}));
+            EXPECT_EQ(variable->colNames, std::vector<std::string>({"_vid", "__UNAMED_COL_2"}));
             EXPECT_TRUE(checkNodes(variable->readBy, {12}));
             EXPECT_TRUE(checkNodes(variable->writtenBy, {11}));
         }
         {
-            auto varName = "__Loop_13";
+            auto varName = "__DataJoin_12";
+            auto* variable = symTable->getVar(varName);
+            EXPECT_NE(variable, nullptr);
+            EXPECT_EQ(variable->name, varName);
+            EXPECT_EQ(variable->type, Value::Type::DATASET);
+            EXPECT_EQ(variable->colNames,
+                      std::vector<std::string>({"id", "__UNAMED_COL_1", "_vid", "__UNAMED_COL_2"}));
+            EXPECT_TRUE(checkNodes(variable->readBy, {13}));
+            EXPECT_TRUE(checkNodes(variable->writtenBy, {12}));
+        }
+        {
+            auto varName = "__Project_13";
+            auto* variable = symTable->getVar(varName);
+            EXPECT_NE(variable, nullptr);
+            EXPECT_EQ(variable->name, varName);
+            EXPECT_EQ(variable->type, Value::Type::DATASET);
+            EXPECT_EQ(variable->colNames, std::vector<std::string>({"id", "__UNAMED_COL_1"}));
+            EXPECT_TRUE(checkNodes(variable->readBy, {14}));
+            EXPECT_TRUE(checkNodes(variable->writtenBy, {13}));
+        }
+        {
+            auto varName = "__Loop_15";
             auto* variable = symTable->getVar(varName);
             EXPECT_NE(variable, nullptr);
             EXPECT_EQ(variable->name, varName);
             EXPECT_EQ(variable->type, Value::Type::DATASET);
             EXPECT_EQ(variable->colNames, std::vector<std::string>({}));
             EXPECT_TRUE(checkNodes(variable->readBy, {}));
-            EXPECT_TRUE(checkNodes(variable->writtenBy, {13}));
+            EXPECT_TRUE(checkNodes(variable->writtenBy, {15}));
         }
         {
-            auto varName = "__GetNeighbors_14";
+            auto varName = "__GetNeighbors_16";
             auto* variable = symTable->getVar(varName);
             EXPECT_NE(variable, nullptr);
             EXPECT_EQ(variable->name, varName);
             EXPECT_EQ(variable->type, Value::Type::DATASET);
             EXPECT_TRUE(variable->colNames.empty());
-            EXPECT_TRUE(checkNodes(variable->readBy, {15}));
-            EXPECT_TRUE(checkNodes(variable->writtenBy, {14}));
+            EXPECT_TRUE(checkNodes(variable->readBy, {17}));
+            EXPECT_TRUE(checkNodes(variable->writtenBy, {16}));
         }
         {
-            auto varName = "__Project_15";
+            auto varName = "__Project_17";
             auto* variable = symTable->getVar(varName);
             EXPECT_NE(variable, nullptr);
             EXPECT_EQ(variable->name, varName);
             EXPECT_EQ(variable->type, Value::Type::DATASET);
             EXPECT_EQ(variable->colNames, std::vector<std::string>({"__UNAMED_COL_0", "_vid"}));
-            EXPECT_TRUE(checkNodes(variable->readBy, {16}));
-            EXPECT_TRUE(checkNodes(variable->writtenBy, {15}));
+            EXPECT_TRUE(checkNodes(variable->readBy, {18}));
+            EXPECT_TRUE(checkNodes(variable->writtenBy, {17}));
         }
         {
-            auto varName = "__DataJoin_16";
+            auto varName = "__DataJoin_18";
             auto* variable = symTable->getVar(varName);
             EXPECT_NE(variable, nullptr);
             EXPECT_EQ(variable->name, varName);
             EXPECT_EQ(variable->type, Value::Type::DATASET);
             EXPECT_EQ(variable->colNames,
                       std::vector<std::string>({"__UNAMED_COL_0", "_vid", "id", "__UNAMED_COL_1"}));
-            EXPECT_TRUE(checkNodes(variable->readBy, {17}));
-            EXPECT_TRUE(checkNodes(variable->writtenBy, {16}));
+            EXPECT_TRUE(checkNodes(variable->readBy, {19}));
+            EXPECT_TRUE(checkNodes(variable->writtenBy, {18}));
         }
         {
-            auto varName = "__DataJoin_17";
+            auto varName = "__DataJoin_19";
             auto* variable = symTable->getVar(varName);
             EXPECT_NE(variable, nullptr);
             EXPECT_EQ(variable->name, varName);
@@ -222,18 +242,18 @@ TEST_F(SymbolsTest, Variables) {
             EXPECT_EQ(variable->colNames,
                       std::vector<std::string>(
                           {"__UNAMED_COL_0", "_vid", "id", "__UNAMED_COL_1", "like._dst"}));
-            EXPECT_TRUE(checkNodes(variable->readBy, {18}));
-            EXPECT_TRUE(checkNodes(variable->writtenBy, {17}));
+            EXPECT_TRUE(checkNodes(variable->readBy, {20}));
+            EXPECT_TRUE(checkNodes(variable->writtenBy, {19}));
         }
         {
-            auto varName = "__Project_18";
+            auto varName = "__Project_20";
             auto* variable = symTable->getVar(varName);
             EXPECT_NE(variable, nullptr);
             EXPECT_EQ(variable->name, varName);
             EXPECT_EQ(variable->type, Value::Type::DATASET);
             EXPECT_EQ(variable->colNames, std::vector<std::string>({"like._dst"}));
             EXPECT_TRUE(checkNodes(variable->readBy, {}));
-            EXPECT_TRUE(checkNodes(variable->writtenBy, {18}));
+            EXPECT_TRUE(checkNodes(variable->writtenBy, {20}));
         }
     }
 }
diff --git a/tests/query/v1/test_go.py b/tests/query/v1/test_go.py
index fe841084..324c3abf 100644
--- a/tests/query/v1/test_go.py
+++ b/tests/query/v1/test_go.py
@@ -977,8 +977,7 @@ class TestGoQuery(NebulaTestSuite):
             ]
         }
         self.check_column_names(resp, expected_data["column_names"])
-        # TODO: datajoin error, need to fix it
-        # self.check_out_of_order_result(resp, expected_data["rows"])
+        self.check_out_of_order_result(resp, expected_data["rows"])
 
     def test_reversely_two_steps(self):
         stmt = "GO 2 STEPS FROM 'Kobe Bryant' OVER like REVERSELY YIELD $$.player.name"
-- 
GitLab