diff --git a/cmd/storage/benchmark.go b/cmd/storage/benchmark.go
new file mode 100644
index 0000000000000000000000000000000000000000..6a7c06f9a209b25df74545d1f23f58b2d1c80594
--- /dev/null
+++ b/cmd/storage/benchmark.go
@@ -0,0 +1,315 @@
+package main
+
+import (
+ "context"
+ "crypto/md5"
+ "flag"
+ "fmt"
+ "log"
+ "math/rand"
+ "os"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/pivotal-golang/bytefmt"
+ "github.com/zilliztech/milvus-distributed/internal/storage"
+ storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
+)
+
+// Global variables
+var durationSecs, threads, loops, numVersion, batchOpSize int
+var valueSize uint64
+var valueData []byte
+var batchValueData [][]byte
+var counter, totalKeyCount, keyNum int32
+var endTime, setFinish, getFinish, deleteFinish time.Time
+var totalKeys [][]byte
+
+var logFileName = "benchmark.log"
+var logFile *os.File
+
+var store storagetype.Store
+var wg sync.WaitGroup
+
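+// Each runX worker loops until endTime, updating the shared counters with
+// atomic operations, and signals wg when it exits; main derives throughput
+// from counter after wg.Wait() returns.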
+func runSet() {
+ for time.Now().Before(endTime) {
+ num := atomic.AddInt32(&keyNum, 1)
+ key := []byte(fmt.Sprint("key", num))
+ for ver := 1; ver <= numVersion; ver++ {
+ atomic.AddInt32(&counter, 1)
+ err := store.PutRow(context.Background(), key, valueData, "empty", uint64(ver))
+ if err != nil {
+ log.Fatalf("Error setting key %s, %s", key, err.Error())
+ }
+ }
+ }
+ // Remember last done time
+ setFinish = time.Now()
+ wg.Done()
+}
+
+func runBatchSet() {
+ for time.Now().Before(endTime) {
+ num := atomic.AddInt32(&keyNum, int32(batchOpSize))
+ keys := make([][]byte, batchOpSize)
+ versions := make([]uint64, batchOpSize)
+ batchSuffix := make([]string, batchOpSize)
+ for n := batchOpSize; n > 0; n-- {
+ keys[n-1] = []byte(fmt.Sprint("key", num-int32(n)))
+ batchSuffix[n-1] = "empty" // same segment suffix that runSet passes to PutRow
+ }
+ for ver := 1; ver <= numVersion; ver++ {
+ // PutRows takes one version per key; stamp the whole batch with this version.
+ for i := range versions {
+ versions[i] = uint64(ver)
+ }
+ atomic.AddInt32(&counter, 1)
+ err := store.PutRows(context.Background(), keys, batchValueData, batchSuffix, versions)
+ if err != nil {
+ log.Fatalf("Error setting batch keys %s, %s", keys, err.Error())
+ }
+ }
+ }
+ setFinish = time.Now()
+ wg.Done()
+}
+
+func runGet() {
+ for time.Now().Before(endTime) {
+ num := atomic.AddInt32(&counter, 1)
+ //num := atomic.AddInt32(&keyNum, 1)
+ //key := []byte(fmt.Sprint("key", num))
+ num = num % totalKeyCount
+ key := totalKeys[num]
+ _, err := store.GetRow(context.Background(), key, uint64(numVersion))
+ if err != nil {
+ log.Fatalf("Error getting key %s, %s", key, err.Error())
+ }
+ }
+ // Remember last done time
+ getFinish = time.Now()
+ wg.Done()
+}
+
+func runBatchGet() {
+ for time.Now().Before(endTime) {
+ num := atomic.AddInt32(&keyNum, int32(batchOpSize))
+ //keys := make([][]byte, batchOpSize)
+ //for n := batchOpSize; n > 0; n-- {
+ // keys[n-1] = []byte(fmt.Sprint("key", num-int32(n)))
+ //}
+ end := num % totalKeyCount
+ if end < int32(batchOpSize) {
+ end = int32(batchOpSize)
+ }
+ start := end - int32(batchOpSize)
+ keys := totalKeys[start:end]
+ versions := make([]uint64, batchOpSize)
+ for i := range versions {
+ versions[i] = uint64(numVersion)
+ }
+ atomic.AddInt32(&counter, 1)
+ _, err := store.GetRows(context.Background(), keys, versions)
+ if err != nil {
+ log.Fatalf("Error getting key %s, %s", keys, err.Error())
+ //atomic.AddInt32(&batchGetCount, -1)
+ }
+ }
+ // Remember last done time
+ getFinish = time.Now()
+ wg.Done()
+}
+
+func runDelete() {
+ for time.Now().Before(endTime) {
+ num := atomic.AddInt32(&counter, 1)
+ //num := atomic.AddInt32(&keyNum, 1)
+ //key := []byte(fmt.Sprint("key", num))
+ num = num % totalKeyCount
+ key := totalKeys[num]
+ err := store.DeleteRow(context.Background(), key, uint64(numVersion))
+ if err != nil {
+ log.Fatalf("Error getting key %s, %s", key, err.Error())
+ //atomic.AddInt32(&deleteCount, -1)
+ }
+ }
+ // Remember last done time
+ deleteFinish = time.Now()
+ wg.Done()
+}
+
+func runBatchDelete() {
+ for time.Now().Before(endTime) {
+ num := atomic.AddInt32(&keyNum, int32(batchOpSize))
+ //keys := make([][]byte, batchOpSize)
+ //for n := batchOpSize; n > 0; n-- {
+ // keys[n-1] = []byte(fmt.Sprint("key", num-int32(n)))
+ //}
+ end := num % totalKeyCount
+ if end < int32(batchOpSize) {
+ end = int32(batchOpSize)
+ }
+ start := end - int32(batchOpSize)
+ keys := totalKeys[start:end]
+ atomic.AddInt32(&counter, 1)
+ versions := make([]uint64, batchOpSize)
+ for i := range versions {
+ versions[i] = uint64(numVersion)
+ }
+ err := store.DeleteRows(context.Background(), keys, versions)
+ if err != nil {
+ log.Fatalf("Error getting key %s, %s", keys, err.Error())
+ //atomic.AddInt32(&batchDeleteCount, -1)
+ }
+ }
+ // Remember last done time
+ deleteFinish = time.Now()
+ wg.Done()
+}
+
+func main() {
+ // Parse command line
+ myflag := flag.NewFlagSet("myflag", flag.ExitOnError)
+ myflag.IntVar(&durationSecs, "d", 5, "Duration of each test in seconds")
+ myflag.IntVar(&threads, "t", 1, "Number of threads to run")
+ myflag.IntVar(&loops, "l", 1, "Number of times to repeat test")
+ var sizeArg string
+ var storeType string
+ myflag.StringVar(&sizeArg, "z", "1k", "Size of objects in bytes with postfix K, M, and G")
+ myflag.StringVar(&storeType, "s", "s3", "Storage type: tikv, minio, or s3")
+ myflag.IntVar(&numVersion, "v", 1, "Max versions for each key")
+ myflag.IntVar(&batchOpSize, "b", 100, "Batch operation kv pair number")
+
+ if err := myflag.Parse(os.Args[1:]); err != nil {
+ os.Exit(1)
+ }
+
+ // Check the arguments
+ var err error
+ if valueSize, err = bytefmt.ToBytes(sizeArg); err != nil {
+ log.Fatalf("Invalid -z argument for object size: %v", err)
+ }
+ var option = storagetype.Option{TikvAddress: "localhost:2379", Type: storeType, BucketName: "zilliz-hz"}
+
+ store, err = storage.NewStore(context.Background(), option)
+ if err != nil {
+ log.Fatalf("Error when creating storage " + err.Error())
+ }
+ logFile, err = os.OpenFile(logFileName, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0777)
+ if err != nil {
+ log.Fatalf("Prepare log file error, " + err.Error())
+ }
+
+ // Echo the parameters
+ log.Printf("Benchmark log will write to file %s\n", logFile.Name())
+ fmt.Fprintf(logFile, "Parameters: duration=%d, threads=%d, loops=%d, valueSize=%s, batchSize=%d, versions=%d\n", durationSecs, threads, loops, sizeArg, batchOpSize, numVersion)
+ // Init test data
+ valueData = make([]byte, valueSize)
+ rand.Read(valueData)
+ hasher := md5.New()
+ hasher.Write(valueData)
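+ // Note: the md5 digest is computed here but never read; it is not used
+ // for verification.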
+
+ batchValueData = make([][]byte, batchOpSize)
+ for i := range batchValueData {
+ batchValueData[i] = make([]byte, valueSize)
+ rand.Read(batchValueData[i])
+ hasher := md5.New()
+ hasher.Write(batchValueData[i])
+ }
+
+ // Loop running the tests
+ for loop := 1; loop <= loops; loop++ {
+
+ // reset counters
+ counter = 0
+ keyNum = 0
+ totalKeyCount = 0
+ totalKeys = nil
+
+ // Run the batchSet case
+ // key seq start from setCount
+ counter = 0
+ startTime := time.Now()
+ endTime = startTime.Add(time.Second * time.Duration(durationSecs))
+ for n := 1; n <= threads; n++ {
+ wg.Add(1)
+ go runBatchSet()
+ }
+ wg.Wait()
+
+ setTime := setFinish.Sub(startTime).Seconds()
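+ // bps counts payload bytes only: batches * batchOpSize * valueSize / elapsed.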
+ bps := float64(uint64(counter)*valueSize*uint64(batchOpSize)) / setTime
+ fmt.Fprintf(logFile, "Loop %d: BATCH PUT time %.1f secs, batchs = %d, kv pairs = %d, speed = %sB/sec, %.1f operations/sec, %.1f kv/sec.\n",
+ loop, setTime, counter, counter*int32(batchOpSize), bytefmt.ByteSize(uint64(bps)), float64(counter)/setTime, float64(counter*int32(batchOpSize))/setTime)
+ // Record all test keys
+ //totalKeyCount = keyNum
+ //totalKeys = make([][]byte, totalKeyCount)
+ //for i := int32(0); i < totalKeyCount; i++ {
+ // totalKeys[i] = []byte(fmt.Sprint("key", i))
+ //}
+ //
+ //// Run the get case
+ //counter = 0
+ //startTime = time.Now()
+ //endTime = startTime.Add(time.Second * time.Duration(durationSecs))
+ //for n := 1; n <= threads; n++ {
+ // wg.Add(1)
+ // go runGet()
+ //}
+ //wg.Wait()
+ //
+ //getTime := getFinish.Sub(startTime).Seconds()
+ //bps = float64(uint64(counter)*valueSize) / getTime
+ //fmt.Fprint(logFile, fmt.Sprintf("Loop %d: GET time %.1f secs, kv pairs = %d, speed = %sB/sec, %.1f operations/sec, %.1f kv/sec.\n",
+ // loop, getTime, counter, bytefmt.ByteSize(uint64(bps)), float64(counter)/getTime, float64(counter)/getTime))
+
+ // Run the batchGet case
+ //counter = 0
+ //startTime = time.Now()
+ //endTime = startTime.Add(time.Second * time.Duration(durationSecs))
+ //for n := 1; n <= threads; n++ {
+ // wg.Add(1)
+ // go runBatchGet()
+ //}
+ //wg.Wait()
+ //
+ //getTime = getFinish.Sub(startTime).Seconds()
+ //bps = float64(uint64(counter)*valueSize*uint64(batchOpSize)) / getTime
+ //fmt.Fprint(logFile, fmt.Sprintf("Loop %d: BATCH GET time %.1f secs, batchs = %d, kv pairs = %d, speed = %sB/sec, %.1f operations/sec, %.1f kv/sec.\n",
+ // loop, getTime, counter, counter*int32(batchOpSize), bytefmt.ByteSize(uint64(bps)), float64(counter)/getTime, float64(counter * int32(batchOpSize))/getTime))
+ //
+ //// Run the delete case
+ //counter = 0
+ //startTime = time.Now()
+ //endTime = startTime.Add(time.Second * time.Duration(durationSecs))
+ //for n := 1; n <= threads; n++ {
+ // wg.Add(1)
+ // go runDelete()
+ //}
+ //wg.Wait()
+ //
+ //deleteTime := deleteFinish.Sub(startTime).Seconds()
+ //bps = float64(uint64(counter)*valueSize) / deleteTime
+ //fmt.Fprint(logFile, fmt.Sprintf("Loop %d: Delete time %.1f secs, kv pairs = %d, %.1f operations/sec, %.1f kv/sec.\n",
+ // loop, deleteTime, counter, float64(counter)/deleteTime, float64(counter)/deleteTime))
+ //
+ //// Run the batchDelete case
+ //counter = 0
+ //startTime = time.Now()
+ //endTime = startTime.Add(time.Second * time.Duration(durationSecs))
+ //for n := 1; n <= threads; n++ {
+ // wg.Add(1)
+ // go runBatchDelete()
+ //}
+ //wg.Wait()
+ //
+ //deleteTime = setFinish.Sub(startTime).Seconds()
+ //bps = float64(uint64(counter)*valueSize*uint64(batchOpSize)) / setTime
+ //fmt.Fprint(logFile, fmt.Sprintf("Loop %d: BATCH DELETE time %.1f secs, batchs = %d, kv pairs = %d, %.1f operations/sec, %.1f kv/sec.\n",
+ // loop, setTime, counter, counter*int32(batchOpSize), float64(counter)/setTime, float64(counter * int32(batchOpSize))/setTime))
+
+ // Print line mark
+ lineMark := "\n"
+ fmt.Fprint(logFile, lineMark)
+ }
+ log.Print("Benchmark test done.")
+}
diff --git a/deployments/docker/docker-compose.yml b/deployments/docker/docker-compose.yml
index 60bf5d9fff1f2a0aa47274becc742f778aaeb77a..0ae708a19ecb9ababafe5fcdb6bd5f9d5eac529e 100644
--- a/deployments/docker/docker-compose.yml
+++ b/deployments/docker/docker-compose.yml
@@ -36,6 +36,14 @@ services:
networks:
- milvus
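+ # Jaeger all-in-one: 6831/udp accepts spans from Jaeger clients (compact
+ # thrift); 16686 serves the web UI.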
+ jaeger:
+ image: jaegertracing/all-in-one:latest
+ ports:
+ - "6831:6831/udp"
+ - "16686:16686"
+ networks:
+ - milvus
+
networks:
milvus:
diff --git a/docker-compose.yml b/docker-compose.yml
index cba23befabc4c39314b179b3227a1d3570ee3c31..9f3599abb9a5b139323717b4e98e4a9d7e91b8f4 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -83,5 +83,10 @@ services:
networks:
- milvus
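+ # Jaeger all-in-one, reachable as host "jaeger" on the milvus network;
+ # no host ports are published here.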
+ jaeger:
+ image: jaegertracing/all-in-one:latest
+ networks:
+ - milvus
+
networks:
milvus:
diff --git a/internal/core/src/query/CMakeLists.txt b/internal/core/src/query/CMakeLists.txt
index b272e388853fedb8b1178d41302826ff2b2fe8b7..a1de1d4ed502053407f16f1fc6e107c163cda653 100644
--- a/internal/core/src/query/CMakeLists.txt
+++ b/internal/core/src/query/CMakeLists.txt
@@ -4,15 +4,13 @@ set(MILVUS_QUERY_SRCS
generated/PlanNode.cpp
generated/Expr.cpp
visitors/ShowPlanNodeVisitor.cpp
- visitors/ShowExprVisitor.cpp
visitors/ExecPlanNodeVisitor.cpp
+ visitors/ShowExprVisitor.cpp
visitors/ExecExprVisitor.cpp
- visitors/VerifyPlanNodeVisitor.cpp
- visitors/VerifyExprVisitor.cpp
Plan.cpp
Search.cpp
SearchOnSealed.cpp
BruteForceSearch.cpp
)
add_library(milvus_query ${MILVUS_QUERY_SRCS})
-target_link_libraries(milvus_query milvus_proto milvus_utils knowhere)
+target_link_libraries(milvus_query milvus_proto milvus_utils)
diff --git a/internal/core/src/query/Plan.cpp b/internal/core/src/query/Plan.cpp
index 78f1f14c73ebb94d897ff6629cdd62bf3c05e05f..96653593516c15f98797a9f2d4c18e316b0161e1 100644
--- a/internal/core/src/query/Plan.cpp
+++ b/internal/core/src/query/Plan.cpp
@@ -21,7 +21,6 @@
#include <boost/align/aligned_allocator.hpp>
#include <boost/algorithm/string.hpp>
#include <algorithm>
-#include "query/generated/VerifyPlanNodeVisitor.h"
namespace milvus::query {
@@ -139,8 +138,6 @@ Parser::CreatePlanImpl(const std::string& dsl_str) {
if (predicate != nullptr) {
vec_node->predicate_ = std::move(predicate);
}
- VerifyPlanNodeVisitor verifier;
- vec_node->accept(verifier);
auto plan = std::make_unique<Plan>(schema);
plan->tag2field_ = std::move(tag2field_);
diff --git a/internal/core/src/query/generated/ExecExprVisitor.h b/internal/core/src/query/generated/ExecExprVisitor.h
index a9e0574a6e527a6f8fe5856ec640f15072824390..250d68a6e567a52f9cf9c6bc8c7c886abf12f3f3 100644
--- a/internal/core/src/query/generated/ExecExprVisitor.h
+++ b/internal/core/src/query/generated/ExecExprVisitor.h
@@ -21,7 +21,7 @@
#include "ExprVisitor.h"
namespace milvus::query {
-class ExecExprVisitor : public ExprVisitor {
+class ExecExprVisitor : ExprVisitor {
public:
void
visit(BoolUnaryExpr& expr) override;
diff --git a/internal/core/src/query/generated/ExecPlanNodeVisitor.h b/internal/core/src/query/generated/ExecPlanNodeVisitor.h
index c026c689857958c27b44daf2c160b51592b26c44..0eb33384d71eec5e01a486014c27c1049e442c46 100644
--- a/internal/core/src/query/generated/ExecPlanNodeVisitor.h
+++ b/internal/core/src/query/generated/ExecPlanNodeVisitor.h
@@ -19,7 +19,7 @@
#include "PlanNodeVisitor.h"
namespace milvus::query {
-class ExecPlanNodeVisitor : public PlanNodeVisitor {
+class ExecPlanNodeVisitor : PlanNodeVisitor {
public:
void
visit(FloatVectorANNS& node) override;
diff --git a/internal/core/src/query/generated/ShowExprVisitor.h b/internal/core/src/query/generated/ShowExprVisitor.h
index 6a1ed2646fc641b7670a0c7f100da9ed8408dc06..55659e24c04e4a419a97b50ec39cfaffb1bcb558 100644
--- a/internal/core/src/query/generated/ShowExprVisitor.h
+++ b/internal/core/src/query/generated/ShowExprVisitor.h
@@ -19,7 +19,7 @@
#include "ExprVisitor.h"
namespace milvus::query {
-class ShowExprVisitor : public ExprVisitor {
+class ShowExprVisitor : ExprVisitor {
public:
void
visit(BoolUnaryExpr& expr) override;
diff --git a/internal/core/src/query/generated/ShowPlanNodeVisitor.h b/internal/core/src/query/generated/ShowPlanNodeVisitor.h
index c518c3f7d0b23204f804c035db3471bcf08c4831..b921ec81fc5aa3eb29eb15c091a72738cfc57d4b 100644
--- a/internal/core/src/query/generated/ShowPlanNodeVisitor.h
+++ b/internal/core/src/query/generated/ShowPlanNodeVisitor.h
@@ -20,7 +20,7 @@
#include "PlanNodeVisitor.h"
namespace milvus::query {
-class ShowPlanNodeVisitor : public PlanNodeVisitor {
+class ShowPlanNodeVisitor : PlanNodeVisitor {
public:
void
visit(FloatVectorANNS& node) override;
diff --git a/internal/core/src/query/generated/VerifyExprVisitor.cpp b/internal/core/src/query/generated/VerifyExprVisitor.cpp
deleted file mode 100644
index 44af4dde81bdeea864b8bef34d064e4fbd4f2fee..0000000000000000000000000000000000000000
--- a/internal/core/src/query/generated/VerifyExprVisitor.cpp
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright (C) 2019-2020 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License
-
-#error TODO: copy this file out, and modify the content.
-#include "query/generated/VerifyExprVisitor.h"
-
-namespace milvus::query {
-void
-VerifyExprVisitor::visit(BoolUnaryExpr& expr) {
- // TODO
-}
-
-void
-VerifyExprVisitor::visit(BoolBinaryExpr& expr) {
- // TODO
-}
-
-void
-VerifyExprVisitor::visit(TermExpr& expr) {
- // TODO
-}
-
-void
-VerifyExprVisitor::visit(RangeExpr& expr) {
- // TODO
-}
-
-} // namespace milvus::query
diff --git a/internal/core/src/query/generated/VerifyExprVisitor.h b/internal/core/src/query/generated/VerifyExprVisitor.h
deleted file mode 100644
index 6b04a76978d7db2c8247ac45399b50b85f44309d..0000000000000000000000000000000000000000
--- a/internal/core/src/query/generated/VerifyExprVisitor.h
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright (C) 2019-2020 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License
-
-#pragma once
-// Generated File
-// DO NOT EDIT
-#include <optional>
-#include <boost/dynamic_bitset.hpp>
-#include <utility>
-#include <deque>
-#include "segcore/SegmentSmallIndex.h"
-#include "query/ExprImpl.h"
-#include "ExprVisitor.h"
-
-namespace milvus::query {
-class VerifyExprVisitor : public ExprVisitor {
- public:
- void
- visit(BoolUnaryExpr& expr) override;
-
- void
- visit(BoolBinaryExpr& expr) override;
-
- void
- visit(TermExpr& expr) override;
-
- void
- visit(RangeExpr& expr) override;
-
- public:
-};
-} // namespace milvus::query
diff --git a/internal/core/src/query/generated/VerifyPlanNodeVisitor.cpp b/internal/core/src/query/generated/VerifyPlanNodeVisitor.cpp
deleted file mode 100644
index c7b0656f57041427e8159e899f891ca0f391c901..0000000000000000000000000000000000000000
--- a/internal/core/src/query/generated/VerifyPlanNodeVisitor.cpp
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright (C) 2019-2020 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License
-
-#error TODO: copy this file out, and modify the content.
-#include "query/generated/VerifyPlanNodeVisitor.h"
-
-namespace milvus::query {
-void
-VerifyPlanNodeVisitor::visit(FloatVectorANNS& node) {
- // TODO
-}
-
-void
-VerifyPlanNodeVisitor::visit(BinaryVectorANNS& node) {
- // TODO
-}
-
-} // namespace milvus::query
diff --git a/internal/core/src/query/generated/VerifyPlanNodeVisitor.h b/internal/core/src/query/generated/VerifyPlanNodeVisitor.h
deleted file mode 100644
index a964e6c08f920bcf3b2b1f4e7f70ebbddb0e264a..0000000000000000000000000000000000000000
--- a/internal/core/src/query/generated/VerifyPlanNodeVisitor.h
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright (C) 2019-2020 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License
-
-#pragma once
-// Generated File
-// DO NOT EDIT
-#include "utils/Json.h"
-#include "query/PlanImpl.h"
-#include "segcore/SegmentBase.h"
-#include <utility>
-#include "PlanNodeVisitor.h"
-
-namespace milvus::query {
-class VerifyPlanNodeVisitor : public PlanNodeVisitor {
- public:
- void
- visit(FloatVectorANNS& node) override;
-
- void
- visit(BinaryVectorANNS& node) override;
-
- public:
- using RetType = QueryResult;
- VerifyPlanNodeVisitor() = default;
-
- private:
- std::optional<RetType> ret_;
-};
-} // namespace milvus::query
diff --git a/internal/core/src/query/visitors/VerifyExprVisitor.cpp b/internal/core/src/query/visitors/VerifyExprVisitor.cpp
deleted file mode 100644
index 9b3326c74a60b9f604de92a24fb71c4156a60cad..0000000000000000000000000000000000000000
--- a/internal/core/src/query/visitors/VerifyExprVisitor.cpp
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright (C) 2019-2020 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License
-
-#include "query/generated/VerifyExprVisitor.h"
-
-namespace milvus::query {
-void
-VerifyExprVisitor::visit(BoolUnaryExpr& expr) {
- // TODO
-}
-
-void
-VerifyExprVisitor::visit(BoolBinaryExpr& expr) {
- // TODO
-}
-
-void
-VerifyExprVisitor::visit(TermExpr& expr) {
- // TODO
-}
-
-void
-VerifyExprVisitor::visit(RangeExpr& expr) {
- // TODO
-}
-
-} // namespace milvus::query
diff --git a/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp b/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp
deleted file mode 100644
index 263390a39a52831c0eb7f1bd71f8dc0cc1cecdb5..0000000000000000000000000000000000000000
--- a/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp
+++ /dev/null
@@ -1,85 +0,0 @@
-// Copyright (C) 2019-2020 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License
-
-#include "query/generated/VerifyPlanNodeVisitor.h"
-#include "knowhere/index/vector_index/ConfAdapterMgr.h"
-#include "segcore/SegmentSmallIndex.h"
-#include "knowhere/index/vector_index/ConfAdapter.h"
-#include "knowhere/index/vector_index/helpers/IndexParameter.h"
-
-namespace milvus::query {
-
-#if 1
-namespace impl {
-// THIS CONTAINS EXTRA BODY FOR VISITOR
-// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/
-class VerifyPlanNodeVisitor : PlanNodeVisitor {
- public:
- using RetType = QueryResult;
- VerifyPlanNodeVisitor() = default;
-
- private:
- std::optional<RetType> ret_;
-};
-} // namespace impl
-#endif
-
-static knowhere::IndexType
-InferIndexType(const Json& search_params) {
- // ivf -> nprobe
- // nsg -> search_length
- // hnsw/rhnsw/*pq/*sq -> ef
- // annoy -> search_k
- // ngtpanng / ngtonng -> max_search_edges / epsilon
- static const std::map<std::string, knowhere::IndexType> key_list = [] {
- std::map<std::string, knowhere::IndexType> list;
- namespace ip = knowhere::IndexParams;
- namespace ie = knowhere::IndexEnum;
- list.emplace(ip::nprobe, ie::INDEX_FAISS_IVFFLAT);
- list.emplace(ip::search_length, ie::INDEX_NSG);
- list.emplace(ip::ef, ie::INDEX_HNSW);
- list.emplace(ip::search_k, ie::INDEX_ANNOY);
- list.emplace(ip::max_search_edges, ie::INDEX_NGTONNG);
- list.emplace(ip::epsilon, ie::INDEX_NGTONNG);
- return list;
- }();
- auto dbg_str = search_params.dump();
- for (auto& kv : search_params.items()) {
- std::string key = kv.key();
- if (key_list.count(key)) {
- return key_list.at(key);
- }
- }
- PanicInfo("failed to infer index type");
-}
-
-void
-VerifyPlanNodeVisitor::visit(FloatVectorANNS& node) {
- auto& search_params = node.query_info_.search_params_;
- auto inferred_type = InferIndexType(search_params);
- auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(inferred_type);
- auto index_mode = knowhere::IndexMode::MODE_CPU;
-
- // mock the api, topk will be passed from placeholder
- auto params_copy = search_params;
- params_copy[knowhere::meta::TOPK] = 10;
-
- // NOTE: the second parameter is not checked in knowhere, may be redundant
- auto passed = adapter->CheckSearch(params_copy, inferred_type, index_mode);
- AssertInfo(passed, "invalid search params");
-}
-
-void
-VerifyPlanNodeVisitor::visit(BinaryVectorANNS& node) {
- // TODO
-}
-
-} // namespace milvus::query
diff --git a/internal/core/src/segcore/CMakeLists.txt b/internal/core/src/segcore/CMakeLists.txt
index 1a011c984b690d92e62dc6bb172aa245cbfd140f..709a983c977aceacc0c049b070887044701840f6 100644
--- a/internal/core/src/segcore/CMakeLists.txt
+++ b/internal/core/src/segcore/CMakeLists.txt
@@ -24,6 +24,5 @@ target_link_libraries(milvus_segcore
dl backtrace
milvus_common
milvus_query
- milvus_utils
)
diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt
index 29fe7b32cfe6068b0b06662d2757de00c3e9a86d..a69690728065e26d373ed961c0fd4ccd4817e10b 100644
--- a/internal/core/unittest/CMakeLists.txt
+++ b/internal/core/unittest/CMakeLists.txt
@@ -24,8 +24,10 @@ target_link_libraries(all_tests
gtest_main
milvus_segcore
milvus_indexbuilder
+ knowhere
log
pthread
+ milvus_utils
)
install (TARGETS all_tests DESTINATION unittest)
diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp
index 65866a60b7f39538dcb5ef2cba8a872f23da8689..0c5dbe83a8ca657cfef829b5767577312124f416 100644
--- a/internal/core/unittest/test_c_api.cpp
+++ b/internal/core/unittest/test_c_api.cpp
@@ -137,7 +137,7 @@ TEST(CApiTest, SearchTest) {
auto offset = PreInsert(segment, N);
auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
- ASSERT_EQ(ins_res.error_code, Success);
+ assert(ins_res.error_code == Success);
const char* dsl_string = R"(
{
@@ -176,11 +176,11 @@ TEST(CApiTest, SearchTest) {
void* plan = nullptr;
auto status = CreatePlan(collection, dsl_string, &plan);
- ASSERT_EQ(status.error_code, Success);
+ assert(status.error_code == Success);
void* placeholderGroup = nullptr;
status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
- ASSERT_EQ(status.error_code, Success);
+ assert(status.error_code == Success);
std::vector<CPlaceholderGroup> placeholderGroups;
placeholderGroups.push_back(placeholderGroup);
@@ -189,7 +189,7 @@ TEST(CApiTest, SearchTest) {
CQueryResult search_result;
auto res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, &search_result);
- ASSERT_EQ(res.error_code, Success);
+ assert(res.error_code == Success);
DeletePlan(plan);
DeletePlaceholderGroup(placeholderGroup);
diff --git a/internal/indexbuilder/indexbuilder.go b/internal/indexbuilder/indexbuilder.go
index 4acfffc3d284665158d9f3c237682a5792257d50..5b21e68dd4eebd11079784d799aaf63679a28359 100644
--- a/internal/indexbuilder/indexbuilder.go
+++ b/internal/indexbuilder/indexbuilder.go
@@ -11,6 +11,9 @@ import (
miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
+ "github.com/minio/minio-go/v7"
+ "github.com/minio/minio-go/v7/pkg/credentials"
+
"go.etcd.io/etcd/clientv3"
"github.com/zilliztech/milvus-distributed/internal/allocator"
@@ -68,16 +71,19 @@ func CreateBuilder(ctx context.Context) (*Builder, error) {
idAllocator, err := allocator.NewIDAllocator(b.loopCtx, Params.MasterAddress)
- option := &miniokv.Option{
- Address: Params.MinIOAddress,
- AccessKeyID: Params.MinIOAccessKeyID,
- SecretAccessKeyID: Params.MinIOSecretAccessKey,
- UseSSL: Params.MinIOUseSSL,
- BucketName: Params.MinioBucketName,
- CreateBucket: true,
+ minIOEndPoint := Params.MinIOAddress
+ minIOAccessKeyID := Params.MinIOAccessKeyID
+ minIOSecretAccessKey := Params.MinIOSecretAccessKey
+ minIOUseSSL := Params.MinIOUseSSL
+ minIOClient, err := minio.New(minIOEndPoint, &minio.Options{
+ Creds: credentials.NewStaticV4(minIOAccessKeyID, minIOSecretAccessKey, ""),
+ Secure: minIOUseSSL,
+ })
+ if err != nil {
+ return nil, err
}
- b.kv, err = miniokv.NewMinIOKV(b.loopCtx, option)
+ b.kv, err = miniokv.NewMinIOKV(b.loopCtx, minIOClient, Params.MinioBucketName)
if err != nil {
return nil, err
}
diff --git a/internal/kv/minio/minio_kv.go b/internal/kv/minio/minio_kv.go
index 6b3522fb454273f24a08e11d83f4f64e43a40381..68bb3a3438bbd771a2d5afc98c23988703db72ea 100644
--- a/internal/kv/minio/minio_kv.go
+++ b/internal/kv/minio/minio_kv.go
@@ -2,15 +2,11 @@ package miniokv
import (
"context"
- "fmt"
-
"io"
"log"
"strings"
"github.com/minio/minio-go/v7"
- "github.com/minio/minio-go/v7/pkg/credentials"
- "github.com/zilliztech/milvus-distributed/internal/errors"
)
type MinIOKV struct {
@@ -19,46 +15,24 @@ type MinIOKV struct {
bucketName string
}
-type Option struct {
- Address string
- AccessKeyID string
- BucketName string
- SecretAccessKeyID string
- UseSSL bool
- CreateBucket bool // when bucket not existed, create it
-}
-
-func NewMinIOKV(ctx context.Context, option *Option) (*MinIOKV, error) {
- minIOClient, err := minio.New(option.Address, &minio.Options{
- Creds: credentials.NewStaticV4(option.AccessKeyID, option.SecretAccessKeyID, ""),
- Secure: option.UseSSL,
- })
- if err != nil {
- return nil, err
- }
+// NewMinIOKV creates a new MinIO kv.
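+// The bucket is created on demand if it does not already exist. Typical
+// construction (a sketch, mirroring the tests):
+//
+//  client, _ := minio.New(addr, &minio.Options{
+//      Creds:  credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
+//      Secure: useSSL,
+//  })
+//  kv, _ := miniokv.NewMinIOKV(ctx, client, bucketName)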
+func NewMinIOKV(ctx context.Context, client *minio.Client, bucketName string) (*MinIOKV, error) {
- bucketExists, err := minIOClient.BucketExists(ctx, option.BucketName)
+ bucketExists, err := client.BucketExists(ctx, bucketName)
if err != nil {
return nil, err
}
- if option.CreateBucket {
- if !bucketExists {
- err = minIOClient.MakeBucket(ctx, option.BucketName, minio.MakeBucketOptions{})
- if err != nil {
- return nil, err
- }
- }
- } else {
- if !bucketExists {
- return nil, errors.New(fmt.Sprintf("Bucket %s not Existed.", option.BucketName))
+ if !bucketExists {
+ err = client.MakeBucket(ctx, bucketName, minio.MakeBucketOptions{})
+ if err != nil {
+ return nil, err
}
}
-
return &MinIOKV{
ctx: ctx,
- minioClient: minIOClient,
- bucketName: option.BucketName,
+ minioClient: client,
+ bucketName: bucketName,
}, nil
}
diff --git a/internal/kv/minio/minio_kv_test.go b/internal/kv/minio/minio_kv_test.go
index 2e50545b40a83516aa16cb0991bbc64cf34360cc..ac2a3180b2966bbdc09158ff70ec8baaef23dedb 100644
--- a/internal/kv/minio/minio_kv_test.go
+++ b/internal/kv/minio/minio_kv_test.go
@@ -5,6 +5,8 @@ import (
"strconv"
"testing"
+ "github.com/minio/minio-go/v7"
+ "github.com/minio/minio-go/v7/pkg/credentials"
miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
"github.com/zilliztech/milvus-distributed/internal/util/paramtable"
@@ -13,31 +15,24 @@ import (
var Params paramtable.BaseTable
-func newMinIOKVClient(ctx context.Context, bucketName string) (*miniokv.MinIOKV, error) {
+func TestMinIOKV_Load(t *testing.T) {
+ Params.Init()
endPoint, _ := Params.Load("_MinioAddress")
accessKeyID, _ := Params.Load("minio.accessKeyID")
secretAccessKey, _ := Params.Load("minio.secretAccessKey")
useSSLStr, _ := Params.Load("minio.useSSL")
- useSSL, _ := strconv.ParseBool(useSSLStr)
- option := &miniokv.Option{
- Address: endPoint,
- AccessKeyID: accessKeyID,
- SecretAccessKeyID: secretAccessKey,
- UseSSL: useSSL,
- BucketName: bucketName,
- CreateBucket: true,
- }
- client, err := miniokv.NewMinIOKV(ctx, option)
- return client, err
-}
-
-func TestMinIOKV_Load(t *testing.T) {
- Params.Init()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
+ useSSL, _ := strconv.ParseBool(useSSLStr)
+
+ minioClient, err := minio.New(endPoint, &minio.Options{
+ Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
+ Secure: useSSL,
+ })
+ assert.Nil(t, err)
bucketName := "fantastic-tech-test"
- MinIOKV, err := newMinIOKVClient(ctx, bucketName)
+ MinIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, bucketName)
assert.Nil(t, err)
defer MinIOKV.RemoveWithPrefix("")
@@ -84,14 +79,25 @@ func TestMinIOKV_Load(t *testing.T) {
}
func TestMinIOKV_MultiSave(t *testing.T) {
- Params.Init()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- bucketName := "fantastic-tech-test"
- MinIOKV, err := newMinIOKVClient(ctx, bucketName)
+ Params.Init()
+ endPoint, _ := Params.Load("_MinioAddress")
+ accessKeyID, _ := Params.Load("minio.accessKeyID")
+ secretAccessKey, _ := Params.Load("minio.secretAccessKey")
+ useSSLStr, _ := Params.Load("minio.useSSL")
+ useSSL, _ := strconv.ParseBool(useSSLStr)
+
+ minioClient, err := minio.New(endPoint, &minio.Options{
+ Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
+ Secure: useSSL,
+ })
assert.Nil(t, err)
+ bucketName := "fantastic-tech-test"
+ MinIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, bucketName)
+ assert.Nil(t, err)
defer MinIOKV.RemoveWithPrefix("")
err = MinIOKV.Save("key_1", "111")
@@ -111,13 +117,25 @@ func TestMinIOKV_MultiSave(t *testing.T) {
}
func TestMinIOKV_Remove(t *testing.T) {
- Params.Init()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
+ Params.Init()
+ endPoint, _ := Params.Load("_MinioAddress")
+ accessKeyID, _ := Params.Load("minio.accessKeyID")
+ secretAccessKey, _ := Params.Load("minio.secretAccessKey")
+ useSSLStr, _ := Params.Load("minio.useSSL")
+ useSSL, _ := strconv.ParseBool(useSSLStr)
+
+ minioClient, err := minio.New(endPoint, &minio.Options{
+ Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
+ Secure: useSSL,
+ })
+ assert.Nil(t, err)
+
bucketName := "fantastic-tech-test"
- MinIOKV, err := newMinIOKVClient(ctx, bucketName)
+ MinIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, bucketName)
assert.Nil(t, err)
defer MinIOKV.RemoveWithPrefix("")
diff --git a/internal/master/master.go b/internal/master/master.go
index 5476a7c8dc177e249c6cca4103a264f8c83d1e17..e8fb4134613a9c0cf26977eb21065f32acb8c798 100644
--- a/internal/master/master.go
+++ b/internal/master/master.go
@@ -218,6 +218,7 @@ func CreateServer(ctx context.Context) (*Master, error) {
m.grpcServer = grpc.NewServer()
masterpb.RegisterMasterServer(m.grpcServer, m)
+
return m, nil
}
diff --git a/internal/master/master_test.go b/internal/master/master_test.go
index 0a44ed90e886b55d6b7bd5bec9e2d1842041fd2c..a605e73aa76127c24918cdd3826fae0d0d186ad8 100644
--- a/internal/master/master_test.go
+++ b/internal/master/master_test.go
@@ -110,6 +110,7 @@ func TestMaster(t *testing.T) {
conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock())
require.Nil(t, err)
+
cli := masterpb.NewMasterClient(conn)
t.Run("TestConfigTask", func(t *testing.T) {
@@ -886,12 +887,6 @@ func TestMaster(t *testing.T) {
var k2sMsgstream ms.MsgStream = k2sMs
assert.True(t, receiveTimeTickMsg(&k2sMsgstream))
- conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock())
- assert.Nil(t, err)
- defer conn.Close()
-
- cli := masterpb.NewMasterClient(conn)
-
sch := schemapb.CollectionSchema{
Name: "name" + strconv.FormatUint(rand.Uint64(), 10),
Description: "test collection",
diff --git a/internal/msgstream/msg.go b/internal/msgstream/msg.go
index 518bcfa7afe56a34cf86ff1464eb309a42560a9c..a71d1cabfe89745cfdfaec2d3411c4e3521c8b19 100644
--- a/internal/msgstream/msg.go
+++ b/internal/msgstream/msg.go
@@ -1,6 +1,8 @@
package msgstream
import (
+ "context"
+
"github.com/golang/protobuf/proto"
internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
)
@@ -8,6 +10,8 @@ import (
type MsgType = internalPb.MsgType
type TsMsg interface {
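+ // GetContext/SetContext let a message carry a context.Context (e.g. one
+ // holding a tracing span) from producer to consumer.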
+ GetContext() context.Context
+ SetContext(context.Context)
BeginTs() Timestamp
EndTs() Timestamp
Type() MsgType
@@ -17,6 +21,7 @@ type TsMsg interface {
}
type BaseMsg struct {
+ ctx context.Context
BeginTimestamp Timestamp
EndTimestamp Timestamp
HashValues []uint32
@@ -44,6 +49,14 @@ func (it *InsertMsg) Type() MsgType {
return it.MsgType
}
+func (it *InsertMsg) GetContext() context.Context {
+ return it.ctx
+}
+
+func (it *InsertMsg) SetContext(ctx context.Context) {
+ it.ctx = ctx
+}
+
func (it *InsertMsg) Marshal(input TsMsg) ([]byte, error) {
insertMsg := input.(*InsertMsg)
insertRequest := &insertMsg.InsertRequest
@@ -88,6 +101,13 @@ func (fl *FlushMsg) Type() MsgType {
return fl.GetMsgType()
}
+func (fl *FlushMsg) GetContext() context.Context {
+ return fl.ctx
+}
+func (fl *FlushMsg) SetContext(ctx context.Context) {
+ fl.ctx = ctx
+}
+
func (fl *FlushMsg) Marshal(input TsMsg) ([]byte, error) {
flushMsgTask := input.(*FlushMsg)
flushMsg := &flushMsgTask.FlushMsg
@@ -121,6 +141,14 @@ func (dt *DeleteMsg) Type() MsgType {
return dt.MsgType
}
+func (dt *DeleteMsg) GetContext() context.Context {
+ return dt.ctx
+}
+
+func (dt *DeleteMsg) SetContext(ctx context.Context) {
+ dt.ctx = ctx
+}
+
func (dt *DeleteMsg) Marshal(input TsMsg) ([]byte, error) {
deleteTask := input.(*DeleteMsg)
deleteRequest := &deleteTask.DeleteRequest
@@ -165,6 +193,14 @@ func (st *SearchMsg) Type() MsgType {
return st.MsgType
}
+func (st *SearchMsg) GetContext() context.Context {
+ return st.ctx
+}
+
+func (st *SearchMsg) SetContext(ctx context.Context) {
+ st.ctx = ctx
+}
+
func (st *SearchMsg) Marshal(input TsMsg) ([]byte, error) {
searchTask := input.(*SearchMsg)
searchRequest := &searchTask.SearchRequest
@@ -198,6 +234,14 @@ func (srt *SearchResultMsg) Type() MsgType {
return srt.MsgType
}
+func (srt *SearchResultMsg) GetContext() context.Context {
+ return srt.ctx
+}
+
+func (srt *SearchResultMsg) SetContext(ctx context.Context) {
+ srt.ctx = ctx
+}
+
func (srt *SearchResultMsg) Marshal(input TsMsg) ([]byte, error) {
searchResultTask := input.(*SearchResultMsg)
searchResultRequest := &searchResultTask.SearchResult
@@ -231,6 +275,14 @@ func (tst *TimeTickMsg) Type() MsgType {
return tst.MsgType
}
+func (tst *TimeTickMsg) GetContext() context.Context {
+ return tst.ctx
+}
+
+func (tst *TimeTickMsg) SetContext(ctx context.Context) {
+ tst.ctx = ctx
+}
+
func (tst *TimeTickMsg) Marshal(input TsMsg) ([]byte, error) {
timeTickTask := input.(*TimeTickMsg)
timeTick := &timeTickTask.TimeTickMsg
@@ -264,6 +316,14 @@ func (qs *QueryNodeStatsMsg) Type() MsgType {
return qs.MsgType
}
+func (qs *QueryNodeStatsMsg) GetContext() context.Context {
+ return qs.ctx
+}
+
+func (qs *QueryNodeStatsMsg) SetContext(ctx context.Context) {
+ qs.ctx = ctx
+}
+
func (qs *QueryNodeStatsMsg) Marshal(input TsMsg) ([]byte, error) {
queryNodeSegStatsTask := input.(*QueryNodeStatsMsg)
queryNodeSegStats := &queryNodeSegStatsTask.QueryNodeStats
@@ -305,6 +365,14 @@ func (cc *CreateCollectionMsg) Type() MsgType {
return cc.MsgType
}
+func (cc *CreateCollectionMsg) GetContext() context.Context {
+ return cc.ctx
+}
+
+func (cc *CreateCollectionMsg) SetContext(ctx context.Context) {
+ cc.ctx = ctx
+}
+
func (cc *CreateCollectionMsg) Marshal(input TsMsg) ([]byte, error) {
createCollectionMsg := input.(*CreateCollectionMsg)
createCollectionRequest := &createCollectionMsg.CreateCollectionRequest
@@ -337,6 +405,13 @@ type DropCollectionMsg struct {
func (dc *DropCollectionMsg) Type() MsgType {
return dc.MsgType
}
+func (dc *DropCollectionMsg) GetContext() context.Context {
+ return dc.ctx
+}
+
+func (dc *DropCollectionMsg) SetContext(ctx context.Context) {
+ dc.ctx = ctx
+}
func (dc *DropCollectionMsg) Marshal(input TsMsg) ([]byte, error) {
dropCollectionMsg := input.(*DropCollectionMsg)
@@ -361,109 +436,18 @@ func (dc *DropCollectionMsg) Unmarshal(input []byte) (TsMsg, error) {
return dropCollectionMsg, nil
}
-/////////////////////////////////////////HasCollection//////////////////////////////////////////
-type HasCollectionMsg struct {
- BaseMsg
- internalPb.HasCollectionRequest
-}
-
-func (hc *HasCollectionMsg) Type() MsgType {
- return hc.MsgType
-}
-
-func (hc *HasCollectionMsg) Marshal(input TsMsg) ([]byte, error) {
- hasCollectionMsg := input.(*HasCollectionMsg)
- hasCollectionRequest := &hasCollectionMsg.HasCollectionRequest
- mb, err := proto.Marshal(hasCollectionRequest)
- if err != nil {
- return nil, err
- }
- return mb, nil
-}
-
-func (hc *HasCollectionMsg) Unmarshal(input []byte) (TsMsg, error) {
- hasCollectionRequest := internalPb.HasCollectionRequest{}
- err := proto.Unmarshal(input, &hasCollectionRequest)
- if err != nil {
- return nil, err
- }
- hasCollectionMsg := &HasCollectionMsg{HasCollectionRequest: hasCollectionRequest}
- hasCollectionMsg.BeginTimestamp = hasCollectionMsg.Timestamp
- hasCollectionMsg.EndTimestamp = hasCollectionMsg.Timestamp
-
- return hasCollectionMsg, nil
-}
-
-/////////////////////////////////////////DescribeCollection//////////////////////////////////////////
-type DescribeCollectionMsg struct {
- BaseMsg
- internalPb.DescribeCollectionRequest
-}
-
-func (dc *DescribeCollectionMsg) Type() MsgType {
- return dc.MsgType
-}
-
-func (dc *DescribeCollectionMsg) Marshal(input TsMsg) ([]byte, error) {
- describeCollectionMsg := input.(*DescribeCollectionMsg)
- describeCollectionRequest := &describeCollectionMsg.DescribeCollectionRequest
- mb, err := proto.Marshal(describeCollectionRequest)
- if err != nil {
- return nil, err
- }
- return mb, nil
-}
-
-func (dc *DescribeCollectionMsg) Unmarshal(input []byte) (TsMsg, error) {
- describeCollectionRequest := internalPb.DescribeCollectionRequest{}
- err := proto.Unmarshal(input, &describeCollectionRequest)
- if err != nil {
- return nil, err
- }
- describeCollectionMsg := &DescribeCollectionMsg{DescribeCollectionRequest: describeCollectionRequest}
- describeCollectionMsg.BeginTimestamp = describeCollectionMsg.Timestamp
- describeCollectionMsg.EndTimestamp = describeCollectionMsg.Timestamp
-
- return describeCollectionMsg, nil
-}
-
-/////////////////////////////////////////ShowCollection//////////////////////////////////////////
-type ShowCollectionMsg struct {
+/////////////////////////////////////////CreatePartition//////////////////////////////////////////
+type CreatePartitionMsg struct {
BaseMsg
- internalPb.ShowCollectionRequest
-}
-
-func (sc *ShowCollectionMsg) Type() MsgType {
- return sc.MsgType
-}
-
-func (sc *ShowCollectionMsg) Marshal(input TsMsg) ([]byte, error) {
- showCollectionMsg := input.(*ShowCollectionMsg)
- showCollectionRequest := &showCollectionMsg.ShowCollectionRequest
- mb, err := proto.Marshal(showCollectionRequest)
- if err != nil {
- return nil, err
- }
- return mb, nil
+ internalPb.CreatePartitionRequest
}
-func (sc *ShowCollectionMsg) Unmarshal(input []byte) (TsMsg, error) {
- showCollectionRequest := internalPb.ShowCollectionRequest{}
- err := proto.Unmarshal(input, &showCollectionRequest)
- if err != nil {
- return nil, err
- }
- showCollectionMsg := &ShowCollectionMsg{ShowCollectionRequest: showCollectionRequest}
- showCollectionMsg.BeginTimestamp = showCollectionMsg.Timestamp
- showCollectionMsg.EndTimestamp = showCollectionMsg.Timestamp
-
- return showCollectionMsg, nil
+func (cc *CreatePartitionMsg) GetContext() context.Context {
+ return cc.ctx
}
-/////////////////////////////////////////CreatePartition//////////////////////////////////////////
-type CreatePartitionMsg struct {
- BaseMsg
- internalPb.CreatePartitionRequest
+func (cc *CreatePartitionMsg) SetContext(ctx context.Context) {
+ cc.ctx = ctx
}
func (cc *CreatePartitionMsg) Type() MsgType {
@@ -499,6 +483,14 @@ type DropPartitionMsg struct {
internalPb.DropPartitionRequest
}
+func (dc *DropPartitionMsg) GetContext() context.Context {
+ return dc.ctx
+}
+
+func (dc *DropPartitionMsg) SetContext(ctx context.Context) {
+ dc.ctx = ctx
+}
+
func (dc *DropPartitionMsg) Type() MsgType {
return dc.MsgType
}
@@ -526,105 +518,6 @@ func (dc *DropPartitionMsg) Unmarshal(input []byte) (TsMsg, error) {
return dropPartitionMsg, nil
}
-/////////////////////////////////////////HasPartition//////////////////////////////////////////
-type HasPartitionMsg struct {
- BaseMsg
- internalPb.HasPartitionRequest
-}
-
-func (hc *HasPartitionMsg) Type() MsgType {
- return hc.MsgType
-}
-
-func (hc *HasPartitionMsg) Marshal(input TsMsg) ([]byte, error) {
- hasPartitionMsg := input.(*HasPartitionMsg)
- hasPartitionRequest := &hasPartitionMsg.HasPartitionRequest
- mb, err := proto.Marshal(hasPartitionRequest)
- if err != nil {
- return nil, err
- }
- return mb, nil
-}
-
-func (hc *HasPartitionMsg) Unmarshal(input []byte) (TsMsg, error) {
- hasPartitionRequest := internalPb.HasPartitionRequest{}
- err := proto.Unmarshal(input, &hasPartitionRequest)
- if err != nil {
- return nil, err
- }
- hasPartitionMsg := &HasPartitionMsg{HasPartitionRequest: hasPartitionRequest}
- hasPartitionMsg.BeginTimestamp = hasPartitionMsg.Timestamp
- hasPartitionMsg.EndTimestamp = hasPartitionMsg.Timestamp
-
- return hasPartitionMsg, nil
-}
-
-/////////////////////////////////////////DescribePartition//////////////////////////////////////////
-type DescribePartitionMsg struct {
- BaseMsg
- internalPb.DescribePartitionRequest
-}
-
-func (dc *DescribePartitionMsg) Type() MsgType {
- return dc.MsgType
-}
-
-func (dc *DescribePartitionMsg) Marshal(input TsMsg) ([]byte, error) {
- describePartitionMsg := input.(*DescribePartitionMsg)
- describePartitionRequest := &describePartitionMsg.DescribePartitionRequest
- mb, err := proto.Marshal(describePartitionRequest)
- if err != nil {
- return nil, err
- }
- return mb, nil
-}
-
-func (dc *DescribePartitionMsg) Unmarshal(input []byte) (TsMsg, error) {
- describePartitionRequest := internalPb.DescribePartitionRequest{}
- err := proto.Unmarshal(input, &describePartitionRequest)
- if err != nil {
- return nil, err
- }
- describePartitionMsg := &DescribePartitionMsg{DescribePartitionRequest: describePartitionRequest}
- describePartitionMsg.BeginTimestamp = describePartitionMsg.Timestamp
- describePartitionMsg.EndTimestamp = describePartitionMsg.Timestamp
-
- return describePartitionMsg, nil
-}
-
-/////////////////////////////////////////ShowPartition//////////////////////////////////////////
-type ShowPartitionMsg struct {
- BaseMsg
- internalPb.ShowPartitionRequest
-}
-
-func (sc *ShowPartitionMsg) Type() MsgType {
- return sc.MsgType
-}
-
-func (sc *ShowPartitionMsg) Marshal(input TsMsg) ([]byte, error) {
- showPartitionMsg := input.(*ShowPartitionMsg)
- showPartitionRequest := &showPartitionMsg.ShowPartitionRequest
- mb, err := proto.Marshal(showPartitionRequest)
- if err != nil {
- return nil, err
- }
- return mb, nil
-}
-
-func (sc *ShowPartitionMsg) Unmarshal(input []byte) (TsMsg, error) {
- showPartitionRequest := internalPb.ShowPartitionRequest{}
- err := proto.Unmarshal(input, &showPartitionRequest)
- if err != nil {
- return nil, err
- }
- showPartitionMsg := &ShowPartitionMsg{ShowPartitionRequest: showPartitionRequest}
- showPartitionMsg.BeginTimestamp = showPartitionMsg.Timestamp
- showPartitionMsg.EndTimestamp = showPartitionMsg.Timestamp
-
- return showPartitionMsg, nil
-}
-
/////////////////////////////////////////LoadIndex//////////////////////////////////////////
type LoadIndexMsg struct {
BaseMsg
@@ -635,6 +528,14 @@ func (lim *LoadIndexMsg) Type() MsgType {
return lim.MsgType
}
+func (lim *LoadIndexMsg) GetContext() context.Context {
+ return lim.ctx
+}
+
+func (lim *LoadIndexMsg) SetContext(ctx context.Context) {
+ lim.ctx = ctx
+}
+
func (lim *LoadIndexMsg) Marshal(input TsMsg) ([]byte, error) {
loadIndexMsg := input.(*LoadIndexMsg)
loadIndexRequest := &loadIndexMsg.LoadIndex
diff --git a/internal/msgstream/msgstream.go b/internal/msgstream/msgstream.go
index 37dd71c053441673334cbda1c60adac9fdc5fc5f..969755feb322a3ac33619500a061dd9a0f984c00 100644
--- a/internal/msgstream/msgstream.go
+++ b/internal/msgstream/msgstream.go
@@ -4,9 +4,13 @@ import (
"context"
"log"
"reflect"
+ "strings"
"sync"
"time"
+ "github.com/opentracing/opentracing-go"
+ "github.com/opentracing/opentracing-go/ext"
+
"github.com/apache/pulsar-client-go/pulsar"
"github.com/golang/protobuf/proto"
@@ -151,6 +155,29 @@ func (ms *PulsarMsgStream) Close() {
}
}
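+
+// propertiesReaderWriter adapts a Pulsar message's Properties map to the
+// opentracing TextMapWriter/TextMapReader carrier interfaces, so span context
+// can be injected on the producer side and extracted on the consumer side.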
+type propertiesReaderWriter struct {
+ ppMap map[string]string
+}
+
+func (ppRW *propertiesReaderWriter) Set(key, val string) {
+ // The GRPC HPACK implementation rejects any uppercase keys here.
+ //
+ // As such, since the HTTP_HEADERS format is case-insensitive anyway, we
+ // blindly lowercase the key (which is guaranteed to work in the
+ // Inject/Extract sense per the OpenTracing spec).
+ key = strings.ToLower(key)
+ ppRW.ppMap[key] = val
+}
+
+func (ppRW *propertiesReaderWriter) ForeachKey(handler func(key, val string) error) error {
+ for k, val := range ppRW.ppMap {
+ if err := handler(k, val); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
func (ms *PulsarMsgStream) Produce(msgPack *MsgPack) error {
tsMsgs := msgPack.Msgs
if len(tsMsgs) <= 0 {
@@ -200,12 +227,41 @@ func (ms *PulsarMsgStream) Produce(msgPack *MsgPack) error {
if err != nil {
return err
}
+
+ msg := &pulsar.ProducerMessage{Payload: mb}
+ var child opentracing.Span
+ if v.Msgs[i].Type() == internalPb.MsgType_kInsert || v.Msgs[i].Type() == internalPb.MsgType_kSearch {
+ tracer := opentracing.GlobalTracer()
+ ctx := v.Msgs[i].GetContext()
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
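+ // FollowsFrom (rather than ChildOf) records that the send happens after
+ // the request span without implying the parent waits on its result.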
+ if parent := opentracing.SpanFromContext(ctx); parent != nil {
+ child = tracer.StartSpan("start send pulsar msg",
+ opentracing.FollowsFrom(parent.Context()))
+ } else {
+ child = tracer.StartSpan("start send pulsar msg")
+ }
+ child.SetTag("hash keys", v.Msgs[i].HashKeys())
+ child.SetTag("start time", v.Msgs[i].BeginTs())
+ child.SetTag("end time", v.Msgs[i].EndTs())
+ msg.Properties = make(map[string]string)
+ err = tracer.Inject(child.Context(), opentracing.TextMap, &propertiesReaderWriter{msg.Properties})
+ if err != nil {
+ return err
+ }
+ }
+
if _, err := (*ms.producers[k]).Send(
context.Background(),
- &pulsar.ProducerMessage{Payload: mb},
+ msg,
); err != nil {
return err
}
+ if child != nil {
+ child.Finish()
+ }
}
}
return nil
@@ -218,10 +274,34 @@ func (ms *PulsarMsgStream) Broadcast(msgPack *MsgPack) error {
if err != nil {
return err
}
+ msg := &pulsar.ProducerMessage{Payload: mb}
+ if v.Type() == internalPb.MsgType_kInsert || v.Type() == internalPb.MsgType_kSearch {
+ tracer := opentracing.GlobalTracer()
+ ctx := v.GetContext()
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ var child opentracing.Span
+ if parent := opentracing.SpanFromContext(ctx); parent != nil {
+ child = tracer.StartSpan("start send pulsar msg",
+ opentracing.FollowsFrom(parent.Context()))
+ } else {
+ child = tracer.StartSpan("start send pulsar msg, start time: %d")
+ }
+ child.SetTag("hash keys", v.HashKeys())
+ child.SetTag("start time", v.BeginTs())
+ child.SetTag("end time", v.EndTs())
+ msg.Properties = make(map[string]string)
+ err = tracer.Inject(child.Context(), opentracing.TextMap, &propertiesReaderWriter{msg.Properties})
+ if err != nil {
+ return err
+ }
+ child.Finish()
+ }
for i := 0; i < producerLen; i++ {
if _, err := (*ms.producers[i]).Send(
context.Background(),
- &pulsar.ProducerMessage{Payload: mb},
+ msg,
); err != nil {
return err
}
@@ -258,6 +338,7 @@ func (ms *PulsarMsgStream) bufMsgPackToChannel() {
for {
select {
case <-ms.ctx.Done():
+ log.Println("done")
return
default:
tsMsgList := make([]TsMsg, 0)
@@ -270,6 +351,7 @@ func (ms *PulsarMsgStream) bufMsgPackToChannel() {
}
pulsarMsg, ok := value.Interface().(pulsar.ConsumerMessage)
+
if !ok {
log.Printf("type assertion failed, not consumer message type")
continue
@@ -283,6 +365,21 @@ func (ms *PulsarMsgStream) bufMsgPackToChannel() {
continue
}
tsMsg, err := ms.unmarshal.Unmarshal(pulsarMsg.Payload(), headerMsg.MsgType)
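+ // Recover the span context that Produce injected into the message
+ // properties and attach a consumer-side span to the message's context.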
+ if err == nil && (tsMsg.Type() == internalPb.MsgType_kInsert || tsMsg.Type() == internalPb.MsgType_kSearch) {
+ tracer := opentracing.GlobalTracer()
+ spanContext, err := tracer.Extract(opentracing.HTTPHeaders, &propertiesReaderWriter{pulsarMsg.Properties()})
+ if err != nil {
+ log.Println("extract message err")
+ log.Println(err.Error())
+ }
+ span := opentracing.StartSpan("pulsar msg received",
+ ext.RPCServerOption(spanContext))
+ span.SetTag("hash keys", tsMsg.HashKeys())
+ span.SetTag("start time", tsMsg.BeginTs())
+ span.SetTag("end time", tsMsg.EndTs())
+ tsMsg.SetContext(opentracing.ContextWithSpan(context.Background(), span))
+ span.Finish()
+ }
if err != nil {
log.Printf("Failed to unmarshal tsMsg, error = %v", err)
continue
@@ -420,6 +517,23 @@ func (ms *PulsarTtMsgStream) findTimeTick(channelIndex int,
if err != nil {
log.Printf("Failed to unmarshal, error = %v", err)
}
+
+ if err == nil && (tsMsg.Type() == internalPb.MsgType_kInsert || tsMsg.Type() == internalPb.MsgType_kSearch) {
+ tracer := opentracing.GlobalTracer()
+ spanContext, err := tracer.Extract(opentracing.HTTPHeaders, &propertiesReaderWriter{pulsarMsg.Properties()})
+ if err != nil {
+ log.Println("extract message err")
+ log.Println(err.Error())
+ }
+ span := opentracing.StartSpan("pulsar msg received",
+ ext.RPCServerOption(spanContext))
+ span.SetTag("hash keys", tsMsg.HashKeys())
+ span.SetTag("start time", tsMsg.BeginTs())
+ span.SetTag("end time", tsMsg.EndTs())
+ tsMsg.SetContext(opentracing.ContextWithSpan(context.Background(), span))
+ span.Finish()
+ }
+
if headerMsg.MsgType == internalPb.MsgType_kTimeTick {
eofMsgMap[channelIndex] = tsMsg.(*TimeTickMsg).Timestamp
return
@@ -500,7 +614,7 @@ func insertRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e
result := make(map[int32]*MsgPack)
for i, request := range tsMsgs {
if request.Type() != internalPb.MsgType_kInsert {
- return nil, errors.New(string("msg's must be Insert"))
+ return nil, errors.New("msg's must be Insert")
}
insertRequest := request.(*InsertMsg)
keys := hashKeys[i]
@@ -511,7 +625,7 @@ func insertRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e
keysLen := len(keys)
if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen {
- return nil, errors.New(string("the length of hashValue, timestamps, rowIDs, RowData are not equal"))
+ return nil, errors.New("the length of hashValue, timestamps, rowIDs, RowData are not equal")
}
for index, key := range keys {
_, ok := result[key]
@@ -534,6 +648,9 @@ func insertRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e
}
insertMsg := &InsertMsg{
+ BaseMsg: BaseMsg{
+ ctx: request.GetContext(),
+ },
InsertRequest: sliceRequest,
}
result[key].Msgs = append(result[key].Msgs, insertMsg)
@@ -546,7 +663,7 @@ func deleteRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e
result := make(map[int32]*MsgPack)
for i, request := range tsMsgs {
if request.Type() != internalPb.MsgType_kDelete {
- return nil, errors.New(string("msg's must be Delete"))
+ return nil, errors.New("msg's must be Delete")
}
deleteRequest := request.(*DeleteMsg)
keys := hashKeys[i]
@@ -556,7 +673,7 @@ func deleteRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e
keysLen := len(keys)
if keysLen != timestampLen || keysLen != primaryKeysLen {
- return nil, errors.New(string("the length of hashValue, timestamps, primaryKeys are not equal"))
+ return nil, errors.New("the length of hashValue, timestamps, primaryKeys are not equal")
}
for index, key := range keys {
@@ -590,7 +707,7 @@ func defaultRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack,
for i, request := range tsMsgs {
keys := hashKeys[i]
if len(keys) != 1 {
- return nil, errors.New(string("len(msg.hashValue) must equal 1"))
+ return nil, errors.New("len(msg.hashValue) must equal 1")
}
key := keys[0]
_, ok := result[key]
diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go
index f4232bc82234f07d54d713772e89d70f77b5586f..6a312c2bc04169e53915820fd400c177c7f48811 100644
--- a/internal/proxy/proxy.go
+++ b/internal/proxy/proxy.go
@@ -2,6 +2,8 @@ package proxy
import (
"context"
+ "fmt"
+ "io"
"log"
"math/rand"
"net"
@@ -9,6 +11,10 @@ import (
"sync"
"time"
+ "github.com/opentracing/opentracing-go"
+ "github.com/uber/jaeger-client-go"
+ "github.com/uber/jaeger-client-go/config"
+
"google.golang.org/grpc"
"github.com/zilliztech/milvus-distributed/internal/allocator"
@@ -39,6 +45,9 @@ type Proxy struct {
manipulationMsgStream *msgstream.PulsarMsgStream
queryMsgStream *msgstream.PulsarMsgStream
+ tracer opentracing.Tracer
+ closer io.Closer
+
// Add callback functions at different stages
startCallbacks []func()
closeCallbacks []func()
@@ -51,11 +60,28 @@ func Init() {
func CreateProxy(ctx context.Context) (*Proxy, error) {
rand.Seed(time.Now().UnixNano())
ctx1, cancel := context.WithCancel(ctx)
+ var err error
p := &Proxy{
proxyLoopCtx: ctx1,
proxyLoopCancel: cancel,
}
+ cfg := &config.Configuration{
+ ServiceName: "tracing",
+ Sampler: &config.SamplerConfig{
+ Type: "const",
+ Param: 1,
+ },
+ Reporter: &config.ReporterConfig{
+ LogSpans: true,
+ },
+ }
+ p.tracer, p.closer, err = cfg.NewTracer(config.Logger(jaeger.StdLogger))
+ if err != nil {
+ panic(fmt.Sprintf("ERROR: cannot init Jaeger: %v\n", err))
+ }
+ opentracing.SetGlobalTracer(p.tracer)
+
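+ // Editor's note: a "const" sampler with Param 1 samples every span, and
+ // LogSpans makes the reporter print each finished span; both are useful
+ // for debugging but noisy in production. A lighter, hypothetical
+ // configuration:
+ //
+ // Sampler: &config.SamplerConfig{Type: "probabilistic", Param: 0.01},
+ // Reporter: &config.ReporterConfig{LogSpans: false},
+ //
+ // ServiceName would also normally name the component (e.g. "proxy")
+ // rather than the generic "tracing".
+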
pulsarAddress := Params.PulsarAddress()
p.queryMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamSearchBufSize())
@@ -198,6 +224,8 @@ func (p *Proxy) stopProxyLoop() {
p.tick.Close()
p.proxyLoopWg.Wait()
+
+ p.closer.Close()
}
// Close closes the server.
diff --git a/internal/proxy/repack_func.go b/internal/proxy/repack_func.go
index 44139999e0403719ca9eaf141f110980b808b6e1..f8873fe12f27bcae72473bd9b5eb252a9eac1aee 100644
--- a/internal/proxy/repack_func.go
+++ b/internal/proxy/repack_func.go
@@ -182,6 +182,7 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg,
insertMsg := &msgstream.InsertMsg{
InsertRequest: sliceRequest,
}
+ insertMsg.SetContext(request.GetContext())
if together { // all rows with same hash value are accumulated to only one message
if len(result[key].Msgs) <= 0 {
result[key].Msgs = append(result[key].Msgs, insertMsg)
diff --git a/internal/proxy/task.go b/internal/proxy/task.go
index 425cae75cfb3e4de24b55cf4a24cf3cc5aa55dbe..d01c45f0632545feca63a477eec4df35de3cf85c 100644
--- a/internal/proxy/task.go
+++ b/internal/proxy/task.go
@@ -7,6 +7,9 @@ import (
"math"
"strconv"
+ "github.com/opentracing/opentracing-go"
+ oplog "github.com/opentracing/opentracing-go/log"
+
"github.com/golang/protobuf/proto"
"github.com/zilliztech/milvus-distributed/internal/allocator"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
@@ -74,12 +77,21 @@ func (it *InsertTask) Type() internalpb.MsgType {
}
func (it *InsertTask) PreExecute() error {
+ span := opentracing.StartSpan("InsertTask preExecute")
+ defer span.Finish()
+ it.ctx = opentracing.ContextWithSpan(it.ctx, span)
+ span.SetTag("hash keys", it.ReqID)
+ span.SetTag("start time", it.BeginTs())
collectionName := it.BaseInsertTask.CollectionName
if err := ValidateCollectionName(collectionName); err != nil {
+ span.LogFields(oplog.Error(err))
return err
}
partitionTag := it.BaseInsertTask.PartitionTag
if err := ValidatePartitionTag(partitionTag, true); err != nil {
+ span.LogFields(oplog.Error(err))
return err
}
@@ -87,22 +99,36 @@ func (it *InsertTask) PreExecute() error {
}
func (it *InsertTask) Execute() error {
+ span, ctx := opentracing.StartSpanFromContext(it.ctx, "InsertTask Execute")
+ defer span.Finish()
+ it.ctx = ctx
+ span.SetTag("hash keys", it.ReqID)
+ span.SetTag("start time", it.BeginTs())
collectionName := it.BaseInsertTask.CollectionName
+ span.LogFields(oplog.String("collection_name", collectionName))
if !globalMetaCache.Hit(collectionName) {
err := globalMetaCache.Sync(collectionName)
if err != nil {
+ span.LogFields(oplog.Error(err))
return err
}
}
description, err := globalMetaCache.Get(collectionName)
if err != nil || description == nil {
+ span.LogFields(oplog.Error(err))
return err
}
autoID := description.Schema.AutoID
+ span.LogFields(oplog.Bool("auto_id", autoID))
var rowIDBegin UniqueID
var rowIDEnd UniqueID
rowNums := len(it.BaseInsertTask.RowData)
rowIDBegin, rowIDEnd, _ = it.rowIDAllocator.Alloc(uint32(rowNums))
+ span.LogFields(oplog.Int("rowNums", rowNums),
+ oplog.Int("rowIDBegin", int(rowIDBegin)),
+ oplog.Int("rowIDEnd", int(rowIDEnd)))
it.BaseInsertTask.RowIDs = make([]UniqueID, rowNums)
for i := rowIDBegin; i < rowIDEnd; i++ {
offset := i - rowIDBegin
@@ -125,6 +151,8 @@ func (it *InsertTask) Execute() error {
EndTs: it.EndTs(),
Msgs: make([]msgstream.TsMsg, 1),
}
+ tsMsg.SetContext(it.ctx)
+ span.LogFields(oplog.String("send msg", "send msg"))
msgPack.Msgs[0] = tsMsg
err = it.manipulationMsgStream.Produce(msgPack)
@@ -138,11 +166,14 @@ func (it *InsertTask) Execute() error {
if err != nil {
it.result.Status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR
it.result.Status.Reason = err.Error()
+ span.LogFields(oplog.Error(err))
}
return nil
}
func (it *InsertTask) PostExecute() error {
+ span, _ := opentracing.StartSpanFromContext(it.ctx, "InsertTask postExecute")
+ defer span.Finish()
return nil
}
@@ -352,24 +383,38 @@ func (qt *QueryTask) SetTs(ts Timestamp) {
}
func (qt *QueryTask) PreExecute() error {
+ span := opentracing.StartSpan("InsertTask preExecute")
+ defer span.Finish()
+ qt.ctx = opentracing.ContextWithSpan(qt.ctx, span)
+ span.SetTag("hash keys", qt.ReqID)
+ span.SetTag("start time", qt.BeginTs())
+
collectionName := qt.query.CollectionName
if !globalMetaCache.Hit(collectionName) {
err := globalMetaCache.Sync(collectionName)
if err != nil {
+ span.LogFields(oplog.Error(err))
return err
}
}
_, err := globalMetaCache.Get(collectionName)
if err != nil { // err is not nil if collection not exists
+ span.LogFields(oplog.Error(err))
return err
}
if err := ValidateCollectionName(qt.query.CollectionName); err != nil {
+ span.LogFields(oplog.Error(err))
return err
}
for _, tag := range qt.query.PartitionTags {
if err := ValidatePartitionTag(tag, false); err != nil {
+ span.LogFields(oplog.Error(err))
return err
}
}
@@ -379,6 +424,8 @@ func (qt *QueryTask) PreExecute() error {
}
queryBytes, err := proto.Marshal(qt.query)
if err != nil {
+ span.LogFields(oplog.Error(err))
return err
}
qt.Query = &commonpb.Blob{
@@ -388,6 +435,10 @@ func (qt *QueryTask) PreExecute() error {
}
func (qt *QueryTask) Execute() error {
+ span, ctx := opentracing.StartSpanFromContext(qt.ctx, "InsertTask Execute")
+ defer span.Finish()
+ span.SetTag("hash keys", qt.ReqID)
+ span.SetTag("start time", qt.BeginTs())
var tsMsg msgstream.TsMsg = &msgstream.SearchMsg{
SearchRequest: qt.SearchRequest,
BaseMsg: msgstream.BaseMsg{
@@ -401,20 +452,28 @@ func (qt *QueryTask) Execute() error {
EndTs: qt.Timestamp,
Msgs: make([]msgstream.TsMsg, 1),
}
+ tsMsg.SetContext(ctx)
msgPack.Msgs[0] = tsMsg
err := qt.queryMsgStream.Produce(msgPack)
log.Printf("[Proxy] length of searchMsg: %v", len(msgPack.Msgs))
if err != nil {
+ span.LogFields(oplog.Error(err))
log.Printf("[Proxy] send search request failed: %v", err)
}
return err
}
func (qt *QueryTask) PostExecute() error {
+ span, _ := opentracing.StartSpanFromContext(qt.ctx, "InsertTask postExecute")
+ span.SetTag("hash keys", qt.ReqID)
+ span.SetTag("start time", qt.BeginTs())
for {
select {
case <-qt.ctx.Done():
log.Print("wait to finish failed, timeout!")
+ span.LogFields(oplog.String("wait to finish failed, timeout", "wait to finish failed, timeout"))
+ span.Finish()
return errors.New("wait to finish failed, timeout")
case searchResults := <-qt.resultBuf:
filterSearchResult := make([]*internalpb.SearchResult, 0)
@@ -435,6 +494,8 @@ func (qt *QueryTask) PostExecute() error {
Reason: filterReason,
},
}
+ span.LogFields(oplog.Error(errors.New(filterReason)))
+ span.Finish()
return errors.New(filterReason)
}
@@ -465,6 +526,7 @@ func (qt *QueryTask) PostExecute() error {
Reason: filterReason,
},
}
+ span.Finish()
return nil
}
@@ -476,6 +538,7 @@ func (qt *QueryTask) PostExecute() error {
Reason: filterReason,
},
}
+ span.Finish()
return nil
}
@@ -526,10 +589,13 @@ func (qt *QueryTask) PostExecute() error {
reducedHitsBs, err := proto.Marshal(reducedHits)
if err != nil {
log.Println("marshal error")
+ span.LogFields(oplog.Error(err))
+ span.Finish()
return err
}
qt.result.Hits = append(qt.result.Hits, reducedHitsBs)
}
+ span.Finish()
return nil
}
}
@@ -637,7 +703,10 @@ func (dct *DescribeCollectionTask) PreExecute() error {
func (dct *DescribeCollectionTask) Execute() error {
var err error
dct.result, err = dct.masterClient.DescribeCollection(dct.ctx, &dct.DescribeCollectionRequest)
- globalMetaCache.Update(dct.CollectionName.CollectionName, dct.result)
+ if err != nil {
+ return err
+ }
+ err = globalMetaCache.Update(dct.CollectionName.CollectionName, dct.result)
return err
}
diff --git a/internal/querynode/load_index_service.go b/internal/querynode/load_index_service.go
index d8ae759b6731512e1e07df7bf03e63c0b421a0e2..cedbd50bb0447e0c535e926eacf208ca1d1e1b29 100644
--- a/internal/querynode/load_index_service.go
+++ b/internal/querynode/load_index_service.go
@@ -11,6 +11,9 @@ import (
"strings"
"time"
+ "github.com/minio/minio-go/v7"
+ "github.com/minio/minio-go/v7/pkg/credentials"
+
minioKV "github.com/zilliztech/milvus-distributed/internal/kv/minio"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
@@ -35,17 +38,16 @@ type loadIndexService struct {
func newLoadIndexService(ctx context.Context, replica collectionReplica) *loadIndexService {
ctx1, cancel := context.WithCancel(ctx)
- option := &minioKV.Option{
- Address: Params.MinioEndPoint,
- AccessKeyID: Params.MinioAccessKeyID,
- SecretAccessKeyID: Params.MinioSecretAccessKey,
- UseSSL: Params.MinioUseSSLStr,
- CreateBucket: true,
- BucketName: Params.MinioBucketName,
+ // init minio
+ minioClient, err := minio.New(Params.MinioEndPoint, &minio.Options{
+ Creds: credentials.NewStaticV4(Params.MinioAccessKeyID, Params.MinioSecretAccessKey, ""),
+ Secure: Params.MinioUseSSLStr,
+ })
+ if err != nil {
+ panic(err)
}
- // TODO: load bucketName from config
- MinioKV, err := minioKV.NewMinIOKV(ctx1, option)
+ MinioKV, err := minioKV.NewMinIOKV(ctx1, minioClient, Params.MinioBucketName)
if err != nil {
panic(err)
}
diff --git a/internal/querynode/load_index_service_test.go b/internal/querynode/load_index_service_test.go
index 000edb49df2bf53fddef89b3261c9ceb15f4c5de..cb2fb8504e190345567ee170ecce71846fe2ce45 100644
--- a/internal/querynode/load_index_service_test.go
+++ b/internal/querynode/load_index_service_test.go
@@ -5,6 +5,8 @@ import (
"sort"
"testing"
+ "github.com/minio/minio-go/v7"
+ "github.com/minio/minio-go/v7/pkg/credentials"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/indexbuilder"
@@ -66,16 +68,13 @@ func TestLoadIndexService(t *testing.T) {
binarySet, err := index.Serialize()
assert.Equal(t, err, nil)
- option := &minioKV.Option{
- Address: Params.MinioEndPoint,
- AccessKeyID: Params.MinioAccessKeyID,
- SecretAccessKeyID: Params.MinioSecretAccessKey,
- UseSSL: Params.MinioUseSSLStr,
- BucketName: Params.MinioBucketName,
- CreateBucket: true,
- }
-
- minioKV, err := minioKV.NewMinIOKV(node.queryNodeLoopCtx, option)
+ // save index to MinIO
+ minioClient, err := minio.New(Params.MinioEndPoint, &minio.Options{
+ Creds: credentials.NewStaticV4(Params.MinioAccessKeyID, Params.MinioSecretAccessKey, ""),
+ Secure: Params.MinioUseSSLStr,
+ })
+ assert.Equal(t, err, nil)
+ minioKV, err := minioKV.NewMinIOKV(node.queryNodeLoopCtx, minioClient, Params.MinioBucketName)
assert.Equal(t, err, nil)
indexPaths := make([]string, 0)
for _, index := range binarySet {
diff --git a/internal/storage/internal/S3/S3_test.go b/internal/storage/internal/S3/S3_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..c565f4a5a15cf25d109f5dcff642a0a21f2f94d1
--- /dev/null
+++ b/internal/storage/internal/S3/S3_test.go
@@ -0,0 +1,134 @@
+package s3driver
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
+)
+
+var option = storagetype.Option{BucketName: "zilliz-hz"}
+var ctx = context.Background()
+var client, err = NewS3Driver(ctx, option)
+
+func TestS3Driver_PutRowAndGetRow(t *testing.T) {
+ err = client.PutRow(ctx, []byte("bar"), []byte("abcdefghijklmnoopqrstuvwxyz"), "SegmentA", 1)
+ assert.Nil(t, err)
+ err = client.PutRow(ctx, []byte("bar"), []byte("djhfkjsbdfbsdughorsgsdjhgoisdgh"), "SegmentA", 2)
+ assert.Nil(t, err)
+ err = client.PutRow(ctx, []byte("bar"), []byte("123854676ershdgfsgdfk,sdhfg;sdi8"), "SegmentB", 3)
+ assert.Nil(t, err)
+ err = client.PutRow(ctx, []byte("bar1"), []byte("testkeybarorbar_1"), "SegmentC", 3)
+ assert.Nil(t, err)
+ object, _ := client.GetRow(ctx, []byte("bar"), 1)
+ assert.Equal(t, "abcdefghijklmnoopqrstuvwxyz", string(object))
+ object, _ = client.GetRow(ctx, []byte("bar"), 2)
+ assert.Equal(t, "djhfkjsbdfbsdughorsgsdjhgoisdgh", string(object))
+ object, _ = client.GetRow(ctx, []byte("bar"), 5)
+ assert.Equal(t, "123854676ershdgfsgdfk,sdhfg;sdi8", string(object))
+ object, _ = client.GetRow(ctx, []byte("bar1"), 5)
+ assert.Equal(t, "testkeybarorbar_1", string(object))
+}
+
+func TestS3Driver_DeleteRow(t *testing.T) {
+ err = client.DeleteRow(ctx, []byte("bar"), 5)
+ assert.Nil(t, err)
+ object, _ := client.GetRow(ctx, []byte("bar"), 6)
+ assert.Nil(t, object)
+ err = client.DeleteRow(ctx, []byte("bar1"), 5)
+ assert.Nil(t, err)
+ object2, _ := client.GetRow(ctx, []byte("bar1"), 6)
+ assert.Nil(t, object2)
+}
+
+func TestS3Driver_GetSegments(t *testing.T) {
+ err = client.PutRow(ctx, []byte("seg"), []byte("abcdefghijklmnoopqrstuvwxyz"), "SegmentA", 1)
+ assert.Nil(t, err)
+ err = client.PutRow(ctx, []byte("seg"), []byte("djhfkjsbdfbsdughorsgsdjhgoisdgh"), "SegmentA", 2)
+ assert.Nil(t, err)
+ err = client.PutRow(ctx, []byte("seg"), []byte("123854676ershdgfsgdfk,sdhfg;sdi8"), "SegmentB", 3)
+ assert.Nil(t, err)
+ err = client.PutRow(ctx, []byte("seg2"), []byte("testkeybarorbar_1"), "SegmentC", 1)
+ assert.Nil(t, err)
+
+ segments, err := client.GetSegments(ctx, []byte("seg"), 4)
+ assert.Nil(t, err)
+ assert.Equal(t, 2, len(segments))
+ if segments[0] == "SegmentA" {
+ assert.Equal(t, "SegmentA", segments[0])
+ assert.Equal(t, "SegmentB", segments[1])
+ } else {
+ assert.Equal(t, "SegmentB", segments[0])
+ assert.Equal(t, "SegmentA", segments[1])
+ }
+}
+
+func TestS3Driver_PutRowsAndGetRows(t *testing.T) {
+ keys := [][]byte{[]byte("foo"), []byte("bar")}
+ values := [][]byte{[]byte("The key is foo!"), []byte("The key is bar!")}
+ segments := []string{"segmentA", "segmentB"}
+ timestamps := []uint64{1, 2}
+ err = client.PutRows(ctx, keys, values, segments, timestamps)
+ assert.Nil(t, err)
+
+ objects, err := client.GetRows(ctx, keys, timestamps)
+ assert.Nil(t, err)
+ assert.Equal(t, "The key is foo!", string(objects[0]))
+ assert.Equal(t, "The key is bar!", string(objects[1]))
+}
+
+func TestS3Driver_DeleteRows(t *testing.T) {
+ keys := [][]byte{[]byte("foo"), []byte("bar")}
+ timestamps := []uint64{3, 3}
+ err := client.DeleteRows(ctx, keys, timestamps)
+ assert.Nil(t, err)
+
+ objects, err := client.GetRows(ctx, keys, timestamps)
+ assert.Nil(t, err)
+ assert.Nil(t, objects[0])
+ assert.Nil(t, objects[1])
+}
+
+func TestS3Driver_PutLogAndGetLog(t *testing.T) {
+ err = client.PutLog(ctx, []byte("insert"), []byte("This is insert log!"), 1, 11)
+ assert.Nil(t, err)
+ err = client.PutLog(ctx, []byte("delete"), []byte("This is delete log!"), 2, 10)
+ assert.Nil(t, err)
+ err = client.PutLog(ctx, []byte("update"), []byte("This is update log!"), 3, 9)
+ assert.Nil(t, err)
+ err = client.PutLog(ctx, []byte("select"), []byte("This is select log!"), 4, 8)
+ assert.Nil(t, err)
+
+ channels := []int{5, 8, 9, 10, 11, 12, 13}
+ logValues, err := client.GetLog(ctx, 0, 5, channels)
+ assert.Nil(t, err)
+ assert.Equal(t, "This is select log!", string(logValues[0]))
+ assert.Equal(t, "This is update log!", string(logValues[1]))
+ assert.Equal(t, "This is delete log!", string(logValues[2]))
+ assert.Equal(t, "This is insert log!", string(logValues[3]))
+}
+
+func TestS3Driver_Segment(t *testing.T) {
+ err := client.PutSegmentIndex(ctx, "segmentA", []byte("This is segmentA's index!"))
+ assert.Nil(t, err)
+
+ segmentIndex, err := client.GetSegmentIndex(ctx, "segmentA")
+ assert.Equal(t, "This is segmentA's index!", string(segmentIndex))
+ assert.Nil(t, err)
+
+ err = client.DeleteSegmentIndex(ctx, "segmentA")
+ assert.Nil(t, err)
+}
+
+func TestS3Driver_SegmentDL(t *testing.T) {
+ err := client.PutSegmentDL(ctx, "segmentB", []byte("This is segmentB's delete log!"))
+ assert.Nil(t, err)
+
+ segmentDL, err := client.GetSegmentDL(ctx, "segmentB")
+ assert.Nil(t, err)
+ assert.Equal(t, "This is segmentB's delete log!", string(segmentDL))
+
+ err = client.DeleteSegmentDL(ctx, "segmentB")
+ assert.Nil(t, err)
+}
diff --git a/internal/storage/internal/S3/s3_engine.go b/internal/storage/internal/S3/s3_engine.go
new file mode 100644
index 0000000000000000000000000000000000000000..8034d679e7e01fb45867b7424b6e9e31bc35c294
--- /dev/null
+++ b/internal/storage/internal/S3/s3_engine.go
@@ -0,0 +1,173 @@
+package s3driver
+
+import (
+ "bytes"
+ "context"
+ "io"
+
+ "github.com/aws/aws-sdk-go/aws"
+ "github.com/aws/aws-sdk-go/aws/session"
+ "github.com/aws/aws-sdk-go/service/s3"
+ . "github.com/zilliztech/milvus-distributed/internal/storage/type"
+)
+
+type S3Store struct {
+ client *s3.S3
+}
+
+func NewS3Store(config aws.Config) (*S3Store, error) {
+ sess := session.Must(session.NewSession(&config))
+ service := s3.New(sess)
+
+ return &S3Store{
+ client: service,
+ }, nil
+}
+
+func (s *S3Store) Put(ctx context.Context, key Key, value Value) error {
+ _, err := s.client.PutObjectWithContext(ctx, &s3.PutObjectInput{
+ Bucket: aws.String(bucketName),
+ Key: aws.String(string(key)),
+ Body: bytes.NewReader(value),
+ })
+
+ //sess := session.Must(session.NewSessionWithOptions(session.Options{
+ // SharedConfigState: session.SharedConfigEnable,
+ //}))
+ //uploader := s3manager.NewUploader(sess)
+ //
+ //_, err := uploader.Upload(&s3manager.UploadInput{
+ // Bucket: aws.String(bucketName),
+ // Key: aws.String(string(key)),
+ // Body: bytes.NewReader(value),
+ //})
+
+ return err
+}
+
+func (s *S3Store) Get(ctx context.Context, key Key) (Value, error) {
+ object, err := s.client.GetObjectWithContext(ctx, &s3.GetObjectInput{
+ Bucket: aws.String(bucketName),
+ Key: aws.String(string(key)),
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ //TODO: get size
+ size := 256 * 1024
+ buf := make([]byte, size)
+ n, err := object.Body.Read(buf)
+ if err != nil && err != io.EOF {
+ return nil, err
+ }
+ return buf[:n], nil
+}
+
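+// Editor's sketch: the fixed 256 KiB buffer in Get above truncates larger
+// objects, and a single Read call may return fewer bytes than are available.
+// A safer variant (assuming objects fit in memory; needs "io/ioutil" among
+// the imports) would drain the body instead; the MinIO store's Get below has
+// the same limitation:
+//
+// defer object.Body.Close()
+// data, err := ioutil.ReadAll(object.Body)
+// if err != nil {
+// return nil, err
+// }
+// return data, nil
+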
+func (s *S3Store) GetByPrefix(ctx context.Context, prefix Key, keyOnly bool) ([]Key, []Value, error) {
+ objectsOutput, err := s.client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{
+ Bucket: aws.String(bucketName),
+ Prefix: aws.String(string(prefix)),
+ })
+
+ var objectsKeys []Key
+ var objectsValues []Value
+
+ if objectsOutput != nil && err == nil {
+ for _, object := range objectsOutput.Contents {
+ objectsKeys = append(objectsKeys, []byte(*object.Key))
+ if !keyOnly {
+ value, err := s.Get(ctx, []byte(*object.Key))
+ if err != nil {
+ return nil, nil, err
+ }
+ objectsValues = append(objectsValues, value)
+ }
+ }
+ } else {
+ return nil, nil, err
+ }
+
+ return objectsKeys, objectsValues, nil
+
+}
+
+func (s *S3Store) Scan(ctx context.Context, keyStart Key, keyEnd Key, limit int, keyOnly bool) ([]Key, []Value, error) {
+ var keys []Key
+ var values []Value
+ limitCount := uint(limit)
+ objects, err := s.client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{
+ Bucket: aws.String(bucketName),
+ Prefix: aws.String(string(keyStart)),
+ })
+ if err == nil && objects != nil {
+ for _, object := range objects.Contents {
+ if *object.Key >= string(keyEnd) {
+ keys = append(keys, []byte(*object.Key))
+ if !keyOnly {
+ value, err := s.Get(ctx, []byte(*object.Key))
+ if err != nil {
+ return nil, nil, err
+ }
+ values = append(values, value)
+ }
+ limitCount--
+ if limitCount <= 0 {
+ break
+ }
+ }
+ }
+ }
+
+ return keys, values, err
+}
+
+func (s *S3Store) Delete(ctx context.Context, key Key) error {
+ _, err := s.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{
+ Bucket: aws.String(bucketName),
+ Key: aws.String(string(key)),
+ })
+ return err
+}
+
+func (s *S3Store) DeleteByPrefix(ctx context.Context, prefix Key) error {
+
+ objects, err := s.client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{
+ Bucket: aws.String(bucketName),
+ Prefix: aws.String(string(prefix)),
+ })
+
+ if objects != nil && err == nil {
+ for _, object := range objects.Contents {
+ _, err := s.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{
+ Bucket: aws.String(bucketName),
+ Key: object.Key,
+ })
+ // keep iterating: an early return here would delete only the first object
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+func (s *S3Store) DeleteRange(ctx context.Context, keyStart Key, keyEnd Key) error {
+
+ objects, err := s.client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{
+ Bucket: aws.String(bucketName),
+ Prefix: aws.String(string(keyStart)),
+ })
+
+ if objects != nil && err == nil {
+ for _, object := range objects.Contents {
+ if *object.Key > string(keyEnd) {
+ _, err := s.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{
+ Bucket: aws.String(bucketName),
+ Key: object.Key,
+ })
+ // keep iterating: an early return here would stop after the first match
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/internal/storage/internal/S3/s3_store.go b/internal/storage/internal/S3/s3_store.go
new file mode 100644
index 0000000000000000000000000000000000000000..19199baa54dcdc50d6ae947f899be1068c383642
--- /dev/null
+++ b/internal/storage/internal/S3/s3_store.go
@@ -0,0 +1,339 @@
+package s3driver
+
+import (
+ "context"
+
+ "github.com/aws/aws-sdk-go/aws"
+ "github.com/aws/aws-sdk-go/aws/endpoints"
+ "github.com/zilliztech/milvus-distributed/internal/storage/internal/minio/codec"
+ . "github.com/zilliztech/milvus-distributed/internal/storage/type"
+)
+
+type S3Driver struct {
+ driver *S3Store
+}
+
+var bucketName string
+
+func NewS3Driver(ctx context.Context, option Option) (*S3Driver, error) {
+ // TODO: read config
+
+ bucketName = option.BucketName
+
+ S3Client, err := NewS3Store(aws.Config{
+ Region: aws.String(endpoints.CnNorthwest1RegionID)})
+
+ if err != nil {
+ return nil, err
+ }
+
+ return &S3Driver{
+ S3Client,
+ }, nil
+}
+
+func (s *S3Driver) put(ctx context.Context, key Key, value Value, timestamp Timestamp, suffix string) error {
+ minioKey, err := codec.MvccEncode(key, timestamp, suffix)
+ if err != nil {
+ return err
+ }
+
+ err = s.driver.Put(ctx, minioKey, value)
+ return err
+}
+
+func (s *S3Driver) scanLE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error) {
+ keyEnd, err := codec.MvccEncode(key, timestamp, "")
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ keys, values, err := s.driver.Scan(ctx, key, []byte(keyEnd), -1, keyOnly)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ var timestamps []Timestamp
+ for _, key := range keys {
+ _, timestamp, _, _ := codec.MvccDecode(key)
+ timestamps = append(timestamps, timestamp)
+ }
+
+ return timestamps, keys, values, nil
+}
+
+func (s *S3Driver) scanGE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error) {
+ keyStart, err := codec.MvccEncode(key, timestamp, "")
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ keys, values, err := s.driver.Scan(ctx, key, keyStart, -1, keyOnly)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ var timestamps []Timestamp
+ for _, key := range keys {
+ _, timestamp, _, _ := codec.MvccDecode(key)
+ timestamps = append(timestamps, timestamp)
+ }
+
+ return timestamps, keys, values, nil
+}
+
+//scan(ctx context.Context, key Key, start Timestamp, end Timestamp, withValue bool) ([]Timestamp, []Key, []Value, error)
+func (s *S3Driver) deleteLE(ctx context.Context, key Key, timestamp Timestamp) error {
+ keyEnd, err := codec.MvccEncode(key, timestamp, "delete")
+ if err != nil {
+ return err
+ }
+ err = s.driver.DeleteRange(ctx, key, keyEnd)
+ return err
+}
+func (s *S3Driver) deleteGE(ctx context.Context, key Key, timestamp Timestamp) error {
+ keys, _, err := s.driver.GetByPrefix(ctx, key, true)
+ if err != nil {
+ return err
+ }
+ keyStart, err := codec.MvccEncode(key, timestamp, "")
+ if err != nil {
+ panic(err)
+ }
+ err = s.driver.DeleteRange(ctx, []byte(keyStart), keys[len(keys)-1])
+ return err
+}
+func (s *S3Driver) deleteRange(ctx context.Context, key Key, start Timestamp, end Timestamp) error {
+ keyStart, err := codec.MvccEncode(key, start, "")
+ if err != nil {
+ return err
+ }
+ keyEnd, err := codec.MvccEncode(key, end, "")
+ if err != nil {
+ return err
+ }
+ err = s.driver.DeleteRange(ctx, keyStart, keyEnd)
+ return err
+}
+
+func (s *S3Driver) GetRow(ctx context.Context, key Key, timestamp Timestamp) (Value, error) {
+ minioKey, err := codec.MvccEncode(key, timestamp, "")
+ if err != nil {
+ return nil, err
+ }
+
+ keys, values, err := s.driver.Scan(ctx, append(key, byte('_')), minioKey, 1, false)
+ if values == nil || keys == nil {
+ return nil, err
+ }
+
+ _, _, suffix, err := codec.MvccDecode(keys[0])
+ if err != nil {
+ return nil, err
+ }
+ if suffix == "delete" {
+ return nil, nil
+ }
+
+ return values[0], err
+}
+func (s *S3Driver) GetRows(ctx context.Context, keys []Key, timestamps []Timestamp) ([]Value, error) {
+ var values []Value
+ for i, key := range keys {
+ value, err := s.GetRow(ctx, key, timestamps[i])
+ if err != nil {
+ return nil, err
+ }
+ values = append(values, value)
+ }
+ return values, nil
+}
+
+func (s *S3Driver) PutRow(ctx context.Context, key Key, value Value, segment string, timestamp Timestamp) error {
+ minioKey, err := codec.MvccEncode(key, timestamp, segment)
+ if err != nil {
+ return err
+ }
+ err = s.driver.Put(ctx, minioKey, value)
+ return err
+}
+func (s *S3Driver) PutRows(ctx context.Context, keys []Key, values []Value, segments []string, timestamps []Timestamp) error {
+ maxThread := 100
+ batchSize := 1
+ keysLength := len(keys)
+
+ if keysLength/batchSize > maxThread {
+ batchSize = keysLength / maxThread
+ }
+
+ batchNums := keysLength / batchSize
+
+ if keysLength%batchSize != 0 {
+ batchNums = keysLength/batchSize + 1
+ }
+
+ // buffered so workers never block if we return early on the first error
+ errCh := make(chan error, len(keys))
+ f := func(ctx2 context.Context, keys2 []Key, values2 []Value, segments2 []string, timestamps2 []Timestamp) {
+ for i := 0; i < len(keys2); i++ {
+ err := s.PutRow(ctx2, keys2[i], values2[i], segments2[i], timestamps2[i])
+ errCh <- err
+ }
+ }
+ for i := 0; i < batchNums; i++ {
+ j := i
+ go func() {
+ start, end := j*batchSize, (j+1)*batchSize
+ if len(keys) < end {
+ end = len(keys)
+ }
+ f(ctx, keys[start:end], values[start:end], segments[start:end], timestamps[start:end])
+ }()
+ }
+
+ for i := 0; i < len(keys); i++ {
+ if err := <-errCh; err != nil {
+ return err
+ }
+ }
+ return nil
+}
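+
+// Editor's note: the sizing above bounds the batch size, not the goroutine
+// count. With maxThread = 100 and 250 keys, batchSize becomes 250/100 = 2 and
+// batchNums becomes 125, so 125 goroutines are spawned. A strict cap would be
+// batchSize = (keysLength + maxThread - 1) / maxThread. The same pattern
+// recurs in DeleteRows here and in the MinIO driver's PutRows/DeleteRows.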
+
+func (s *S3Driver) GetSegments(ctx context.Context, key Key, timestamp Timestamp) ([]string, error) {
+ keyEnd, err := codec.MvccEncode(key, timestamp, "")
+ if err != nil {
+ return nil, err
+ }
+ keys, _, err := s.driver.Scan(ctx, append(key, byte('_')), keyEnd, -1, true)
+ if err != nil {
+ return nil, err
+ }
+ segmentsSet := map[string]bool{}
+ for _, key := range keys {
+ _, _, segment, err := codec.MvccDecode(key)
+ if err != nil {
+ panic("must no error")
+ }
+ if segment != "delete" {
+ segmentsSet[segment] = true
+ }
+ }
+
+ var segments []string
+ for k, v := range segmentsSet {
+ if v {
+ segments = append(segments, k)
+ }
+ }
+ return segments, err
+}
+
+func (s *S3Driver) DeleteRow(ctx context.Context, key Key, timestamp Timestamp) error {
+ minioKey, err := codec.MvccEncode(key, timestamp, "delete")
+ if err != nil {
+ return err
+ }
+ value := []byte("0")
+ err = s.driver.Put(ctx, minioKey, value)
+ return err
+}
+
+func (s *S3Driver) DeleteRows(ctx context.Context, keys []Key, timestamps []Timestamp) error {
+ maxThread := 100
+ batchSize := 1
+ keysLength := len(keys)
+
+ if keysLength/batchSize > maxThread {
+ batchSize = keysLength / maxThread
+ }
+
+ batchNums := keysLength / batchSize
+
+ if keysLength%batchSize != 0 {
+ batchNums = keysLength/batchSize + 1
+ }
+
+ // buffered so workers never block if we return early on the first error
+ errCh := make(chan error, len(keys))
+ f := func(ctx2 context.Context, keys2 []Key, timestamps2 []Timestamp) {
+ for i := 0; i < len(keys2); i++ {
+ err := s.DeleteRow(ctx2, keys2[i], timestamps2[i])
+ errCh <- err
+ }
+ }
+ for i := 0; i < batchNums; i++ {
+ j := i
+ go func() {
+ start, end := j*batchSize, (j+1)*batchSize
+ if len(keys) < end {
+ end = len(keys)
+ }
+ f(ctx, keys[start:end], timestamps[start:end])
+ }()
+ }
+
+ for i := 0; i < len(keys); i++ {
+ if err := <-errCh; err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (s *S3Driver) PutLog(ctx context.Context, key Key, value Value, timestamp Timestamp, channel int) error {
+ logKey := codec.LogEncode(key, timestamp, channel)
+ err := s.driver.Put(ctx, logKey, value)
+ return err
+}
+
+func (s *S3Driver) GetLog(ctx context.Context, start Timestamp, end Timestamp, channels []int) ([]Value, error) {
+ keys, values, err := s.driver.GetByPrefix(ctx, []byte("log_"), false)
+ if err != nil {
+ return nil, err
+ }
+
+ var resultValues []Value
+ for i, key := range keys {
+ _, ts, channel, err := codec.LogDecode(string(key))
+ if err != nil {
+ return nil, err
+ }
+ if ts >= start && ts <= end {
+ for j := 0; j < len(channels); j++ {
+ if channel == channels[j] {
+ resultValues = append(resultValues, values[i])
+ }
+ }
+ }
+ }
+
+ return resultValues, nil
+}
+
+func (s *S3Driver) GetSegmentIndex(ctx context.Context, segment string) (SegmentIndex, error) {
+
+ return s.driver.Get(ctx, codec.SegmentEncode(segment, "index"))
+}
+
+func (s *S3Driver) PutSegmentIndex(ctx context.Context, segment string, index SegmentIndex) error {
+
+ return s.driver.Put(ctx, codec.SegmentEncode(segment, "index"), index)
+}
+
+func (s *S3Driver) DeleteSegmentIndex(ctx context.Context, segment string) error {
+
+ return s.driver.Delete(ctx, codec.SegmentEncode(segment, "index"))
+}
+
+func (s *S3Driver) GetSegmentDL(ctx context.Context, segment string) (SegmentDL, error) {
+
+ return s.driver.Get(ctx, codec.SegmentEncode(segment, "DL"))
+}
+
+func (s *S3Driver) PutSegmentDL(ctx context.Context, segment string, log SegmentDL) error {
+
+ return s.driver.Put(ctx, codec.SegmentEncode(segment, "DL"), log)
+}
+
+func (s *S3Driver) DeleteSegmentDL(ctx context.Context, segment string) error {
+
+ return s.driver.Delete(ctx, codec.SegmentEncode(segment, "DL"))
+}
diff --git a/internal/storage/internal/minio/codec/codec.go b/internal/storage/internal/minio/codec/codec.go
new file mode 100644
index 0000000000000000000000000000000000000000..4d2b76ee2f939b9f221b47d2ef3f38ce8fe0c71e
--- /dev/null
+++ b/internal/storage/internal/minio/codec/codec.go
@@ -0,0 +1,101 @@
+package codec
+
+import (
+ "errors"
+ "fmt"
+)
+
+func MvccEncode(key []byte, ts uint64, suffix string) ([]byte, error) {
+ return []byte(string(key) + "_" + fmt.Sprintf("%016x", ^ts) + "_" + suffix), nil
+}
+
+func MvccDecode(key []byte) (string, uint64, string, error) {
+ if len(key) < 16 {
+ return "", 0, "", errors.New("insufficient bytes to decode value")
+ }
+
+ suffixIndex := 0
+ TSIndex := 0
+ undersCount := 0
+ for i := len(key) - 1; i > 0; i-- {
+ if key[i] == byte('_') {
+ undersCount++
+ if undersCount == 1 {
+ suffixIndex = i + 1
+ }
+ if undersCount == 2 {
+ TSIndex = i + 1
+ break
+ }
+ }
+ }
+ if suffixIndex == 0 || TSIndex == 0 {
+ return "", 0, "", errors.New("key is wrong formatted")
+ }
+
+ var TS uint64
+ _, err := fmt.Sscanf(string(key[TSIndex:suffixIndex-1]), "%x", &TS)
+ TS = ^TS
+ if err != nil {
+ return "", 0, "", err
+ }
+
+ return string(key[0 : TSIndex-1]), TS, string(key[suffixIndex:]), nil
+}
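+
+// Editor's worked example: MvccEncode([]byte("bar"), 1, "SegmentA") returns
+// "bar_fffffffffffffffe_SegmentA" (%016x of ^uint64(1)), and MvccDecode of
+// that key yields ("bar", 1, "SegmentA", nil). Because the timestamp is
+// bit-inverted before formatting, larger timestamps sort lexicographically
+// earlier, so a prefix scan meets the newest version of a key first.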
+
+func LogEncode(key []byte, ts uint64, channel int) []byte {
+ suffix := string(key) + "_" + fmt.Sprintf("%d", channel)
+ logKey, err := MvccEncode([]byte("log"), ts, suffix)
+ if err != nil {
+ return nil
+ }
+ return logKey
+}
+
+func LogDecode(logKey string) (string, uint64, int, error) {
+ if len(logKey) < 16 {
+ return "", 0, 0, errors.New("insufficient bytes to decode value")
+ }
+
+ channelIndex := 0
+ keyIndex := 0
+ TSIndex := 0
+ undersCount := 0
+
+ for i := len(logKey) - 1; i > 0; i-- {
+ if logKey[i] == '_' {
+ undersCount++
+ if undersCount == 1 {
+ channelIndex = i + 1
+ }
+ if undersCount == 2 {
+ keyIndex = i + 1
+ }
+ if undersCount == 3 {
+ TSIndex = i + 1
+ break
+ }
+ }
+ }
+ if channelIndex == 0 || TSIndex == 0 || keyIndex == 0 || logKey[:TSIndex-1] != "log" {
+ return "", 0, 0, errors.New("key is wrong formatted")
+ }
+
+ var TS uint64
+ var channel int
+ _, err := fmt.Sscanf(logKey[TSIndex:keyIndex-1], "%x", &TS)
+ if err != nil {
+ return "", 0, 0, err
+ }
+ TS = ^TS
+
+ _, err = fmt.Sscanf(logKey[channelIndex:], "%d", &channel)
+ if err != nil {
+ return "", 0, 0, err
+ }
+ return logKey[keyIndex : channelIndex-1], TS, channel, nil
+}
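+
+// Editor's worked example: LogEncode([]byte("insert"), 1, 11) produces
+// "log_fffffffffffffffe_insert_11", which LogDecode parses back into
+// ("insert", 1, 11, nil) by scanning for '_' from the right. A key that
+// itself contains '_' would be mis-split here, so keys are assumed to be
+// delimiter-free, like the suffix in MvccEncode.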
+
+func SegmentEncode(segment string, suffix string) []byte {
+ return []byte(segment + "_" + suffix)
+}
diff --git a/internal/storage/internal/minio/minio_store.go b/internal/storage/internal/minio/minio_store.go
new file mode 100644
index 0000000000000000000000000000000000000000..18e2512401ac94093cd55eb89b016a9be4d92dcf
--- /dev/null
+++ b/internal/storage/internal/minio/minio_store.go
@@ -0,0 +1,361 @@
+package miniodriver
+
+import (
+ "context"
+
+ "github.com/minio/minio-go/v7"
+ "github.com/minio/minio-go/v7/pkg/credentials"
+ "github.com/zilliztech/milvus-distributed/internal/storage/internal/minio/codec"
+ storageType "github.com/zilliztech/milvus-distributed/internal/storage/type"
+)
+
+type MinioDriver struct {
+ driver *minioStore
+}
+
+var bucketName string
+
+func NewMinioDriver(ctx context.Context, option storageType.Option) (*MinioDriver, error) {
+ // TODO: read config
+ var endPoint = "localhost:9000"
+ var accessKeyID = "testminio"
+ var secretAccessKey = "testminio"
+ var useSSL = false
+
+ bucketName = option.BucketName // assign the package-level variable; ":=" would shadow it and leave it empty
+
+ minioClient, err := minio.New(endPoint, &minio.Options{
+ Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
+ Secure: useSSL,
+ })
+
+ if err != nil {
+ return nil, err
+ }
+
+ bucketExists, err := minioClient.BucketExists(ctx, bucketName)
+ if err != nil {
+ return nil, err
+ }
+
+ if !bucketExists {
+ err = minioClient.MakeBucket(ctx, bucketName, minio.MakeBucketOptions{})
+ if err != nil {
+ return nil, err
+ }
+ }
+ return &MinioDriver{
+ &minioStore{
+ client: minioClient,
+ },
+ }, nil
+}
+
+func (s *MinioDriver) put(ctx context.Context, key storageType.Key, value storageType.Value, timestamp storageType.Timestamp, suffix string) error {
+ minioKey, err := codec.MvccEncode(key, timestamp, suffix)
+ if err != nil {
+ return err
+ }
+
+ err = s.driver.Put(ctx, minioKey, value)
+ return err
+}
+
+func (s *MinioDriver) scanLE(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp, keyOnly bool) ([]storageType.Timestamp, []storageType.Key, []storageType.Value, error) {
+ keyEnd, err := codec.MvccEncode(key, timestamp, "")
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ keys, values, err := s.driver.Scan(ctx, key, []byte(keyEnd), -1, keyOnly)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ var timestamps []storageType.Timestamp
+ for _, key := range keys {
+ _, timestamp, _, _ := codec.MvccDecode(key)
+ timestamps = append(timestamps, timestamp)
+ }
+
+ return timestamps, keys, values, nil
+}
+
+func (s *MinioDriver) scanGE(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp, keyOnly bool) ([]storageType.Timestamp, []storageType.Key, []storageType.Value, error) {
+ keyStart, err := codec.MvccEncode(key, timestamp, "")
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ keys, values, err := s.driver.Scan(ctx, key, keyStart, -1, keyOnly)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ var timestamps []storageType.Timestamp
+ for _, key := range keys {
+ _, timestamp, _, _ := codec.MvccDecode(key)
+ timestamps = append(timestamps, timestamp)
+ }
+
+ return timestamps, keys, values, nil
+}
+
+//scan(ctx context.Context, key storageType.Key, start storageType.Timestamp, end storageType.Timestamp, withValue bool) ([]storageType.Timestamp, []storageType.Key, []storageType.Value, error)
+func (s *MinioDriver) deleteLE(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp) error {
+ keyEnd, err := codec.MvccEncode(key, timestamp, "delete")
+ if err != nil {
+ return err
+ }
+ err = s.driver.DeleteRange(ctx, key, keyEnd)
+ return err
+}
+func (s *MinioDriver) deleteGE(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp) error {
+ keys, _, err := s.driver.GetByPrefix(ctx, key, true)
+ if err != nil {
+ return err
+ }
+ keyStart, err := codec.MvccEncode(key, timestamp, "")
+ if err != nil {
+ panic(err)
+ }
+ err = s.driver.DeleteRange(ctx, keyStart, keys[len(keys)-1])
+ if err != nil {
+ panic(err)
+ }
+ return nil
+}
+func (s *MinioDriver) deleteRange(ctx context.Context, key storageType.Key, start storageType.Timestamp, end storageType.Timestamp) error {
+ keyStart, err := codec.MvccEncode(key, start, "")
+ if err != nil {
+ return err
+ }
+ keyEnd, err := codec.MvccEncode(key, end, "")
+ if err != nil {
+ return err
+ }
+ err = s.driver.DeleteRange(ctx, keyStart, keyEnd)
+ return err
+}
+
+func (s *MinioDriver) GetRow(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp) (storageType.Value, error) {
+ minioKey, err := codec.MvccEncode(key, timestamp, "")
+ if err != nil {
+ return nil, err
+ }
+
+ keys, values, err := s.driver.Scan(ctx, append(key, byte('_')), minioKey, 1, false)
+ if values == nil || keys == nil {
+ return nil, err
+ }
+
+ _, _, suffix, err := codec.MvccDecode(keys[0])
+ if err != nil {
+ return nil, err
+ }
+ if suffix == "delete" {
+ return nil, nil
+ }
+
+ return values[0], err
+}
+func (s *MinioDriver) GetRows(ctx context.Context, keys []storageType.Key, timestamps []storageType.Timestamp) ([]storageType.Value, error) {
+ var values []storageType.Value
+ for i, key := range keys {
+ value, err := s.GetRow(ctx, key, timestamps[i])
+ if err != nil {
+ return nil, err
+ }
+ values = append(values, value)
+ }
+ return values, nil
+}
+
+func (s *MinioDriver) PutRow(ctx context.Context, key storageType.Key, value storageType.Value, segment string, timestamp storageType.Timestamp) error {
+ minioKey, err := codec.MvccEncode(key, timestamp, segment)
+ if err != nil {
+ return err
+ }
+ err = s.driver.Put(ctx, minioKey, value)
+ return err
+}
+func (s *MinioDriver) PutRows(ctx context.Context, keys []storageType.Key, values []storageType.Value, segments []string, timestamps []storageType.Timestamp) error {
+ maxThread := 100
+ batchSize := 1
+ keysLength := len(keys)
+
+ if keysLength/batchSize > maxThread {
+ batchSize = keysLength / maxThread
+ }
+
+ batchNums := keysLength / batchSize
+
+ if keysLength%batchSize != 0 {
+ batchNums = keysLength/batchSize + 1
+ }
+
+ // buffered so workers never block if we return early on the first error
+ errCh := make(chan error, len(keys))
+ f := func(ctx2 context.Context, keys2 []storageType.Key, values2 []storageType.Value, segments2 []string, timestamps2 []storageType.Timestamp) {
+ for i := 0; i < len(keys2); i++ {
+ err := s.PutRow(ctx2, keys2[i], values2[i], segments2[i], timestamps2[i])
+ errCh <- err
+ }
+ }
+ for i := 0; i < batchNums; i++ {
+ j := i
+ go func() {
+ start, end := j*batchSize, (j+1)*batchSize
+ if len(keys) < end {
+ end = len(keys)
+ }
+ f(ctx, keys[start:end], values[start:end], segments[start:end], timestamps[start:end])
+ }()
+ }
+
+ for i := 0; i < len(keys); i++ {
+ if err := <-errCh; err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (s *MinioDriver) GetSegments(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp) ([]string, error) {
+ keyEnd, err := codec.MvccEncode(key, timestamp, "")
+ if err != nil {
+ return nil, err
+ }
+ keys, _, err := s.driver.Scan(ctx, append(key, byte('_')), keyEnd, -1, true)
+ if err != nil {
+ return nil, err
+ }
+ segmentsSet := map[string]bool{}
+ for _, key := range keys {
+ _, _, segment, err := codec.MvccDecode(key)
+ if err != nil {
+ panic("must no error")
+ }
+ if segment != "delete" {
+ segmentsSet[segment] = true
+ }
+ }
+
+ var segments []string
+ for k, v := range segmentsSet {
+ if v {
+ segments = append(segments, k)
+ }
+ }
+ return segments, err
+}
+
+func (s *MinioDriver) DeleteRow(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp) error {
+ minioKey, err := codec.MvccEncode(key, timestamp, "delete")
+ if err != nil {
+ return err
+ }
+ value := []byte("0")
+ err = s.driver.Put(ctx, minioKey, value)
+ return err
+}
+
+func (s *MinioDriver) DeleteRows(ctx context.Context, keys []storageType.Key, timestamps []storageType.Timestamp) error {
+ maxThread := 100
+ batchSize := 1
+ keysLength := len(keys)
+
+ if keysLength/batchSize > maxThread {
+ batchSize = keysLength / maxThread
+ }
+
+ batchNums := keysLength / batchSize
+
+ if keysLength%batchSize != 0 {
+ batchNums = keysLength/batchSize + 1
+ }
+
+ // buffered so workers never block if we return early on the first error
+ errCh := make(chan error, len(keys))
+ f := func(ctx2 context.Context, keys2 []storageType.Key, timestamps2 []storageType.Timestamp) {
+ for i := 0; i < len(keys2); i++ {
+ err := s.DeleteRow(ctx2, keys2[i], timestamps2[i])
+ errCh <- err
+ }
+ }
+ for i := 0; i < batchNums; i++ {
+ j := i
+ go func() {
+ start, end := j*batchSize, (j+1)*batchSize
+ if len(keys) < end {
+ end = len(keys)
+ }
+ f(ctx, keys[start:end], timestamps[start:end])
+ }()
+ }
+
+ for i := 0; i < len(keys); i++ {
+ if err := <-errCh; err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (s *MinioDriver) PutLog(ctx context.Context, key storageType.Key, value storageType.Value, timestamp storageType.Timestamp, channel int) error {
+ logKey := codec.LogEncode(key, timestamp, channel)
+ err := s.driver.Put(ctx, logKey, value)
+ return err
+}
+
+func (s *MinioDriver) GetLog(ctx context.Context, start storageType.Timestamp, end storageType.Timestamp, channels []int) ([]storageType.Value, error) {
+ keys, values, err := s.driver.GetByPrefix(ctx, []byte("log_"), false)
+ if err != nil {
+ return nil, err
+ }
+
+ var resultValues []storageType.Value
+ for i, key := range keys {
+ _, ts, channel, err := codec.LogDecode(string(key))
+ if err != nil {
+ return nil, err
+ }
+ if ts >= start && ts <= end {
+ for j := 0; j < len(channels); j++ {
+ if channel == channels[j] {
+ resultValues = append(resultValues, values[i])
+ }
+ }
+ }
+ }
+
+ return resultValues, nil
+}
+
+func (s *MinioDriver) GetSegmentIndex(ctx context.Context, segment string) (storageType.SegmentIndex, error) {
+
+ return s.driver.Get(ctx, codec.SegmentEncode(segment, "index"))
+}
+
+func (s *MinioDriver) PutSegmentIndex(ctx context.Context, segment string, index storageType.SegmentIndex) error {
+
+ return s.driver.Put(ctx, codec.SegmentEncode(segment, "index"), index)
+}
+
+func (s *MinioDriver) DeleteSegmentIndex(ctx context.Context, segment string) error {
+
+ return s.driver.Delete(ctx, codec.SegmentEncode(segment, "index"))
+}
+
+func (s *MinioDriver) GetSegmentDL(ctx context.Context, segment string) (storageType.SegmentDL, error) {
+
+ return s.driver.Get(ctx, codec.SegmentEncode(segment, "DL"))
+}
+
+func (s *MinioDriver) PutSegmentDL(ctx context.Context, segment string, log storageType.SegmentDL) error {
+
+ return s.driver.Put(ctx, codec.SegmentEncode(segment, "DL"), log)
+}
+
+func (s *MinioDriver) DeleteSegmentDL(ctx context.Context, segment string) error {
+
+ return s.driver.Delete(ctx, codec.SegmentEncode(segment, "DL"))
+}
diff --git a/internal/storage/internal/minio/minio_storeEngine.go b/internal/storage/internal/minio/minio_storeEngine.go
new file mode 100644
index 0000000000000000000000000000000000000000..64d74e859032b0306c9e9cce655d48d3df52a8a0
--- /dev/null
+++ b/internal/storage/internal/minio/minio_storeEngine.go
@@ -0,0 +1,130 @@
+package miniodriver
+
+import (
+ "bytes"
+ "context"
+ "io"
+
+ "github.com/minio/minio-go/v7"
+ . "github.com/zilliztech/milvus-distributed/internal/storage/type"
+)
+
+type minioStore struct {
+ client *minio.Client
+}
+
+func (s *minioStore) Put(ctx context.Context, key Key, value Value) error {
+ reader := bytes.NewReader(value)
+ _, err := s.client.PutObject(ctx, bucketName, string(key), reader, int64(len(value)), minio.PutObjectOptions{})
+ return err
+}
+
+func (s *minioStore) Get(ctx context.Context, key Key) (Value, error) {
+ object, err := s.client.GetObject(ctx, bucketName, string(key), minio.GetObjectOptions{})
+ if err != nil {
+ return nil, err
+ }
+
+ size := 256 * 1024
+ buf := make([]byte, size)
+ n, err := object.Read(buf)
+ if err != nil && err != io.EOF {
+ return nil, err
+ }
+ return buf[:n], nil
+}
+
+func (s *minioStore) GetByPrefix(ctx context.Context, prefix Key, keyOnly bool) ([]Key, []Value, error) {
+ objects := s.client.ListObjects(ctx, bucketName, minio.ListObjectsOptions{Prefix: string(prefix)})
+
+ var objectsKeys []Key
+ var objectsValues []Value
+
+ for object := range objects {
+ objectsKeys = append(objectsKeys, []byte(object.Key))
+ if !keyOnly {
+ value, err := s.Get(ctx, []byte(object.Key))
+ if err != nil {
+ return nil, nil, err
+ }
+ objectsValues = append(objectsValues, value)
+ }
+ }
+
+ return objectsKeys, objectsValues, nil
+
+}
+
+func (s *minioStore) Scan(ctx context.Context, keyStart Key, keyEnd Key, limit int, keyOnly bool) ([]Key, []Value, error) {
+ var keys []Key
+ var values []Value
+ limitCount := uint(limit)
+ for object := range s.client.ListObjects(ctx, bucketName, minio.ListObjectsOptions{Prefix: string(keyStart)}) {
+ if object.Key >= string(keyEnd) {
+ keys = append(keys, []byte(object.Key))
+ if !keyOnly {
+ value, err := s.Get(ctx, []byte(object.Key))
+ if err != nil {
+ return nil, nil, err
+ }
+ values = append(values, value)
+ }
+ limitCount--
+ if limitCount <= 0 {
+ break
+ }
+ }
+ }
+
+ return keys, values, nil
+}
+
+func (s *minioStore) Delete(ctx context.Context, key Key) error {
+ err := s.client.RemoveObject(ctx, bucketName, string(key), minio.RemoveObjectOptions{})
+ return err
+}
+
+func (s *minioStore) DeleteByPrefix(ctx context.Context, prefix Key) error {
+ objectsCh := make(chan minio.ObjectInfo)
+
+ go func() {
+ defer close(objectsCh)
+
+ for object := range s.client.ListObjects(ctx, bucketName, minio.ListObjectsOptions{Prefix: string(prefix)}) {
+ objectsCh <- object
+ }
+ }()
+
+ for rErr := range s.client.RemoveObjects(ctx, bucketName, objectsCh, minio.RemoveObjectsOptions{GovernanceBypass: true}) {
+ if rErr.Err != nil {
+ return rErr.Err
+ }
+ }
+ return nil
+}
+
+func (s *minioStore) DeleteRange(ctx context.Context, keyStart Key, keyEnd Key) error {
+ objectsCh := make(chan minio.ObjectInfo)
+
+ go func() {
+ defer close(objectsCh)
+
+ for object := range s.client.ListObjects(ctx, bucketName, minio.ListObjectsOptions{Prefix: string(keyStart)}) {
+ if object.Key <= string(keyEnd) {
+ objectsCh <- object
+ }
+ }
+ }()
+
+ for rErr := range s.client.RemoveObjects(ctx, bucketName, objectsCh, minio.RemoveObjectsOptions{GovernanceBypass: true}) {
+ if rErr.Err != nil {
+ return rErr.Err
+ }
+ }
+ return nil
+}
diff --git a/internal/storage/internal/minio/minio_test.go b/internal/storage/internal/minio/minio_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..d98a98cfeaac2a4d77b8cf999940a04c1a61c71d
--- /dev/null
+++ b/internal/storage/internal/minio/minio_test.go
@@ -0,0 +1,134 @@
+package miniodriver
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
+)
+
+var option = storagetype.Option{BucketName: "zilliz-hz"}
+var ctx = context.Background()
+var client, err = NewMinioDriver(ctx, option)
+
+func TestMinioDriver_PutRowAndGetRow(t *testing.T) {
+ err = client.PutRow(ctx, []byte("bar"), []byte("abcdefghijklmnoopqrstuvwxyz"), "SegmentA", 1)
+ assert.Nil(t, err)
+ err = client.PutRow(ctx, []byte("bar"), []byte("djhfkjsbdfbsdughorsgsdjhgoisdgh"), "SegmentA", 2)
+ assert.Nil(t, err)
+ err = client.PutRow(ctx, []byte("bar"), []byte("123854676ershdgfsgdfk,sdhfg;sdi8"), "SegmentB", 3)
+ assert.Nil(t, err)
+ err = client.PutRow(ctx, []byte("bar1"), []byte("testkeybarorbar_1"), "SegmentC", 3)
+ assert.Nil(t, err)
+ object, _ := client.GetRow(ctx, []byte("bar"), 5)
+ assert.Equal(t, "abcdefghijklmnoopqrstuvwxyz", string(object))
+ object, _ = client.GetRow(ctx, []byte("bar"), 2)
+ assert.Equal(t, "djhfkjsbdfbsdughorsgsdjhgoisdgh", string(object))
+ object, _ = client.GetRow(ctx, []byte("bar"), 5)
+ assert.Equal(t, "123854676ershdgfsgdfk,sdhfg;sdi8", string(object))
+ object, _ = client.GetRow(ctx, []byte("bar1"), 5)
+ assert.Equal(t, "testkeybarorbar_1", string(object))
+}
+
+func TestMinioDriver_DeleteRow(t *testing.T) {
+ err = client.DeleteRow(ctx, []byte("bar"), 5)
+ assert.Nil(t, err)
+ object, _ := client.GetRow(ctx, []byte("bar"), 6)
+ assert.Nil(t, object)
+ err = client.DeleteRow(ctx, []byte("bar1"), 5)
+ assert.Nil(t, err)
+ object2, _ := client.GetRow(ctx, []byte("bar1"), 6)
+ assert.Nil(t, object2)
+}
+
+func TestMinioDriver_GetSegments(t *testing.T) {
+ err = client.PutRow(ctx, []byte("seg"), []byte("abcdefghijklmnoopqrstuvwxyz"), "SegmentA", 1)
+ assert.Nil(t, err)
+ err = client.PutRow(ctx, []byte("seg"), []byte("djhfkjsbdfbsdughorsgsdjhgoisdgh"), "SegmentA", 2)
+ assert.Nil(t, err)
+ err = client.PutRow(ctx, []byte("seg"), []byte("123854676ershdgfsgdfk,sdhfg;sdi8"), "SegmentB", 3)
+ assert.Nil(t, err)
+ err = client.PutRow(ctx, []byte("seg2"), []byte("testkeybarorbar_1"), "SegmentC", 1)
+ assert.Nil(t, err)
+
+ segments, err := client.GetSegments(ctx, []byte("seg"), 4)
+ assert.Nil(t, err)
+ assert.Equal(t, 2, len(segments))
+ if segments[0] == "SegmentA" {
+ assert.Equal(t, "SegmentA", segments[0])
+ assert.Equal(t, "SegmentB", segments[1])
+ } else {
+ assert.Equal(t, "SegmentB", segments[0])
+ assert.Equal(t, "SegmentA", segments[1])
+ }
+}
+
+func TestMinioDriver_PutRowsAndGetRows(t *testing.T) {
+ keys := [][]byte{[]byte("foo"), []byte("bar")}
+ values := [][]byte{[]byte("The key is foo!"), []byte("The key is bar!")}
+ segments := []string{"segmentA", "segmentB"}
+ timestamps := []uint64{1, 2}
+ err = client.PutRows(ctx, keys, values, segments, timestamps)
+ assert.Nil(t, err)
+
+ objects, err := client.GetRows(ctx, keys, timestamps)
+ assert.Nil(t, err)
+ assert.Equal(t, "The key is foo!", string(objects[0]))
+ assert.Equal(t, "The key is bar!", string(objects[1]))
+}
+
+func TestMinioDriver_DeleteRows(t *testing.T) {
+ keys := [][]byte{[]byte("foo"), []byte("bar")}
+ timestamps := []uint64{3, 3}
+ err := client.DeleteRows(ctx, keys, timestamps)
+ assert.Nil(t, err)
+
+ objects, err := client.GetRows(ctx, keys, timestamps)
+ assert.Nil(t, err)
+ assert.Nil(t, objects[0])
+ assert.Nil(t, objects[1])
+}
+
+func TestMinioDriver_PutLogAndGetLog(t *testing.T) {
+ err = client.PutLog(ctx, []byte("insert"), []byte("This is insert log!"), 1, 11)
+ assert.Nil(t, err)
+ err = client.PutLog(ctx, []byte("delete"), []byte("This is delete log!"), 2, 10)
+ assert.Nil(t, err)
+ err = client.PutLog(ctx, []byte("update"), []byte("This is update log!"), 3, 9)
+ assert.Nil(t, err)
+ err = client.PutLog(ctx, []byte("select"), []byte("This is select log!"), 4, 8)
+ assert.Nil(t, err)
+
+ channels := []int{5, 8, 9, 10, 11, 12, 13}
+ logValues, err := client.GetLog(ctx, 0, 5, channels)
+ assert.Nil(t, err)
+ assert.Equal(t, "This is select log!", string(logValues[0]))
+ assert.Equal(t, "This is update log!", string(logValues[1]))
+ assert.Equal(t, "This is delete log!", string(logValues[2]))
+ assert.Equal(t, "This is insert log!", string(logValues[3]))
+}
+
+func TestMinioDriver_Segment(t *testing.T) {
+ err := client.PutSegmentIndex(ctx, "segmentA", []byte("This is segmentA's index!"))
+ assert.Nil(t, err)
+
+ segmentIndex, err := client.GetSegmentIndex(ctx, "segmentA")
+ assert.Equal(t, "This is segmentA's index!", string(segmentIndex))
+ assert.Nil(t, err)
+
+ err = client.DeleteSegmentIndex(ctx, "segmentA")
+ assert.Nil(t, err)
+}
+
+func TestMinioDriver_SegmentDL(t *testing.T) {
+ err := client.PutSegmentDL(ctx, "segmentB", []byte("This is segmentB's delete log!"))
+ assert.Nil(t, err)
+
+ segmentDL, err := client.GetSegmentDL(ctx, "segmentB")
+ assert.Nil(t, err)
+ assert.Equal(t, "This is segmentB's delete log!", string(segmentDL))
+
+ err = client.DeleteSegmentDL(ctx, "segmentB")
+ assert.Nil(t, err)
+}
diff --git a/internal/storage/internal/tikv/codec/codec.go b/internal/storage/internal/tikv/codec/codec.go
new file mode 100644
index 0000000000000000000000000000000000000000..ca09296097e8778ca1023c3395801e7cb0e99669
--- /dev/null
+++ b/internal/storage/internal/tikv/codec/codec.go
@@ -0,0 +1,62 @@
+package codec
+
+import (
+ "encoding/binary"
+ "errors"
+
+ "github.com/tikv/client-go/codec"
+)
+
+var (
+ Delimiter = byte('_')
+ DelimiterPlusOne = Delimiter + 0x01
+ DeleteMark = byte('d')
+ SegmentIndexMark = byte('i')
+ SegmentDLMark = byte('d')
+)
+
+// EncodeKey append timestamp, delimiter, and suffix string
+// to one slice key.
+// Note: suffix string should not contains Delimiter
+func EncodeKey(key []byte, timestamp uint64, suffix string) []byte {
+ //TODO: should we encode key to memory comparable
+ ret := EncodeDelimiter(key, Delimiter)
+ ret = codec.EncodeUintDesc(ret, timestamp)
+ return append(ret, suffix...)
+}
+
+func DecodeKey(key []byte) ([]byte, uint64, string, error) {
+ if len(key) < 8 {
+ return nil, 0, "", errors.New("insufficient bytes to decode value")
+ }
+
+ lenDeKey := 0
+ for i := len(key) - 1; i > 0; i-- {
+ if key[i] == Delimiter {
+ lenDeKey = i
+ break
+ }
+ }
+
+ if lenDeKey == 0 || lenDeKey+9 > len(key) {
+ return nil, 0, "", errors.New("insufficient bytes to decode value")
+ }
+
+ tsBytes := key[lenDeKey+1 : lenDeKey+9]
+ ts := binary.BigEndian.Uint64(tsBytes)
+ suffix := string(key[lenDeKey+9:])
+ key = key[:lenDeKey]
+ return key, ^ts, suffix, nil
+}
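+
+// Editor's note on the layout: EncodeKey([]byte("bar"), ts, "seg") yields
+// "bar" + '_' + 8 big-endian bytes of ^ts + "seg"; inverting the timestamp
+// makes newer versions sort first, and DecodeKey recovers (key, ts, suffix)
+// by scanning for the delimiter from the right. The raw ^ts bytes can
+// themselves contain the delimiter byte, which this reverse scan does not
+// guard against.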
+
+// EncodeDelimiter append a delimiter byte to slice b, and return the appended slice.
+func EncodeDelimiter(b []byte, delimiter byte) []byte {
+ return append(b, delimiter)
+}
+
+func EncodeSegment(segName []byte, segType byte) []byte {
+ segmentKey := []byte("segment")
+ segmentKey = append(segmentKey, Delimiter)
+ segmentKey = append(segmentKey, segName...)
+ return append(segmentKey, Delimiter, segType)
+}
diff --git a/internal/storage/internal/tikv/tikv_store.go b/internal/storage/internal/tikv/tikv_store.go
new file mode 100644
index 0000000000000000000000000000000000000000..5ecf8936f675f574f262219013001db4519f77cd
--- /dev/null
+++ b/internal/storage/internal/tikv/tikv_store.go
@@ -0,0 +1,389 @@
+package tikvdriver
+
+import (
+ "context"
+ "errors"
+ "strconv"
+ "strings"
+
+ "github.com/tikv/client-go/config"
+ "github.com/tikv/client-go/rawkv"
+ . "github.com/zilliztech/milvus-distributed/internal/storage/internal/tikv/codec"
+ . "github.com/zilliztech/milvus-distributed/internal/storage/type"
+ storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
+)
+
+func keyAddOne(key Key) Key {
+ if key == nil {
+ return nil
+ }
+ lenKey := len(key)
+ ret := make(Key, lenKey)
+ copy(ret, key)
+ ret[lenKey-1] += 0x01
+ return ret
+}
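+
+// Editor's worked example: keyAddOne turns a prefix into an exclusive upper
+// bound for range scans, e.g. keyAddOne([]byte("seg_")) == []byte("seg`")
+// since '`' is '_'+1, so [prefix, keyAddOne(prefix)) covers exactly the keys
+// starting with the prefix. It does not carry, though: a trailing 0xFF byte
+// wraps to 0x00 and yields a bound that is too small.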
+
+type tikvEngine struct {
+ client *rawkv.Client
+ conf config.Config
+}
+
+func (e tikvEngine) Put(ctx context.Context, key Key, value Value) error {
+ return e.client.Put(ctx, key, value)
+}
+
+func (e tikvEngine) BatchPut(ctx context.Context, keys []Key, values []Value) error {
+ return e.client.BatchPut(ctx, keys, values)
+}
+
+func (e tikvEngine) Get(ctx context.Context, key Key) (Value, error) {
+ return e.client.Get(ctx, key)
+}
+
+func (e tikvEngine) GetByPrefix(ctx context.Context, prefix Key, keyOnly bool) (keys []Key, values []Value, err error) {
+ startKey := prefix
+ endKey := keyAddOne(prefix)
+ limit := e.conf.Raw.MaxScanLimit
+ for {
+ ks, vs, err := e.Scan(ctx, startKey, endKey, limit, keyOnly)
+ if err != nil {
+ return keys, values, err
+ }
+ keys = append(keys, ks...)
+ values = append(values, vs...)
+ if len(ks) < limit {
+ break
+ }
+ // update the start key, and exclude the start key
+ startKey = append(ks[len(ks)-1], '\000')
+ }
+ return
+}
+
+func (e tikvEngine) Scan(ctx context.Context, startKey Key, endKey Key, limit int, keyOnly bool) ([]Key, []Value, error) {
+ return e.client.Scan(ctx, startKey, endKey, limit, rawkv.ScanOption{KeyOnly: keyOnly})
+}
+
+func (e tikvEngine) Delete(ctx context.Context, key Key) error {
+ return e.client.Delete(ctx, key)
+}
+
+func (e tikvEngine) DeleteByPrefix(ctx context.Context, prefix Key) error {
+ startKey := prefix
+ endKey := keyAddOne(prefix)
+ return e.client.DeleteRange(ctx, startKey, endKey)
+}
+
+func (e tikvEngine) DeleteRange(ctx context.Context, startKey Key, endKey Key) error {
+ return e.client.DeleteRange(ctx, startKey, endKey)
+}
+
+func (e tikvEngine) Close() error {
+ return e.client.Close()
+}
+
+type TikvStore struct {
+ engine *tikvEngine
+}
+
+func NewTikvStore(ctx context.Context, option storagetype.Option) (*TikvStore, error) {
+
+ conf := config.Default()
+ client, err := rawkv.NewClient(ctx, []string{option.TikvAddress}, conf)
+ if err != nil {
+ return nil, err
+ }
+ return &TikvStore{
+ &tikvEngine{
+ client: client,
+ conf: conf,
+ },
+ }, nil
+}
+
+func (s *TikvStore) Name() string {
+ return "TiKV storage"
+}
+
+func (s *TikvStore) put(ctx context.Context, key Key, value Value, timestamp Timestamp, suffix string) error {
+ return s.engine.Put(ctx, EncodeKey(key, timestamp, suffix), value)
+}
+
+func (s *TikvStore) scanLE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error) {
+ panic("implement me")
+}
+
+func (s *TikvStore) scanGE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error) {
+ panic("implement me")
+}
+
+func (s *TikvStore) scan(ctx context.Context, key Key, start Timestamp, end Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error) {
+ //startKey := EncodeKey(key, start, "")
+ //endKey := EncodeKey(EncodeDelimiter(key, DelimiterPlusOne), end, "")
+ //return s.engine.Scan(ctx, startKey, endKey, -1, keyOnly)
+ panic("implement me")
+}
+
+func (s *TikvStore) deleteLE(ctx context.Context, key Key, timestamp Timestamp) error {
+ panic("implement me")
+}
+
+func (s *TikvStore) deleteGE(ctx context.Context, key Key, timestamp Timestamp) error {
+ panic("implement me")
+}
+
+func (s *TikvStore) deleteRange(ctx context.Context, key Key, start Timestamp, end Timestamp) error {
+ panic("implement me")
+}
+
+func (s *TikvStore) GetRow(ctx context.Context, key Key, timestamp Timestamp) (Value, error) {
+ startKey := EncodeKey(key, timestamp, "")
+ endKey := EncodeDelimiter(key, DelimiterPlusOne)
+ keys, values, err := s.engine.Scan(ctx, startKey, endKey, 1, false)
+ if err != nil || len(keys) == 0 {
+ return nil, err
+ }
+ _, _, suffix, err := DecodeKey(keys[0])
+ if err != nil {
+ return nil, err
+ }
+ // key is marked deleted
+ if suffix == string(DeleteMark) {
+ return nil, nil
+ }
+ return values[0], nil
+}
+
+// TODO: decide how to split keys into appropriately sized batches
+var batchSize = 100
+
+type kvPair struct {
+ key Key
+ value Value
+ err error
+}
+
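+// batchKeys splits keys into contiguous slices of at most batchSize entries;
+// e.g. 250 keys with batchSize 100 yield batches of 100, 100, and 50.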
+func batchKeys(keys []Key) [][]Key {
+ keysLen := len(keys)
+ numBatch := (keysLen-1)/batchSize + 1
+ batches := make([][]Key, numBatch)
+
+ for i := 0; i < numBatch; i++ {
+ batchStart := i * batchSize
+ batchEnd := batchStart + batchSize
+ // the last batch
+ if i == numBatch-1 {
+ batchEnd = keysLen
+ }
+ batches[i] = keys[batchStart:batchEnd]
+ }
+ return batches
+}
+
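+// GetRows fetches the newest visible version of each key at its paired
+// timestamp, fanning batches out to goroutines and reassembling the results
+// in the caller's key order.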
+func (s *TikvStore) GetRows(ctx context.Context, keys []Key, timestamps []Timestamp) ([]Value, error) {
+ if len(keys) != len(timestamps) {
+ return nil, errors.New("the len of keys is not equal to the len of timestamps")
+ }
+
+ batches := batchKeys(keys)
+ ch := make(chan kvPair, len(keys))
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ for n, b := range batches {
+ batch := b
+ batchIndex := n
+ go func() {
+ for i, key := range batch {
+ select {
+ case <-ctx.Done():
+ // still emit a result so the collecting loop below never blocks
+ ch <- kvPair{key: key, err: ctx.Err()}
+ default:
+ v, err := s.GetRow(ctx, key, timestamps[batchIndex*batchSize+i])
+ ch <- kvPair{
+ key: key,
+ value: v,
+ err: err,
+ }
+ }
+ }
+ }()
+ }
+
+ var err error
+ var values []Value
+ kvMap := make(map[string]Value)
+ for i := 0; i < len(keys); i++ {
+ kv := <-ch
+ if kv.err != nil {
+ cancel()
+ if err == nil {
+ err = kv.err
+ }
+ }
+ kvMap[string(kv.key)] = kv.value
+ }
+ for _, key := range keys {
+ values = append(values, kvMap[string(key)])
+ }
+ return values, err
+}
+
+func (s *TikvStore) PutRow(ctx context.Context, key Key, value Value, segment string, timestamp Timestamp) error {
+ return s.put(ctx, key, value, timestamp, segment)
+}
+
+func (s *TikvStore) PutRows(ctx context.Context, keys []Key, values []Value, segments []string, timestamps []Timestamp) error {
+ if len(keys) != len(values) {
+ return errors.New("the len of keys is not equal to the len of values")
+ }
+ if len(keys) != len(timestamps) {
+ return errors.New("the len of keys is not equal to the len of timestamps")
+ }
+
+ encodedKeys := make([]Key, len(keys))
+ for i, key := range keys {
+ encodedKeys[i] = EncodeKey(key, timestamps[i], segments[i])
+ }
+ return s.engine.BatchPut(ctx, encodedKeys, values)
+}
+
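+// DeleteRow writes a tombstone (a DeleteMark-suffixed version) at timestamp
+// rather than physically removing older versions.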
+func (s *TikvStore) DeleteRow(ctx context.Context, key Key, timestamp Timestamp) error {
+ return s.put(ctx, key, Value{0x00}, timestamp, string(DeleteMark))
+}
+
+func (s *TikvStore) DeleteRows(ctx context.Context, keys []Key, timestamps []Timestamp) error {
+ encodeKeys := make([]Key, len(keys))
+ values := make([]Value, len(keys))
+ for i, key := range keys {
+ encodeKeys[i] = EncodeKey(key, timestamps[i], string(DeleteMark))
+ values[i] = Value{0x00}
+ }
+ return s.engine.BatchPut(ctx, encodeKeys, values)
+}
+
+func (s *TikvStore) PutLog(ctx context.Context, key Key, value Value, timestamp Timestamp, channel int) error {
+ suffix := string(EncodeDelimiter(key, DelimiterPlusOne)) + strconv.Itoa(channel)
+ return s.put(ctx, Key("log"), value, timestamp, suffix)
+}
+
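+// GetLog returns log entries written in [start, end], optionally filtered by
+// channel. Encoded timestamps are inverted, so the scan starts at end and
+// finishes at start.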
+func (s *TikvStore) GetLog(ctx context.Context, start Timestamp, end Timestamp, channels []int) (logs []Value, err error) {
+ key := Key("log")
+ startKey := EncodeKey(key, end, "")
+ endKey := EncodeKey(key, start, "")
+ // TODO: page with a loop (as GetByPrefix does) to fetch more than MaxScanLimit entries
+ keys, values, err := s.engine.Scan(ctx, startKey, endKey, s.engine.conf.Raw.MaxScanLimit, false)
+ if err != nil || keys == nil {
+ return nil, err
+ }
+
+ for i, key := range keys {
+ _, _, suffix, err := DecodeKey(key)
+ log := values[i]
+ if err != nil {
+ return logs, err
+ }
+
+ // no channel filter: keep every log entry
+ if len(channels) == 0 {
+ logs = append(logs, log)
+ continue
+ }
+ slice := strings.Split(suffix, string(DelimiterPlusOne))
+ channel, err := strconv.Atoi(slice[len(slice)-1])
+ if err != nil {
+ return logs, err
+ }
+ for _, item := range channels {
+ if item == channel {
+ logs = append(logs, log)
+ break
+ }
+ }
+ }
+ return
+}
+
+func (s *TikvStore) GetSegmentIndex(ctx context.Context, segment string) (SegmentIndex, error) {
+ return s.engine.Get(ctx, EncodeSegment([]byte(segment), SegmentIndexMark))
+}
+
+func (s *TikvStore) PutSegmentIndex(ctx context.Context, segment string, index SegmentIndex) error {
+ return s.engine.Put(ctx, EncodeSegment([]byte(segment), SegmentIndexMark), index)
+}
+
+func (s *TikvStore) DeleteSegmentIndex(ctx context.Context, segment string) error {
+ return s.engine.Delete(ctx, EncodeSegment([]byte(segment), SegmentIndexMark))
+}
+
+func (s *TikvStore) GetSegmentDL(ctx context.Context, segment string) (SegmentDL, error) {
+ return s.engine.Get(ctx, EncodeSegment([]byte(segment), SegmentDLMark))
+}
+
+func (s *TikvStore) PutSegmentDL(ctx context.Context, segment string, log SegmentDL) error {
+ return s.engine.Put(ctx, EncodeSegment([]byte(segment), SegmentDLMark), log)
+}
+
+func (s *TikvStore) DeleteSegmentDL(ctx context.Context, segment string) error {
+ return s.engine.Delete(ctx, EncodeSegment([]byte(segment), SegmentDLMark))
+}
+
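+// GetSegments returns the distinct segment names holding versions of key
+// visible at timestamp, skipping tombstoned versions.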
+func (s *TikvStore) GetSegments(ctx context.Context, key Key, timestamp Timestamp) ([]string, error) {
+ keys, _, err := s.engine.GetByPrefix(ctx, EncodeDelimiter(key, Delimiter), true)
+ if err != nil {
+ return nil, err
+ }
+ segmentsSet := map[string]bool{}
+ for _, key := range keys {
+ _, ts, segment, err := DecodeKey(key)
+ if err != nil {
+ return nil, err
+ }
+ if ts <= timestamp && segment != string(DeleteMark) {
+ segmentsSet[segment] = true
+ }
+ }
+
+ var segments []string
+ for k, v := range segmentsSet {
+ if v {
+ segments = append(segments, k)
+ }
+ }
+ return segments, nil
+}
+
+func (s *TikvStore) Close() error {
+ return s.engine.Close()
+}
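
The store implements MVCC-style row access on top of the codec above: writes are versioned by timestamp, and reads return the newest version at or below the requested timestamp. A minimal usage sketch (assumes a reachable PD endpoint, as the tests below do; `Key` and `Value` are `[]byte` aliases from the type package):

```go
ctx := context.Background()
store, err := NewTikvStore(ctx, storagetype.Option{TikvAddress: "localhost:2379"})
if err != nil {
	panic(err)
}
defer store.Close()

_ = store.PutRow(ctx, Key("doc1"), Value("v1"), "segment0", 10)
_ = store.PutRow(ctx, Key("doc1"), Value("v2"), "segment0", 20)

v, _ := store.GetRow(ctx, Key("doc1"), 15) // Value("v1"): newest version at ts <= 15
v, _ = store.GetRow(ctx, Key("doc1"), 25)  // Value("v2")
_ = v
```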
diff --git a/internal/storage/internal/tikv/tikv_test.go b/internal/storage/internal/tikv/tikv_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..4e69d14d2c3079f7251123666881206a9af03d03
--- /dev/null
+++ b/internal/storage/internal/tikv/tikv_test.go
@@ -0,0 +1,293 @@
+package tikvdriver
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "math"
+ "os"
+ "sort"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ . "github.com/zilliztech/milvus-distributed/internal/storage/internal/tikv/codec"
+ . "github.com/zilliztech/milvus-distributed/internal/storage/type"
+)
+
+var store *TikvStore
+var option = Option{TikvAddress: "localhost:2379"}
+
+func TestMain(m *testing.M) {
+ var err error
+ store, err = NewTikvStore(context.Background(), option)
+ if err != nil {
+ panic(err)
+ }
+ exitCode := m.Run()
+ _ = store.Close()
+ os.Exit(exitCode)
+}
+
+func TestTikvEngine_Prefix(t *testing.T) {
+ ctx := context.Background()
+ prefix := Key("key")
+ engine := store.engine
+ value := Value("value")
+
+ // Put some key with same prefix
+ key := prefix
+ err := engine.Put(ctx, key, value)
+ require.Nil(t, err)
+ key = EncodeKey(prefix, 0, "")
+ err = engine.Put(ctx, key, value)
+ assert.Nil(t, err)
+
+ // Get by prefix
+ ks, _, err := engine.GetByPrefix(ctx, prefix, true)
+ assert.Equal(t, 2, len(ks))
+ assert.Nil(t, err)
+
+ // Delete by prefix
+ err = engine.DeleteByPrefix(ctx, prefix)
+ assert.Nil(t, err)
+ ks, _, err = engine.GetByPrefix(ctx, prefix, true)
+ assert.Equal(t, 0, len(ks))
+ assert.Nil(t, err)
+
+ // Test more keys than a single scan can return
+ num := engine.conf.Raw.MaxScanLimit + 1
+ keys := make([]Key, num)
+ values := make([]Value, num)
+ for i := 0; i < num; i++ {
+ key = EncodeKey(prefix, uint64(i), "")
+ keys[i] = key
+ values[i] = value
+ }
+ err = engine.BatchPut(ctx, keys, values)
+ assert.Nil(t, err)
+
+ ks, _, err = engine.GetByPrefix(ctx, prefix, true)
+ assert.Nil(t, err)
+ assert.Equal(t, num, len(ks))
+ err = engine.DeleteByPrefix(ctx, prefix)
+ assert.Nil(t, err)
+}
+
+func TestTikvStore_Row(t *testing.T) {
+ ctx := context.Background()
+ key := Key("key")
+
+ // Add same row with different timestamp
+ err := store.PutRow(ctx, key, Value("value0"), "segment0", 0)
+ assert.Nil(t, err)
+ err = store.PutRow(ctx, key, Value("value1"), "segment0", 2)
+ assert.Nil(t, err)
+
+ // Get most recent row using key and timestamp
+ v, err := store.GetRow(ctx, key, 3)
+ assert.Nil(t, err)
+ assert.Equal(t, Value("value1"), v)
+ v, err = store.GetRow(ctx, key, 2)
+ assert.Nil(t, err)
+ assert.Equal(t, Value("value1"), v)
+ v, err = store.GetRow(ctx, key, 1)
+ assert.Nil(t, err)
+ assert.Equal(t, Value("value0"), v)
+
+ // Add a different row, but with same prefix
+ key1 := Key("key_y")
+ err = store.PutRow(ctx, key1, Value("valuey"), "segment0", 2)
+ assert.Nil(t, err)
+
+ // Get most recent row using key and timestamp
+ v, err = store.GetRow(ctx, key, 3)
+ assert.Nil(t, err)
+ assert.Equal(t, Value("value1"), v)
+ v, err = store.GetRow(ctx, key1, 3)
+ assert.Nil(t, err)
+ assert.Equal(t, Value("valuey"), v)
+
+ // Delete a row
+ err = store.DeleteRow(ctx, key, 4)
+ assert.Nil(t, err)
+ v, err = store.GetRow(ctx, key, 5)
+ assert.Nil(t, err)
+ assert.Nil(t, v)
+
+ // Clear test data
+ err = store.engine.DeleteByPrefix(ctx, key)
+ assert.Nil(t, err)
+ k, va, err := store.engine.GetByPrefix(ctx, key, false)
+ assert.Nil(t, err)
+ assert.Nil(t, k)
+ assert.Nil(t, va)
+}
+
+func TestTikvStore_BatchRow(t *testing.T) {
+ ctx := context.Background()
+
+ // Prepare test data
+ size := 0
+ var testKeys []Key
+ var testValues []Value
+ var segments []string
+ var timestamps []Timestamp
+ for i := 0; size/store.engine.conf.Raw.MaxBatchPutSize < 1; i++ {
+ key := fmt.Sprint("key", i)
+ size += len(key)
+ testKeys = append(testKeys, []byte(key))
+ value := fmt.Sprint("value", i)
+ size += len(value)
+ testValues = append(testValues, []byte(value))
+ segments = append(segments, "test")
+ v, err := store.GetRow(ctx, Key(key), math.MaxUint64)
+ assert.Nil(t, v)
+ assert.Nil(t, err)
+ }
+
+ // Batch put rows
+ for range testKeys {
+ timestamps = append(timestamps, 1)
+ }
+ err := store.PutRows(ctx, testKeys, testValues, segments, timestamps)
+ assert.Nil(t, err)
+
+ // Batch get rows
+ for i := range timestamps {
+ timestamps[i] = 2
+ }
+ checkValues, err := store.GetRows(ctx, testKeys, timestamps)
+ assert.NotNil(t, checkValues)
+ assert.Nil(t, err)
+ assert.Equal(t, len(checkValues), len(testValues))
+ for i := range testKeys {
+ assert.Equal(t, testValues[i], checkValues[i])
+ }
+
+ // Delete all test rows
+ for i := range timestamps {
+ timestamps[i] = math.MaxUint64
+ }
+ err = store.DeleteRows(ctx, testKeys, timestamps)
+ assert.Nil(t, err)
+ // Ensure all test rows are deleted
+ for i := range timestamps {
+ timestamps[i] = math.MaxUint64
+ }
+ checkValues, err = store.GetRows(ctx, testKeys, timestamps)
+ assert.Nil(t, err)
+ for _, value := range checkValues {
+ assert.Nil(t, value)
+ }
+
+ // Clean test data
+ err = store.engine.DeleteByPrefix(ctx, Key("key"))
+ assert.Nil(t, err)
+}
+
+func TestTikvStore_GetSegments(t *testing.T) {
+ ctx := context.Background()
+ key := Key("key")
+
+ // Put rows
+ err := store.PutRow(ctx, key, Value{0}, "a", 1)
+ assert.Nil(t, err)
+ err = store.PutRow(ctx, key, Value{0}, "a", 2)
+ assert.Nil(t, err)
+ err = store.PutRow(ctx, key, Value{0}, "c", 3)
+ assert.Nil(t, err)
+
+ // Get segments
+ segs, err := store.GetSegments(ctx, key, 2)
+ assert.Nil(t, err)
+ assert.Equal(t, 1, len(segs))
+ assert.Equal(t, "a", segs[0])
+
+ segs, err = store.GetSegments(ctx, key, 3)
+ assert.Nil(t, err)
+ assert.Equal(t, 2, len(segs))
+
+ // Clean test data
+ err = store.engine.DeleteByPrefix(ctx, key)
+ assert.Nil(t, err)
+}
+
+func TestTikvStore_Log(t *testing.T) {
+ ctx := context.Background()
+
+ // Put some log
+ err := store.PutLog(ctx, Key("key1"), Value("value1"), 1, 1)
+ assert.Nil(t, err)
+ err = store.PutLog(ctx, Key("key1"), Value("value1_1"), 1, 2)
+ assert.Nil(t, err)
+ err = store.PutLog(ctx, Key("key2"), Value("value2"), 2, 1)
+ assert.Nil(t, err)
+
+ // Check log
+ log, err := store.GetLog(ctx, 0, 2, []int{1, 2})
+ if err != nil {
+ panic(err)
+ }
+ sort.Slice(log, func(i, j int) bool {
+ return bytes.Compare(log[i], log[j]) == -1
+ })
+ assert.Equal(t, log[0], Value("value1"))
+ assert.Equal(t, log[1], Value("value1_1"))
+ assert.Equal(t, log[2], Value("value2"))
+
+ // Delete test data
+ err = store.engine.DeleteByPrefix(ctx, Key("log"))
+ assert.Nil(t, err)
+}
+
+func TestTikvStore_SegmentIndex(t *testing.T) {
+ ctx := context.Background()
+
+ // Put segment index
+ err := store.PutSegmentIndex(ctx, "segment0", []byte("index0"))
+ assert.Nil(t, err)
+ err = store.PutSegmentIndex(ctx, "segment1", []byte("index1"))
+ assert.Nil(t, err)
+
+ // Get segment index
+ index, err := store.GetSegmentIndex(ctx, "segment0")
+ assert.Nil(t, err)
+ assert.Equal(t, []byte("index0"), index)
+ index, err = store.GetSegmentIndex(ctx, "segment1")
+ assert.Nil(t, err)
+ assert.Equal(t, []byte("index1"), index)
+
+ // Delete segment index
+ err = store.DeleteSegmentIndex(ctx, "segment0")
+ assert.Nil(t, err)
+ err = store.DeleteSegmentIndex(ctx, "segment1")
+ assert.Nil(t, err)
+ index, err = store.GetSegmentIndex(ctx, "segment0")
+ assert.Nil(t, err)
+ assert.Nil(t, index)
+}
+
+func TestTikvStore_DeleteSegmentDL(t *testing.T) {
+ ctx := context.Background()
+
+ // Put segment delete log
+ err := store.PutSegmentDL(ctx, "segment0", []byte("index0"))
+ assert.Nil(t, err)
+ err = store.PutSegmentDL(ctx, "segment1", []byte("index1"))
+ assert.Nil(t, err)
+
+ // Get segment delete log
+ index, err := store.GetSegmentDL(ctx, "segment0")
+ assert.Nil(t, err)
+ assert.Equal(t, []byte("index0"), index)
+ index, err = store.GetSegmentDL(ctx, "segment1")
+ assert.Nil(t, err)
+ assert.Equal(t, []byte("index1"), index)
+
+ // Delete segment delete log
+ err = store.DeleteSegmentDL(ctx, "segment0")
+ assert.Nil(t, err)
+ err = store.DeleteSegmentDL(ctx, "segment1")
+ assert.Nil(t, err)
+ index, err = store.GetSegmentDL(ctx, "segment0")
+ assert.Nil(t, err)
+ assert.Nil(t, index)
+}
diff --git a/internal/storage/storage.go b/internal/storage/storage.go
new file mode 100644
index 0000000000000000000000000000000000000000..67e9e44e8939075caaaa8aefa17cccc12e786bbd
--- /dev/null
+++ b/internal/storage/storage.go
@@ -0,0 +1,39 @@
+package storage
+
+import (
+ "context"
+ "errors"
+
+ S3Driver "github.com/zilliztech/milvus-distributed/internal/storage/internal/S3"
+ minIODriver "github.com/zilliztech/milvus-distributed/internal/storage/internal/minio"
+ tikvDriver "github.com/zilliztech/milvus-distributed/internal/storage/internal/tikv"
+ storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
+)
+
+func NewStore(ctx context.Context, option storagetype.Option) (storagetype.Store, error) {
+ var err error
+ var store storagetype.Store
+ switch option.Type {
+ case storagetype.TIKVDriver:
+ store, err = tikvDriver.NewTikvStore(ctx, option)
+ case storagetype.MinIODriver:
+ store, err = minIODriver.NewMinioDriver(ctx, option)
+ case storagetype.S3DRIVER:
+ store, err = S3Driver.NewS3Driver(ctx, option)
+ default:
+ return nil, errors.New("unsupported driver type: " + option.Type)
+ }
+ if err != nil {
+ return nil, err
+ }
+ return store, nil
+}
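
A hypothetical caller of the factory, selecting the TiKV backend (the endpoint value is an assumption matching the tests above):

```go
package main

import (
	"context"
	"log"

	"github.com/zilliztech/milvus-distributed/internal/storage"
	storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
)

func main() {
	store, err := storage.NewStore(context.Background(), storagetype.Option{
		Type:        storagetype.TIKVDriver,
		TikvAddress: "localhost:2379",
	})
	if err != nil {
		log.Fatal(err)
	}
	defer store.Close()
}
```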
diff --git a/internal/storage/type/storagetype.go b/internal/storage/type/storagetype.go
new file mode 100644
index 0000000000000000000000000000000000000000..9549a106e573e96589ec9f415379a0a5b7144c2b
--- /dev/null
+++ b/internal/storage/type/storagetype.go
@@ -0,0 +1,79 @@
+package storagetype
+
+import (
+ "context"
+
+ "github.com/zilliztech/milvus-distributed/internal/util/typeutil"
+)
+
+type Key = []byte
+type Value = []byte
+type Timestamp = typeutil.Timestamp
+type DriverType = string
+type SegmentIndex = []byte
+type SegmentDL = []byte
+
+type Option struct {
+ Type DriverType
+ TikvAddress string
+ BucketName string
+}
+
+const (
+ MinIODriver DriverType = "MinIO"
+ TIKVDriver DriverType = "TIKV"
+ S3DRIVER DriverType = "S3"
+)
+
+type storeEngine interface {
+ Put(ctx context.Context, key Key, value Value) error
+ Get(ctx context.Context, key Key) (Value, error)
+ GetByPrefix(ctx context.Context, prefix Key, keyOnly bool) ([]Key, []Value, error)
+ Scan(ctx context.Context, startKey Key, endKey Key, limit int, keyOnly bool) ([]Key, []Value, error)
+ Delete(ctx context.Context, key Key) error
+ DeleteByPrefix(ctx context.Context, prefix Key) error
+ DeleteRange(ctx context.Context, keyStart Key, keyEnd Key) error
+}
+
+type Store interface {
+ //put(ctx context.Context, key Key, value Value, timestamp Timestamp, suffix string) error
+ //scanLE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error)
+ //scanGE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error)
+ //deleteLE(ctx context.Context, key Key, timestamp Timestamp) error
+ //deleteGE(ctx context.Context, key Key, timestamp Timestamp) error
+ //deleteRange(ctx context.Context, key Key, start Timestamp, end Timestamp) error
+
+ GetRow(ctx context.Context, key Key, timestamp Timestamp) (Value, error)
+ GetRows(ctx context.Context, keys []Key, timestamps []Timestamp) ([]Value, error)
+
+ PutRow(ctx context.Context, key Key, value Value, segment string, timestamp Timestamp) error
+ PutRows(ctx context.Context, keys []Key, values []Value, segments []string, timestamps []Timestamp) error
+
+ GetSegments(ctx context.Context, key Key, timestamp Timestamp) ([]string, error)
+
+ DeleteRow(ctx context.Context, key Key, timestamp Timestamp) error
+ DeleteRows(ctx context.Context, keys []Key, timestamps []Timestamp) error
+
+ PutLog(ctx context.Context, key Key, value Value, timestamp Timestamp, channel int) error
+ GetLog(ctx context.Context, start Timestamp, end Timestamp, channels []int) ([]Value, error)
+
+ GetSegmentIndex(ctx context.Context, segment string) (SegmentIndex, error)
+ PutSegmentIndex(ctx context.Context, segment string, index SegmentIndex) error
+ DeleteSegmentIndex(ctx context.Context, segment string) error
+
+ GetSegmentDL(ctx context.Context, segment string) (SegmentDL, error)
+ PutSegmentDL(ctx context.Context, segment string, log SegmentDL) error
+ DeleteSegmentDL(ctx context.Context, segment string) error
+}
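
The interface is satisfied structurally, so driver drift only surfaces at the call site. A common Go idiom (not in the patch) is to pin each driver to the interface with a compile-time assertion:

```go
// e.g. in package tikvdriver; repeat for the MinIO and S3 drivers
var _ storagetype.Store = (*TikvStore)(nil)
```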
diff --git a/internal/util/flowgraph/input_node.go b/internal/util/flowgraph/input_node.go
index 7c4271b23be5e31373966c3b64acfc395285916f..b7891040e8208894b2c66f07582ddc26adc3d59e 100644
--- a/internal/util/flowgraph/input_node.go
+++ b/internal/util/flowgraph/input_node.go
@@ -1,8 +1,12 @@
package flowgraph
import (
+ "fmt"
"log"
+ "github.com/opentracing/opentracing-go"
+ "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
+
"github.com/zilliztech/milvus-distributed/internal/msgstream"
)
@@ -25,11 +29,32 @@ func (inNode *InputNode) InStream() *msgstream.MsgStream {
}
// empty input and return one *Msg
-func (inNode *InputNode) Operate(in []*Msg) []*Msg {
+func (inNode *InputNode) Operate([]*Msg) []*Msg {
//fmt.Println("Do InputNode operation")
-
msgPack := (*inNode.inStream).Consume()
+ var spans []opentracing.Span
+ tracer := opentracing.GlobalTracer()
+ if tracer != nil && msgPack != nil {
+ for _, msg := range msgPack.Msgs {
+ if msg.Type() == internalpb.MsgType_kInsert || msg.Type() == internalpb.MsgType_kSearch {
+ var child opentracing.Span
+ ctx := msg.GetContext()
+ if parent := opentracing.SpanFromContext(ctx); parent != nil {
+ child = tracer.StartSpan(fmt.Sprintf("through msg input node, start time = %d", msg.BeginTs()),
+ opentracing.FollowsFrom(parent.Context()))
+ } else {
+ child = tracer.StartSpan(fmt.Sprintf("through msg input node, start time = %d", msg.BeginTs()))
+ }
+ child.SetTag("hash keys", msg.HashKeys())
+ child.SetTag("start time", msg.BeginTs())
+ child.SetTag("end time", msg.EndTs())
+ msg.SetContext(opentracing.ContextWithSpan(ctx, child))
+ spans = append(spans, child)
+ }
+ }
+ }
+
// TODO: add status
if msgPack == nil {
log.Println("null msg pack")
@@ -42,6 +67,10 @@ func (inNode *InputNode) Operate(in []*Msg) []*Msg {
timestampMax: msgPack.EndTs,
}
+ for _, span := range spans {
+ span.Finish()
+ }
+
return []*Msg{&msgStreamMsg}
}
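
The parent-or-root span logic above is repeated almost verbatim in the write-node hunks below. A hedged sketch of a helper that could factor it out (startFollowingSpan is a hypothetical name, not part of the patch):

```go
// startFollowingSpan continues a trace from the message context when a parent
// span exists there, and starts a fresh root span otherwise.
func startFollowingSpan(ctx context.Context, name string) opentracing.Span {
	tracer := opentracing.GlobalTracer()
	if parent := opentracing.SpanFromContext(ctx); parent != nil {
		return tracer.StartSpan(name, opentracing.FollowsFrom(parent.Context()))
	}
	return tracer.StartSpan(name)
}
```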
diff --git a/internal/writenode/flow_graph_dd_node.go b/internal/writenode/flow_graph_dd_node.go
index 2c77e398c53eb7a03b7afb90600b85f574dd1628..7dd8e1fd433fda223cab267bbae39bb07ce2122c 100644
--- a/internal/writenode/flow_graph_dd_node.go
+++ b/internal/writenode/flow_graph_dd_node.go
@@ -9,6 +9,9 @@ import (
"strconv"
"github.com/golang/protobuf/proto"
+ "github.com/minio/minio-go/v7"
+ "github.com/minio/minio-go/v7/pkg/credentials"
+
"github.com/zilliztech/milvus-distributed/internal/allocator"
"github.com/zilliztech/milvus-distributed/internal/kv"
miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
@@ -357,16 +360,19 @@ func newDDNode(ctx context.Context, outCh chan *ddlFlushSyncMsg) *ddNode {
partitionRecords: make(map[UniqueID]interface{}),
}
- bucketName := Params.MinioBucketName
- option := &miniokv.Option{
- Address: Params.MinioAddress,
- AccessKeyID: Params.MinioAccessKeyID,
- SecretAccessKeyID: Params.MinioSecretAccessKey,
- UseSSL: Params.MinioUseSSL,
- BucketName: bucketName,
- CreateBucket: true,
+ minIOEndPoint := Params.MinioAddress
+ minIOAccessKeyID := Params.MinioAccessKeyID
+ minIOSecretAccessKey := Params.MinioSecretAccessKey
+ minIOUseSSL := Params.MinioUseSSL
+ minIOClient, err := minio.New(minIOEndPoint, &minio.Options{
+ Creds: credentials.NewStaticV4(minIOAccessKeyID, minIOSecretAccessKey, ""),
+ Secure: minIOUseSSL,
+ })
+ if err != nil {
+ panic(err)
}
- minioKV, err := miniokv.NewMinIOKV(ctx, option)
+ bucketName := Params.MinioBucketName
+ minioKV, err := miniokv.NewMinIOKV(ctx, minIOClient, bucketName)
if err != nil {
panic(err)
}
diff --git a/internal/writenode/flow_graph_filter_dm_node.go b/internal/writenode/flow_graph_filter_dm_node.go
index 98e9ec0bc4cc5fb596c93a2725ef2bd521da911e..0bca67ebcb4ec4723454c083331822f617275990 100644
--- a/internal/writenode/flow_graph_filter_dm_node.go
+++ b/internal/writenode/flow_graph_filter_dm_node.go
@@ -1,8 +1,11 @@
package writenode
import (
+ "context"
"log"
+ "github.com/opentracing/opentracing-go"
+
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
@@ -31,11 +34,34 @@ func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg {
// TODO: add error handling
}
+ spans := make(map[int]opentracing.Span)
+ tracer := opentracing.GlobalTracer()
+ if tracer != nil {
+ for i, msg := range msgStreamMsg.TsMessages() {
+ if msg.Type() == internalPb.MsgType_kInsert {
+ var child opentracing.Span
+ ctx := msg.GetContext()
+ if parent := opentracing.SpanFromContext(ctx); parent != nil {
+ child = tracer.StartSpan("pass filter node",
+ opentracing.FollowsFrom(parent.Context()))
+ } else {
+ child = tracer.StartSpan("pass filter node")
+ }
+ child.SetTag("hash keys", msg.HashKeys())
+ child.SetTag("start time", msg.BeginTs())
+ child.SetTag("end time", msg.EndTs())
+ msg.SetContext(opentracing.ContextWithSpan(ctx, child))
+ // index spans by message position so the insert loop below can match them
+ spans[i] = child
+ }
+ }
+ }
+
ddMsg, ok := (*in[1]).(*ddMsg)
if !ok {
log.Println("type assertion failed for ddMsg")
// TODO: add error handling
}
+
fdmNode.ddMsg = ddMsg
var iMsg = insertMsg{
@@ -56,11 +82,20 @@ func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg {
}
}
- for _, msg := range msgStreamMsg.TsMessages() {
+ for key, msg := range msgStreamMsg.TsMessages() {
switch msg.Type() {
case internalPb.MsgType_kInsert:
+ ctx2 := context.Background()
+ if span, ok := spans[key]; ok {
+ ctx2 = opentracing.ContextWithSpan(msg.GetContext(), span)
+ }
resMsg := fdmNode.filterInvalidInsertMessage(msg.(*msgstream.InsertMsg))
if resMsg != nil {
+ resMsg.SetContext(ctx2)
iMsg.insertMessages = append(iMsg.insertMessages, resMsg)
}
// case internalPb.MsgType_kDelete:
@@ -69,8 +104,11 @@ func (fdmNode *filterDmNode) Operate(in []*Msg) []*Msg {
log.Println("Non supporting message type:", msg.Type())
}
}
-
var res Msg = &iMsg
+
+ for _, span := range spans {
+ span.Finish()
+ }
return []*Msg{&res}
}
diff --git a/internal/writenode/flow_graph_insert_buffer_node.go b/internal/writenode/flow_graph_insert_buffer_node.go
index 2ebc1300f8a0fc1c1f56c5a55a203f3b821f87d4..927e751a62fd5c9ddd29e298b7e692353e723d9b 100644
--- a/internal/writenode/flow_graph_insert_buffer_node.go
+++ b/internal/writenode/flow_graph_insert_buffer_node.go
@@ -4,13 +4,19 @@ import (
"bytes"
"context"
"encoding/binary"
+ "fmt"
"log"
"path"
"strconv"
"time"
"unsafe"
+ "github.com/opentracing/opentracing-go"
+ oplog "github.com/opentracing/opentracing-go/log"
+
"github.com/golang/protobuf/proto"
+ "github.com/minio/minio-go/v7"
+ "github.com/minio/minio-go/v7/pkg/credentials"
"github.com/zilliztech/milvus-distributed/internal/allocator"
"github.com/zilliztech/milvus-distributed/internal/kv"
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
@@ -100,11 +106,22 @@ func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg {
// iMsg is insertMsg
// 1. iMsg -> buffer
for _, msg := range iMsg.insertMessages {
+ ctx := msg.GetContext()
+ var span opentracing.Span
+ if ctx != nil {
+ span, _ = opentracing.StartSpanFromContext(ctx, fmt.Sprintf("insert buffer node, start time = %d", msg.BeginTs()))
+ } else {
+ span = opentracing.StartSpan(fmt.Sprintf("insert buffer node, start time = %d", msg.BeginTs()))
+ }
+ span.SetTag("hash keys", msg.HashKeys())
+ span.SetTag("start time", msg.BeginTs())
+ span.SetTag("end time", msg.EndTs())
if len(msg.RowIDs) != len(msg.Timestamps) || len(msg.RowIDs) != len(msg.RowData) {
log.Println("Error: misaligned messages detected")
continue
}
currentSegID := msg.GetSegmentID()
+ span.LogFields(oplog.Int("segment id", int(currentSegID)))
idata, ok := ibNode.insertBuffer.insertData[currentSegID]
if !ok {
@@ -113,6 +130,21 @@ func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg {
}
}
+ // Timestamps: field ID 1 is the reserved timestamp column
+ _, ok = idata.Data[1].(*storage.Int64FieldData)
+ if !ok {
+ idata.Data[1] = &storage.Int64FieldData{
+ Data: []int64{},
+ NumRows: 0,
+ }
+ }
+ tsData := idata.Data[1].(*storage.Int64FieldData)
+ for _, ts := range msg.Timestamps {
+ tsData.Data = append(tsData.Data, int64(ts))
+ }
+ tsData.NumRows += len(msg.Timestamps)
+ span.LogFields(oplog.Int("tsData numRows", tsData.NumRows))
+
// 1.1 Get CollectionMeta from etcd
segMeta, collMeta, err := ibNode.getMeta(currentSegID)
if err != nil {
@@ -358,9 +390,11 @@ func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg {
// 1.3 store in buffer
ibNode.insertBuffer.insertData[currentSegID] = idata
+ span.LogFields(oplog.String("store in buffer", "store in buffer"))
// 1.4 if full
// 1.4.1 generate binlogs
+ span.LogFields(oplog.String("generate binlogs", "generate binlogs"))
if ibNode.insertBuffer.full(currentSegID) {
log.Printf(". Insert Buffer full, auto flushing (%v) rows of data...", ibNode.insertBuffer.size(currentSegID))
// partitionTag -> partitionID
@@ -426,6 +460,7 @@ func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg {
ibNode.outCh <- inBinlogMsg
}
}
+ span.Finish()
}
if len(iMsg.insertMessages) > 0 {
@@ -608,17 +643,20 @@ func newInsertBufferNode(ctx context.Context, outCh chan *insertFlushSyncMsg) *i
kvClient := etcdkv.NewEtcdKV(cli, MetaRootPath)
// MinIO
-
- option := &miniokv.Option{
- Address: Params.MinioAddress,
- AccessKeyID: Params.MinioAccessKeyID,
- SecretAccessKeyID: Params.MinioSecretAccessKey,
- UseSSL: Params.MinioUseSSL,
- CreateBucket: true,
- BucketName: Params.MinioBucketName,
+ minioEndPoint := Params.MinioAddress
+ minioAccessKeyID := Params.MinioAccessKeyID
+ minioSecretAccessKey := Params.MinioSecretAccessKey
+ minioUseSSL := Params.MinioUseSSL
+ minioBucketName := Params.MinioBucketName
+
+ minioClient, err := minio.New(minioEndPoint, &minio.Options{
+ Creds: credentials.NewStaticV4(minioAccessKeyID, minioSecretAccessKey, ""),
+ Secure: minioUseSSL,
+ })
+ if err != nil {
+ panic(err)
}
-
- minIOKV, err := miniokv.NewMinIOKV(ctx, option)
+ minIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, minioBucketName)
if err != nil {
panic(err)
}
diff --git a/internal/writenode/write_node.go b/internal/writenode/write_node.go
index b747bd397be62c5dd08ca7ac42322bb7e1e8e962..77aeea56684cee54c81271e734b1c9ba837e7033 100644
--- a/internal/writenode/write_node.go
+++ b/internal/writenode/write_node.go
@@ -2,6 +2,12 @@ package writenode
import (
"context"
+ "fmt"
+ "io"
+
+ "github.com/opentracing/opentracing-go"
+ "github.com/uber/jaeger-client-go"
+ "github.com/uber/jaeger-client-go/config"
)
type WriteNode struct {
@@ -9,6 +15,9 @@ type WriteNode struct {
WriteNodeID uint64
dataSyncService *dataSyncService
flushSyncService *flushSyncService
+
+ tracer opentracing.Tracer
+ closer io.Closer
}
func NewWriteNode(ctx context.Context, writeNodeID uint64) *WriteNode {
@@ -28,6 +37,22 @@ func Init() {
}
func (node *WriteNode) Start() error {
+ cfg := &config.Configuration{
+ ServiceName: "tracing",
+ Sampler: &config.SamplerConfig{
+ Type: "const",
+ Param: 1,
+ },
+ Reporter: &config.ReporterConfig{
+ LogSpans: true,
+ },
+ }
+ var err error
+ node.tracer, node.closer, err = cfg.NewTracer(config.Logger(jaeger.StdLogger))
+ if err != nil {
+ panic(fmt.Sprintf("ERROR: cannot init Jaeger: %v\n", err))
+ }
+ opentracing.SetGlobalTracer(node.tracer)
// TODO GOOSE Init Size??
chanSize := 100
@@ -39,6 +64,7 @@ func (node *WriteNode) Start() error {
go node.dataSyncService.start()
go node.flushSyncService.start()
+
return nil
}
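
Note that `node.closer` is stored but never closed in this diff; Jaeger buffers spans, so they are only flushed when the closer is closed. A sketch of the missing teardown (assuming WriteNode gains, or already has, a Close path elsewhere):

```go
func (node *WriteNode) Close() {
	if node.closer != nil {
		_ = node.closer.Close()
	}
}
```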
diff --git a/tools/core_gen/all_generate.py b/tools/core_gen/all_generate.py
index a4be00cead810a06504156c058b1755086ac1899..499022d583bf44465868530a4ddee5aec991523b 100755
--- a/tools/core_gen/all_generate.py
+++ b/tools/core_gen/all_generate.py
@@ -58,10 +58,6 @@ if __name__ == "__main__":
'visitor_name': "ExecExprVisitor",
"parameter_name": 'expr',
},
- {
- 'visitor_name': "VerifyExprVisitor",
- "parameter_name": 'expr',
- },
],
'PlanNode': [
{
@@ -72,10 +68,7 @@ if __name__ == "__main__":
'visitor_name': "ExecPlanNodeVisitor",
"parameter_name": 'node',
},
- {
- 'visitor_name': "VerifyPlanNodeVisitor",
- "parameter_name": 'node',
- },
+
]
}
extract_extra_body(visitor_info, query_path)
diff --git a/tools/core_gen/templates/visitor_derived.h b/tools/core_gen/templates/visitor_derived.h
index 5fe2be775ddaebaa817d4819413898d71d4468a1..cda1ea1a74c15cb90ca13fe26b879bdf1cd021cc 100644
--- a/tools/core_gen/templates/visitor_derived.h
+++ b/tools/core_gen/templates/visitor_derived.h
@@ -13,7 +13,7 @@
#include "@@base_visitor@@.h"
namespace @@namespace@@ {
-class @@visitor_name@@ : public @@base_visitor@@ {
+class @@visitor_name@@ : @@base_visitor@@ {
public:
@@body@@