diff --git a/internal/core/src/segcore/load_index_c.cpp b/internal/core/src/segcore/load_index_c.cpp index 01c9789e5d32304a11b37865d5976e3d20b32bbe..6ef5e5647fd605a0d96d64eed317a146b2c575b3 100644 --- a/internal/core/src/segcore/load_index_c.cpp +++ b/internal/core/src/segcore/load_index_c.cpp @@ -32,6 +32,12 @@ NewLoadIndexInfo(CLoadIndexInfo* c_load_index_info) { } } +// Destroys the LoadIndexInfo object behind the opaque handle. +void +DeleteLoadIndexInfo(CLoadIndexInfo c_load_index_info) { + delete static_cast<LoadIndexInfo*>(c_load_index_info); +} + CStatus AppendIndexParam(CLoadIndexInfo c_load_index_info, const char* c_index_key, const char* c_index_value) { try { diff --git a/internal/core/src/segcore/load_index_c.h b/internal/core/src/segcore/load_index_c.h index 88985a1e81c2f42b799a3d974463d7337ff2410d..3508018ed48b6087cbb8587b283a9adfdebf03b6 100644 --- a/internal/core/src/segcore/load_index_c.h +++ b/internal/core/src/segcore/load_index_c.h @@ -25,6 +25,9 @@ typedef void* CBinarySet; CStatus NewLoadIndexInfo(CLoadIndexInfo* c_load_index_info); +void +DeleteLoadIndexInfo(CLoadIndexInfo c_load_index_info); + CStatus AppendIndexParam(CLoadIndexInfo c_load_index_info, const char* index_key, const char* index_value); diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index a9e2d5d5219309f9e16251b881c3da81f11b6d1c..6a95a622b27301f084f699a08452734bc8382481 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -176,8 +176,9 @@ FillTargetEntry(CSegmentBase c_segment, CPlan c_plan, CQueryResult c_result) { CStatus UpdateSegmentIndex(CSegmentBase c_segment, CLoadIndexInfo c_load_index_info) { - auto load_index_info = (LoadIndexInfo*)c_load_index_info; try { + auto segment = (milvus::segcore::SegmentBase*)c_segment; + auto load_index_info = (LoadIndexInfo*)c_load_index_info; auto status = CStatus(); status.error_code = Success; status.error_msg = ""; @@ -189,7 +190,6 @@ UpdateSegmentIndex(CSegmentBase c_segment, CLoadIndexInfo 
c_load_index_info) { return status; } } - ////////////////////////////////////////////////////////////////// int diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 2aaff331d876326424dbae729468305d9308d128..dedfd71c2987734d35872728e265a116cd7c4c17 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -685,6 +685,50 @@ TEST(CApiTest, Reduce) { DeleteSegment(segment); } +TEST(CApiTest, LoadIndexInfo) { + // build and serialize an IVF_PQ index to feed into the load-index C API + constexpr auto DIM = 16; + constexpr auto K = 10; + + auto N = 1024 * 10; + auto [raw_data, timestamps, uids] = generate_data(N); + auto indexing = std::make_shared<milvus::knowhere::IVFPQ>(); + auto conf = milvus::knowhere::Config{{milvus::knowhere::meta::DIM, DIM}, + {milvus::knowhere::meta::TOPK, K}, + {milvus::knowhere::IndexParams::nlist, 100}, + {milvus::knowhere::IndexParams::nprobe, 4}, + {milvus::knowhere::IndexParams::m, 4}, + {milvus::knowhere::IndexParams::nbits, 8}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {milvus::knowhere::meta::DEVICEID, 0}}; + + auto database = milvus::knowhere::GenDataset(N, DIM, raw_data.data()); + indexing->Train(database, conf); + indexing->AddWithoutIds(database, conf); + EXPECT_EQ(indexing->Count(), N); + EXPECT_EQ(indexing->Dim(), DIM); + auto binary_set = indexing->Serialize(conf); + CBinarySet c_binary_set = (CBinarySet)&binary_set; + + void* c_load_index_info = nullptr; + auto status = NewLoadIndexInfo(&c_load_index_info); + assert(status.error_code == Success); + std::string index_param_key1 = "index_type"; + std::string index_param_value1 = "IVF_PQ"; + status = AppendIndexParam(c_load_index_info, index_param_key1.data(), index_param_value1.data()); + assert(status.error_code == Success); + std::string index_param_key2 = "index_mode"; + std::string index_param_value2 = "cpu"; + status = AppendIndexParam(c_load_index_info, index_param_key2.data(), index_param_value2.data()); + assert(status.error_code == Success); + std::string 
field_name = "field0"; + status = AppendFieldInfo(c_load_index_info, field_name.data(), 0); + assert(status.error_code == Success); + status = AppendIndex(c_load_index_info, c_binary_set); + assert(status.error_code == Success); + DeleteLoadIndexInfo(c_load_index_info); +} + TEST(CApiTest, LoadIndex_Search) { // generator index constexpr auto DIM = 16; diff --git a/internal/querynode/client/client.go b/internal/querynode/client/client.go index 5cbfc2a08ecb77b296ca2165595884343ddd912e..d1de811a9102aaf1763c4cac0da782f17134bbd8 100644 --- a/internal/querynode/client/client.go +++ b/internal/querynode/client/client.go @@ -4,6 +4,7 @@ import ( "context" "github.com/zilliztech/milvus-distributed/internal/msgstream" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" ) @@ -21,18 +22,29 @@ func NewLoadIndexClient(ctx context.Context, pulsarAddress string, loadIndexChan } } -func (lic *LoadIndexClient) LoadIndex(indexPaths []string, segmentID int64, fieldID int64, indexParam map[string]string) error { - // TODO:: add indexParam to proto +func (lic *LoadIndexClient) LoadIndex(indexPaths []string, segmentID int64, fieldID int64, fieldName string, indexParams map[string]string) error { baseMsg := msgstream.BaseMsg{ BeginTimestamp: 0, EndTimestamp: 0, HashValues: []uint32{0}, } + + // Flatten the params map into the KeyValuePair list carried by the LoadIndex message. + var kvPairs []*commonpb.KeyValuePair + for key, value := range indexParams { + kvPairs = append(kvPairs, &commonpb.KeyValuePair{ + Key: key, + Value: value, + }) + } + loadIndexRequest := internalPb.LoadIndex{ - MsgType: internalPb.MsgType_kLoadIndex, - SegmentID: segmentID, - FieldID: fieldID, - IndexPaths: indexPaths, + MsgType: internalPb.MsgType_kLoadIndex, + SegmentID: segmentID, + FieldName: fieldName, + FieldID: fieldID, + IndexPaths: indexPaths, + IndexParams: kvPairs, } loadIndexMsg := &msgstream.LoadIndexMsg{ diff --git 
a/internal/querynode/load_index_info.go b/internal/querynode/load_index_info.go index 362b687b764c32dedf6fad5f0332b765120a4549..0187a04e1333fdea658261b3b971acce89f3f302 100644 --- a/internal/querynode/load_index_info.go +++ b/internal/querynode/load_index_info.go @@ -18,7 +18,7 @@ type LoadIndexInfo struct { cLoadIndexInfo C.CLoadIndexInfo } -func NewLoadIndexInfo() (*LoadIndexInfo, error) { +func newLoadIndexInfo() (*LoadIndexInfo, error) { var cLoadIndexInfo C.CLoadIndexInfo status := C.NewLoadIndexInfo(&cLoadIndexInfo) errorCode := status.error_code @@ -31,7 +31,15 @@ func NewLoadIndexInfo() (*LoadIndexInfo, error) { return &LoadIndexInfo{cLoadIndexInfo: cLoadIndexInfo}, nil } -func (li *LoadIndexInfo) AppendIndexParam(indexKey string, indexValue string) error { +// deleteLoadIndexInfo releases the C-side LoadIndexInfo held by info; a nil info is a no-op. +func deleteLoadIndexInfo(info *LoadIndexInfo) { + if info == nil { + return + } + C.DeleteLoadIndexInfo(info.cLoadIndexInfo) +} + +func (li *LoadIndexInfo) appendIndexParam(indexKey string, indexValue string) error { cIndexKey := C.CString(indexKey) cIndexValue := C.CString(indexValue) status := C.AppendIndexParam(li.cLoadIndexInfo, cIndexKey, cIndexValue) @@ -45,7 +53,7 @@ func (li *LoadIndexInfo) AppendIndexParam(indexKey string, indexValue string) er return nil } -func (li *LoadIndexInfo) AppendFieldInfo(fieldName string, fieldID int64) error { +func (li *LoadIndexInfo) appendFieldInfo(fieldName string, fieldID int64) error { cFieldName := C.CString(fieldName) cFieldID := C.long(fieldID) status := C.AppendFieldInfo(li.cLoadIndexInfo, cFieldName, cFieldID) @@ -59,7 +67,7 @@ func (li *LoadIndexInfo) AppendFieldInfo(fieldName string, fieldID int64) error return nil } -func (li *LoadIndexInfo) AppendIndex(bytesIndex [][]byte, indexKeys []string) error { +func (li *LoadIndexInfo) appendIndex(bytesIndex [][]byte, indexKeys []string) error { var cBinarySet C.CBinarySet status := C.NewBinarySet(&cBinarySet) diff --git a/internal/querynode/load_index_info_test.go b/internal/querynode/load_index_info_test.go new file mode 100644 index 
0000000000000000000000000000000000000000..95261c7002eb7c9012598afab0fe59d66d69840d --- /dev/null +++ b/internal/querynode/load_index_info_test.go @@ -0,0 +1,29 @@ +package querynode + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" +) + +func TestLoadIndexInfo(t *testing.T) { + // Walk a LoadIndexInfo through create, append params/field/index, and delete. + indexParams := []*commonpb.KeyValuePair{ + {Key: "index_type", Value: "IVF_PQ"}, + {Key: "index_mode", Value: "cpu"}, + } + + indexBytes := [][]byte{make([]byte, 10)} + indexPaths := []string{"index-0"} + + loadIndexInfo, err := newLoadIndexInfo() + assert.Nil(t, err) + for _, kv := range indexParams { + loadIndexInfo.appendIndexParam(kv.Key, kv.Value) + } + loadIndexInfo.appendFieldInfo("field0", 0) + loadIndexInfo.appendIndex(indexBytes, indexPaths) + + deleteLoadIndexInfo(loadIndexInfo) +} diff --git a/internal/querynode/load_index_service.go b/internal/querynode/load_index_service.go index a2eaac7bfeca7f7ca51c4a4443ecc95d74fe5d6e..3c9a1a5454dbcda599ea1c6d75b3b728432fd9cd 100644 --- a/internal/querynode/load_index_service.go +++ b/internal/querynode/load_index_service.go @@ -107,17 +107,28 @@ func (lis *loadIndexService) start() { log.Println("type assertion failed for LoadIndexMsg") continue } - /* TODO: debug - // 1. use msg's index paths to get index bytes - indexBuffer := lis.loadIndex(indexMsg.IndexPaths) - // 2. use index bytes and index path to update segment - err := lis.updateSegmentIndex(indexBuffer, indexMsg.IndexPaths, indexMsg.SegmentID) - if err != nil { - log.Println(err) - continue - } - */ - // 3. update segment index stats + //// 1. 
use msg's index paths to get index bytes + //var indexBuffer [][]byte + //var err error + //fn := func() error { + // indexBuffer, err = lis.loadIndex(indexMsg.IndexPaths) + // if err != nil { + // return err + // } + // return nil + //} + //err = msgstream.Retry(5, time.Millisecond*200, fn) + //if err != nil { + // log.Println(err) + // continue + //} + //// 2. use index bytes and index path to update segment + //err = lis.updateSegmentIndex(indexBuffer, indexMsg) + //if err != nil { + // log.Println(err) + // continue + //} + //3. update segment index stats err := lis.updateSegmentIndexStats(indexMsg) if err != nil { log.Println(err) @@ -216,7 +227,7 @@ func (lis *loadIndexService) updateSegmentIndexStats(indexMsg *msgstream.LoadInd return nil } -func (lis *loadIndexService) loadIndex(indexPath []string) [][]byte { +func (lis *loadIndexService) loadIndex(indexPath []string) ([][]byte, error) { index := make([][]byte, 0) for _, path := range indexPath { @@ -224,13 +235,12 @@ func (lis *loadIndexService) loadIndex(indexPath []string) [][]byte { binarySetKey := filepath.Base(path) indexPiece, err := (*lis.client).Load(binarySetKey) if err != nil { - log.Println(err) - return nil + return nil, err } index = append(index, []byte(indexPiece)) } - return index + return index, nil } func (lis *loadIndexService) updateSegmentIndex(bytesIndex [][]byte, loadIndexMsg *msgstream.LoadIndexMsg) error { @@ -239,21 +249,23 @@ func (lis *loadIndexService) updateSegmentIndex(bytesIndex [][]byte, loadIndexMs return err } - loadIndexInfo, err := NewLoadIndexInfo() + loadIndexInfo, err := newLoadIndexInfo() if err != nil { return err } + // Register cleanup only once creation has succeeded; on error loadIndexInfo may be nil. + defer deleteLoadIndexInfo(loadIndexInfo) - err = loadIndexInfo.AppendFieldInfo(loadIndexMsg.FieldName, loadIndexMsg.FieldID) + err = loadIndexInfo.appendFieldInfo(loadIndexMsg.FieldName, loadIndexMsg.FieldID) if err != nil { return err } for _, indexParam := range loadIndexMsg.IndexParams { - err = loadIndexInfo.AppendIndexParam(indexParam.Key, 
indexParam.Value) + err = loadIndexInfo.appendIndexParam(indexParam.Key, indexParam.Value) if err != nil { return err } } - err = loadIndexInfo.AppendIndex(bytesIndex, loadIndexMsg.IndexPaths) + err = loadIndexInfo.appendIndex(bytesIndex, loadIndexMsg.IndexPaths) if err != nil { return err } diff --git a/internal/querynode/load_index_service_test.go b/internal/querynode/load_index_service_test.go index 49d215670395a8b0dc6a45b58a3165ce05de8fac..2a59eb30e51e2b0cdb8a4e62f935adf716b77691 100644 --- a/internal/querynode/load_index_service_test.go +++ b/internal/querynode/load_index_service_test.go @@ -1,6 +1,7 @@ package querynode import ( + "context" "math" "math/rand" "sort" @@ -11,8 +12,27 @@ import ( "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" + "github.com/zilliztech/milvus-distributed/internal/querynode/client" ) +func TestLoadIndexClient_LoadIndex(t *testing.T) { + pulsarURL := Params.PulsarAddress + loadIndexChannels := Params.LoadIndexChannelNames + loadIndexClient := client.NewLoadIndexClient(context.Background(), pulsarURL, loadIndexChannels) + + loadIndexPath := "collection0-segment0-field0" + loadIndexPaths := make([]string, 0) + loadIndexPaths = append(loadIndexPaths, loadIndexPath) + + indexParams := make(map[string]string) + indexParams["index_type"] = "IVF_PQ" + indexParams["index_mode"] = "cpu" + + err := loadIndexClient.LoadIndex(loadIndexPaths, 0, 0, "field0", indexParams) + assert.Nil(t, err) + loadIndexClient.Close() +} + func TestLoadIndexService_PulsarAddress(t *testing.T) { node := newQueryNode() collectionID := rand.Int63n(1000000) @@ -125,24 +145,38 @@ func TestLoadIndexService_PulsarAddress(t *testing.T) { statsMs.CreatePulsarConsumers([]string{Params.StatsChannelName}, Params.MsgChannelSubName, msgstream.NewUnmarshalDispatcher(), Params.StatsReceiveBufSize) statsMs.Start() - receiveMsg := 
msgstream.MsgStream(statsMs).Consume() - assert.NotNil(t, receiveMsg) - assert.NotEqual(t, len(receiveMsg.Msgs), 0) - statsMsg, ok := receiveMsg.Msgs[0].(*msgstream.QueryNodeStatsMsg) - assert.Equal(t, ok, true) - assert.Equal(t, len(statsMsg.FieldStats), 1) - fieldStats0 := statsMsg.FieldStats[0] - assert.Equal(t, fieldStats0.FieldID, fieldID) - assert.Equal(t, fieldStats0.CollectionID, collectionID) - assert.Equal(t, len(fieldStats0.IndexStats), 1) - indexStats0 := fieldStats0.IndexStats[0] - - params := indexStats0.IndexParams - // sort index params by key - sort.Slice(indexParams, func(i, j int) bool { return indexParams[i].Key < indexParams[j].Key }) - indexEqual := node.loadIndexService.indexParamsEqual(params, indexParams) - assert.Equal(t, indexEqual, true) + findFieldStats := false + for { + receiveMsg := msgstream.MsgStream(statsMs).Consume() + assert.NotNil(t, receiveMsg) + assert.NotEqual(t, len(receiveMsg.Msgs), 0) + + for _, msg := range receiveMsg.Msgs { + statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg) + assert.Equal(t, ok, true) + if !ok || len(statsMsg.FieldStats) == 0 { + continue // a failed assertion leaves statsMsg nil; otherwise keep waiting for a message that carries field stats + } + findFieldStats = true + assert.Equal(t, len(statsMsg.FieldStats), 1) + fieldStats0 := statsMsg.FieldStats[0] + assert.Equal(t, fieldStats0.FieldID, fieldID) + assert.Equal(t, fieldStats0.CollectionID, collectionID) + assert.Equal(t, len(fieldStats0.IndexStats), 1) + indexStats0 := fieldStats0.IndexStats[0] + params := indexStats0.IndexParams + // sort index params by key + sort.Slice(indexParams, func(i, j int) bool { return indexParams[i].Key < indexParams[j].Key }) + indexEqual := node.loadIndexService.indexParamsEqual(params, indexParams) + assert.Equal(t, indexEqual, true) + } + + if findFieldStats { + break + } + } + assert.Equal(t, findFieldStats, true) <-node.queryNodeLoopCtx.Done() node.Close() }