diff --git a/cmd/storage/benchmark.go b/cmd/storage/benchmark.go deleted file mode 100644 index 6a7c06f9a209b25df74545d1f23f58b2d1c80594..0000000000000000000000000000000000000000 --- a/cmd/storage/benchmark.go +++ /dev/null @@ -1,315 +0,0 @@ -package main - -import ( - "context" - "crypto/md5" - "flag" - "fmt" - "log" - "math/rand" - "os" - "sync" - "sync/atomic" - "time" - - "github.com/pivotal-golang/bytefmt" - "github.com/zilliztech/milvus-distributed/internal/storage" - storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type" -) - -// Global variables -var durationSecs, threads, loops, numVersion, batchOpSize int -var valueSize uint64 -var valueData []byte -var batchValueData [][]byte -var counter, totalKeyCount, keyNum int32 -var endTime, setFinish, getFinish, deleteFinish time.Time -var totalKeys [][]byte - -var logFileName = "benchmark.log" -var logFile *os.File - -var store storagetype.Store -var wg sync.WaitGroup - -func runSet() { - for time.Now().Before(endTime) { - num := atomic.AddInt32(&keyNum, 1) - key := []byte(fmt.Sprint("key", num)) - for ver := 1; ver <= numVersion; ver++ { - atomic.AddInt32(&counter, 1) - err := store.PutRow(context.Background(), key, valueData, "empty", uint64(ver)) - if err != nil { - log.Fatalf("Error setting key %s, %s", key, err.Error()) - //atomic.AddInt32(&setCount, -1) - } - } - } - // Remember last done time - setFinish = time.Now() - wg.Done() -} - -func runBatchSet() { - for time.Now().Before(endTime) { - num := atomic.AddInt32(&keyNum, int32(batchOpSize)) - keys := make([][]byte, batchOpSize) - versions := make([]uint64, batchOpSize) - batchSuffix := make([]string, batchOpSize) - for n := batchOpSize; n > 0; n-- { - keys[n-1] = []byte(fmt.Sprint("key", num-int32(n))) - } - for ver := 1; ver <= numVersion; ver++ { - atomic.AddInt32(&counter, 1) - err := store.PutRows(context.Background(), keys, batchValueData, batchSuffix, versions) - if err != nil { - log.Fatalf("Error setting batch keys %s %s", keys, err.Error()) - //atomic.AddInt32(&batchSetCount, -1) - } - } - } - setFinish = time.Now() - wg.Done() -} - -func runGet() { - for time.Now().Before(endTime) { - num := atomic.AddInt32(&counter, 1) - //num := atomic.AddInt32(&keyNum, 1) - //key := []byte(fmt.Sprint("key", num)) - num = num % totalKeyCount - key := totalKeys[num] - _, err := store.GetRow(context.Background(), key, uint64(numVersion)) - if err != nil { - log.Fatalf("Error getting key %s, %s", key, err.Error()) - //atomic.AddInt32(&getCount, -1) - } - } - // Remember last done time - getFinish = time.Now() - wg.Done() -} - -func runBatchGet() { - for time.Now().Before(endTime) { - num := atomic.AddInt32(&keyNum, int32(batchOpSize)) - //keys := make([][]byte, batchOpSize) - //for n := batchOpSize; n > 0; n-- { - // keys[n-1] = []byte(fmt.Sprint("key", num-int32(n))) - //} - end := num % totalKeyCount - if end < int32(batchOpSize) { - end = int32(batchOpSize) - } - start := end - int32(batchOpSize) - keys := totalKeys[start:end] - versions := make([]uint64, batchOpSize) - for i := range versions { - versions[i] = uint64(numVersion) - } - atomic.AddInt32(&counter, 1) - _, err := store.GetRows(context.Background(), keys, versions) - if err != nil { - log.Fatalf("Error getting key %s, %s", keys, err.Error()) - //atomic.AddInt32(&batchGetCount, -1) - } - } - // Remember last done time - getFinish = time.Now() - wg.Done() -} - -func runDelete() { - for time.Now().Before(endTime) { - num := atomic.AddInt32(&counter, 1) - //num := atomic.AddInt32(&keyNum, 1) - //key := 
[]byte(fmt.Sprint("key", num)) - num = num % totalKeyCount - key := totalKeys[num] - err := store.DeleteRow(context.Background(), key, uint64(numVersion)) - if err != nil { - log.Fatalf("Error getting key %s, %s", key, err.Error()) - //atomic.AddInt32(&deleteCount, -1) - } - } - // Remember last done time - deleteFinish = time.Now() - wg.Done() -} - -func runBatchDelete() { - for time.Now().Before(endTime) { - num := atomic.AddInt32(&keyNum, int32(batchOpSize)) - //keys := make([][]byte, batchOpSize) - //for n := batchOpSize; n > 0; n-- { - // keys[n-1] = []byte(fmt.Sprint("key", num-int32(n))) - //} - end := num % totalKeyCount - if end < int32(batchOpSize) { - end = int32(batchOpSize) - } - start := end - int32(batchOpSize) - keys := totalKeys[start:end] - atomic.AddInt32(&counter, 1) - versions := make([]uint64, batchOpSize) - for i := range versions { - versions[i] = uint64(numVersion) - } - err := store.DeleteRows(context.Background(), keys, versions) - if err != nil { - log.Fatalf("Error getting key %s, %s", keys, err.Error()) - //atomic.AddInt32(&batchDeleteCount, -1) - } - } - // Remember last done time - getFinish = time.Now() - wg.Done() -} - -func main() { - // Parse command line - myflag := flag.NewFlagSet("myflag", flag.ExitOnError) - myflag.IntVar(&durationSecs, "d", 5, "Duration of each test in seconds") - myflag.IntVar(&threads, "t", 1, "Number of threads to run") - myflag.IntVar(&loops, "l", 1, "Number of times to repeat test") - var sizeArg string - var storeType string - myflag.StringVar(&sizeArg, "z", "1k", "Size of objects in bytes with postfix K, M, and G") - myflag.StringVar(&storeType, "s", "s3", "Storage type, tikv or minio or s3") - myflag.IntVar(&numVersion, "v", 1, "Max versions for each key") - myflag.IntVar(&batchOpSize, "b", 100, "Batch operation kv pair number") - - if err := myflag.Parse(os.Args[1:]); err != nil { - os.Exit(1) - } - - // Check the arguments - var err error - if valueSize, err = bytefmt.ToBytes(sizeArg); err != nil { - log.Fatalf("Invalid -z argument for object size: %v", err) - } - var option = storagetype.Option{TikvAddress: "localhost:2379", Type: storeType, BucketName: "zilliz-hz"} - - store, err = storage.NewStore(context.Background(), option) - if err != nil { - log.Fatalf("Error when creating storage " + err.Error()) - } - logFile, err = os.OpenFile(logFileName, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0777) - if err != nil { - log.Fatalf("Prepare log file error, " + err.Error()) - } - - // Echo the parameters - log.Printf("Benchmark log will write to file %s\n", logFile.Name()) - fmt.Fprintf(logFile, "Parameters: duration=%d, threads=%d, loops=%d, valueSize=%s, batchSize=%d, versions=%d\n", durationSecs, threads, loops, sizeArg, batchOpSize, numVersion) - // Init test data - valueData = make([]byte, valueSize) - rand.Read(valueData) - hasher := md5.New() - hasher.Write(valueData) - - batchValueData = make([][]byte, batchOpSize) - for i := range batchValueData { - batchValueData[i] = make([]byte, valueSize) - rand.Read(batchValueData[i]) - hasher := md5.New() - hasher.Write(batchValueData[i]) - } - - // Loop running the tests - for loop := 1; loop <= loops; loop++ { - - // reset counters - counter = 0 - keyNum = 0 - totalKeyCount = 0 - totalKeys = nil - - // Run the batchSet case - // key seq start from setCount - counter = 0 - startTime := time.Now() - endTime = startTime.Add(time.Second * time.Duration(durationSecs)) - for n := 1; n <= threads; n++ { - wg.Add(1) - go runBatchSet() - } - wg.Wait() - - setTime := 
setFinish.Sub(startTime).Seconds() - bps := float64(uint64(counter)*valueSize*uint64(batchOpSize)) / setTime - fmt.Fprintf(logFile, "Loop %d: BATCH PUT time %.1f secs, batchs = %d, kv pairs = %d, speed = %sB/sec, %.1f operations/sec, %.1f kv/sec.\n", - loop, setTime, counter, counter*int32(batchOpSize), bytefmt.ByteSize(uint64(bps)), float64(counter)/setTime, float64(counter*int32(batchOpSize))/setTime) - // Record all test keys - //totalKeyCount = keyNum - //totalKeys = make([][]byte, totalKeyCount) - //for i := int32(0); i < totalKeyCount; i++ { - // totalKeys[i] = []byte(fmt.Sprint("key", i)) - //} - // - //// Run the get case - //counter = 0 - //startTime = time.Now() - //endTime = startTime.Add(time.Second * time.Duration(durationSecs)) - //for n := 1; n <= threads; n++ { - // wg.Add(1) - // go runGet() - //} - //wg.Wait() - // - //getTime := getFinish.Sub(startTime).Seconds() - //bps = float64(uint64(counter)*valueSize) / getTime - //fmt.Fprint(logFile, fmt.Sprintf("Loop %d: GET time %.1f secs, kv pairs = %d, speed = %sB/sec, %.1f operations/sec, %.1f kv/sec.\n", - // loop, getTime, counter, bytefmt.ByteSize(uint64(bps)), float64(counter)/getTime, float64(counter)/getTime)) - - // Run the batchGet case - //counter = 0 - //startTime = time.Now() - //endTime = startTime.Add(time.Second * time.Duration(durationSecs)) - //for n := 1; n <= threads; n++ { - // wg.Add(1) - // go runBatchGet() - //} - //wg.Wait() - // - //getTime = getFinish.Sub(startTime).Seconds() - //bps = float64(uint64(counter)*valueSize*uint64(batchOpSize)) / getTime - //fmt.Fprint(logFile, fmt.Sprintf("Loop %d: BATCH GET time %.1f secs, batchs = %d, kv pairs = %d, speed = %sB/sec, %.1f operations/sec, %.1f kv/sec.\n", - // loop, getTime, counter, counter*int32(batchOpSize), bytefmt.ByteSize(uint64(bps)), float64(counter)/getTime, float64(counter * int32(batchOpSize))/getTime)) - // - //// Run the delete case - //counter = 0 - //startTime = time.Now() - //endTime = startTime.Add(time.Second * time.Duration(durationSecs)) - //for n := 1; n <= threads; n++ { - // wg.Add(1) - // go runDelete() - //} - //wg.Wait() - // - //deleteTime := deleteFinish.Sub(startTime).Seconds() - //bps = float64(uint64(counter)*valueSize) / deleteTime - //fmt.Fprint(logFile, fmt.Sprintf("Loop %d: Delete time %.1f secs, kv pairs = %d, %.1f operations/sec, %.1f kv/sec.\n", - // loop, deleteTime, counter, float64(counter)/deleteTime, float64(counter)/deleteTime)) - // - //// Run the batchDelete case - //counter = 0 - //startTime = time.Now() - //endTime = startTime.Add(time.Second * time.Duration(durationSecs)) - //for n := 1; n <= threads; n++ { - // wg.Add(1) - // go runBatchDelete() - //} - //wg.Wait() - // - //deleteTime = setFinish.Sub(startTime).Seconds() - //bps = float64(uint64(counter)*valueSize*uint64(batchOpSize)) / setTime - //fmt.Fprint(logFile, fmt.Sprintf("Loop %d: BATCH DELETE time %.1f secs, batchs = %d, kv pairs = %d, %.1f operations/sec, %.1f kv/sec.\n", - // loop, setTime, counter, counter*int32(batchOpSize), float64(counter)/setTime, float64(counter * int32(batchOpSize))/setTime)) - - // Print line mark - lineMark := "\n" - fmt.Fprint(logFile, lineMark) - } - log.Print("Benchmark test done.") -} diff --git a/internal/core/src/indexbuilder/IndexWrapper.cpp b/internal/core/src/indexbuilder/IndexWrapper.cpp index 5d95eabf3d97ba40fcb9ee646b480a9c5a280254..fcced635e0e508894a70995b68002310c7b8182f 100644 --- a/internal/core/src/indexbuilder/IndexWrapper.cpp +++ b/internal/core/src/indexbuilder/IndexWrapper.cpp @@ -55,7 +55,6 @@ 
IndexWrapper::parse_impl(const std::string& serialized_params_str, knowhere::Con } auto stoi_closure = [](const std::string& s) -> int { return std::stoi(s); }; - auto stof_closure = [](const std::string& s) -> int { return std::stof(s); }; /***************************** meta *******************************/ check_parameter<int>(conf, milvus::knowhere::meta::DIM, stoi_closure, std::nullopt); @@ -89,7 +88,7 @@ IndexWrapper::parse_impl(const std::string& serialized_params_str, knowhere::Con check_parameter<int>(conf, milvus::knowhere::IndexParams::edge_size, stoi_closure, std::nullopt); /************************** NGT Search Params *****************************/ - check_parameter<float>(conf, milvus::knowhere::IndexParams::epsilon, stof_closure, std::nullopt); + check_parameter<int>(conf, milvus::knowhere::IndexParams::epsilon, stoi_closure, std::nullopt); check_parameter<int>(conf, milvus::knowhere::IndexParams::max_search_edges, stoi_closure, std::nullopt); /************************** NGT_PANNG Params *****************************/ @@ -275,12 +274,6 @@ IndexWrapper::QueryWithParam(const knowhere::DatasetPtr& dataset, const char* se std::unique_ptr<IndexWrapper::QueryResult> IndexWrapper::QueryImpl(const knowhere::DatasetPtr& dataset, const knowhere::Config& conf) { - auto load_raw_data_closure = [&]() { LoadRawData(); }; // hide this pointer - auto index_type = get_index_type(); - if (is_in_nm_list(index_type)) { - std::call_once(raw_data_loaded_, load_raw_data_closure); - } - auto res = index_->Query(dataset, conf, nullptr); auto ids = res->Get<int64_t*>(milvus::knowhere::meta::IDS); auto distances = res->Get<float*>(milvus::knowhere::meta::DISTANCE); @@ -298,19 +291,5 @@ IndexWrapper::QueryImpl(const knowhere::DatasetPtr& dataset, const knowhere::Con return std::move(query_res); } -void -IndexWrapper::LoadRawData() { - auto index_type = get_index_type(); - if (is_in_nm_list(index_type)) { - auto bs = index_->Serialize(config_); - auto bptr = std::make_shared<milvus::knowhere::Binary>(); - auto deleter = [&](uint8_t*) {}; // avoid repeated deconstruction - bptr->data = std::shared_ptr<uint8_t[]>(static_cast<uint8_t*>(raw_data_.data()), deleter); - bptr->size = raw_data_.size(); - bs.Append(RAW_DATA, bptr); - index_->Load(bs); - } -} - } // namespace indexbuilder } // namespace milvus diff --git a/internal/core/src/indexbuilder/IndexWrapper.h b/internal/core/src/indexbuilder/IndexWrapper.h index 16f2721712c655bff7b2e7d53a235e32ed1d6458..65c6f149febf89bd30521e0478ba4eb2782b8583 100644 --- a/internal/core/src/indexbuilder/IndexWrapper.h +++ b/internal/core/src/indexbuilder/IndexWrapper.h @@ -66,9 +66,6 @@ class IndexWrapper { void StoreRawData(const knowhere::DatasetPtr& dataset); - void - LoadRawData(); - template <typename T> void check_parameter(knowhere::Config& conf, @@ -95,7 +92,6 @@ class IndexWrapper { milvus::json index_config_; knowhere::Config config_; std::vector<uint8_t> raw_data_; - std::once_flag raw_data_loaded_; }; } // namespace indexbuilder diff --git a/internal/core/src/query/CMakeLists.txt b/internal/core/src/query/CMakeLists.txt index a1de1d4ed502053407f16f1fc6e107c163cda653..b272e388853fedb8b1178d41302826ff2b2fe8b7 100644 --- a/internal/core/src/query/CMakeLists.txt +++ b/internal/core/src/query/CMakeLists.txt @@ -4,13 +4,15 @@ set(MILVUS_QUERY_SRCS generated/PlanNode.cpp generated/Expr.cpp visitors/ShowPlanNodeVisitor.cpp - visitors/ExecPlanNodeVisitor.cpp visitors/ShowExprVisitor.cpp + visitors/ExecPlanNodeVisitor.cpp visitors/ExecExprVisitor.cpp + 
visitors/VerifyPlanNodeVisitor.cpp + visitors/VerifyExprVisitor.cpp Plan.cpp Search.cpp SearchOnSealed.cpp BruteForceSearch.cpp ) add_library(milvus_query ${MILVUS_QUERY_SRCS}) -target_link_libraries(milvus_query milvus_proto milvus_utils) +target_link_libraries(milvus_query milvus_proto milvus_utils knowhere) diff --git a/internal/core/src/query/Plan.cpp b/internal/core/src/query/Plan.cpp index 96653593516c15f98797a9f2d4c18e316b0161e1..78f1f14c73ebb94d897ff6629cdd62bf3c05e05f 100644 --- a/internal/core/src/query/Plan.cpp +++ b/internal/core/src/query/Plan.cpp @@ -21,6 +21,7 @@ #include <boost/align/aligned_allocator.hpp> #include <boost/algorithm/string.hpp> #include <algorithm> +#include "query/generated/VerifyPlanNodeVisitor.h" namespace milvus::query { @@ -138,6 +139,8 @@ Parser::CreatePlanImpl(const std::string& dsl_str) { if (predicate != nullptr) { vec_node->predicate_ = std::move(predicate); } + VerifyPlanNodeVisitor verifier; + vec_node->accept(verifier); auto plan = std::make_unique<Plan>(schema); plan->tag2field_ = std::move(tag2field_); diff --git a/internal/core/src/query/generated/ExecExprVisitor.h b/internal/core/src/query/generated/ExecExprVisitor.h index 250d68a6e567a52f9cf9c6bc8c7c886abf12f3f3..a9e0574a6e527a6f8fe5856ec640f15072824390 100644 --- a/internal/core/src/query/generated/ExecExprVisitor.h +++ b/internal/core/src/query/generated/ExecExprVisitor.h @@ -21,7 +21,7 @@ #include "ExprVisitor.h" namespace milvus::query { -class ExecExprVisitor : ExprVisitor { +class ExecExprVisitor : public ExprVisitor { public: void visit(BoolUnaryExpr& expr) override; diff --git a/internal/core/src/query/generated/ExecPlanNodeVisitor.h b/internal/core/src/query/generated/ExecPlanNodeVisitor.h index 0eb33384d71eec5e01a486014c27c1049e442c46..c026c689857958c27b44daf2c160b51592b26c44 100644 --- a/internal/core/src/query/generated/ExecPlanNodeVisitor.h +++ b/internal/core/src/query/generated/ExecPlanNodeVisitor.h @@ -19,7 +19,7 @@ #include "PlanNodeVisitor.h" namespace milvus::query { -class ExecPlanNodeVisitor : PlanNodeVisitor { +class ExecPlanNodeVisitor : public PlanNodeVisitor { public: void visit(FloatVectorANNS& node) override; diff --git a/internal/core/src/query/generated/ShowExprVisitor.h b/internal/core/src/query/generated/ShowExprVisitor.h index 55659e24c04e4a419a97b50ec39cfaffb1bcb558..6a1ed2646fc641b7670a0c7f100da9ed8408dc06 100644 --- a/internal/core/src/query/generated/ShowExprVisitor.h +++ b/internal/core/src/query/generated/ShowExprVisitor.h @@ -19,7 +19,7 @@ #include "ExprVisitor.h" namespace milvus::query { -class ShowExprVisitor : ExprVisitor { +class ShowExprVisitor : public ExprVisitor { public: void visit(BoolUnaryExpr& expr) override; diff --git a/internal/core/src/query/generated/ShowPlanNodeVisitor.h b/internal/core/src/query/generated/ShowPlanNodeVisitor.h index b921ec81fc5aa3eb29eb15c091a72738cfc57d4b..c518c3f7d0b23204f804c035db3471bcf08c4831 100644 --- a/internal/core/src/query/generated/ShowPlanNodeVisitor.h +++ b/internal/core/src/query/generated/ShowPlanNodeVisitor.h @@ -20,7 +20,7 @@ #include "PlanNodeVisitor.h" namespace milvus::query { -class ShowPlanNodeVisitor : PlanNodeVisitor { +class ShowPlanNodeVisitor : public PlanNodeVisitor { public: void visit(FloatVectorANNS& node) override; diff --git a/internal/core/src/query/generated/VerifyExprVisitor.cpp b/internal/core/src/query/generated/VerifyExprVisitor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..44af4dde81bdeea864b8bef34d064e4fbd4f2fee --- /dev/null +++ 
b/internal/core/src/query/generated/VerifyExprVisitor.cpp @@ -0,0 +1,36 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#error TODO: copy this file out, and modify the content. +#include "query/generated/VerifyExprVisitor.h" + +namespace milvus::query { +void +VerifyExprVisitor::visit(BoolUnaryExpr& expr) { + // TODO +} + +void +VerifyExprVisitor::visit(BoolBinaryExpr& expr) { + // TODO +} + +void +VerifyExprVisitor::visit(TermExpr& expr) { + // TODO +} + +void +VerifyExprVisitor::visit(RangeExpr& expr) { + // TODO +} + +} // namespace milvus::query diff --git a/internal/core/src/query/generated/VerifyExprVisitor.h b/internal/core/src/query/generated/VerifyExprVisitor.h new file mode 100644 index 0000000000000000000000000000000000000000..6b04a76978d7db2c8247ac45399b50b85f44309d --- /dev/null +++ b/internal/core/src/query/generated/VerifyExprVisitor.h @@ -0,0 +1,40 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once +// Generated File +// DO NOT EDIT +#include <optional> +#include <boost/dynamic_bitset.hpp> +#include <utility> +#include <deque> +#include "segcore/SegmentSmallIndex.h" +#include "query/ExprImpl.h" +#include "ExprVisitor.h" + +namespace milvus::query { +class VerifyExprVisitor : public ExprVisitor { + public: + void + visit(BoolUnaryExpr& expr) override; + + void + visit(BoolBinaryExpr& expr) override; + + void + visit(TermExpr& expr) override; + + void + visit(RangeExpr& expr) override; + + public: +}; +} // namespace milvus::query diff --git a/internal/core/src/query/generated/VerifyPlanNodeVisitor.cpp b/internal/core/src/query/generated/VerifyPlanNodeVisitor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c7b0656f57041427e8159e899f891ca0f391c901 --- /dev/null +++ b/internal/core/src/query/generated/VerifyPlanNodeVisitor.cpp @@ -0,0 +1,26 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#error TODO: copy this file out, and modify the content. +#include "query/generated/VerifyPlanNodeVisitor.h" + +namespace milvus::query { +void +VerifyPlanNodeVisitor::visit(FloatVectorANNS& node) { + // TODO +} + +void +VerifyPlanNodeVisitor::visit(BinaryVectorANNS& node) { + // TODO +} + +} // namespace milvus::query diff --git a/internal/core/src/query/generated/VerifyPlanNodeVisitor.h b/internal/core/src/query/generated/VerifyPlanNodeVisitor.h new file mode 100644 index 0000000000000000000000000000000000000000..a964e6c08f920bcf3b2b1f4e7f70ebbddb0e264a --- /dev/null +++ b/internal/core/src/query/generated/VerifyPlanNodeVisitor.h @@ -0,0 +1,37 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once +// Generated File +// DO NOT EDIT +#include "utils/Json.h" +#include "query/PlanImpl.h" +#include "segcore/SegmentBase.h" +#include <utility> +#include "PlanNodeVisitor.h" + +namespace milvus::query { +class VerifyPlanNodeVisitor : public PlanNodeVisitor { + public: + void + visit(FloatVectorANNS& node) override; + + void + visit(BinaryVectorANNS& node) override; + + public: + using RetType = QueryResult; + VerifyPlanNodeVisitor() = default; + + private: + std::optional<RetType> ret_; +}; +} // namespace milvus::query diff --git a/internal/core/src/query/visitors/VerifyExprVisitor.cpp b/internal/core/src/query/visitors/VerifyExprVisitor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9b3326c74a60b9f604de92a24fb71c4156a60cad --- /dev/null +++ b/internal/core/src/query/visitors/VerifyExprVisitor.cpp @@ -0,0 +1,35 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#include "query/generated/VerifyExprVisitor.h" + +namespace milvus::query { +void +VerifyExprVisitor::visit(BoolUnaryExpr& expr) { + // TODO +} + +void +VerifyExprVisitor::visit(BoolBinaryExpr& expr) { + // TODO +} + +void +VerifyExprVisitor::visit(TermExpr& expr) { + // TODO +} + +void +VerifyExprVisitor::visit(RangeExpr& expr) { + // TODO +} + +} // namespace milvus::query diff --git a/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp b/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..263390a39a52831c0eb7f1bd71f8dc0cc1cecdb5 --- /dev/null +++ b/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp @@ -0,0 +1,85 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include "query/generated/VerifyPlanNodeVisitor.h" +#include "knowhere/index/vector_index/ConfAdapterMgr.h" +#include "segcore/SegmentSmallIndex.h" +#include "knowhere/index/vector_index/ConfAdapter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" + +namespace milvus::query { + +#if 1 +namespace impl { +// THIS CONTAINS EXTRA BODY FOR VISITOR +// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/ +class VerifyPlanNodeVisitor : PlanNodeVisitor { + public: + using RetType = QueryResult; + VerifyPlanNodeVisitor() = default; + + private: + std::optional<RetType> ret_; +}; +} // namespace impl +#endif + +static knowhere::IndexType +InferIndexType(const Json& search_params) { + // ivf -> nprobe + // nsg -> search_length + // hnsw/rhnsw/*pq/*sq -> ef + // annoy -> search_k + // ngtpanng / ngtonng -> max_search_edges / epsilon + static const std::map<std::string, knowhere::IndexType> key_list = [] { + std::map<std::string, knowhere::IndexType> list; + namespace ip = knowhere::IndexParams; + namespace ie = knowhere::IndexEnum; + list.emplace(ip::nprobe, ie::INDEX_FAISS_IVFFLAT); + list.emplace(ip::search_length, ie::INDEX_NSG); + list.emplace(ip::ef, ie::INDEX_HNSW); + list.emplace(ip::search_k, ie::INDEX_ANNOY); + list.emplace(ip::max_search_edges, ie::INDEX_NGTONNG); + list.emplace(ip::epsilon, ie::INDEX_NGTONNG); + return list; + }(); + auto dbg_str = search_params.dump(); + for (auto& kv : search_params.items()) { + std::string key = kv.key(); + if (key_list.count(key)) { + return key_list.at(key); + } + } + PanicInfo("failed to infer index type"); +} + +void +VerifyPlanNodeVisitor::visit(FloatVectorANNS& node) { + auto& search_params = node.query_info_.search_params_; + auto inferred_type = InferIndexType(search_params); + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(inferred_type); + auto index_mode = knowhere::IndexMode::MODE_CPU; + + // mock the api, topk will be passed from placeholder + auto params_copy = search_params; + params_copy[knowhere::meta::TOPK] = 10; + + // NOTE: the second parameter is not checked in knowhere, may be redundant + auto passed 
= adapter->CheckSearch(params_copy, inferred_type, index_mode); + AssertInfo(passed, "invalid search params"); +} + +void +VerifyPlanNodeVisitor::visit(BinaryVectorANNS& node) { + // TODO +} + +} // namespace milvus::query diff --git a/internal/core/src/segcore/CMakeLists.txt b/internal/core/src/segcore/CMakeLists.txt index 709a983c977aceacc0c049b070887044701840f6..1a011c984b690d92e62dc6bb172aa245cbfd140f 100644 --- a/internal/core/src/segcore/CMakeLists.txt +++ b/internal/core/src/segcore/CMakeLists.txt @@ -24,5 +24,6 @@ target_link_libraries(milvus_segcore dl backtrace milvus_common milvus_query + milvus_utils ) diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index a69690728065e26d373ed961c0fd4ccd4817e10b..29fe7b32cfe6068b0b06662d2757de00c3e9a86d 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -24,10 +24,8 @@ target_link_libraries(all_tests gtest_main milvus_segcore milvus_indexbuilder - knowhere log pthread - milvus_utils ) install (TARGETS all_tests DESTINATION unittest) diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 0c5dbe83a8ca657cfef829b5767577312124f416..65866a60b7f39538dcb5ef2cba8a872f23da8689 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -137,7 +137,7 @@ TEST(CApiTest, SearchTest) { auto offset = PreInsert(segment, N); auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); - assert(ins_res.error_code == Success); + ASSERT_EQ(ins_res.error_code, Success); const char* dsl_string = R"( { @@ -176,11 +176,11 @@ TEST(CApiTest, SearchTest) { void* plan = nullptr; auto status = CreatePlan(collection, dsl_string, &plan); - assert(status.error_code == Success); + ASSERT_EQ(status.error_code, Success); void* placeholderGroup = nullptr; status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup); - assert(status.error_code == Success); + ASSERT_EQ(status.error_code, Success); std::vector<CPlaceholderGroup> placeholderGroups; placeholderGroups.push_back(placeholderGroup); @@ -189,7 +189,7 @@ TEST(CApiTest, SearchTest) { CQueryResult search_result; auto res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, &search_result); - assert(res.error_code == Success); + ASSERT_EQ(res.error_code, Success); DeletePlan(plan); DeletePlaceholderGroup(placeholderGroup); diff --git a/internal/core/unittest/test_index_wrapper.cpp b/internal/core/unittest/test_index_wrapper.cpp index bd335951f8053029f720e368d7907cb0d65d451d..a885c837a096b9558da69ec74c73d2c9c019e510 100644 --- a/internal/core/unittest/test_index_wrapper.cpp +++ b/internal/core/unittest/test_index_wrapper.cpp @@ -11,8 +11,6 @@ #include <tuple> #include <map> -#include <limits> -#include <math.h> #include <gtest/gtest.h> #include <google/protobuf/text_format.h> @@ -43,16 +41,16 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_IDMAP) { return milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, DIM}, - {milvus::knowhere::meta::TOPK, K}, + // {milvus::knowhere::meta::TOPK, K}, {milvus::knowhere::Metric::TYPE, metric_type}, {milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE, 4}, }; } else if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ) { return milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, DIM}, - {milvus::knowhere::meta::TOPK, K}, 
+ // {milvus::knowhere::meta::TOPK, K}, {milvus::knowhere::IndexParams::nlist, 100}, - {milvus::knowhere::IndexParams::nprobe, 4}, + // {milvus::knowhere::IndexParams::nprobe, 4}, {milvus::knowhere::IndexParams::m, 4}, {milvus::knowhere::IndexParams::nbits, 8}, {milvus::knowhere::Metric::TYPE, metric_type}, @@ -61,9 +59,9 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh } else if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) { return milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, DIM}, - {milvus::knowhere::meta::TOPK, K}, + // {milvus::knowhere::meta::TOPK, K}, {milvus::knowhere::IndexParams::nlist, 100}, - {milvus::knowhere::IndexParams::nprobe, 4}, + // {milvus::knowhere::IndexParams::nprobe, 4}, {milvus::knowhere::Metric::TYPE, metric_type}, {milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE, 4}, #ifdef MILVUS_GPU_VERSION @@ -73,9 +71,9 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh } else if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8) { return milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, DIM}, - {milvus::knowhere::meta::TOPK, K}, + // {milvus::knowhere::meta::TOPK, K}, {milvus::knowhere::IndexParams::nlist, 100}, - {milvus::knowhere::IndexParams::nprobe, 4}, + // {milvus::knowhere::IndexParams::nprobe, 4}, {milvus::knowhere::IndexParams::nbits, 8}, {milvus::knowhere::Metric::TYPE, metric_type}, {milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE, 4}, @@ -86,9 +84,9 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh } else if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT) { return milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, DIM}, - {milvus::knowhere::meta::TOPK, K}, + // {milvus::knowhere::meta::TOPK, K}, {milvus::knowhere::IndexParams::nlist, 100}, - {milvus::knowhere::IndexParams::nprobe, 4}, + // {milvus::knowhere::IndexParams::nprobe, 4}, {milvus::knowhere::IndexParams::m, 4}, {milvus::knowhere::IndexParams::nbits, 8}, {milvus::knowhere::Metric::TYPE, metric_type}, @@ -97,14 +95,13 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh } else if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP) { return milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, DIM}, - {milvus::knowhere::meta::TOPK, K}, + // {milvus::knowhere::meta::TOPK, K}, {milvus::knowhere::Metric::TYPE, metric_type}, }; } else if (index_type == milvus::knowhere::IndexEnum::INDEX_NSG) { return milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, DIM}, {milvus::knowhere::IndexParams::nlist, 163}, - {milvus::knowhere::meta::TOPK, K}, {milvus::knowhere::IndexParams::nprobe, 8}, {milvus::knowhere::IndexParams::knng, 20}, {milvus::knowhere::IndexParams::search_length, 40}, @@ -130,14 +127,17 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh #endif } else if (index_type == milvus::knowhere::IndexEnum::INDEX_HNSW) { return milvus::knowhere::Config{ - {milvus::knowhere::meta::DIM, DIM}, {milvus::knowhere::meta::TOPK, K}, - {milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200}, - {milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, metric_type}, + {milvus::knowhere::meta::DIM, DIM}, + // {milvus::knowhere::meta::TOPK, 10}, + {milvus::knowhere::IndexParams::M, 16}, + {milvus::knowhere::IndexParams::efConstruction, 200}, + {milvus::knowhere::IndexParams::ef, 200}, + 
{milvus::knowhere::Metric::TYPE, metric_type}, }; } else if (index_type == milvus::knowhere::IndexEnum::INDEX_ANNOY) { return milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, DIM}, - {milvus::knowhere::meta::TOPK, K}, + // {milvus::knowhere::meta::TOPK, 10}, {milvus::knowhere::IndexParams::n_trees, 4}, {milvus::knowhere::IndexParams::search_k, 100}, {milvus::knowhere::Metric::TYPE, metric_type}, @@ -146,7 +146,7 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh } else if (index_type == milvus::knowhere::IndexEnum::INDEX_RHNSWFlat) { return milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, DIM}, - {milvus::knowhere::meta::TOPK, K}, + // {milvus::knowhere::meta::TOPK, 10}, {milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200}, {milvus::knowhere::IndexParams::ef, 200}, @@ -156,7 +156,7 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh } else if (index_type == milvus::knowhere::IndexEnum::INDEX_RHNSWPQ) { return milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, DIM}, - {milvus::knowhere::meta::TOPK, K}, + // {milvus::knowhere::meta::TOPK, 10}, {milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200}, {milvus::knowhere::IndexParams::ef, 200}, @@ -167,7 +167,7 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh } else if (index_type == milvus::knowhere::IndexEnum::INDEX_RHNSWSQ) { return milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, DIM}, - {milvus::knowhere::meta::TOPK, K}, + // {milvus::knowhere::meta::TOPK, 10}, {milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200}, {milvus::knowhere::IndexParams::ef, 200}, @@ -177,7 +177,7 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh } else if (index_type == milvus::knowhere::IndexEnum::INDEX_NGTPANNG) { return milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, DIM}, - {milvus::knowhere::meta::TOPK, K}, + // {milvus::knowhere::meta::TOPK, 10}, {milvus::knowhere::Metric::TYPE, metric_type}, {milvus::knowhere::IndexParams::edge_size, 10}, {milvus::knowhere::IndexParams::epsilon, 0.1}, @@ -189,7 +189,7 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh } else if (index_type == milvus::knowhere::IndexEnum::INDEX_NGTONNG) { return milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, DIM}, - {milvus::knowhere::meta::TOPK, K}, + // {milvus::knowhere::meta::TOPK, 10}, {milvus::knowhere::Metric::TYPE, metric_type}, {milvus::knowhere::IndexParams::edge_size, 20}, {milvus::knowhere::IndexParams::epsilon, 0.1}, @@ -234,99 +234,6 @@ GenDataset(int64_t N, const milvus::knowhere::MetricType& metric_type, bool is_b return milvus::segcore::DataGen(schema, N); } } - -using QueryResultPtr = std::unique_ptr<milvus::indexbuilder::IndexWrapper::QueryResult>; -void -PrintQueryResult(const QueryResultPtr& result) { - auto nq = result->nq; - auto k = result->topk; - - std::stringstream ss_id; - std::stringstream ss_dist; - - for (auto i = 0; i < nq; i++) { - for (auto j = 0; j < k; ++j) { - ss_id << result->ids[i * k + j] << " "; - ss_dist << result->distances[i * k + j] << " "; - } - ss_id << std::endl; - ss_dist << std::endl; - } - std::cout << "id\n" << ss_id.str() << std::endl; - std::cout << "dist\n" << ss_dist.str() << std::endl; -} - -float -L2(const float* point_a, const float* point_b, int dim) { - float dis = 0; - for (auto i = 0; i < dim; i++) { - auto 
c_a = point_a[i]; - auto c_b = point_b[i]; - dis += pow(c_b - c_a, 2); - } - return dis; -} - -int hamming_weight(uint8_t n) { - int count=0; - while(n != 0){ - count += n&1; - n >>= 1; - } - return count; -} -float -Jaccard(const uint8_t* point_a, const uint8_t* point_b, int dim) { - float dis; - int len = dim / 8; - float intersection = 0; - float union_num = 0; - for (int i = 0; i < len; i++) { - intersection += hamming_weight(point_a[i] & point_b[i]); - union_num += hamming_weight(point_a[i] | point_b[i]); - } - dis = 1 - (intersection / union_num); - return dis; -} - -float -CountDistance(const void* point_a, - const void* point_b, - int dim, - const milvus::knowhere::MetricType& metric, - bool is_binary = false) { - if (point_a == nullptr || point_b == nullptr) { - return std::numeric_limits<float>::max(); - } - if (metric == milvus::knowhere::Metric::L2) { - return L2(static_cast<const float*>(point_a), static_cast<const float*>(point_b), dim); - } else if (metric == milvus::knowhere::Metric::JACCARD) { - return Jaccard(static_cast<const uint8_t*>(point_a), static_cast<const uint8_t*>(point_b), dim); - } else { - return std::numeric_limits<float>::max(); - } -} - -void -CheckDistances(const QueryResultPtr& result, - const milvus::knowhere::DatasetPtr& base_dataset, - const milvus::knowhere::DatasetPtr& query_dataset, - const milvus::knowhere::MetricType& metric, - const float threshold = 1.0e-5) { - auto base_vecs = base_dataset->Get<float*>(milvus::knowhere::meta::TENSOR); - auto query_vecs = query_dataset->Get<float*>(milvus::knowhere::meta::TENSOR); - auto dim = base_dataset->Get<int64_t>(milvus::knowhere::meta::DIM); - auto nq = result->nq; - auto k = result->topk; - for (auto i = 0; i < nq; i++) { - for (auto j = 0; j < k; ++j) { - auto dis = result->distances[i * k + j]; - auto id = result->ids[i * k + j]; - auto count_dis = CountDistance(query_vecs + i * dim, base_vecs + id * dim, dim, metric); - // assert(std::abs(dis - count_dis) < threshold); - } - } -} } // namespace using Param = std::pair<milvus::knowhere::IndexType, milvus::knowhere::MetricType>; @@ -340,26 +247,8 @@ class IndexWrapperTest : public ::testing::TestWithParam<Param> { metric_type = param.second; std::tie(type_params, index_params) = generate_params(index_type, metric_type); - std::map<std::string, bool> is_binary_map = { - {milvus::knowhere::IndexEnum::INDEX_FAISS_IDMAP, false}, - {milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ, false}, - {milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, false}, - {milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, false}, - {milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, true}, - {milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, true}, -#ifdef MILVUS_SUPPORT_SPTAG - {milvus::knowhere::IndexEnum::INDEX_SPTAG_KDT_RNT, false}, - {milvus::knowhere::IndexEnum::INDEX_SPTAG_BKT_RNT, false}, -#endif - {milvus::knowhere::IndexEnum::INDEX_HNSW, false}, - {milvus::knowhere::IndexEnum::INDEX_ANNOY, false}, - {milvus::knowhere::IndexEnum::INDEX_RHNSWFlat, false}, - {milvus::knowhere::IndexEnum::INDEX_RHNSWPQ, false}, - {milvus::knowhere::IndexEnum::INDEX_RHNSWSQ, false}, - {milvus::knowhere::IndexEnum::INDEX_NGTPANNG, false}, - {milvus::knowhere::IndexEnum::INDEX_NGTONNG, false}, - {milvus::knowhere::IndexEnum::INDEX_NSG, false}, - }; + std::map<std::string, bool> is_binary_map = {{milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ, false}, + {milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, true}}; is_binary = is_binary_map[index_type]; @@ -373,13 +262,9 @@ class 
IndexWrapperTest : public ::testing::TestWithParam<Param> { if (!is_binary) { xb_data = dataset.get_col<float>(0); xb_dataset = milvus::knowhere::GenDataset(NB, DIM, xb_data.data()); - xq_data = dataset.get_col<float>(0); - xq_dataset = milvus::knowhere::GenDataset(NQ, DIM, xq_data.data()); } else { xb_bin_data = dataset.get_col<uint8_t>(0); xb_dataset = milvus::knowhere::GenDataset(NB, DIM, xb_bin_data.data()); - xq_bin_data = dataset.get_col<uint8_t>(0); - xq_dataset = milvus::knowhere::GenDataset(NQ, DIM, xq_bin_data.data()); } } @@ -397,9 +282,6 @@ class IndexWrapperTest : public ::testing::TestWithParam<Param> { std::vector<float> xb_data; std::vector<uint8_t> xb_bin_data; std::vector<milvus::knowhere::IDType> ids; - milvus::knowhere::DatasetPtr xq_dataset; - std::vector<float> xq_data; - std::vector<uint8_t> xq_bin_data; }; TEST(PQ, Build) { @@ -426,47 +308,6 @@ TEST(IVFFLATNM, Build) { ASSERT_NO_THROW(index->AddWithoutIds(xb_dataset, conf)); } -TEST(IVFFLATNM, Query) { - auto index_type = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; - auto metric_type = milvus::knowhere::Metric::L2; - auto conf = generate_conf(index_type, metric_type); - auto index = milvus::knowhere::VecIndexFactory::GetInstance().CreateVecIndex(index_type); - auto dataset = GenDataset(NB, metric_type, false); - auto xb_data = dataset.get_col<float>(0); - auto xb_dataset = milvus::knowhere::GenDataset(NB, DIM, xb_data.data()); - ASSERT_NO_THROW(index->Train(xb_dataset, conf)); - ASSERT_NO_THROW(index->AddWithoutIds(xb_dataset, conf)); - auto bs = index->Serialize(conf); - auto bptr = std::make_shared<milvus::knowhere::Binary>(); - bptr->data = std::shared_ptr<uint8_t[]>((uint8_t*)xb_data.data(), [&](uint8_t*) {}); - bptr->size = DIM * NB * sizeof(float); - bs.Append(RAW_DATA, bptr); - index->Load(bs); - auto xq_data = dataset.get_col<float>(0); - auto xq_dataset = milvus::knowhere::GenDataset(NQ, DIM, xq_data.data()); - auto result = index->Query(xq_dataset, conf, nullptr); -} - -TEST(NSG, Query) { - auto index_type = milvus::knowhere::IndexEnum::INDEX_NSG; - auto metric_type = milvus::knowhere::Metric::L2; - auto conf = generate_conf(index_type, metric_type); - auto index = milvus::knowhere::VecIndexFactory::GetInstance().CreateVecIndex(index_type); - auto dataset = GenDataset(NB, metric_type, false); - auto xb_data = dataset.get_col<float>(0); - auto xb_dataset = milvus::knowhere::GenDataset(NB, DIM, xb_data.data()); - index->BuildAll(xb_dataset, conf); - auto bs = index->Serialize(conf); - auto bptr = std::make_shared<milvus::knowhere::Binary>(); - bptr->data = std::shared_ptr<uint8_t[]>((uint8_t*)xb_data.data(), [&](uint8_t*) {}); - bptr->size = DIM * NB * sizeof(float); - bs.Append(RAW_DATA, bptr); - index->Load(bs); - auto xq_data = dataset.get_col<float>(0); - auto xq_dataset = milvus::knowhere::GenDataset(NQ, DIM, xq_data.data()); - auto result = index->Query(xq_dataset, conf, nullptr); -} - TEST(BINFLAT, Build) { auto index_type = milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT; auto metric_type = milvus::knowhere::Metric::JACCARD; @@ -644,7 +485,12 @@ TEST_P(IndexWrapperTest, Dim) { TEST_P(IndexWrapperTest, BuildWithoutIds) { auto index = std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str()); - ASSERT_NO_THROW(index->BuildWithoutIds(xb_dataset)); + + if (milvus::indexbuilder::is_in_need_id_list(index_type)) { + ASSERT_ANY_THROW(index->BuildWithoutIds(xb_dataset)); + } else { + ASSERT_NO_THROW(index->BuildWithoutIds(xb_dataset)); + } } 
TEST_P(IndexWrapperTest, Codec) { @@ -665,16 +511,3 @@ TEST_P(IndexWrapperTest, Codec) { ASSERT_EQ(strcmp(binary.data, copy_binary.data), 0); } } - -TEST_P(IndexWrapperTest, Query) { - auto index_wrapper = - std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str()); - - index_wrapper->BuildWithoutIds(xb_dataset); - - std::unique_ptr<milvus::indexbuilder::IndexWrapper::QueryResult> query_result = index_wrapper->Query(xq_dataset); - ASSERT_EQ(query_result->topk, K); - ASSERT_EQ(query_result->nq, NQ); - ASSERT_EQ(query_result->distances.size(), query_result->topk * query_result->nq); - ASSERT_EQ(query_result->ids.size(), query_result->topk * query_result->nq); -} diff --git a/internal/indexbuilder/indexbuilder.go b/internal/indexbuilder/indexbuilder.go index 5b21e68dd4eebd11079784d799aaf63679a28359..4acfffc3d284665158d9f3c237682a5792257d50 100644 --- a/internal/indexbuilder/indexbuilder.go +++ b/internal/indexbuilder/indexbuilder.go @@ -11,9 +11,6 @@ import ( miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio" - "github.com/minio/minio-go/v7" - "github.com/minio/minio-go/v7/pkg/credentials" - "go.etcd.io/etcd/clientv3" "github.com/zilliztech/milvus-distributed/internal/allocator" @@ -71,19 +68,16 @@ func CreateBuilder(ctx context.Context) (*Builder, error) { idAllocator, err := allocator.NewIDAllocator(b.loopCtx, Params.MasterAddress) - minIOEndPoint := Params.MinIOAddress - minIOAccessKeyID := Params.MinIOAccessKeyID - minIOSecretAccessKey := Params.MinIOSecretAccessKey - minIOUseSSL := Params.MinIOUseSSL - minIOClient, err := minio.New(minIOEndPoint, &minio.Options{ - Creds: credentials.NewStaticV4(minIOAccessKeyID, minIOSecretAccessKey, ""), - Secure: minIOUseSSL, - }) - if err != nil { - return nil, err + option := &miniokv.Option{ + Address: Params.MinIOAddress, + AccessKeyID: Params.MinIOAccessKeyID, + SecretAccessKeyID: Params.MinIOSecretAccessKey, + UseSSL: Params.MinIOUseSSL, + BucketName: Params.MinioBucketName, + CreateBucket: true, } - b.kv, err = miniokv.NewMinIOKV(b.loopCtx, minIOClient, Params.MinioBucketName) + b.kv, err = miniokv.NewMinIOKV(b.loopCtx, option) if err != nil { return nil, err } diff --git a/internal/kv/minio/minio_kv.go b/internal/kv/minio/minio_kv.go index 68bb3a3438bbd771a2d5afc98c23988703db72ea..6b3522fb454273f24a08e11d83f4f64e43a40381 100644 --- a/internal/kv/minio/minio_kv.go +++ b/internal/kv/minio/minio_kv.go @@ -2,11 +2,15 @@ package miniokv import ( "context" + "fmt" + "io" "log" "strings" "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" + "github.com/zilliztech/milvus-distributed/internal/errors" ) type MinIOKV struct { @@ -15,24 +19,46 @@ type MinIOKV struct { bucketName string } -// NewMinIOKV creates a new MinIO kv. 
-func NewMinIOKV(ctx context.Context, client *minio.Client, bucketName string) (*MinIOKV, error) { +type Option struct { + Address string + AccessKeyID string + BucketName string + SecretAccessKeyID string + UseSSL bool + CreateBucket bool // create the bucket when it does not exist +} - bucketExists, err := client.BucketExists(ctx, bucketName) +func NewMinIOKV(ctx context.Context, option *Option) (*MinIOKV, error) { + minIOClient, err := minio.New(option.Address, &minio.Options{ + Creds: credentials.NewStaticV4(option.AccessKeyID, option.SecretAccessKeyID, ""), + Secure: option.UseSSL, + }) if err != nil { return nil, err } - if !bucketExists { - err = client.MakeBucket(ctx, bucketName, minio.MakeBucketOptions{}) - if err != nil { - return nil, err + bucketExists, err := minIOClient.BucketExists(ctx, option.BucketName) + if err != nil { + return nil, err + } + + if option.CreateBucket { + if !bucketExists { + err = minIOClient.MakeBucket(ctx, option.BucketName, minio.MakeBucketOptions{}) + if err != nil { + return nil, err + } + } + } else { + if !bucketExists { + return nil, errors.New(fmt.Sprintf("bucket %s does not exist", option.BucketName)) } } + return &MinIOKV{ ctx: ctx, - minioClient: client, - bucketName: bucketName, + minioClient: minIOClient, + bucketName: option.BucketName, }, nil } diff --git a/internal/kv/minio/minio_kv_test.go b/internal/kv/minio/minio_kv_test.go index ac2a3180b2966bbdc09158ff70ec8baaef23dedb..2e50545b40a83516aa16cb0991bbc64cf34360cc 100644 --- a/internal/kv/minio/minio_kv_test.go +++ b/internal/kv/minio/minio_kv_test.go @@ -5,8 +5,6 @@ import ( "strconv" "testing" - "github.com/minio/minio-go/v7" - "github.com/minio/minio-go/v7/pkg/credentials" miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio" "github.com/zilliztech/milvus-distributed/internal/util/paramtable" @@ -15,24 +13,31 @@ import ( var Params paramtable.BaseTable -func TestMinIOKV_Load(t *testing.T) { - Params.Init() +func newMinIOKVClient(ctx context.Context, bucketName string) (*miniokv.MinIOKV, error) { endPoint, _ := Params.Load("_MinioAddress") accessKeyID, _ := Params.Load("minio.accessKeyID") secretAccessKey, _ := Params.Load("minio.secretAccessKey") useSSLStr, _ := Params.Load("minio.useSSL") - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() useSSL, _ := strconv.ParseBool(useSSLStr) + option := &miniokv.Option{ + Address: endPoint, + AccessKeyID: accessKeyID, + SecretAccessKeyID: secretAccessKey, + UseSSL: useSSL, + BucketName: bucketName, + CreateBucket: true, + } + client, err := miniokv.NewMinIOKV(ctx, option) + return client, err +} - minioClient, err := minio.New(endPoint, &minio.Options{ - Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""), - Secure: useSSL, - }) - assert.Nil(t, err) +func TestMinIOKV_Load(t *testing.T) { + Params.Init() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() bucketName := "fantastic-tech-test" - MinIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, bucketName) + MinIOKV, err := newMinIOKVClient(ctx, bucketName) assert.Nil(t, err) defer MinIOKV.RemoveWithPrefix("") @@ -79,25 +84,14 @@ func TestMinIOKV_Load(t *testing.T) { } func TestMinIOKV_MultiSave(t *testing.T) { + Params.Init() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - Params.Init() - endPoint, _ := Params.Load("_MinioAddress") - accessKeyID, _ := Params.Load("minio.accessKeyID") - secretAccessKey, _ := Params.Load("minio.secretAccessKey") - useSSLStr, _ := Params.Load("minio.useSSL") - useSSL, _ := 
strconv.ParseBool(useSSLStr) - - minioClient, err := minio.New(endPoint, &minio.Options{ - Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""), - Secure: useSSL, - }) - assert.Nil(t, err) - bucketName := "fantastic-tech-test" - MinIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, bucketName) + MinIOKV, err := newMinIOKVClient(ctx, bucketName) assert.Nil(t, err) + defer MinIOKV.RemoveWithPrefix("") err = MinIOKV.Save("key_1", "111") @@ -117,25 +111,13 @@ func TestMinIOKV_MultiSave(t *testing.T) { } func TestMinIOKV_Remove(t *testing.T) { + Params.Init() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - Params.Init() - endPoint, _ := Params.Load("_MinioAddress") - accessKeyID, _ := Params.Load("minio.accessKeyID") - secretAccessKey, _ := Params.Load("minio.secretAccessKey") - useSSLStr, _ := Params.Load("minio.useSSL") - useSSL, _ := strconv.ParseBool(useSSLStr) - - minioClient, err := minio.New(endPoint, &minio.Options{ - Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""), - Secure: useSSL, - }) - assert.Nil(t, err) - bucketName := "fantastic-tech-test" - MinIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, bucketName) + MinIOKV, err := newMinIOKVClient(ctx, bucketName) assert.Nil(t, err) defer MinIOKV.RemoveWithPrefix("") diff --git a/internal/querynode/load_index_service.go b/internal/querynode/load_index_service.go index cedbd50bb0447e0c535e926eacf208ca1d1e1b29..d8ae759b6731512e1e07df7bf03e63c0b421a0e2 100644 --- a/internal/querynode/load_index_service.go +++ b/internal/querynode/load_index_service.go @@ -11,9 +11,6 @@ import ( "strings" "time" - "github.com/minio/minio-go/v7" - "github.com/minio/minio-go/v7/pkg/credentials" - minioKV "github.com/zilliztech/milvus-distributed/internal/kv/minio" "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" @@ -38,16 +35,17 @@ type loadIndexService struct { func newLoadIndexService(ctx context.Context, replica collectionReplica) *loadIndexService { ctx1, cancel := context.WithCancel(ctx) - // init minio - minioClient, err := minio.New(Params.MinioEndPoint, &minio.Options{ - Creds: credentials.NewStaticV4(Params.MinioAccessKeyID, Params.MinioSecretAccessKey, ""), - Secure: Params.MinioUseSSLStr, - }) - if err != nil { - panic(err) + option := &minioKV.Option{ + Address: Params.MinioEndPoint, + AccessKeyID: Params.MinioAccessKeyID, + SecretAccessKeyID: Params.MinioSecretAccessKey, + UseSSL: Params.MinioUseSSLStr, + CreateBucket: true, + BucketName: Params.MinioBucketName, } - MinioKV, err := minioKV.NewMinIOKV(ctx1, minioClient, Params.MinioBucketName) + // TODO: load bucketName from config + MinioKV, err := minioKV.NewMinIOKV(ctx1, option) if err != nil { panic(err) } diff --git a/internal/querynode/load_index_service_test.go b/internal/querynode/load_index_service_test.go index cb2fb8504e190345567ee170ecce71846fe2ce45..000edb49df2bf53fddef89b3261c9ceb15f4c5de 100644 --- a/internal/querynode/load_index_service_test.go +++ b/internal/querynode/load_index_service_test.go @@ -5,8 +5,6 @@ import ( "sort" "testing" - "github.com/minio/minio-go/v7" - "github.com/minio/minio-go/v7/pkg/credentials" "github.com/stretchr/testify/assert" "github.com/zilliztech/milvus-distributed/internal/indexbuilder" @@ -68,13 +66,16 @@ func TestLoadIndexService(t *testing.T) { binarySet, err := index.Serialize() assert.Equal(t, err, nil) - //save index to minio - minioClient, err := minio.New(Params.MinioEndPoint, &minio.Options{ - Creds: 
credentials.NewStaticV4(Params.MinioAccessKeyID, Params.MinioSecretAccessKey, ""), - Secure: Params.MinioUseSSLStr, - }) - assert.Equal(t, err, nil) - minioKV, err := minioKV.NewMinIOKV(node.queryNodeLoopCtx, minioClient, Params.MinioBucketName) + option := &minioKV.Option{ + Address: Params.MinioEndPoint, + AccessKeyID: Params.MinioAccessKeyID, + SecretAccessKeyID: Params.MinioSecretAccessKey, + UseSSL: Params.MinioUseSSLStr, + BucketName: Params.MinioBucketName, + CreateBucket: true, + } + + minioKV, err := minioKV.NewMinIOKV(node.queryNodeLoopCtx, option) assert.Equal(t, err, nil) indexPaths := make([]string, 0) for _, index := range binarySet { diff --git a/internal/storage/internal/S3/S3_test.go b/internal/storage/internal/S3/S3_test.go deleted file mode 100644 index c565f4a5a15cf25d109f5dcff642a0a21f2f94d1..0000000000000000000000000000000000000000 --- a/internal/storage/internal/S3/S3_test.go +++ /dev/null @@ -1,134 +0,0 @@ -package s3driver - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type" -) - -var option = storagetype.Option{BucketName: "zilliz-hz"} -var ctx = context.Background() -var client, err = NewS3Driver(ctx, option) - -func TestS3Driver_PutRowAndGetRow(t *testing.T) { - err = client.PutRow(ctx, []byte("bar"), []byte("abcdefghijklmnoopqrstuvwxyz"), "SegmentA", 1) - assert.Nil(t, err) - err = client.PutRow(ctx, []byte("bar"), []byte("djhfkjsbdfbsdughorsgsdjhgoisdgh"), "SegmentA", 2) - assert.Nil(t, err) - err = client.PutRow(ctx, []byte("bar"), []byte("123854676ershdgfsgdfk,sdhfg;sdi8"), "SegmentB", 3) - assert.Nil(t, err) - err = client.PutRow(ctx, []byte("bar1"), []byte("testkeybarorbar_1"), "SegmentC", 3) - assert.Nil(t, err) - object, _ := client.GetRow(ctx, []byte("bar"), 1) - assert.Equal(t, "abcdefghijklmnoopqrstuvwxyz", string(object)) - object, _ = client.GetRow(ctx, []byte("bar"), 2) - assert.Equal(t, "djhfkjsbdfbsdughorsgsdjhgoisdgh", string(object)) - object, _ = client.GetRow(ctx, []byte("bar"), 5) - assert.Equal(t, "123854676ershdgfsgdfk,sdhfg;sdi8", string(object)) - object, _ = client.GetRow(ctx, []byte("bar1"), 5) - assert.Equal(t, "testkeybarorbar_1", string(object)) -} - -func TestS3Driver_DeleteRow(t *testing.T) { - err = client.DeleteRow(ctx, []byte("bar"), 5) - assert.Nil(t, err) - object, _ := client.GetRow(ctx, []byte("bar"), 6) - assert.Nil(t, object) - err = client.DeleteRow(ctx, []byte("bar1"), 5) - assert.Nil(t, err) - object2, _ := client.GetRow(ctx, []byte("bar1"), 6) - assert.Nil(t, object2) -} - -func TestS3Driver_GetSegments(t *testing.T) { - err = client.PutRow(ctx, []byte("seg"), []byte("abcdefghijklmnoopqrstuvwxyz"), "SegmentA", 1) - assert.Nil(t, err) - err = client.PutRow(ctx, []byte("seg"), []byte("djhfkjsbdfbsdughorsgsdjhgoisdgh"), "SegmentA", 2) - assert.Nil(t, err) - err = client.PutRow(ctx, []byte("seg"), []byte("123854676ershdgfsgdfk,sdhfg;sdi8"), "SegmentB", 3) - assert.Nil(t, err) - err = client.PutRow(ctx, []byte("seg2"), []byte("testkeybarorbar_1"), "SegmentC", 1) - assert.Nil(t, err) - - segements, err := client.GetSegments(ctx, []byte("seg"), 4) - assert.Nil(t, err) - assert.Equal(t, 2, len(segements)) - if segements[0] == "SegmentA" { - assert.Equal(t, "SegmentA", segements[0]) - assert.Equal(t, "SegmentB", segements[1]) - } else { - assert.Equal(t, "SegmentB", segements[0]) - assert.Equal(t, "SegmentA", segements[1]) - } -} - -func TestS3Driver_PutRowsAndGetRows(t *testing.T) { - keys := [][]byte{[]byte("foo"), 
[]byte("bar")} - values := [][]byte{[]byte("The key is foo!"), []byte("The key is bar!")} - segments := []string{"segmentA", "segmentB"} - timestamps := []uint64{1, 2} - err = client.PutRows(ctx, keys, values, segments, timestamps) - assert.Nil(t, err) - - objects, err := client.GetRows(ctx, keys, timestamps) - assert.Nil(t, err) - assert.Equal(t, "The key is foo!", string(objects[0])) - assert.Equal(t, "The key is bar!", string(objects[1])) -} - -func TestS3Driver_DeleteRows(t *testing.T) { - keys := [][]byte{[]byte("foo"), []byte("bar")} - timestamps := []uint64{3, 3} - err := client.DeleteRows(ctx, keys, timestamps) - assert.Nil(t, err) - - objects, err := client.GetRows(ctx, keys, timestamps) - assert.Nil(t, err) - assert.Nil(t, objects[0]) - assert.Nil(t, objects[1]) -} - -func TestS3Driver_PutLogAndGetLog(t *testing.T) { - err = client.PutLog(ctx, []byte("insert"), []byte("This is insert log!"), 1, 11) - assert.Nil(t, err) - err = client.PutLog(ctx, []byte("delete"), []byte("This is delete log!"), 2, 10) - assert.Nil(t, err) - err = client.PutLog(ctx, []byte("update"), []byte("This is update log!"), 3, 9) - assert.Nil(t, err) - err = client.PutLog(ctx, []byte("select"), []byte("This is select log!"), 4, 8) - assert.Nil(t, err) - - channels := []int{5, 8, 9, 10, 11, 12, 13} - logValues, err := client.GetLog(ctx, 0, 5, channels) - assert.Nil(t, err) - assert.Equal(t, "This is select log!", string(logValues[0])) - assert.Equal(t, "This is update log!", string(logValues[1])) - assert.Equal(t, "This is delete log!", string(logValues[2])) - assert.Equal(t, "This is insert log!", string(logValues[3])) -} - -func TestS3Driver_Segment(t *testing.T) { - err := client.PutSegmentIndex(ctx, "segmentA", []byte("This is segmentA's index!")) - assert.Nil(t, err) - - segmentIndex, err := client.GetSegmentIndex(ctx, "segmentA") - assert.Equal(t, "This is segmentA's index!", string(segmentIndex)) - assert.Nil(t, err) - - err = client.DeleteSegmentIndex(ctx, "segmentA") - assert.Nil(t, err) -} - -func TestS3Driver_SegmentDL(t *testing.T) { - err := client.PutSegmentDL(ctx, "segmentB", []byte("This is segmentB's delete log!")) - assert.Nil(t, err) - - segmentDL, err := client.GetSegmentDL(ctx, "segmentB") - assert.Nil(t, err) - assert.Equal(t, "This is segmentB's delete log!", string(segmentDL)) - - err = client.DeleteSegmentDL(ctx, "segmentB") - assert.Nil(t, err) -} diff --git a/internal/storage/internal/S3/s3_engine.go b/internal/storage/internal/S3/s3_engine.go deleted file mode 100644 index 8034d679e7e01fb45867b7424b6e9e31bc35c294..0000000000000000000000000000000000000000 --- a/internal/storage/internal/S3/s3_engine.go +++ /dev/null @@ -1,173 +0,0 @@ -package s3driver - -import ( - "bytes" - "context" - "io" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3" - . 
"github.com/zilliztech/milvus-distributed/internal/storage/type" -) - -type S3Store struct { - client *s3.S3 -} - -func NewS3Store(config aws.Config) (*S3Store, error) { - sess := session.Must(session.NewSession(&config)) - service := s3.New(sess) - - return &S3Store{ - client: service, - }, nil -} - -func (s *S3Store) Put(ctx context.Context, key Key, value Value) error { - _, err := s.client.PutObjectWithContext(ctx, &s3.PutObjectInput{ - Bucket: aws.String(bucketName), - Key: aws.String(string(key)), - Body: bytes.NewReader(value), - }) - - //sess := session.Must(session.NewSessionWithOptions(session.Options{ - // SharedConfigState: session.SharedConfigEnable, - //})) - //uploader := s3manager.NewUploader(sess) - // - //_, err := uploader.Upload(&s3manager.UploadInput{ - // Bucket: aws.String(bucketName), - // Key: aws.String(string(key)), - // Body: bytes.NewReader(value), - //}) - - return err -} - -func (s *S3Store) Get(ctx context.Context, key Key) (Value, error) { - object, err := s.client.GetObjectWithContext(ctx, &s3.GetObjectInput{ - Bucket: aws.String(bucketName), - Key: aws.String(string(key)), - }) - if err != nil { - return nil, err - } - - //TODO: get size - size := 256 * 1024 - buf := make([]byte, size) - n, err := object.Body.Read(buf) - if err != nil && err != io.EOF { - return nil, err - } - return buf[:n], nil -} - -func (s *S3Store) GetByPrefix(ctx context.Context, prefix Key, keyOnly bool) ([]Key, []Value, error) { - objectsOutput, err := s.client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ - Bucket: aws.String(bucketName), - Prefix: aws.String(string(prefix)), - }) - - var objectsKeys []Key - var objectsValues []Value - - if objectsOutput != nil && err == nil { - for _, object := range objectsOutput.Contents { - objectsKeys = append(objectsKeys, []byte(*object.Key)) - if !keyOnly { - value, err := s.Get(ctx, []byte(*object.Key)) - if err != nil { - return nil, nil, err - } - objectsValues = append(objectsValues, value) - } - } - } else { - return nil, nil, err - } - - return objectsKeys, objectsValues, nil - -} - -func (s *S3Store) Scan(ctx context.Context, keyStart Key, keyEnd Key, limit int, keyOnly bool) ([]Key, []Value, error) { - var keys []Key - var values []Value - limitCount := uint(limit) - objects, err := s.client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ - Bucket: aws.String(bucketName), - Prefix: aws.String(string(keyStart)), - }) - if err == nil && objects != nil { - for _, object := range objects.Contents { - if *object.Key >= string(keyEnd) { - keys = append(keys, []byte(*object.Key)) - if !keyOnly { - value, err := s.Get(ctx, []byte(*object.Key)) - if err != nil { - return nil, nil, err - } - values = append(values, value) - } - limitCount-- - if limitCount <= 0 { - break - } - } - } - } - - return keys, values, err -} - -func (s *S3Store) Delete(ctx context.Context, key Key) error { - _, err := s.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{ - Bucket: aws.String(bucketName), - Key: aws.String(string(key)), - }) - return err -} - -func (s *S3Store) DeleteByPrefix(ctx context.Context, prefix Key) error { - - objects, err := s.client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ - Bucket: aws.String(bucketName), - Prefix: aws.String(string(prefix)), - }) - - if objects != nil && err == nil { - for _, object := range objects.Contents { - _, err := s.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{ - Bucket: aws.String(bucketName), - Key: object.Key, - }) - return err - } - } - - return nil -} - -func (s *S3Store) 
DeleteRange(ctx context.Context, keyStart Key, keyEnd Key) error { - - objects, err := s.client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ - Bucket: aws.String(bucketName), - Prefix: aws.String(string(keyStart)), - }) - - if objects != nil && err == nil { - for _, object := range objects.Contents { - if *object.Key > string(keyEnd) { - _, err := s.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{ - Bucket: aws.String(bucketName), - Key: object.Key, - }) - return err - } - } - } - - return nil -} diff --git a/internal/storage/internal/S3/s3_store.go b/internal/storage/internal/S3/s3_store.go deleted file mode 100644 index 19199baa54dcdc50d6ae947f899be1068c383642..0000000000000000000000000000000000000000 --- a/internal/storage/internal/S3/s3_store.go +++ /dev/null @@ -1,339 +0,0 @@ -package s3driver - -import ( - "context" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/zilliztech/milvus-distributed/internal/storage/internal/minio/codec" - . "github.com/zilliztech/milvus-distributed/internal/storage/type" -) - -type S3Driver struct { - driver *S3Store -} - -var bucketName string - -func NewS3Driver(ctx context.Context, option Option) (*S3Driver, error) { - // to-do read conf - - bucketName = option.BucketName - - S3Client, err := NewS3Store(aws.Config{ - Region: aws.String(endpoints.CnNorthwest1RegionID)}) - - if err != nil { - return nil, err - } - - return &S3Driver{ - S3Client, - }, nil -} - -func (s *S3Driver) put(ctx context.Context, key Key, value Value, timestamp Timestamp, suffix string) error { - minioKey, err := codec.MvccEncode(key, timestamp, suffix) - if err != nil { - return err - } - - err = s.driver.Put(ctx, minioKey, value) - return err -} - -func (s *S3Driver) scanLE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error) { - keyEnd, err := codec.MvccEncode(key, timestamp, "") - if err != nil { - return nil, nil, nil, err - } - - keys, values, err := s.driver.Scan(ctx, key, []byte(keyEnd), -1, keyOnly) - if err != nil { - return nil, nil, nil, err - } - - var timestamps []Timestamp - for _, key := range keys { - _, timestamp, _, _ := codec.MvccDecode(key) - timestamps = append(timestamps, timestamp) - } - - return timestamps, keys, values, nil -} - -func (s *S3Driver) scanGE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error) { - keyStart, err := codec.MvccEncode(key, timestamp, "") - if err != nil { - return nil, nil, nil, err - } - - keys, values, err := s.driver.Scan(ctx, key, keyStart, -1, keyOnly) - if err != nil { - return nil, nil, nil, err - } - - var timestamps []Timestamp - for _, key := range keys { - _, timestamp, _, _ := codec.MvccDecode(key) - timestamps = append(timestamps, timestamp) - } - - return timestamps, keys, values, nil -} - -//scan(ctx context.Context, key Key, start Timestamp, end Timestamp, withValue bool) ([]Timestamp, []Key, []Value, error) -func (s *S3Driver) deleteLE(ctx context.Context, key Key, timestamp Timestamp) error { - keyEnd, err := codec.MvccEncode(key, timestamp, "delete") - if err != nil { - return err - } - err = s.driver.DeleteRange(ctx, key, keyEnd) - return err -} -func (s *S3Driver) deleteGE(ctx context.Context, key Key, timestamp Timestamp) error { - keys, _, err := s.driver.GetByPrefix(ctx, key, true) - if err != nil { - return err - } - keyStart, err := codec.MvccEncode(key, timestamp, "") - if err != nil { - panic(err) - } - err = s.driver.DeleteRange(ctx, 
[]byte(keyStart), keys[len(keys)-1]) - return err -} -func (s *S3Driver) deleteRange(ctx context.Context, key Key, start Timestamp, end Timestamp) error { - keyStart, err := codec.MvccEncode(key, start, "") - if err != nil { - return err - } - keyEnd, err := codec.MvccEncode(key, end, "") - if err != nil { - return err - } - err = s.driver.DeleteRange(ctx, keyStart, keyEnd) - return err -} - -func (s *S3Driver) GetRow(ctx context.Context, key Key, timestamp Timestamp) (Value, error) { - minioKey, err := codec.MvccEncode(key, timestamp, "") - if err != nil { - return nil, err - } - - keys, values, err := s.driver.Scan(ctx, append(key, byte('_')), minioKey, 1, false) - if values == nil || keys == nil { - return nil, err - } - - _, _, suffix, err := codec.MvccDecode(keys[0]) - if err != nil { - return nil, err - } - if suffix == "delete" { - return nil, nil - } - - return values[0], err -} -func (s *S3Driver) GetRows(ctx context.Context, keys []Key, timestamps []Timestamp) ([]Value, error) { - var values []Value - for i, key := range keys { - value, err := s.GetRow(ctx, key, timestamps[i]) - if err != nil { - return nil, err - } - values = append(values, value) - } - return values, nil -} - -func (s *S3Driver) PutRow(ctx context.Context, key Key, value Value, segment string, timestamp Timestamp) error { - minioKey, err := codec.MvccEncode(key, timestamp, segment) - if err != nil { - return err - } - err = s.driver.Put(ctx, minioKey, value) - return err -} -func (s *S3Driver) PutRows(ctx context.Context, keys []Key, values []Value, segments []string, timestamps []Timestamp) error { - maxThread := 100 - batchSize := 1 - keysLength := len(keys) - - if keysLength/batchSize > maxThread { - batchSize = keysLength / maxThread - } - - batchNums := keysLength / batchSize - - if keysLength%batchSize != 0 { - batchNums = keysLength/batchSize + 1 - } - - errCh := make(chan error) - f := func(ctx2 context.Context, keys2 []Key, values2 []Value, segments2 []string, timestamps2 []Timestamp) { - for i := 0; i < len(keys2); i++ { - err := s.PutRow(ctx2, keys2[i], values2[i], segments2[i], timestamps2[i]) - errCh <- err - } - } - for i := 0; i < batchNums; i++ { - j := i - go func() { - start, end := j*batchSize, (j+1)*batchSize - if len(keys) < end { - end = len(keys) - } - f(ctx, keys[start:end], values[start:end], segments[start:end], timestamps[start:end]) - }() - } - - for i := 0; i < len(keys); i++ { - if err := <-errCh; err != nil { - return err - } - } - return nil -} - -func (s *S3Driver) GetSegments(ctx context.Context, key Key, timestamp Timestamp) ([]string, error) { - keyEnd, err := codec.MvccEncode(key, timestamp, "") - if err != nil { - return nil, err - } - keys, _, err := s.driver.Scan(ctx, append(key, byte('_')), keyEnd, -1, true) - if err != nil { - return nil, err - } - segmentsSet := map[string]bool{} - for _, key := range keys { - _, _, segment, err := codec.MvccDecode(key) - if err != nil { - panic("must no error") - } - if segment != "delete" { - segmentsSet[segment] = true - } - } - - var segments []string - for k, v := range segmentsSet { - if v { - segments = append(segments, k) - } - } - return segments, err -} - -func (s *S3Driver) DeleteRow(ctx context.Context, key Key, timestamp Timestamp) error { - minioKey, err := codec.MvccEncode(key, timestamp, "delete") - if err != nil { - return err - } - value := []byte("0") - err = s.driver.Put(ctx, minioKey, value) - return err -} - -func (s *S3Driver) DeleteRows(ctx context.Context, keys []Key, timestamps []Timestamp) error { - maxThread 
:= 100 - batchSize := 1 - keysLength := len(keys) - - if keysLength/batchSize > maxThread { - batchSize = keysLength / maxThread - } - - batchNums := keysLength / batchSize - - if keysLength%batchSize != 0 { - batchNums = keysLength/batchSize + 1 - } - - errCh := make(chan error) - f := func(ctx2 context.Context, keys2 []Key, timestamps2 []Timestamp) { - for i := 0; i < len(keys2); i++ { - err := s.DeleteRow(ctx2, keys2[i], timestamps2[i]) - errCh <- err - } - } - for i := 0; i < batchNums; i++ { - j := i - go func() { - start, end := j*batchSize, (j+1)*batchSize - if len(keys) < end { - end = len(keys) - } - f(ctx, keys[start:end], timestamps[start:end]) - }() - } - - for i := 0; i < len(keys); i++ { - if err := <-errCh; err != nil { - return err - } - } - return nil -} - -func (s *S3Driver) PutLog(ctx context.Context, key Key, value Value, timestamp Timestamp, channel int) error { - logKey := codec.LogEncode(key, timestamp, channel) - err := s.driver.Put(ctx, logKey, value) - return err -} - -func (s *S3Driver) GetLog(ctx context.Context, start Timestamp, end Timestamp, channels []int) ([]Value, error) { - keys, values, err := s.driver.GetByPrefix(ctx, []byte("log_"), false) - if err != nil { - return nil, err - } - - var resultValues []Value - for i, key := range keys { - _, ts, channel, err := codec.LogDecode(string(key)) - if err != nil { - return nil, err - } - if ts >= start && ts <= end { - for j := 0; j < len(channels); j++ { - if channel == channels[j] { - resultValues = append(resultValues, values[i]) - } - } - } - } - - return resultValues, nil -} - -func (s *S3Driver) GetSegmentIndex(ctx context.Context, segment string) (SegmentIndex, error) { - - return s.driver.Get(ctx, codec.SegmentEncode(segment, "index")) -} - -func (s *S3Driver) PutSegmentIndex(ctx context.Context, segment string, index SegmentIndex) error { - - return s.driver.Put(ctx, codec.SegmentEncode(segment, "index"), index) -} - -func (s *S3Driver) DeleteSegmentIndex(ctx context.Context, segment string) error { - - return s.driver.Delete(ctx, codec.SegmentEncode(segment, "index")) -} - -func (s *S3Driver) GetSegmentDL(ctx context.Context, segment string) (SegmentDL, error) { - - return s.driver.Get(ctx, codec.SegmentEncode(segment, "DL")) -} - -func (s *S3Driver) PutSegmentDL(ctx context.Context, segment string, log SegmentDL) error { - - return s.driver.Put(ctx, codec.SegmentEncode(segment, "DL"), log) -} - -func (s *S3Driver) DeleteSegmentDL(ctx context.Context, segment string) error { - - return s.driver.Delete(ctx, codec.SegmentEncode(segment, "DL")) -} diff --git a/internal/storage/internal/minio/codec/codec.go b/internal/storage/internal/minio/codec/codec.go deleted file mode 100644 index 4d2b76ee2f939b9f221b47d2ef3f38ce8fe0c71e..0000000000000000000000000000000000000000 --- a/internal/storage/internal/minio/codec/codec.go +++ /dev/null @@ -1,101 +0,0 @@ -package codec - -import ( - "errors" - "fmt" -) - -func MvccEncode(key []byte, ts uint64, suffix string) ([]byte, error) { - return []byte(string(key) + "_" + fmt.Sprintf("%016x", ^ts) + "_" + suffix), nil -} - -func MvccDecode(key []byte) (string, uint64, string, error) { - if len(key) < 16 { - return "", 0, "", errors.New("insufficient bytes to decode value") - } - - suffixIndex := 0 - TSIndex := 0 - undersCount := 0 - for i := len(key) - 1; i > 0; i-- { - if key[i] == byte('_') { - undersCount++ - if undersCount == 1 { - suffixIndex = i + 1 - } - if undersCount == 2 { - TSIndex = i + 1 - break - } - } - } - if suffixIndex == 0 || TSIndex == 0 { - return 
"", 0, "", errors.New("key is wrong formatted") - } - - var TS uint64 - _, err := fmt.Sscanf(string(key[TSIndex:suffixIndex-1]), "%x", &TS) - TS = ^TS - if err != nil { - return "", 0, "", err - } - - return string(key[0 : TSIndex-1]), TS, string(key[suffixIndex:]), nil -} - -func LogEncode(key []byte, ts uint64, channel int) []byte { - suffix := string(key) + "_" + fmt.Sprintf("%d", channel) - logKey, err := MvccEncode([]byte("log"), ts, suffix) - if err != nil { - return nil - } - return logKey -} - -func LogDecode(logKey string) (string, uint64, int, error) { - if len(logKey) < 16 { - return "", 0, 0, errors.New("insufficient bytes to decode value") - } - - channelIndex := 0 - keyIndex := 0 - TSIndex := 0 - undersCount := 0 - - for i := len(logKey) - 1; i > 0; i-- { - if logKey[i] == '_' { - undersCount++ - if undersCount == 1 { - channelIndex = i + 1 - } - if undersCount == 2 { - keyIndex = i + 1 - } - if undersCount == 3 { - TSIndex = i + 1 - break - } - } - } - if channelIndex == 0 || TSIndex == 0 || keyIndex == 0 || logKey[:TSIndex-1] != "log" { - return "", 0, 0, errors.New("key is wrong formatted") - } - - var TS uint64 - var channel int - _, err := fmt.Sscanf(logKey[TSIndex:keyIndex-1], "%x", &TS) - if err != nil { - return "", 0, 0, err - } - TS = ^TS - - _, err = fmt.Sscanf(logKey[channelIndex:], "%d", &channel) - if err != nil { - return "", 0, 0, err - } - return logKey[keyIndex : channelIndex-1], TS, channel, nil -} - -func SegmentEncode(segment string, suffix string) []byte { - return []byte(segment + "_" + suffix) -} diff --git a/internal/storage/internal/minio/minio_store.go b/internal/storage/internal/minio/minio_store.go deleted file mode 100644 index 18e2512401ac94093cd55eb89b016a9be4d92dcf..0000000000000000000000000000000000000000 --- a/internal/storage/internal/minio/minio_store.go +++ /dev/null @@ -1,361 +0,0 @@ -package miniodriver - -import ( - "context" - - "github.com/minio/minio-go/v7" - "github.com/minio/minio-go/v7/pkg/credentials" - "github.com/zilliztech/milvus-distributed/internal/storage/internal/minio/codec" - storageType "github.com/zilliztech/milvus-distributed/internal/storage/type" -) - -type MinioDriver struct { - driver *minioStore -} - -var bucketName string - -func NewMinioDriver(ctx context.Context, option storageType.Option) (*MinioDriver, error) { - // to-do read conf - var endPoint = "localhost:9000" - var accessKeyID = "testminio" - var secretAccessKey = "testminio" - var useSSL = false - - bucketName := option.BucketName - - minioClient, err := minio.New(endPoint, &minio.Options{ - Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""), - Secure: useSSL, - }) - - if err != nil { - return nil, err - } - - bucketExists, err := minioClient.BucketExists(ctx, bucketName) - if err != nil { - return nil, err - } - - if !bucketExists { - err = minioClient.MakeBucket(ctx, bucketName, minio.MakeBucketOptions{}) - if err != nil { - return nil, err - } - } - return &MinioDriver{ - &minioStore{ - client: minioClient, - }, - }, nil -} - -func (s *MinioDriver) put(ctx context.Context, key storageType.Key, value storageType.Value, timestamp storageType.Timestamp, suffix string) error { - minioKey, err := codec.MvccEncode(key, timestamp, suffix) - if err != nil { - return err - } - - err = s.driver.Put(ctx, minioKey, value) - return err -} - -func (s *MinioDriver) scanLE(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp, keyOnly bool) ([]storageType.Timestamp, []storageType.Key, []storageType.Value, error) { - keyEnd, err 
:= codec.MvccEncode(key, timestamp, "") - if err != nil { - return nil, nil, nil, err - } - - keys, values, err := s.driver.Scan(ctx, key, []byte(keyEnd), -1, keyOnly) - if err != nil { - return nil, nil, nil, err - } - - var timestamps []storageType.Timestamp - for _, key := range keys { - _, timestamp, _, _ := codec.MvccDecode(key) - timestamps = append(timestamps, timestamp) - } - - return timestamps, keys, values, nil -} - -func (s *MinioDriver) scanGE(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp, keyOnly bool) ([]storageType.Timestamp, []storageType.Key, []storageType.Value, error) { - keyStart, err := codec.MvccEncode(key, timestamp, "") - if err != nil { - return nil, nil, nil, err - } - - keys, values, err := s.driver.Scan(ctx, key, keyStart, -1, keyOnly) - if err != nil { - return nil, nil, nil, err - } - - var timestamps []storageType.Timestamp - for _, key := range keys { - _, timestamp, _, _ := codec.MvccDecode(key) - timestamps = append(timestamps, timestamp) - } - - return timestamps, keys, values, nil -} - -//scan(ctx context.Context, key storageType.Key, start storageType.Timestamp, end storageType.Timestamp, withValue bool) ([]storageType.Timestamp, []storageType.Key, []storageType.Value, error) -func (s *MinioDriver) deleteLE(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp) error { - keyEnd, err := codec.MvccEncode(key, timestamp, "delete") - if err != nil { - return err - } - err = s.driver.DeleteRange(ctx, key, keyEnd) - return err -} -func (s *MinioDriver) deleteGE(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp) error { - keys, _, err := s.driver.GetByPrefix(ctx, key, true) - if err != nil { - return err - } - keyStart, err := codec.MvccEncode(key, timestamp, "") - if err != nil { - panic(err) - } - err = s.driver.DeleteRange(ctx, keyStart, keys[len(keys)-1]) - if err != nil { - panic(err) - } - return nil -} -func (s *MinioDriver) deleteRange(ctx context.Context, key storageType.Key, start storageType.Timestamp, end storageType.Timestamp) error { - keyStart, err := codec.MvccEncode(key, start, "") - if err != nil { - return err - } - keyEnd, err := codec.MvccEncode(key, end, "") - if err != nil { - return err - } - err = s.driver.DeleteRange(ctx, keyStart, keyEnd) - return err -} - -func (s *MinioDriver) GetRow(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp) (storageType.Value, error) { - minioKey, err := codec.MvccEncode(key, timestamp, "") - if err != nil { - return nil, err - } - - keys, values, err := s.driver.Scan(ctx, append(key, byte('_')), minioKey, 1, false) - if values == nil || keys == nil { - return nil, err - } - - _, _, suffix, err := codec.MvccDecode(keys[0]) - if err != nil { - return nil, err - } - if suffix == "delete" { - return nil, nil - } - - return values[0], err -} -func (s *MinioDriver) GetRows(ctx context.Context, keys []storageType.Key, timestamps []storageType.Timestamp) ([]storageType.Value, error) { - var values []storageType.Value - for i, key := range keys { - value, err := s.GetRow(ctx, key, timestamps[i]) - if err != nil { - return nil, err - } - values = append(values, value) - } - return values, nil -} - -func (s *MinioDriver) PutRow(ctx context.Context, key storageType.Key, value storageType.Value, segment string, timestamp storageType.Timestamp) error { - minioKey, err := codec.MvccEncode(key, timestamp, segment) - if err != nil { - return err - } - err = s.driver.Put(ctx, minioKey, value) - return err -} -func (s 
*MinioDriver) PutRows(ctx context.Context, keys []storageType.Key, values []storageType.Value, segments []string, timestamps []storageType.Timestamp) error { - maxThread := 100 - batchSize := 1 - keysLength := len(keys) - - if keysLength/batchSize > maxThread { - batchSize = keysLength / maxThread - } - - batchNums := keysLength / batchSize - - if keysLength%batchSize != 0 { - batchNums = keysLength/batchSize + 1 - } - - errCh := make(chan error) - f := func(ctx2 context.Context, keys2 []storageType.Key, values2 []storageType.Value, segments2 []string, timestamps2 []storageType.Timestamp) { - for i := 0; i < len(keys2); i++ { - err := s.PutRow(ctx2, keys2[i], values2[i], segments2[i], timestamps2[i]) - errCh <- err - } - } - for i := 0; i < batchNums; i++ { - j := i - go func() { - start, end := j*batchSize, (j+1)*batchSize - if len(keys) < end { - end = len(keys) - } - f(ctx, keys[start:end], values[start:end], segments[start:end], timestamps[start:end]) - }() - } - - for i := 0; i < len(keys); i++ { - if err := <-errCh; err != nil { - return err - } - } - return nil -} - -func (s *MinioDriver) GetSegments(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp) ([]string, error) { - keyEnd, err := codec.MvccEncode(key, timestamp, "") - if err != nil { - return nil, err - } - keys, _, err := s.driver.Scan(ctx, append(key, byte('_')), keyEnd, -1, true) - if err != nil { - return nil, err - } - segmentsSet := map[string]bool{} - for _, key := range keys { - _, _, segment, err := codec.MvccDecode(key) - if err != nil { - panic("must no error") - } - if segment != "delete" { - segmentsSet[segment] = true - } - } - - var segments []string - for k, v := range segmentsSet { - if v { - segments = append(segments, k) - } - } - return segments, err -} - -func (s *MinioDriver) DeleteRow(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp) error { - minioKey, err := codec.MvccEncode(key, timestamp, "delete") - if err != nil { - return err - } - value := []byte("0") - err = s.driver.Put(ctx, minioKey, value) - return err -} - -func (s *MinioDriver) DeleteRows(ctx context.Context, keys []storageType.Key, timestamps []storageType.Timestamp) error { - maxThread := 100 - batchSize := 1 - keysLength := len(keys) - - if keysLength/batchSize > maxThread { - batchSize = keysLength / maxThread - } - - batchNums := keysLength / batchSize - - if keysLength%batchSize != 0 { - batchNums = keysLength/batchSize + 1 - } - - errCh := make(chan error) - f := func(ctx2 context.Context, keys2 []storageType.Key, timestamps2 []storageType.Timestamp) { - for i := 0; i < len(keys2); i++ { - err := s.DeleteRow(ctx2, keys2[i], timestamps2[i]) - errCh <- err - } - } - for i := 0; i < batchNums; i++ { - j := i - go func() { - start, end := j*batchSize, (j+1)*batchSize - if len(keys) < end { - end = len(keys) - } - f(ctx, keys[start:end], timestamps[start:end]) - }() - } - - for i := 0; i < len(keys); i++ { - if err := <-errCh; err != nil { - return err - } - } - return nil -} - -func (s *MinioDriver) PutLog(ctx context.Context, key storageType.Key, value storageType.Value, timestamp storageType.Timestamp, channel int) error { - logKey := codec.LogEncode(key, timestamp, channel) - err := s.driver.Put(ctx, logKey, value) - return err -} - -func (s *MinioDriver) GetLog(ctx context.Context, start storageType.Timestamp, end storageType.Timestamp, channels []int) ([]storageType.Value, error) { - keys, values, err := s.driver.GetByPrefix(ctx, []byte("log_"), false) - if err != nil { - return nil, 
err - } - - var resultValues []storageType.Value - for i, key := range keys { - _, ts, channel, err := codec.LogDecode(string(key)) - if err != nil { - return nil, err - } - if ts >= start && ts <= end { - for j := 0; j < len(channels); j++ { - if channel == channels[j] { - resultValues = append(resultValues, values[i]) - } - } - } - } - - return resultValues, nil -} - -func (s *MinioDriver) GetSegmentIndex(ctx context.Context, segment string) (storageType.SegmentIndex, error) { - - return s.driver.Get(ctx, codec.SegmentEncode(segment, "index")) -} - -func (s *MinioDriver) PutSegmentIndex(ctx context.Context, segment string, index storageType.SegmentIndex) error { - - return s.driver.Put(ctx, codec.SegmentEncode(segment, "index"), index) -} - -func (s *MinioDriver) DeleteSegmentIndex(ctx context.Context, segment string) error { - - return s.driver.Delete(ctx, codec.SegmentEncode(segment, "index")) -} - -func (s *MinioDriver) GetSegmentDL(ctx context.Context, segment string) (storageType.SegmentDL, error) { - - return s.driver.Get(ctx, codec.SegmentEncode(segment, "DL")) -} - -func (s *MinioDriver) PutSegmentDL(ctx context.Context, segment string, log storageType.SegmentDL) error { - - return s.driver.Put(ctx, codec.SegmentEncode(segment, "DL"), log) -} - -func (s *MinioDriver) DeleteSegmentDL(ctx context.Context, segment string) error { - - return s.driver.Delete(ctx, codec.SegmentEncode(segment, "DL")) -} diff --git a/internal/storage/internal/minio/minio_storeEngine.go b/internal/storage/internal/minio/minio_storeEngine.go deleted file mode 100644 index 64d74e859032b0306c9e9cce655d48d3df52a8a0..0000000000000000000000000000000000000000 --- a/internal/storage/internal/minio/minio_storeEngine.go +++ /dev/null @@ -1,130 +0,0 @@ -package miniodriver - -import ( - "bytes" - "context" - "io" - - "github.com/minio/minio-go/v7" - . 
"github.com/zilliztech/milvus-distributed/internal/storage/type" -) - -type minioStore struct { - client *minio.Client -} - -func (s *minioStore) Put(ctx context.Context, key Key, value Value) error { - reader := bytes.NewReader(value) - _, err := s.client.PutObject(ctx, bucketName, string(key), reader, int64(len(value)), minio.PutObjectOptions{}) - - if err != nil { - return err - } - - return err -} - -func (s *minioStore) Get(ctx context.Context, key Key) (Value, error) { - object, err := s.client.GetObject(ctx, bucketName, string(key), minio.GetObjectOptions{}) - if err != nil { - return nil, err - } - - size := 256 * 1024 - buf := make([]byte, size) - n, err := object.Read(buf) - if err != nil && err != io.EOF { - return nil, err - } - return buf[:n], nil -} - -func (s *minioStore) GetByPrefix(ctx context.Context, prefix Key, keyOnly bool) ([]Key, []Value, error) { - objects := s.client.ListObjects(ctx, bucketName, minio.ListObjectsOptions{Prefix: string(prefix)}) - - var objectsKeys []Key - var objectsValues []Value - - for object := range objects { - objectsKeys = append(objectsKeys, []byte(object.Key)) - if !keyOnly { - value, err := s.Get(ctx, []byte(object.Key)) - if err != nil { - return nil, nil, err - } - objectsValues = append(objectsValues, value) - } - } - - return objectsKeys, objectsValues, nil - -} - -func (s *minioStore) Scan(ctx context.Context, keyStart Key, keyEnd Key, limit int, keyOnly bool) ([]Key, []Value, error) { - var keys []Key - var values []Value - limitCount := uint(limit) - for object := range s.client.ListObjects(ctx, bucketName, minio.ListObjectsOptions{Prefix: string(keyStart)}) { - if object.Key >= string(keyEnd) { - keys = append(keys, []byte(object.Key)) - if !keyOnly { - value, err := s.Get(ctx, []byte(object.Key)) - if err != nil { - return nil, nil, err - } - values = append(values, value) - } - limitCount-- - if limitCount <= 0 { - break - } - } - } - - return keys, values, nil -} - -func (s *minioStore) Delete(ctx context.Context, key Key) error { - err := s.client.RemoveObject(ctx, bucketName, string(key), minio.RemoveObjectOptions{}) - return err -} - -func (s *minioStore) DeleteByPrefix(ctx context.Context, prefix Key) error { - objectsCh := make(chan minio.ObjectInfo) - - go func() { - defer close(objectsCh) - - for object := range s.client.ListObjects(ctx, bucketName, minio.ListObjectsOptions{Prefix: string(prefix)}) { - objectsCh <- object - } - }() - - for rErr := range s.client.RemoveObjects(ctx, bucketName, objectsCh, minio.RemoveObjectsOptions{GovernanceBypass: true}) { - if rErr.Err != nil { - return rErr.Err - } - } - return nil -} - -func (s *minioStore) DeleteRange(ctx context.Context, keyStart Key, keyEnd Key) error { - objectsCh := make(chan minio.ObjectInfo) - - go func() { - defer close(objectsCh) - - for object := range s.client.ListObjects(ctx, bucketName, minio.ListObjectsOptions{Prefix: string(keyStart)}) { - if object.Key <= string(keyEnd) { - objectsCh <- object - } - } - }() - - for rErr := range s.client.RemoveObjects(ctx, bucketName, objectsCh, minio.RemoveObjectsOptions{GovernanceBypass: true}) { - if rErr.Err != nil { - return rErr.Err - } - } - return nil -} diff --git a/internal/storage/internal/minio/minio_test.go b/internal/storage/internal/minio/minio_test.go deleted file mode 100644 index d98a98cfeaac2a4d77b8cf999940a04c1a61c71d..0000000000000000000000000000000000000000 --- a/internal/storage/internal/minio/minio_test.go +++ /dev/null @@ -1,134 +0,0 @@ -package miniodriver - -import ( - "context" - "testing" - 
- "github.com/stretchr/testify/assert" - storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type" -) - -var option = storagetype.Option{BucketName: "zilliz-hz"} -var ctx = context.Background() -var client, err = NewMinioDriver(ctx, option) - -func TestMinioDriver_PutRowAndGetRow(t *testing.T) { - err = client.PutRow(ctx, []byte("bar"), []byte("abcdefghijklmnoopqrstuvwxyz"), "SegmentA", 1) - assert.Nil(t, err) - err = client.PutRow(ctx, []byte("bar"), []byte("djhfkjsbdfbsdughorsgsdjhgoisdgh"), "SegmentA", 2) - assert.Nil(t, err) - err = client.PutRow(ctx, []byte("bar"), []byte("123854676ershdgfsgdfk,sdhfg;sdi8"), "SegmentB", 3) - assert.Nil(t, err) - err = client.PutRow(ctx, []byte("bar1"), []byte("testkeybarorbar_1"), "SegmentC", 3) - assert.Nil(t, err) - object, _ := client.GetRow(ctx, []byte("bar"), 5) - assert.Equal(t, "abcdefghijklmnoopqrstuvwxyz", string(object)) - object, _ = client.GetRow(ctx, []byte("bar"), 2) - assert.Equal(t, "djhfkjsbdfbsdughorsgsdjhgoisdgh", string(object)) - object, _ = client.GetRow(ctx, []byte("bar"), 5) - assert.Equal(t, "123854676ershdgfsgdfk,sdhfg;sdi8", string(object)) - object, _ = client.GetRow(ctx, []byte("bar1"), 5) - assert.Equal(t, "testkeybarorbar_1", string(object)) -} - -func TestMinioDriver_DeleteRow(t *testing.T) { - err = client.DeleteRow(ctx, []byte("bar"), 5) - assert.Nil(t, err) - object, _ := client.GetRow(ctx, []byte("bar"), 6) - assert.Nil(t, object) - err = client.DeleteRow(ctx, []byte("bar1"), 5) - assert.Nil(t, err) - object2, _ := client.GetRow(ctx, []byte("bar1"), 6) - assert.Nil(t, object2) -} - -func TestMinioDriver_GetSegments(t *testing.T) { - err = client.PutRow(ctx, []byte("seg"), []byte("abcdefghijklmnoopqrstuvwxyz"), "SegmentA", 1) - assert.Nil(t, err) - err = client.PutRow(ctx, []byte("seg"), []byte("djhfkjsbdfbsdughorsgsdjhgoisdgh"), "SegmentA", 2) - assert.Nil(t, err) - err = client.PutRow(ctx, []byte("seg"), []byte("123854676ershdgfsgdfk,sdhfg;sdi8"), "SegmentB", 3) - assert.Nil(t, err) - err = client.PutRow(ctx, []byte("seg2"), []byte("testkeybarorbar_1"), "SegmentC", 1) - assert.Nil(t, err) - - segements, err := client.GetSegments(ctx, []byte("seg"), 4) - assert.Nil(t, err) - assert.Equal(t, 2, len(segements)) - if segements[0] == "SegmentA" { - assert.Equal(t, "SegmentA", segements[0]) - assert.Equal(t, "SegmentB", segements[1]) - } else { - assert.Equal(t, "SegmentB", segements[0]) - assert.Equal(t, "SegmentA", segements[1]) - } -} - -func TestMinioDriver_PutRowsAndGetRows(t *testing.T) { - keys := [][]byte{[]byte("foo"), []byte("bar")} - values := [][]byte{[]byte("The key is foo!"), []byte("The key is bar!")} - segments := []string{"segmentA", "segmentB"} - timestamps := []uint64{1, 2} - err = client.PutRows(ctx, keys, values, segments, timestamps) - assert.Nil(t, err) - - objects, err := client.GetRows(ctx, keys, timestamps) - assert.Nil(t, err) - assert.Equal(t, "The key is foo!", string(objects[0])) - assert.Equal(t, "The key is bar!", string(objects[1])) -} - -func TestMinioDriver_DeleteRows(t *testing.T) { - keys := [][]byte{[]byte("foo"), []byte("bar")} - timestamps := []uint64{3, 3} - err := client.DeleteRows(ctx, keys, timestamps) - assert.Nil(t, err) - - objects, err := client.GetRows(ctx, keys, timestamps) - assert.Nil(t, err) - assert.Nil(t, objects[0]) - assert.Nil(t, objects[1]) -} - -func TestMinioDriver_PutLogAndGetLog(t *testing.T) { - err = client.PutLog(ctx, []byte("insert"), []byte("This is insert log!"), 1, 11) - assert.Nil(t, err) - err = client.PutLog(ctx, []byte("delete"), 
[]byte("This is delete log!"), 2, 10) - assert.Nil(t, err) - err = client.PutLog(ctx, []byte("update"), []byte("This is update log!"), 3, 9) - assert.Nil(t, err) - err = client.PutLog(ctx, []byte("select"), []byte("This is select log!"), 4, 8) - assert.Nil(t, err) - - channels := []int{5, 8, 9, 10, 11, 12, 13} - logValues, err := client.GetLog(ctx, 0, 5, channels) - assert.Nil(t, err) - assert.Equal(t, "This is select log!", string(logValues[0])) - assert.Equal(t, "This is update log!", string(logValues[1])) - assert.Equal(t, "This is delete log!", string(logValues[2])) - assert.Equal(t, "This is insert log!", string(logValues[3])) -} - -func TestMinioDriver_Segment(t *testing.T) { - err := client.PutSegmentIndex(ctx, "segmentA", []byte("This is segmentA's index!")) - assert.Nil(t, err) - - segmentIndex, err := client.GetSegmentIndex(ctx, "segmentA") - assert.Equal(t, "This is segmentA's index!", string(segmentIndex)) - assert.Nil(t, err) - - err = client.DeleteSegmentIndex(ctx, "segmentA") - assert.Nil(t, err) -} - -func TestMinioDriver_SegmentDL(t *testing.T) { - err := client.PutSegmentDL(ctx, "segmentB", []byte("This is segmentB's delete log!")) - assert.Nil(t, err) - - segmentDL, err := client.GetSegmentDL(ctx, "segmentB") - assert.Nil(t, err) - assert.Equal(t, "This is segmentB's delete log!", string(segmentDL)) - - err = client.DeleteSegmentDL(ctx, "segmentB") - assert.Nil(t, err) -} diff --git a/internal/storage/internal/tikv/codec/codec.go b/internal/storage/internal/tikv/codec/codec.go deleted file mode 100644 index ca09296097e8778ca1023c3395801e7cb0e99669..0000000000000000000000000000000000000000 --- a/internal/storage/internal/tikv/codec/codec.go +++ /dev/null @@ -1,62 +0,0 @@ -package codec - -import ( - "encoding/binary" - "errors" - - "github.com/tikv/client-go/codec" -) - -var ( - Delimiter = byte('_') - DelimiterPlusOne = Delimiter + 0x01 - DeleteMark = byte('d') - SegmentIndexMark = byte('i') - SegmentDLMark = byte('d') -) - -// EncodeKey append timestamp, delimiter, and suffix string -// to one slice key. -// Note: suffix string should not contains Delimiter -func EncodeKey(key []byte, timestamp uint64, suffix string) []byte { - //TODO: should we encode key to memory comparable - ret := EncodeDelimiter(key, Delimiter) - ret = codec.EncodeUintDesc(ret, timestamp) - return append(ret, suffix...) -} - -func DecodeKey(key []byte) ([]byte, uint64, string, error) { - if len(key) < 8 { - return nil, 0, "", errors.New("insufficient bytes to decode value") - } - - lenDeKey := 0 - for i := len(key) - 1; i > 0; i-- { - if key[i] == Delimiter { - lenDeKey = i - break - } - } - - if lenDeKey == 0 || lenDeKey+8 > len(key) { - return nil, 0, "", errors.New("insufficient bytes to decode value") - } - - tsBytes := key[lenDeKey+1 : lenDeKey+9] - ts := binary.BigEndian.Uint64(tsBytes) - suffix := string(key[lenDeKey+9:]) - key = key[:lenDeKey-1] - return key, ^ts, suffix, nil -} - -// EncodeDelimiter append a delimiter byte to slice b, and return the appended slice. -func EncodeDelimiter(b []byte, delimiter byte) []byte { - return append(b, delimiter) -} - -func EncodeSegment(segName []byte, segType byte) []byte { - segmentKey := []byte("segment") - segmentKey = append(segmentKey, Delimiter) - segmentKey = append(segmentKey, segName...) 
- return append(segmentKey, Delimiter, segType) -} diff --git a/internal/storage/internal/tikv/tikv_store.go b/internal/storage/internal/tikv/tikv_store.go deleted file mode 100644 index 5ecf8936f675f574f262219013001db4519f77cd..0000000000000000000000000000000000000000 --- a/internal/storage/internal/tikv/tikv_store.go +++ /dev/null @@ -1,389 +0,0 @@ -package tikvdriver - -import ( - "context" - "errors" - "strconv" - "strings" - - "github.com/tikv/client-go/config" - "github.com/tikv/client-go/rawkv" - . "github.com/zilliztech/milvus-distributed/internal/storage/internal/tikv/codec" - . "github.com/zilliztech/milvus-distributed/internal/storage/type" - storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type" -) - -func keyAddOne(key Key) Key { - if key == nil { - return nil - } - lenKey := len(key) - ret := make(Key, lenKey) - copy(ret, key) - ret[lenKey-1] += 0x01 - return ret -} - -type tikvEngine struct { - client *rawkv.Client - conf config.Config -} - -func (e tikvEngine) Put(ctx context.Context, key Key, value Value) error { - return e.client.Put(ctx, key, value) -} - -func (e tikvEngine) BatchPut(ctx context.Context, keys []Key, values []Value) error { - return e.client.BatchPut(ctx, keys, values) -} - -func (e tikvEngine) Get(ctx context.Context, key Key) (Value, error) { - return e.client.Get(ctx, key) -} - -func (e tikvEngine) GetByPrefix(ctx context.Context, prefix Key, keyOnly bool) (keys []Key, values []Value, err error) { - startKey := prefix - endKey := keyAddOne(prefix) - limit := e.conf.Raw.MaxScanLimit - for { - ks, vs, err := e.Scan(ctx, startKey, endKey, limit, keyOnly) - if err != nil { - return keys, values, err - } - keys = append(keys, ks...) - values = append(values, vs...) - if len(ks) < limit { - break - } - // update the start key, and exclude the start key - startKey = append(ks[len(ks)-1], '\000') - } - return -} - -func (e tikvEngine) Scan(ctx context.Context, startKey Key, endKey Key, limit int, keyOnly bool) ([]Key, []Value, error) { - return e.client.Scan(ctx, startKey, endKey, limit, rawkv.ScanOption{KeyOnly: keyOnly}) -} - -func (e tikvEngine) Delete(ctx context.Context, key Key) error { - return e.client.Delete(ctx, key) -} - -func (e tikvEngine) DeleteByPrefix(ctx context.Context, prefix Key) error { - startKey := prefix - endKey := keyAddOne(prefix) - return e.client.DeleteRange(ctx, startKey, endKey) -} - -func (e tikvEngine) DeleteRange(ctx context.Context, startKey Key, endKey Key) error { - return e.client.DeleteRange(ctx, startKey, endKey) -} - -func (e tikvEngine) Close() error { - return e.client.Close() -} - -type TikvStore struct { - engine *tikvEngine -} - -func NewTikvStore(ctx context.Context, option storagetype.Option) (*TikvStore, error) { - - conf := config.Default() - client, err := rawkv.NewClient(ctx, []string{option.TikvAddress}, conf) - if err != nil { - return nil, err - } - return &TikvStore{ - &tikvEngine{ - client: client, - conf: conf, - }, - }, nil -} - -func (s *TikvStore) Name() string { - return "TiKV storage" -} - -func (s *TikvStore) put(ctx context.Context, key Key, value Value, timestamp Timestamp, suffix string) error { - return s.engine.Put(ctx, EncodeKey(key, timestamp, suffix), value) -} - -func (s *TikvStore) scanLE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error) { - panic("implement me") -} - -func (s *TikvStore) scanGE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error) { - 
panic("implement me") -} - -func (s *TikvStore) scan(ctx context.Context, key Key, start Timestamp, end Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error) { - //startKey := EncodeKey(key, start, "") - //endKey := EncodeKey(EncodeDelimiter(key, DelimiterPlusOne), end, "") - //return s.engine.Scan(ctx, startKey, endKey, -1, keyOnly) - panic("implement me") -} - -func (s *TikvStore) deleteLE(ctx context.Context, key Key, timestamp Timestamp) error { - panic("implement me") -} - -func (s *TikvStore) deleteGE(ctx context.Context, key Key, timestamp Timestamp) error { - panic("implement me") -} - -func (s *TikvStore) deleteRange(ctx context.Context, key Key, start Timestamp, end Timestamp) error { - panic("implement me") -} - -func (s *TikvStore) GetRow(ctx context.Context, key Key, timestamp Timestamp) (Value, error) { - startKey := EncodeKey(key, timestamp, "") - endKey := EncodeDelimiter(key, DelimiterPlusOne) - keys, values, err := s.engine.Scan(ctx, startKey, endKey, 1, false) - if err != nil || keys == nil { - return nil, err - } - _, _, suffix, err := DecodeKey(keys[0]) - if err != nil { - return nil, err - } - // key is marked deleted - if suffix == string(DeleteMark) { - return nil, nil - } - return values[0], nil -} - -// TODO: how to spilt keys to some batches -var batchSize = 100 - -type kvPair struct { - key Key - value Value - err error -} - -func batchKeys(keys []Key) [][]Key { - keysLen := len(keys) - numBatch := (keysLen-1)/batchSize + 1 - batches := make([][]Key, numBatch) - - for i := 0; i < numBatch; i++ { - batchStart := i * batchSize - batchEnd := batchStart + batchSize - // the last batch - if i == numBatch-1 { - batchEnd = keysLen - } - batches[i] = keys[batchStart:batchEnd] - } - return batches -} - -func (s *TikvStore) GetRows(ctx context.Context, keys []Key, timestamps []Timestamp) ([]Value, error) { - if len(keys) != len(timestamps) { - return nil, errors.New("the len of keys is not equal to the len of timestamps") - } - - batches := batchKeys(keys) - ch := make(chan kvPair, len(keys)) - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - for n, b := range batches { - batch := b - numBatch := n - go func() { - for i, key := range batch { - select { - case <-ctx.Done(): - return - default: - v, err := s.GetRow(ctx, key, timestamps[numBatch*batchSize+i]) - ch <- kvPair{ - key: key, - value: v, - err: err, - } - } - } - }() - } - - var err error - var values []Value - kvMap := make(map[string]Value) - for i := 0; i < len(keys); i++ { - kv := <-ch - if kv.err != nil { - cancel() - if err == nil { - err = kv.err - } - } - kvMap[string(kv.key)] = kv.value - } - for _, key := range keys { - values = append(values, kvMap[string(key)]) - } - return values, err -} - -func (s *TikvStore) PutRow(ctx context.Context, key Key, value Value, segment string, timestamp Timestamp) error { - return s.put(ctx, key, value, timestamp, segment) -} - -func (s *TikvStore) PutRows(ctx context.Context, keys []Key, values []Value, segments []string, timestamps []Timestamp) error { - if len(keys) != len(values) { - return errors.New("the len of keys is not equal to the len of values") - } - if len(keys) != len(timestamps) { - return errors.New("the len of keys is not equal to the len of timestamps") - } - - encodedKeys := make([]Key, len(keys)) - for i, key := range keys { - encodedKeys[i] = EncodeKey(key, timestamps[i], segments[i]) - } - return s.engine.BatchPut(ctx, encodedKeys, values) -} - -func (s *TikvStore) DeleteRow(ctx context.Context, key Key, timestamp Timestamp) 
error { - return s.put(ctx, key, Value{0x00}, timestamp, string(DeleteMark)) -} - -func (s *TikvStore) DeleteRows(ctx context.Context, keys []Key, timestamps []Timestamp) error { - encodeKeys := make([]Key, len(keys)) - values := make([]Value, len(keys)) - for i, key := range keys { - encodeKeys[i] = EncodeKey(key, timestamps[i], string(DeleteMark)) - values[i] = Value{0x00} - } - return s.engine.BatchPut(ctx, encodeKeys, values) -} - -//func (s *TikvStore) DeleteRows(ctx context.Context, keys []Key, timestamp Timestamp) error { -// batches := batchKeys(keys) -// ch := make(chan error, len(batches)) -// ctx, cancel := context.WithCancel(ctx) -// -// for _, b := range batches { -// batch := b -// go func() { -// for _, key := range batch { -// select { -// case <-ctx.Done(): -// return -// default: -// ch <- s.DeleteRow(ctx, key, timestamp) -// } -// } -// }() -// } -// -// var err error -// for i := 0; i < len(keys); i++ { -// if e := <-ch; e != nil { -// cancel() -// if err == nil { -// err = e -// } -// } -// } -// return err -//} - -func (s *TikvStore) PutLog(ctx context.Context, key Key, value Value, timestamp Timestamp, channel int) error { - suffix := string(EncodeDelimiter(key, DelimiterPlusOne)) + strconv.Itoa(channel) - return s.put(ctx, Key("log"), value, timestamp, suffix) -} - -func (s *TikvStore) GetLog(ctx context.Context, start Timestamp, end Timestamp, channels []int) (logs []Value, err error) { - key := Key("log") - startKey := EncodeKey(key, end, "") - endKey := EncodeKey(key, start, "") - // TODO: use for loop to ensure get all keys - keys, values, err := s.engine.Scan(ctx, startKey, endKey, s.engine.conf.Raw.MaxScanLimit, false) - if err != nil || keys == nil { - return nil, err - } - - for i, key := range keys { - _, _, suffix, err := DecodeKey(key) - log := values[i] - if err != nil { - return logs, err - } - - // no channels filter - if len(channels) == 0 { - logs = append(logs, log) - } - slice := strings.Split(suffix, string(DelimiterPlusOne)) - channel, err := strconv.Atoi(slice[len(slice)-1]) - if err != nil { - panic(err) - } - for _, item := range channels { - if item == channel { - logs = append(logs, log) - break - } - } - } - return -} - -func (s *TikvStore) GetSegmentIndex(ctx context.Context, segment string) (SegmentIndex, error) { - return s.engine.Get(ctx, EncodeSegment([]byte(segment), SegmentIndexMark)) -} - -func (s *TikvStore) PutSegmentIndex(ctx context.Context, segment string, index SegmentIndex) error { - return s.engine.Put(ctx, EncodeSegment([]byte(segment), SegmentIndexMark), index) -} - -func (s *TikvStore) DeleteSegmentIndex(ctx context.Context, segment string) error { - return s.engine.Delete(ctx, EncodeSegment([]byte(segment), SegmentIndexMark)) -} - -func (s *TikvStore) GetSegmentDL(ctx context.Context, segment string) (SegmentDL, error) { - return s.engine.Get(ctx, EncodeSegment([]byte(segment), SegmentDLMark)) -} - -func (s *TikvStore) PutSegmentDL(ctx context.Context, segment string, log SegmentDL) error { - return s.engine.Put(ctx, EncodeSegment([]byte(segment), SegmentDLMark), log) -} - -func (s *TikvStore) DeleteSegmentDL(ctx context.Context, segment string) error { - return s.engine.Delete(ctx, EncodeSegment([]byte(segment), SegmentDLMark)) -} - -func (s *TikvStore) GetSegments(ctx context.Context, key Key, timestamp Timestamp) ([]string, error) { - keys, _, err := s.engine.GetByPrefix(ctx, EncodeDelimiter(key, Delimiter), true) - if err != nil { - return nil, err - } - segmentsSet := map[string]bool{} - for _, key := range keys { - 
_, ts, segment, err := DecodeKey(key) - if err != nil { - panic("must no error") - } - if ts <= timestamp && segment != string(DeleteMark) { - segmentsSet[segment] = true - } - } - - var segments []string - for k, v := range segmentsSet { - if v { - segments = append(segments, k) - } - } - return segments, err -} - -func (s *TikvStore) Close() error { - return s.engine.Close() -} diff --git a/internal/storage/internal/tikv/tikv_test.go b/internal/storage/internal/tikv/tikv_test.go deleted file mode 100644 index 4e69d14d2c3079f7251123666881206a9af03d03..0000000000000000000000000000000000000000 --- a/internal/storage/internal/tikv/tikv_test.go +++ /dev/null @@ -1,293 +0,0 @@ -package tikvdriver - -import ( - "bytes" - "context" - "fmt" - "math" - "os" - "sort" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - . "github.com/zilliztech/milvus-distributed/internal/storage/internal/tikv/codec" - . "github.com/zilliztech/milvus-distributed/internal/storage/type" -) - -//var store TikvStore -var store *TikvStore -var option = Option{TikvAddress: "localhost:2379"} - -func TestMain(m *testing.M) { - store, _ = NewTikvStore(context.Background(), option) - exitCode := m.Run() - _ = store.Close() - os.Exit(exitCode) -} - -func TestTikvEngine_Prefix(t *testing.T) { - ctx := context.Background() - prefix := Key("key") - engine := store.engine - value := Value("value") - - // Put some key with same prefix - key := prefix - err := engine.Put(ctx, key, value) - require.Nil(t, err) - key = EncodeKey(prefix, 0, "") - err = engine.Put(ctx, key, value) - assert.Nil(t, err) - - // Get by prefix - ks, _, err := engine.GetByPrefix(ctx, prefix, true) - assert.Equal(t, 2, len(ks)) - assert.Nil(t, err) - - // Delete by prefix - err = engine.DeleteByPrefix(ctx, prefix) - assert.Nil(t, err) - ks, _, err = engine.GetByPrefix(ctx, prefix, true) - assert.Equal(t, 0, len(ks)) - assert.Nil(t, err) - - //Test large amount keys - num := engine.conf.Raw.MaxScanLimit + 1 - keys := make([]Key, num) - values := make([]Value, num) - for i := 0; i < num; i++ { - key = EncodeKey(prefix, uint64(i), "") - keys[i] = key - values[i] = value - } - err = engine.BatchPut(ctx, keys, values) - assert.Nil(t, err) - - ks, _, err = engine.GetByPrefix(ctx, prefix, true) - assert.Nil(t, err) - assert.Equal(t, num, len(ks)) - err = engine.DeleteByPrefix(ctx, prefix) - assert.Nil(t, err) -} - -func TestTikvStore_Row(t *testing.T) { - ctx := context.Background() - key := Key("key") - - // Add same row with different timestamp - err := store.PutRow(ctx, key, Value("value0"), "segment0", 0) - assert.Nil(t, err) - err = store.PutRow(ctx, key, Value("value1"), "segment0", 2) - assert.Nil(t, err) - - // Get most recent row using key and timestamp - v, err := store.GetRow(ctx, key, 3) - assert.Nil(t, err) - assert.Equal(t, Value("value1"), v) - v, err = store.GetRow(ctx, key, 2) - assert.Nil(t, err) - assert.Equal(t, Value("value1"), v) - v, err = store.GetRow(ctx, key, 1) - assert.Nil(t, err) - assert.Equal(t, Value("value0"), v) - - // Add a different row, but with same prefix - key1 := Key("key_y") - err = store.PutRow(ctx, key1, Value("valuey"), "segment0", 2) - assert.Nil(t, err) - - // Get most recent row using key and timestamp - v, err = store.GetRow(ctx, key, 3) - assert.Nil(t, err) - assert.Equal(t, Value("value1"), v) - v, err = store.GetRow(ctx, key1, 3) - assert.Nil(t, err) - assert.Equal(t, Value("valuey"), v) - - // Delete a row - err = store.DeleteRow(ctx, key, 4) - assert.Nil(t, err) - v, err = 
store.GetRow(ctx, key, 5)
-	assert.Nil(t, err)
-	assert.Nil(t, v)
-
-	// Clear test data
-	err = store.engine.DeleteByPrefix(ctx, key)
-	assert.Nil(t, err)
-	k, va, err := store.engine.GetByPrefix(ctx, key, false)
-	assert.Nil(t, err)
-	assert.Nil(t, k)
-	assert.Nil(t, va)
-}
-
-func TestTikvStore_BatchRow(t *testing.T) {
-	ctx := context.Background()
-
-	// Prepare test data
-	size := 0
-	var testKeys []Key
-	var testValues []Value
-	var segments []string
-	var timestamps []Timestamp
-	for i := 0; size/store.engine.conf.Raw.MaxBatchPutSize < 1; i++ {
-		key := fmt.Sprint("key", i)
-		size += len(key)
-		testKeys = append(testKeys, []byte(key))
-		value := fmt.Sprint("value", i)
-		size += len(value)
-		testValues = append(testValues, []byte(value))
-		segments = append(segments, "test")
-		v, err := store.GetRow(ctx, Key(key), math.MaxUint64)
-		assert.Nil(t, v)
-		assert.Nil(t, err)
-	}
-
-	// Batch put rows
-	for range testKeys {
-		timestamps = append(timestamps, 1)
-	}
-	err := store.PutRows(ctx, testKeys, testValues, segments, timestamps)
-	assert.Nil(t, err)
-
-	// Batch get rows
-	for i := range timestamps {
-		timestamps[i] = 2
-	}
-	checkValues, err := store.GetRows(ctx, testKeys, timestamps)
-	assert.NotNil(t, checkValues)
-	assert.Nil(t, err)
-	assert.Equal(t, len(checkValues), len(testValues))
-	for i := range testKeys {
-		assert.Equal(t, testValues[i], checkValues[i])
-	}
-
-	// Delete all test rows
-	for i := range timestamps {
-		timestamps[i] = math.MaxUint64
-	}
-	err = store.DeleteRows(ctx, testKeys, timestamps)
-	assert.Nil(t, err)
-	// Ensure all test row is deleted
-	for i := range timestamps {
-		timestamps[i] = math.MaxUint64
-	}
-	checkValues, err = store.GetRows(ctx, testKeys, timestamps)
-	assert.Nil(t, err)
-	for _, value := range checkValues {
-		assert.Nil(t, value)
-	}
-
-	// Clean test data
-	err = store.engine.DeleteByPrefix(ctx, Key("key"))
-	assert.Nil(t, err)
-}
-
-func TestTikvStore_GetSegments(t *testing.T) {
-	ctx := context.Background()
-	key := Key("key")
-
-	// Put rows
-	err := store.PutRow(ctx, key, Value{0}, "a", 1)
-	assert.Nil(t, err)
-	err = store.PutRow(ctx, key, Value{0}, "a", 2)
-	assert.Nil(t, err)
-	err = store.PutRow(ctx, key, Value{0}, "c", 3)
-	assert.Nil(t, err)
-
-	// Get segments
-	segs, err := store.GetSegments(ctx, key, 2)
-	assert.Nil(t, err)
-	assert.Equal(t, 1, len(segs))
-	assert.Equal(t, "a", segs[0])
-
-	segs, err = store.GetSegments(ctx, key, 3)
-	assert.Nil(t, err)
-	assert.Equal(t, 2, len(segs))
-
-	// Clean test data
-	err = store.engine.DeleteByPrefix(ctx, key)
-	assert.Nil(t, err)
-}
-
-func TestTikvStore_Log(t *testing.T) {
-	ctx := context.Background()
-
-	// Put some log
-	err := store.PutLog(ctx, Key("key1"), Value("value1"), 1, 1)
-	assert.Nil(t, err)
-	err = store.PutLog(ctx, Key("key1"), Value("value1_1"), 1, 2)
-	assert.Nil(t, err)
-	err = store.PutLog(ctx, Key("key2"), Value("value2"), 2, 1)
-	assert.Nil(t, err)
-
-	// Check log
-	log, err := store.GetLog(ctx, 0, 2, []int{1, 2})
-	if err != nil {
-		panic(err)
-	}
-	sort.Slice(log, func(i, j int) bool {
-		return bytes.Compare(log[i], log[j]) == -1
-	})
-	assert.Equal(t, log[0], Value("value1"))
-	assert.Equal(t, log[1], Value("value1_1"))
-	assert.Equal(t, log[2], Value("value2"))
-
-	// Delete test data
-	err = store.engine.DeleteByPrefix(ctx, Key("log"))
-	assert.Nil(t, err)
-}
-
-func TestTikvStore_SegmentIndex(t *testing.T) {
-	ctx := context.Background()
-
-	// Put segment index
-	err := store.PutSegmentIndex(ctx, "segment0", []byte("index0"))
-	assert.Nil(t, err)
-	err = store.PutSegmentIndex(ctx, "segment1", []byte("index1"))
-	assert.Nil(t, err)
-
-	// Get segment index
-	index, err := store.GetSegmentIndex(ctx, "segment0")
-	assert.Nil(t, err)
-	assert.Equal(t, []byte("index0"), index)
-	index, err = store.GetSegmentIndex(ctx, "segment1")
-	assert.Nil(t, err)
-	assert.Equal(t, []byte("index1"), index)
-
-	// Delete segment index
-	err = store.DeleteSegmentIndex(ctx, "segment0")
-	assert.Nil(t, err)
-	err = store.DeleteSegmentIndex(ctx, "segment1")
-	assert.Nil(t, err)
-	index, err = store.GetSegmentIndex(ctx, "segment0")
-	assert.Nil(t, err)
-	assert.Nil(t, index)
-}
-
-func TestTikvStore_DeleteSegmentDL(t *testing.T) {
-	ctx := context.Background()
-
-	// Put segment delete log
-	err := store.PutSegmentDL(ctx, "segment0", []byte("index0"))
-	assert.Nil(t, err)
-	err = store.PutSegmentDL(ctx, "segment1", []byte("index1"))
-	assert.Nil(t, err)
-
-	// Get segment delete log
-	index, err := store.GetSegmentDL(ctx, "segment0")
-	assert.Nil(t, err)
-	assert.Equal(t, []byte("index0"), index)
-	index, err = store.GetSegmentDL(ctx, "segment1")
-	assert.Nil(t, err)
-	assert.Equal(t, []byte("index1"), index)
-
-	// Delete segment delete log
-	err = store.DeleteSegmentDL(ctx, "segment0")
-	assert.Nil(t, err)
-	err = store.DeleteSegmentDL(ctx, "segment1")
-	assert.Nil(t, err)
-	index, err = store.GetSegmentDL(ctx, "segment0")
-	assert.Nil(t, err)
-	assert.Nil(t, index)
-}
diff --git a/internal/storage/storage.go b/internal/storage/storage.go
deleted file mode 100644
index 67e9e44e8939075caaaa8aefa17cccc12e786bbd..0000000000000000000000000000000000000000
--- a/internal/storage/storage.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package storage
-
-import (
-	"context"
-	"errors"
-
-	S3Driver "github.com/zilliztech/milvus-distributed/internal/storage/internal/S3"
-	minIODriver "github.com/zilliztech/milvus-distributed/internal/storage/internal/minio"
-	tikvDriver "github.com/zilliztech/milvus-distributed/internal/storage/internal/tikv"
-	storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
-)
-
-func NewStore(ctx context.Context, option storagetype.Option) (storagetype.Store, error) {
-	var err error
-	var store storagetype.Store
-	switch option.Type {
-	case storagetype.TIKVDriver:
-		store, err = tikvDriver.NewTikvStore(ctx, option)
-		if err != nil {
-			panic(err.Error())
-		}
-		return store, nil
-	case storagetype.MinIODriver:
-		store, err = minIODriver.NewMinioDriver(ctx, option)
-		if err != nil {
-			//panic(err.Error())
-			return nil, err
-		}
-		return store, nil
-	case storagetype.S3DRIVER:
-		store, err = S3Driver.NewS3Driver(ctx, option)
-		if err != nil {
-			//panic(err.Error())
-			return nil, err
-		}
-		return store, nil
-	}
-	return nil, errors.New("unsupported driver")
-}
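Reviewer note: the factory removed above was the only public entry point into this storage layer. For context, a minimal sketch of how a caller drove it, assuming the pre-deletion API shown in this patch; the TiKV address is a placeholder, and this code is illustrative, not part of the change:

// Illustrative only: constructing and using the deleted storage.NewStore
// factory. The TiKV address below is a placeholder value.
package main

import (
	"context"
	"log"

	"github.com/zilliztech/milvus-distributed/internal/storage"
	storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
)

func main() {
	ctx := context.Background()
	option := storagetype.Option{Type: storagetype.TIKVDriver, TikvAddress: "127.0.0.1:2379"}
	store, err := storage.NewStore(ctx, option)
	if err != nil {
		log.Fatalf("create store: %v", err)
	}
	// Each row is written under a segment name and a timestamp version.
	if err := store.PutRow(ctx, []byte("row0"), []byte("payload"), "segment0", 1); err != nil {
		log.Fatalf("put row: %v", err)
	}
	value, err := store.GetRow(ctx, []byte("row0"), 2) // read at a later timestamp
	if err != nil {
		log.Fatalf("get row: %v", err)
	}
	log.Printf("read back %q", value)
}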
diff --git a/internal/storage/type/storagetype.go b/internal/storage/type/storagetype.go
deleted file mode 100644
index 9549a106e573e96589ec9f415379a0a5b7144c2b..0000000000000000000000000000000000000000
--- a/internal/storage/type/storagetype.go
+++ /dev/null
@@ -1,79 +0,0 @@
-package storagetype
-
-import (
-	"context"
-
-	"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
-)
-
-type Key = []byte
-type Value = []byte
-type Timestamp = typeutil.Timestamp
-type DriverType = string
-type SegmentIndex = []byte
-type SegmentDL = []byte
-
-type Option struct {
-	Type        DriverType
-	TikvAddress string
-	BucketName  string
-}
-
-const (
-	MinIODriver DriverType = "MinIO"
-	TIKVDriver  DriverType = "TIKV"
-	S3DRIVER    DriverType = "S3"
-)
-
-/*
-type Store interface {
-	Get(ctx context.Context, key Key, timestamp Timestamp) (Value, error)
-	BatchGet(ctx context.Context, keys [] Key, timestamp Timestamp) ([]Value, error)
-	Set(ctx context.Context, key Key, v Value, timestamp Timestamp) error
-	BatchSet(ctx context.Context, keys []Key, v []Value, timestamp Timestamp) error
-	Delete(ctx context.Context, key Key, timestamp Timestamp) error
-	BatchDelete(ctx context.Context, keys []Key, timestamp Timestamp) error
-	Close() error
-}
-*/
-
-type storeEngine interface {
-	Put(ctx context.Context, key Key, value Value) error
-	Get(ctx context.Context, key Key) (Value, error)
-	GetByPrefix(ctx context.Context, prefix Key, keyOnly bool) ([]Key, []Value, error)
-	Scan(ctx context.Context, startKey Key, endKey Key, limit int, keyOnly bool) ([]Key, []Value, error)
-	Delete(ctx context.Context, key Key) error
-	DeleteByPrefix(ctx context.Context, prefix Key) error
-	DeleteRange(ctx context.Context, keyStart Key, keyEnd Key) error
-}
-
-type Store interface {
-	//put(ctx context.Context, key Key, value Value, timestamp Timestamp, suffix string) error
-	//scanLE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error)
-	//scanGE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error)
-	//deleteLE(ctx context.Context, key Key, timestamp Timestamp) error
-	//deleteGE(ctx context.Context, key Key, timestamp Timestamp) error
-	//deleteRange(ctx context.Context, key Key, start Timestamp, end Timestamp) error
-
-	GetRow(ctx context.Context, key Key, timestamp Timestamp) (Value, error)
-	GetRows(ctx context.Context, keys []Key, timestamps []Timestamp) ([]Value, error)
-
-	PutRow(ctx context.Context, key Key, value Value, segment string, timestamp Timestamp) error
-	PutRows(ctx context.Context, keys []Key, values []Value, segments []string, timestamps []Timestamp) error
-
-	GetSegments(ctx context.Context, key Key, timestamp Timestamp) ([]string, error)
-
-	DeleteRow(ctx context.Context, key Key, timestamp Timestamp) error
-	DeleteRows(ctx context.Context, keys []Key, timestamps []Timestamp) error
-
-	PutLog(ctx context.Context, key Key, value Value, timestamp Timestamp, channel int) error
-	GetLog(ctx context.Context, start Timestamp, end Timestamp, channels []int) ([]Value, error)
-
-	GetSegmentIndex(ctx context.Context, segment string) (SegmentIndex, error)
-	PutSegmentIndex(ctx context.Context, segment string, index SegmentIndex) error
-	DeleteSegmentIndex(ctx context.Context, segment string) error
-
-	GetSegmentDL(ctx context.Context, segment string) (SegmentDL, error)
-	PutSegmentDL(ctx context.Context, segment string, log SegmentDL) error
-	DeleteSegmentDL(ctx context.Context, segment string) error
-}
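Reviewer note: the Store interface deleted above is the contract the removed TiKV tests exercise. A compact sketch of the versioning semantics those tests encode, written against the pre-deletion API (the helper name and values are hypothetical): a row written at timestamp T is visible to reads at any timestamp >= T, and the tests delete the newest version by passing math.MaxUint64.

// Hypothetical round trip against the deleted Store interface, mirroring
// the timestamp semantics asserted by the removed TiKV tests.
package storagedemo

import (
	"context"
	"errors"
	"math"

	storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
)

func roundTrip(ctx context.Context, s storagetype.Store) error {
	key := storagetype.Key("demo")
	if err := s.PutRow(ctx, key, storagetype.Value("v1"), "segment-a", 10); err != nil {
		return err
	}
	if v, err := s.GetRow(ctx, key, 5); err != nil || v != nil {
		return errors.New("expected no value below the write timestamp")
	}
	if v, err := s.GetRow(ctx, key, 20); err != nil || v == nil {
		return errors.New("expected the value at or above the write timestamp")
	}
	// The removed tests pass math.MaxUint64 to delete the latest version.
	return s.DeleteRow(ctx, key, math.MaxUint64)
}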
diff --git a/internal/writenode/flow_graph_dd_node.go b/internal/writenode/flow_graph_dd_node.go
index 7dd8e1fd433fda223cab267bbae39bb07ce2122c..2c77e398c53eb7a03b7afb90600b85f574dd1628 100644
--- a/internal/writenode/flow_graph_dd_node.go
+++ b/internal/writenode/flow_graph_dd_node.go
@@ -9,9 +9,6 @@ import (
 	"strconv"
 
 	"github.com/golang/protobuf/proto"
-	"github.com/minio/minio-go/v7"
-	"github.com/minio/minio-go/v7/pkg/credentials"
-
 	"github.com/zilliztech/milvus-distributed/internal/allocator"
 	"github.com/zilliztech/milvus-distributed/internal/kv"
 	miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
@@ -360,19 +357,16 @@ func newDDNode(ctx context.Context, outCh chan *ddlFlushSyncMsg) *ddNode {
 		partitionRecords: make(map[UniqueID]interface{}),
 	}
 
-	minIOEndPoint := Params.MinioAddress
-	minIOAccessKeyID := Params.MinioAccessKeyID
-	minIOSecretAccessKey := Params.MinioSecretAccessKey
-	minIOUseSSL := Params.MinioUseSSL
-	minIOClient, err := minio.New(minIOEndPoint, &minio.Options{
-		Creds:  credentials.NewStaticV4(minIOAccessKeyID, minIOSecretAccessKey, ""),
-		Secure: minIOUseSSL,
-	})
-	if err != nil {
-		panic(err)
-	}
 	bucketName := Params.MinioBucketName
-	minioKV, err := miniokv.NewMinIOKV(ctx, minIOClient, bucketName)
+	option := &miniokv.Option{
+		Address:           Params.MinioAddress,
+		AccessKeyID:       Params.MinioAccessKeyID,
+		SecretAccessKeyID: Params.MinioSecretAccessKey,
+		UseSSL:            Params.MinioUseSSL,
+		BucketName:        bucketName,
+		CreateBucket:      true,
+	}
+	minioKV, err := miniokv.NewMinIOKV(ctx, option)
 	if err != nil {
 		panic(err)
 	}
diff --git a/internal/writenode/flow_graph_insert_buffer_node.go b/internal/writenode/flow_graph_insert_buffer_node.go
index 0c001ee55d802fc3de1316de49679ce99a986e76..2ebc1300f8a0fc1c1f56c5a55a203f3b821f87d4 100644
--- a/internal/writenode/flow_graph_insert_buffer_node.go
+++ b/internal/writenode/flow_graph_insert_buffer_node.go
@@ -11,8 +11,6 @@ import (
 	"unsafe"
 
 	"github.com/golang/protobuf/proto"
-	"github.com/minio/minio-go/v7"
-	"github.com/minio/minio-go/v7/pkg/credentials"
 	"github.com/zilliztech/milvus-distributed/internal/allocator"
 	"github.com/zilliztech/milvus-distributed/internal/kv"
 	etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
@@ -610,20 +608,17 @@ func newInsertBufferNode(ctx context.Context, outCh chan *insertFlushSyncMsg) *i
 	kvClient := etcdkv.NewEtcdKV(cli, MetaRootPath)
 
 	// MinIO
-	minioendPoint := Params.MinioAddress
-	miniioAccessKeyID := Params.MinioAccessKeyID
-	miniioSecretAccessKey := Params.MinioSecretAccessKey
-	minioUseSSL := Params.MinioUseSSL
-	minioBucketName := Params.MinioBucketName
-
-	minioClient, err := minio.New(minioendPoint, &minio.Options{
-		Creds:  credentials.NewStaticV4(miniioAccessKeyID, miniioSecretAccessKey, ""),
-		Secure: minioUseSSL,
-	})
-	if err != nil {
-		panic(err)
+
+	option := &miniokv.Option{
+		Address:           Params.MinioAddress,
+		AccessKeyID:       Params.MinioAccessKeyID,
+		SecretAccessKeyID: Params.MinioSecretAccessKey,
+		UseSSL:            Params.MinioUseSSL,
+		CreateBucket:      true,
+		BucketName:        Params.MinioBucketName,
 	}
-	minIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, minioBucketName)
+
+	minIOKV, err := miniokv.NewMinIOKV(ctx, option)
 	if err != nil {
 		panic(err)
 	}
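Reviewer note: both write-node call sites now build their KV store through the same option struct, so the duplicated MinIO client boilerplate moves behind miniokv.NewMinIOKV. A sketch of the shared shape, assuming the surrounding package's Params and the miniokv import shown above; the *miniokv.MinIOKV return type is an assumption, since this diff shows only the constructor call, not its signature:

// Sketch: the option-based construction both flow-graph nodes now share.
// Return type is assumed; field names are taken verbatim from the diff.
func newWriteNodeKV(ctx context.Context) (*miniokv.MinIOKV, error) {
	option := &miniokv.Option{
		Address:           Params.MinioAddress,
		AccessKeyID:       Params.MinioAccessKeyID,
		SecretAccessKeyID: Params.MinioSecretAccessKey,
		UseSSL:            Params.MinioUseSSL,
		BucketName:        Params.MinioBucketName,
		CreateBucket:      true, // both call sites opt in to bucket creation
	}
	return miniokv.NewMinIOKV(ctx, option)
}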
diff --git a/tools/core_gen/all_generate.py b/tools/core_gen/all_generate.py
index 499022d583bf44465868530a4ddee5aec991523b..a4be00cead810a06504156c058b1755086ac1899 100755
--- a/tools/core_gen/all_generate.py
+++ b/tools/core_gen/all_generate.py
@@ -58,6 +58,10 @@ if __name__ == "__main__":
             'visitor_name': "ExecExprVisitor",
             "parameter_name": 'expr',
         },
+        {
+            'visitor_name': "VerifyExprVisitor",
+            "parameter_name": 'expr',
+        },
     ],
     'PlanNode': [
         {
@@ -68,7 +72,10 @@ if __name__ == "__main__":
             'visitor_name': "ExecPlanNodeVisitor",
             "parameter_name": 'node',
         },
-
+        {
+            'visitor_name': "VerifyPlanNodeVisitor",
+            "parameter_name": 'node',
+        },
     ]
 }
 extract_extra_body(visitor_info, query_path)
diff --git a/tools/core_gen/templates/visitor_derived.h b/tools/core_gen/templates/visitor_derived.h
index cda1ea1a74c15cb90ca13fe26b879bdf1cd021cc..5fe2be775ddaebaa817d4819413898d71d4468a1 100644
--- a/tools/core_gen/templates/visitor_derived.h
+++ b/tools/core_gen/templates/visitor_derived.h
@@ -13,7 +13,7 @@
 #include "@@base_visitor@@.h"
 
 namespace @@namespace@@ {
-class @@visitor_name@@ : @@base_visitor@@ {
+class @@visitor_name@@ : public @@base_visitor@@ {
  public:
 @@body@@
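Reviewer note: the one-word template fix above is behavioral, not cosmetic. In C++, class inheritance defaults to private, so visitors generated from the old template (including the new VerifyExprVisitor and VerifyPlanNodeVisitor registered in all_generate.py) could not be passed where the base visitor type is expected; adding "public" restores that derived-to-base conversion.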