diff --git a/internal/core/src/indexbuilder/IndexWrapper.cpp b/internal/core/src/indexbuilder/IndexWrapper.cpp index cb53d66bfd94a40d6759f42bf12b0ca74d6287f4..f06c349cf26f0a0e5274b4eaabc2a90d7bcef4fa 100644 --- a/internal/core/src/indexbuilder/IndexWrapper.cpp +++ b/internal/core/src/indexbuilder/IndexWrapper.cpp @@ -9,6 +9,10 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License +#include <map> +#include <exception> +#include <stdexcept> + #include "pb/index_cgo_msg.pb.h" #include "knowhere/index/vector_index/VecIndexFactory.h" #include "knowhere/index/vector_index/helpers/IndexParameter.h" @@ -24,11 +28,14 @@ IndexWrapper::IndexWrapper(const char* serialized_type_params, const char* seria parse(); - auto index_type = index_config_["index_type"].get<std::string>(); - auto index_mode = index_config_["index_mode"].get<std::string>(); - auto mode = index_mode == "CPU" ? knowhere::IndexMode::MODE_CPU : knowhere::IndexMode::MODE_GPU; + std::map<std::string, knowhere::IndexMode> mode_map = {{"CPU", knowhere::IndexMode::MODE_CPU}, + {"GPU", knowhere::IndexMode::MODE_GPU}}; + auto type = get_config_by_name<std::string>("index_type"); + auto mode = get_config_by_name<std::string>("index_mode"); + auto index_type = type.has_value() ? type.value() : knowhere::IndexEnum::INDEX_FAISS_IVFPQ; + auto index_mode = mode.has_value() ? 
mode_map[mode.value()] : knowhere::IndexMode::MODE_CPU; - index_ = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(index_type, mode); + index_ = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(index_type, index_mode); } void @@ -36,38 +43,90 @@ IndexWrapper::parse() { namespace indexcgo = milvus::proto::indexcgo; bool serialized_success; - indexcgo::BinarySet type_config; + indexcgo::TypeParams type_config; serialized_success = type_config.ParseFromString(type_params_); Assert(serialized_success); - indexcgo::BinarySet index_config; + indexcgo::IndexParams index_config; serialized_success = index_config.ParseFromString(index_params_); Assert(serialized_success); - for (auto i = 0; i < type_config.datas_size(); ++i) { - auto binary = type_config.datas(i); - type_config_[binary.key()] = binary.value(); + for (auto i = 0; i < type_config.params_size(); ++i) { + auto type_param = type_config.params(i); + auto key = type_param.key(); + auto value = type_param.value(); + type_config_[key] = value; + config_[key] = value; + } + + for (auto i = 0; i < index_config.params_size(); ++i) { + auto index_param = index_config.params(i); + auto key = index_param.key(); + auto value = index_param.value(); + index_config_[key] = value; + config_[key] = value; + } + + if (!config_.contains(milvus::knowhere::meta::DIM)) { + // dim is mandatory: raise a std::exception-derived error that callers can catch 
+ throw std::invalid_argument("dim must be specific in type params or index params!"); + } else { + auto dim = config_[milvus::knowhere::meta::DIM].get<std::string>(); + config_[milvus::knowhere::meta::DIM] = std::stoi(dim); + } + + if (!config_.contains(milvus::knowhere::meta::TOPK)) { + } else { + auto topk = config_[milvus::knowhere::meta::TOPK].get<std::string>(); + config_[milvus::knowhere::meta::TOPK] = std::stoi(topk); + } + + if (!config_.contains(milvus::knowhere::IndexParams::nlist)) { + } else { + auto nlist = config_[milvus::knowhere::IndexParams::nlist].get<std::string>(); + config_[milvus::knowhere::IndexParams::nlist] = std::stoi(nlist); } - for (auto i = 0; i < index_config.datas_size(); ++i) { - auto binary = index_config.datas(i); - index_config_[binary.key()] = binary.value(); + if (!config_.contains(milvus::knowhere::IndexParams::nprobe)) { + } else { + auto nprobe = config_[milvus::knowhere::IndexParams::nprobe].get<std::string>(); + config_[milvus::knowhere::IndexParams::nprobe] = std::stoi(nprobe); } - // TODO: parse from type_params & index_params - auto dim = 128; - config_ = knowhere::Config{ - {knowhere::meta::DIM, dim}, {knowhere::IndexParams::nlist, 100}, - {knowhere::IndexParams::nprobe, 4}, {knowhere::IndexParams::m, 4}, - {knowhere::IndexParams::nbits, 8}, {knowhere::Metric::TYPE, knowhere::Metric::L2}, - {knowhere::meta::DEVICEID, 0}, - }; + if (!config_.contains(milvus::knowhere::IndexParams::nbits)) { + } else { + auto nbits = config_[milvus::knowhere::IndexParams::nbits].get<std::string>(); + config_[milvus::knowhere::IndexParams::nbits] = std::stoi(nbits); + } + + if (!config_.contains(milvus::knowhere::IndexParams::m)) { + } else { + auto m = config_[milvus::knowhere::IndexParams::m].get<std::string>(); + config_[milvus::knowhere::IndexParams::m] = std::stoi(m); + } + + if (!config_.contains(milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) { + config_[milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE] = 4; + } else { + auto slice_size = 
config_[milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<std::string>(); + config_[milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE] = std::stoi(slice_size); + } +} + +template <typename T> +std::optional<T> +IndexWrapper::get_config_by_name(std::string name) { + if (config_.contains(name)) { + return {config_[name].get<T>()}; + } + return std::nullopt; } int64_t IndexWrapper::dim() { - // TODO: get from config - return 128; + auto dimension = get_config_by_name<int64_t>(milvus::knowhere::meta::DIM); + Assert(dimension.has_value()); + return (dimension.value()); } void @@ -88,24 +147,23 @@ IndexWrapper::Serialize() { for (auto [key, value] : binarySet.binary_map_) { auto binary = ret.add_datas(); binary->set_key(key); - binary->set_value(reinterpret_cast<char*>(value->data.get())); + binary->set_value(value->data.get(), value->size); } std::string serialized_data; auto ok = ret.SerializeToString(&serialized_data); Assert(ok); - auto data = new char[serialized_data.length() + 1]; + auto data = new char[serialized_data.length()]; memcpy(data, serialized_data.c_str(), serialized_data.length()); - data[serialized_data.length()] = 0; - return {data, static_cast<int32_t>(serialized_data.length() + 1)}; + return {data, static_cast<int32_t>(serialized_data.length())}; } void -IndexWrapper::Load(const char* serialized_sliced_blob_buffer) { +IndexWrapper::Load(const char* serialized_sliced_blob_buffer, int32_t size) { namespace indexcgo = milvus::proto::indexcgo; - auto data = std::string(serialized_sliced_blob_buffer); + auto data = std::string(serialized_sliced_blob_buffer, size); indexcgo::BinarySet blob_buffer; auto ok = blob_buffer.ParseFromString(data); diff --git a/internal/core/src/indexbuilder/IndexWrapper.h b/internal/core/src/indexbuilder/IndexWrapper.h index ab326fc2eafe77db6738712d54ade3eda806f4ca..49954cc209e83d95af803a72654e3233b63f500d 100644 --- a/internal/core/src/indexbuilder/IndexWrapper.h +++ b/internal/core/src/indexbuilder/IndexWrapper.h @@ 
-10,6 +10,7 @@ // or implied. See the License for the specific language governing permissions and limitations under the License #include <string> +#include <optional> #include "knowhere/index/vector_index/VecIndex.h" namespace milvus { @@ -34,12 +35,16 @@ class IndexWrapper { Serialize(); void - Load(const char* serialized_sliced_blob_buffer); + Load(const char* serialized_sliced_blob_buffer, int32_t size); private: void parse(); + template <typename T> + std::optional<T> + get_config_by_name(std::string name); + private: knowhere::VecIndexPtr index_ = nullptr; std::string type_params_; diff --git a/internal/core/src/indexbuilder/index_c.cpp b/internal/core/src/indexbuilder/index_c.cpp index 29a52526296e9e0e8e9d4e06697caf6de65e9330..51e4a9cf461126589425a2d1ba7044f3f893ec85 100644 --- a/internal/core/src/indexbuilder/index_c.cpp +++ b/internal/core/src/indexbuilder/index_c.cpp @@ -18,7 +18,7 @@ CIndex CreateIndex(const char* serialized_type_params, const char* serialized_index_params) { auto index = std::make_unique<milvus::indexbuilder::IndexWrapper>(serialized_type_params, serialized_index_params); - return (void*)(index.release()); + return index.release(); } void @@ -45,7 +45,7 @@ SerializeToSlicedBuffer(CIndex index, int32_t* buffer_size) { } void -LoadFromSlicedBuffer(CIndex index, const char* serialized_sliced_blob_buffer) { +LoadFromSlicedBuffer(CIndex index, const char* serialized_sliced_blob_buffer, int32_t size) { auto cIndex = (milvus::indexbuilder::IndexWrapper*)index; - cIndex->Load(serialized_sliced_blob_buffer); + cIndex->Load(serialized_sliced_blob_buffer, size); } diff --git a/internal/core/src/indexbuilder/index_c.h b/internal/core/src/indexbuilder/index_c.h index 6918c8743ae2f184c5101d422c0f263ff17c5d2d..128bdc8ac2bf6a41ab7992c522e2f77d919275ef 100644 --- a/internal/core/src/indexbuilder/index_c.h +++ b/internal/core/src/indexbuilder/index_c.h @@ -47,7 +47,7 @@ char* SerializeToSlicedBuffer(CIndex index, int32_t* buffer_size); void 
-LoadFromSlicedBuffer(CIndex index, const char* serialized_sliced_blob_buffer); +LoadFromSlicedBuffer(CIndex index, const char* serialized_sliced_blob_buffer, int32_t size); #ifdef __cplusplus }; diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index 29fef50fb70a9c26b49e4f1f5b7d573d1263049f..afef74785ec502cabf8663c3c06d74516cd244e0 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -12,6 +12,7 @@ set(MILVUS_TEST_FILES test_expr.cpp test_bitmap.cpp test_binary.cpp + test_index_wrapper.cpp ) add_executable(all_tests ${MILVUS_TEST_FILES} @@ -21,6 +22,7 @@ target_link_libraries(all_tests gtest gtest_main milvus_segcore + milvus_indexbuilder knowhere log pthread diff --git a/internal/core/unittest/test_index_wrapper.cpp b/internal/core/unittest/test_index_wrapper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..635950a76acb756d154000f83e912d50b49e5b02 --- /dev/null +++ b/internal/core/unittest/test_index_wrapper.cpp @@ -0,0 +1,168 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#include <tuple> +#include <random> +#include <gtest/gtest.h> + +#include "pb/index_cgo_msg.pb.h" +#include "index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h" +#include "index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "indexbuilder/IndexWrapper.h" +#include "indexbuilder/index_c.h" +#include "test_utils/DataGen.h" +#include "faiss/MetricType.h" +#include "index/knowhere/knowhere/index/vector_index/VecIndexFactory.h" + +namespace indexcgo = milvus::proto::indexcgo; + +constexpr int64_t DIM = 4; +constexpr int64_t NB = 10000; +constexpr int64_t NQ = 10; +constexpr int64_t K = 4; +constexpr auto METRIC_TYPE = milvus::knowhere::Metric::L2; + +namespace { +auto +generate_conf(const milvus::knowhere::IndexType& type) { + if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ) { + return milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, DIM}, + {milvus::knowhere::meta::TOPK, K}, + {milvus::knowhere::IndexParams::nlist, 100}, + // {milvus::knowhere::IndexParams::nprobe, 4}, + {milvus::knowhere::IndexParams::m, 4}, + {milvus::knowhere::IndexParams::nbits, 8}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE, 4}, + }; + } + return milvus::knowhere::Config(); +} + +auto +generate_params() { + indexcgo::TypeParams type_params; + indexcgo::IndexParams index_params; + + auto configs = generate_conf(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ); + for (auto& [key, value] : configs.items()) { + auto param = index_params.add_params(); + auto value_str = value.is_string() ? 
value.get<std::string>() : value.dump(); + param->set_key(key); + param->set_value(value_str); + } + + return std::make_tuple(type_params, index_params); +} +} // namespace + +TEST(IndexWrapperTest, Constructor) { + auto [type_params, index_params] = generate_params(); + std::string type_params_str, index_params_str; + bool ok; + + ok = type_params.SerializeToString(&type_params_str); + assert(ok); + ok = index_params.SerializeToString(&index_params_str); + assert(ok); + + auto index = + std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str()); +} + +TEST(IndexWrapperTest, Dim) { + auto [type_params, index_params] = generate_params(); + std::string type_params_str, index_params_str; + bool ok; + + ok = type_params.SerializeToString(&type_params_str); + assert(ok); + ok = index_params.SerializeToString(&index_params_str); + assert(ok); + + auto index = + std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str()); + + ASSERT_EQ(index->dim(), DIM); +} + +TEST(IndexWrapperTest, BuildWithoutIds) { + auto [type_params, index_params] = generate_params(); + std::string type_params_str, index_params_str; + bool ok; + + ok = type_params.SerializeToString(&type_params_str); + assert(ok); + ok = index_params.SerializeToString(&index_params_str); + assert(ok); + + auto index = + std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str()); + + auto schema = std::make_shared<milvus::Schema>(); + schema->AddField("fakevec", milvus::engine::DataType::VECTOR_FLOAT, DIM, faiss::MetricType::METRIC_L2); + auto dataset = milvus::segcore::DataGen(schema, NB); + auto xb_data = dataset.get_col<float>(0); + + index->BuildWithoutIds(milvus::knowhere::GenDataset(NB, DIM, xb_data.data())); +} + +TEST(IndexWrapperTest, Load) { + auto type = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ; + auto index = 
milvus::knowhere::VecIndexFactory::GetInstance().CreateVecIndex(type); + auto conf = generate_conf(type); + auto schema = std::make_shared<milvus::Schema>(); + schema->AddField("fakevec", milvus::engine::DataType::VECTOR_FLOAT, DIM, faiss::MetricType::METRIC_L2); + auto dataset = milvus::segcore::DataGen(schema, NB); + auto xb_data = dataset.get_col<float>(0); + auto ds = milvus::knowhere::GenDataset(NB, DIM, xb_data.data()); + index->Train(ds, conf); + index->AddWithoutIds(ds, conf); + // std::vector<int64_t> ids(NB); + // std::iota(ids.begin(), ids.end(), 0); // range(0, NB) + // auto ds = milvus::knowhere::GenDatasetWithIds(NB, DIM, xb_data.data(), ids.data()); + // index->Train(ds, conf); + // index->Add(ds, conf); + auto binary_set = index->Serialize(conf); + auto copy_index = milvus::knowhere::VecIndexFactory::GetInstance().CreateVecIndex(type); + copy_index->Load(binary_set); +} + +TEST(IndexWrapperTest, Codec) { + auto [type_params, index_params] = generate_params(); + std::string type_params_str, index_params_str; + bool ok; + + ok = type_params.SerializeToString(&type_params_str); + assert(ok); + ok = index_params.SerializeToString(&index_params_str); + assert(ok); + + auto index = + std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str()); + + auto schema = std::make_shared<milvus::Schema>(); + schema->AddField("fakevec", milvus::engine::DataType::VECTOR_FLOAT, DIM, faiss::MetricType::METRIC_L2); + auto dataset = milvus::segcore::DataGen(schema, NB); + auto xb_data = dataset.get_col<float>(0); + + index->BuildWithoutIds(milvus::knowhere::GenDataset(NB, DIM, xb_data.data())); + + auto binary = index->Serialize(); + auto copy_index = + std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str()); + copy_index->Load(binary.data, binary.size); + ASSERT_EQ(copy_index->dim(), index->dim()); + auto copy_binary = copy_index->Serialize(); + ASSERT_EQ(binary.size, 
copy_binary.size); + ASSERT_EQ(memcmp(binary.data, copy_binary.data, binary.size), 0); +} diff --git a/internal/indexbuilder/index.go b/internal/indexbuilder/index.go index 278677ec45ccd9027d3bd69e03d4da830357ff1b..fadaa2b356a03c3a626e47cc15c0318068c86160 100644 --- a/internal/indexbuilder/index.go +++ b/internal/indexbuilder/index.go @@ -18,14 +18,11 @@ import ( "github.com/golang/protobuf/proto" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/proto/indexcgopb" + "github.com/zilliztech/milvus-distributed/internal/storage" ) // TODO: use storage.Blob instead later -// type Blob = storage.Blob -type Blob struct { - Key string - Value []byte -} +type Blob = storage.Blob type Index interface { Serialize() ([]*Blob, error) @@ -75,9 +72,9 @@ func (index *CIndex) Load(blobs []*Blob) error { /* void - LoadFromSlicedBuffer(CIndex index, const char* serialized_sliced_blob_buffer); + LoadFromSlicedBuffer(CIndex index, const char* serialized_sliced_blob_buffer, int32_t size); */ - C.LoadFromSlicedBuffer(index.indexPtr, (*C.char)(unsafe.Pointer(&datas[0]))) + C.LoadFromSlicedBuffer(index.indexPtr, (*C.char)(unsafe.Pointer(&datas[0])), (C.int32_t)(len(datas))) return nil } diff --git a/internal/indexbuilder/task.go b/internal/indexbuilder/task.go index 5819e565db90573ab5bd59f3191f54f32a503704..f4e73aa27552978936419d2317fd54ba2815308a 100644 --- a/internal/indexbuilder/task.go +++ b/internal/indexbuilder/task.go @@ -198,13 +198,7 @@ func (it *IndexBuildTask) Execute() error { }, nil } getStorageBlobs := func(blobs []*Blob) []*storage.Blob { - // when storage.Blob.Key & storage.Blob.Value is visible, - // use `return blobs` - ret := make([]*storage.Blob, 0) - for _, blob := range blobs { - ret = append(ret, storage.NewBlob(blob.Key, blob.Value)) - } - return ret + return blobs } toLoadDataPaths := it.indexMeta.Req.GetDataPaths() @@ -259,7 +253,7 @@ func (it *IndexBuildTask) Execute() error { it.savePaths = make([]string, 0) 
for _, blob := range serializedIndexBlobs { - key, value := blob.GetKey(), blob.GetValue() + key, value := blob.Key, blob.Value savePath := getSavePathByKey(key) err := saveBlob(savePath, value) if err != nil {