diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp
index d6f97c72fc4e25502d2d99f78f65539791a3f57c..e97223744738c697de636a7c4111a70b7c0f6199 100644
--- a/internal/core/src/segcore/segment_c.cpp
+++ b/internal/core/src/segcore/segment_c.cpp
@@ -69,28 +69,30 @@ Search(CSegmentInterface c_segment,
        uint64_t* timestamps,
        int num_groups,
        CQueryResult* result) {
-    auto status = CStatus();
+    auto segment = (milvus::segcore::SegmentInterface*)c_segment;
+    auto plan = (milvus::query::Plan*)c_plan;
+    std::vector<const milvus::query::PlaceholderGroup*> placeholder_groups;
+    for (int i = 0; i < num_groups; ++i) {
+        placeholder_groups.push_back((const milvus::query::PlaceholderGroup*)c_placeholder_groups[i]);
+    }
+    auto query_result = std::make_unique<milvus::QueryResult>();
+
+    auto status = CStatus();
     try {
-        auto segment = (milvus::segcore::SegmentInterface*)c_segment;
-        auto plan = (milvus::query::Plan*)c_plan;
-        std::vector<const milvus::query::PlaceholderGroup*> placeholder_groups;
-        for (int i = 0; i < num_groups; ++i) {
-            placeholder_groups.push_back((const milvus::query::PlaceholderGroup*)c_placeholder_groups[i]);
-        }
         *query_result = segment->Search(plan, placeholder_groups.data(), timestamps, num_groups);
         if (plan->plan_node_->query_info_.metric_type_ != milvus::MetricType::METRIC_INNER_PRODUCT) {
            for (auto& dis : query_result->result_distances_) {
                dis *= -1;
            }
        }
-        *result = query_result.release();
        status.error_code = Success;
        status.error_msg = "";
    } catch (std::exception& e) {
        status.error_code = UnexpectedException;
        status.error_msg = strdup(e.what());
    }
+    *result = query_result.release();
 
    // result_ids and result_distances have been allocated in Go,
    // so we don't need to malloc here.
diff --git a/internal/proxynode/insert_channels.go b/internal/proxynode/insert_channels.go
index c480319d1ed4be268abaaab147769db2d826623b..73465f9928a5705d036e1fdfc12b64cd39989a20 100644
--- a/internal/proxynode/insert_channels.go
+++ b/internal/proxynode/insert_channels.go
@@ -75,6 +75,7 @@ type InsertChannelsMap struct {
 	insertChannels   [][]string            // using []string as a map key is awkward, so we search this slice instead
 	insertMsgStreams []msgstream.MsgStream // a crude stand-in for a Set; there may be a better implementation
 	droppedBitMap    []int                 // 0 -> normal, 1 -> dropped
+	usageHistogram   []int                 // a message stream may be closed only when its use count drops to zero
 	mtx              sync.RWMutex
 	nodeInstance     *NodeImpl
 }
@@ -93,6 +94,7 @@ func (m *InsertChannelsMap) createInsertMsgStream(collID UniqueID, channels []st
 	for loc, existedChannels := range m.insertChannels {
 		if m.droppedBitMap[loc] == 0 && SortedSliceEqual(existedChannels, channels) {
 			m.collectionID2InsertChannels[collID] = loc
+			m.usageHistogram[loc]++
 			return nil
 		}
 	}
@@ -108,6 +110,7 @@ func (m *InsertChannelsMap) createInsertMsgStream(collID UniqueID, channels []st
 	stream.Start()
 	m.insertMsgStreams = append(m.insertMsgStreams, stream)
 	m.droppedBitMap = append(m.droppedBitMap, 0)
+	m.usageHistogram = append(m.usageHistogram, 1)
 	return nil
 }
@@ -123,7 +126,14 @@ func (m *InsertChannelsMap) closeInsertMsgStream(collID UniqueID) error {
 	if m.droppedBitMap[loc] != 0 {
 		return errors.New("insert message stream already closed")
 	}
 
-	m.insertMsgStreams[loc].Close()
-	log.Print("close insert message stream ...")
-	m.droppedBitMap[loc] = 1
+	if m.usageHistogram[loc] <= 0 {
+		return errors.New("insert message stream is not in use")
+	}
+
+	m.usageHistogram[loc]--
+	if m.usageHistogram[loc] <= 0 {
+		m.insertMsgStreams[loc].Close()
+		log.Print("close insert message stream ...")
+		m.droppedBitMap[loc] = 1
+	}
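A note on the usageHistogram bookkeeping above: despite the name, it is a per-stream reference count rather than a histogram. Each collection that reuses an existing insert stream increments the count, and the underlying message stream is physically closed only when the last user releases it. The following self-contained Go sketch shows the pattern in isolation; the type and method names are illustrative, not code from this patch.

package main

import (
	"errors"
	"fmt"
	"sync"
)

// refCountedStream stands in for a msgstream.MsgStream shared by several collections.
type refCountedStream struct {
	mtx      sync.Mutex
	useCount int
	closed   bool
}

// acquire registers one more user of the stream.
func (s *refCountedStream) acquire() {
	s.mtx.Lock()
	defer s.mtx.Unlock()
	s.useCount++
}

// release unregisters one user and closes the stream once the count reaches zero.
func (s *refCountedStream) release() error {
	s.mtx.Lock()
	defer s.mtx.Unlock()
	if s.closed {
		return errors.New("stream already closed")
	}
	s.useCount--
	if s.useCount <= 0 {
		s.closed = true // the real code calls insertMsgStreams[loc].Close() here
	}
	return nil
}

func main() {
	s := &refCountedStream{}
	s.acquire()                        // collection A starts using the stream
	s.acquire()                        // collection B reuses the same stream
	fmt.Println(s.release(), s.closed) // <nil> false: B still holds a reference
	fmt.Println(s.release(), s.closed) // <nil> true: the last release closes it
}

Performing the decrement and the close under the same mutex is what makes the scheme safe: without it, a concurrent createInsertMsgStream could revive a stream that is in the middle of being closed.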
@@ -164,11 +174,28 @@ func (m *InsertChannelsMap) getInsertMsgStream(collID UniqueID) (msgstream.MsgSt
 	return m.insertMsgStreams[loc], nil
 }
 
+func (m *InsertChannelsMap) closeAllMsgStream() {
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
+
+	for _, stream := range m.insertMsgStreams {
+		stream.Close()
+	}
+
+	m.collectionID2InsertChannels = make(map[UniqueID]int)
+	m.insertChannels = make([][]string, 0)
+	m.insertMsgStreams = make([]msgstream.MsgStream, 0)
+	m.droppedBitMap = make([]int, 0)
+	m.usageHistogram = make([]int, 0)
+}
+
 func newInsertChannelsMap(node *NodeImpl) *InsertChannelsMap {
 	return &InsertChannelsMap{
 		collectionID2InsertChannels: make(map[UniqueID]int),
 		insertChannels:              make([][]string, 0),
 		insertMsgStreams:            make([]msgstream.MsgStream, 0),
+		droppedBitMap:               make([]int, 0),
+		usageHistogram:              make([]int, 0),
 		nodeInstance:                node,
 	}
 }
diff --git a/internal/proxynode/proxy_node.go b/internal/proxynode/proxy_node.go
index a871558c55266c618d896122166dba7fa53913bb..62b06b79db43c1dff6482acb84bdce6909eaa4d5 100644
--- a/internal/proxynode/proxy_node.go
+++ b/internal/proxynode/proxy_node.go
@@ -285,6 +285,7 @@ func (node *NodeImpl) Start() error {
 func (node *NodeImpl) Stop() error {
 	node.cancel()
+	globalInsertChannelsMap.closeAllMsgStream()
 	node.tsoAllocator.Close()
 	node.idAllocator.Close()
 	node.segAssigner.Close()
diff --git a/internal/querynode/client/client.go b/internal/querynode/client/client.go
new file mode 100644
index 0000000000000000000000000000000000000000..7445301c2b0f84fa17e3b9804b0eb4e008962188
--- /dev/null
+++ b/internal/querynode/client/client.go
@@ -0,0 +1,67 @@
+package client
+
+import (
+	"context"
+
+	"github.com/zilliztech/milvus-distributed/internal/msgstream"
+	"github.com/zilliztech/milvus-distributed/internal/msgstream/pulsarms"
+	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
+	"github.com/zilliztech/milvus-distributed/internal/proto/internalpb2"
+)
+
+type Client struct {
+	inputStream msgstream.MsgStream
+}
+
+func NewQueryNodeClient(ctx context.Context, pulsarAddress string, loadIndexChannels []string) *Client {
+	loadIndexStream := pulsarms.NewPulsarMsgStream(ctx, 0)
+	loadIndexStream.SetPulsarClient(pulsarAddress)
+	loadIndexStream.CreatePulsarProducers(loadIndexChannels)
+	return &Client{
+		inputStream: loadIndexStream,
+	}
+}
+
+func (c *Client) Close() {
+	c.inputStream.Close()
+}
+
+func (c *Client) LoadIndex(indexPaths []string,
+	segmentID int64,
+	fieldID int64,
+	fieldName string,
+	indexParams map[string]string) error {
+	baseMsg := msgstream.BaseMsg{
+		BeginTimestamp: 0,
+		EndTimestamp:   0,
+		HashValues:     []uint32{0},
+	}
+
+	var indexParamsKV []*commonpb.KeyValuePair
+	for key, value := range indexParams {
+		indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{
+			Key:   key,
+			Value: value,
+		})
+	}
+
+	loadIndexRequest := internalpb2.LoadIndex{
+		Base: &commonpb.MsgBase{
+			MsgType: commonpb.MsgType_kLoadIndex,
+		},
+		SegmentID:   segmentID,
+		FieldName:   fieldName,
+		FieldID:     fieldID,
+		IndexPaths:  indexPaths,
+		IndexParams: indexParamsKV,
+	}
+
+	loadIndexMsg := &msgstream.LoadIndexMsg{
+		BaseMsg:   baseMsg,
+		LoadIndex: loadIndexRequest,
+	}
+	msgPack := msgstream.MsgPack{}
+	msgPack.Msgs = append(msgPack.Msgs, loadIndexMsg)
+
+	return c.inputStream.Produce(&msgPack)
+}
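The new internal/querynode/client package gives other components a producer-side handle on the load-index channel, so an index can be pushed to a query node without importing query node internals. A minimal usage sketch, assuming a reachable Pulsar broker; the address, channel name, segment and field IDs, and index paths below are illustrative placeholders, not values from this patch.

package main

import (
	"context"

	"github.com/zilliztech/milvus-distributed/internal/querynode/client"
)

func main() {
	ctx := context.Background()

	// In a real deployment these come from the node's Params
	// (PulsarAddress, LoadIndexChannelNames).
	c := client.NewQueryNodeClient(ctx, "pulsar://localhost:6650", []string{"load-index-channel"})
	defer c.Close()

	// Ask the query node to load a built index for vector field 100 ("vec")
	// of segment 1; the index files are addressed by their storage paths.
	indexParams := map[string]string{
		"index_type":  "IVF_PQ",
		"metric_type": "L2",
	}
	if err := c.LoadIndex([]string{"1/index-file-1"}, 1, 100, "vec", indexParams); err != nil {
		panic(err)
	}
}

Because LoadIndex only produces a LoadIndexMsg on the stream, the call returns as soon as the message is published; the query node applies the index asynchronously in loadService.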
diff --git a/internal/querynode/collection_replica.go b/internal/querynode/collection_replica.go
index 901ce5db57c395932ee75833b711aa2b6a604293..a8957bcaef14c106ffd1953db4764c7cf749d2fa 100644
--- a/internal/querynode/collection_replica.go
+++ b/internal/querynode/collection_replica.go
@@ -41,7 +41,6 @@ type collectionReplica interface {
 	getCollectionByID(collectionID UniqueID) (*Collection, error)
 	getCollectionByName(collectionName string) (*Collection, error)
 	hasCollection(collectionID UniqueID) bool
-	getVecFieldsByCollectionID(collectionID UniqueID) (map[int64]string, error)
 
 	// partition
 	// Partition tags in different collections are not unique,
@@ -67,8 +66,8 @@ type collectionReplica interface {
 	removeSegment(segmentID UniqueID) error
 	getSegmentByID(segmentID UniqueID) (*Segment, error)
 	hasSegment(segmentID UniqueID) bool
+	getVecFieldsBySegmentID(segmentID UniqueID) (map[int64]string, error)
 	getSealedSegments() ([]UniqueID, []UniqueID)
-	replaceGrowingSegmentBySealedSegment(segment *Segment) error
 	freeAll()
 }
@@ -175,29 +174,6 @@ func (colReplica *collectionReplicaImpl) hasCollection(collectionID UniqueID) bo
 	return false
 }
 
-func (colReplica *collectionReplicaImpl) getVecFieldsByCollectionID(collectionID UniqueID) (map[int64]string, error) {
-	colReplica.mu.RLock()
-	defer colReplica.mu.RUnlock()
-
-	col, err := colReplica.getCollectionByIDPrivate(collectionID)
-	if err != nil {
-		return nil, err
-	}
-
-	vecFields := make(map[int64]string)
-	for _, field := range col.Schema().Fields {
-		if field.DataType == schemapb.DataType_VECTOR_BINARY || field.DataType == schemapb.DataType_VECTOR_FLOAT {
-			vecFields[field.FieldID] = field.Name
-		}
-	}
-
-	if len(vecFields) <= 0 {
-		return nil, errors.New("no vector field in segment " + strconv.FormatInt(collectionID, 10))
-	}
-
-	return vecFields, nil
-}
-
 //----------------------------------------------------------------------------------------------------- partition
 func (colReplica *collectionReplicaImpl) getPartitionNum(collectionID UniqueID) (int, error) {
 	colReplica.mu.RLock()
 	defer colReplica.mu.RUnlock()
@@ -508,10 +484,6 @@ func (colReplica *collectionReplicaImpl) removeSegment(segmentID UniqueID) error
 	colReplica.mu.Lock()
 	defer colReplica.mu.Unlock()
 
-	return colReplica.removeSegmentPrivate(segmentID)
-}
-
-func (colReplica *collectionReplicaImpl) removeSegmentPrivate(segmentID UniqueID) error {
 	var targetPartition *Partition
 	var segmentIndex = -1
@@ -521,7 +493,6 @@ func (colReplica *collectionReplicaImpl) removeSegmentPrivate(segmentID UniqueID
 			if s.ID() == segmentID {
 				targetPartition = p
 				segmentIndex = i
-				deleteSegment(colReplica.segments[s.ID()])
 			}
 		}
 	}
@@ -562,6 +533,34 @@ func (colReplica *collectionReplicaImpl) hasSegment(segmentID UniqueID) bool {
 	return ok
 }
 
+func (colReplica *collectionReplicaImpl) getVecFieldsBySegmentID(segmentID UniqueID) (map[int64]string, error) {
+	colReplica.mu.RLock()
+	defer colReplica.mu.RUnlock()
+
+	seg, err := colReplica.getSegmentByIDPrivate(segmentID)
+	if err != nil {
+		return nil, err
+	}
+	col, err2 := colReplica.getCollectionByIDPrivate(seg.collectionID)
+	if err2 != nil {
+		return nil, err2
+	}
+
+	vecFields := make(map[int64]string)
+	for _, field := range col.Schema().Fields {
+		if field.DataType == schemapb.DataType_VECTOR_BINARY || field.DataType == schemapb.DataType_VECTOR_FLOAT {
+			vecFields[field.FieldID] = field.Name
+		}
+	}
+
+	if len(vecFields) <= 0 {
+		return nil, errors.New("no vector field in segment " + strconv.FormatInt(segmentID, 10))
+	}
+
+	// return map[fieldID]fieldName
+	return vecFields, nil
+}
+
 func (colReplica *collectionReplicaImpl) getSealedSegments() ([]UniqueID, []UniqueID) {
 	colReplica.mu.RLock()
 	defer colReplica.mu.RUnlock()
@@ -578,28 +577,6 @@ func (colReplica
*collectionReplicaImpl) getSealedSegments() ([]UniqueID, []Uniq return collectionIDs, segmentIDs } -func (colReplica *collectionReplicaImpl) replaceGrowingSegmentBySealedSegment(segment *Segment) error { - colReplica.mu.Lock() - defer colReplica.mu.Unlock() - targetSegment, ok := colReplica.segments[segment.ID()] - if ok { - if targetSegment.segmentType != segTypeGrowing { - return nil - } - deleteSegment(targetSegment) - targetSegment = segment - } else { - // add segment - targetPartition, err := colReplica.getPartitionByIDPrivate(segment.collectionID, segment.partitionID) - if err != nil { - return err - } - targetPartition.segments = append(targetPartition.segments, segment) - colReplica.segments[segment.ID()] = segment - } - return nil -} - //----------------------------------------------------------------------------------------------------- func (colReplica *collectionReplicaImpl) freeAll() { colReplica.mu.Lock() @@ -611,7 +588,4 @@ func (colReplica *collectionReplicaImpl) freeAll() { for _, col := range colReplica.collections { deleteCollection(col) } - - colReplica.segments = make(map[UniqueID]*Segment) - colReplica.collections = make([]*Collection, 0) } diff --git a/internal/querynode/load_service.go b/internal/querynode/load_service.go index 285ea354446f4fd25daa0f5daadecf33f869c865..8f40b4ed3095d5848d855ea026c2481f4d79f453 100644 --- a/internal/querynode/load_service.go +++ b/internal/querynode/load_service.go @@ -11,15 +11,12 @@ import ( "strings" "time" - "github.com/zilliztech/milvus-distributed/internal/kv" minioKV "github.com/zilliztech/milvus-distributed/internal/kv/minio" "github.com/zilliztech/milvus-distributed/internal/msgstream" + "github.com/zilliztech/milvus-distributed/internal/msgstream/pulsarms" "github.com/zilliztech/milvus-distributed/internal/msgstream/util" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" - "github.com/zilliztech/milvus-distributed/internal/proto/datapb" - "github.com/zilliztech/milvus-distributed/internal/proto/indexpb" "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2" - "github.com/zilliztech/milvus-distributed/internal/proto/milvuspb" "github.com/zilliztech/milvus-distributed/internal/storage" ) @@ -28,59 +25,83 @@ const indexCheckInterval = 1 type loadService struct { ctx context.Context cancel context.CancelFunc + client *minioKV.MinIOKV - replica collectionReplica + queryNodeID UniqueID + replica collectionReplica fieldIndexes map[string][]*internalpb2.IndexStats fieldStatsChan chan []*internalpb2.FieldStats - dmStream msgstream.MsgStream + loadIndexReqChan chan []msgstream.TsMsg + loadIndexMsgStream msgstream.MsgStream - masterClient MasterServiceInterface - dataClient DataServiceInterface - indexClient IndexServiceInterface - - kv kv.Base // minio kv - iCodec *storage.InsertCodec + segManager *segmentManager } -type loadIndex struct { - segmentID UniqueID - fieldID int64 - fieldName string - indexPaths []string +func (lis *loadService) consume() { + for { + select { + case <-lis.ctx.Done(): + return + default: + messages := lis.loadIndexMsgStream.Consume() + if messages == nil || len(messages.Msgs) <= 0 { + log.Println("null msg pack") + continue + } + lis.loadIndexReqChan <- messages.Msgs + } + } } -// -------------------------------------------- load index -------------------------------------------- // -func (s *loadService) start() { +func (lis *loadService) indexListener() { for { select { - case <-s.ctx.Done(): + case <-lis.ctx.Done(): return case <-time.After(indexCheckInterval * 
time.Second): - collectionIDs, segmentIDs := s.replica.getSealedSegments() - if len(collectionIDs) <= 0 { - continue - } - fmt.Println("do load index for segments:", segmentIDs) + collectionIDs, segmentIDs := lis.replica.getSealedSegments() for i := range collectionIDs { // we don't need index id yet - _, buildID, err := s.getIndexInfo(collectionIDs[i], segmentIDs[i]) + _, buildID, err := lis.segManager.getIndexInfo(collectionIDs[i], segmentIDs[i]) if err != nil { - indexPaths, err := s.getIndexPaths(buildID) + indexPaths, err := lis.segManager.getIndexPaths(buildID) if err != nil { log.Println(err) continue } - err = s.loadIndexDelayed(collectionIDs[i], segmentIDs[i], indexPaths) + err = lis.segManager.loadIndex(segmentIDs[i], indexPaths) if err != nil { log.Println(err) continue } } } + } + } +} + +func (lis *loadService) start() { + lis.loadIndexMsgStream.Start() + go lis.consume() + go lis.indexListener() + + for { + select { + case <-lis.ctx.Done(): + return + case messages := <-lis.loadIndexReqChan: + for _, msg := range messages { + err := lis.execute(msg) + if err != nil { + log.Println(err) + continue + } + } + // sendQueryNodeStats - err := s.sendQueryNodeStats() + err := lis.sendQueryNodeStats() if err != nil { log.Println(err) continue @@ -89,13 +110,17 @@ func (s *loadService) start() { } } -func (s *loadService) execute(l *loadIndex) error { +func (lis *loadService) execute(msg msgstream.TsMsg) error { + indexMsg, ok := msg.(*msgstream.LoadIndexMsg) + if !ok { + return errors.New("type assertion failed for LoadIndexMsg") + } // 1. use msg's index paths to get index bytes var err error var indexBuffer [][]byte var indexParams indexParam fn := func() error { - indexBuffer, indexParams, err = s.loadIndex(l.indexPaths) + indexBuffer, indexParams, err = lis.loadIndex(indexMsg.IndexPaths) if err != nil { return err } @@ -105,7 +130,7 @@ func (s *loadService) execute(l *loadIndex) error { if err != nil { return err } - ok, err := s.checkIndexReady(indexParams, l) + ok, err = lis.checkIndexReady(indexParams, indexMsg) if err != nil { return err } @@ -114,12 +139,12 @@ func (s *loadService) execute(l *loadIndex) error { return errors.New("") } // 2. use index bytes and index path to update segment - err = s.updateSegmentIndex(indexParams, indexBuffer, l) + err = lis.updateSegmentIndex(indexParams, indexBuffer, indexMsg) if err != nil { return err } //3. 
update segment index stats - err = s.updateSegmentIndexStats(indexParams, l) + err = lis.updateSegmentIndexStats(indexParams, indexMsg) if err != nil { return err } @@ -127,18 +152,21 @@ func (s *loadService) execute(l *loadIndex) error { return nil } -func (s *loadService) close() { - s.cancel() +func (lis *loadService) close() { + if lis.loadIndexMsgStream != nil { + lis.loadIndexMsgStream.Close() + } + lis.cancel() } -func (s *loadService) printIndexParams(index []*commonpb.KeyValuePair) { +func (lis *loadService) printIndexParams(index []*commonpb.KeyValuePair) { fmt.Println("=================================================") for i := 0; i < len(index); i++ { fmt.Println(index[i]) } } -func (s *loadService) indexParamsEqual(index1 []*commonpb.KeyValuePair, index2 []*commonpb.KeyValuePair) bool { +func (lis *loadService) indexParamsEqual(index1 []*commonpb.KeyValuePair, index2 []*commonpb.KeyValuePair) bool { if len(index1) != len(index2) { return false } @@ -154,11 +182,11 @@ func (s *loadService) indexParamsEqual(index1 []*commonpb.KeyValuePair, index2 [ return true } -func (s *loadService) fieldsStatsIDs2Key(collectionID UniqueID, fieldID UniqueID) string { +func (lis *loadService) fieldsStatsIDs2Key(collectionID UniqueID, fieldID UniqueID) string { return strconv.FormatInt(collectionID, 10) + "/" + strconv.FormatInt(fieldID, 10) } -func (s *loadService) fieldsStatsKey2IDs(key string) (UniqueID, UniqueID, error) { +func (lis *loadService) fieldsStatsKey2IDs(key string) (UniqueID, UniqueID, error) { ids := strings.Split(key, "/") if len(ids) != 2 { return 0, 0, errors.New("illegal fieldsStatsKey") @@ -174,14 +202,14 @@ func (s *loadService) fieldsStatsKey2IDs(key string) (UniqueID, UniqueID, error) return collectionID, fieldID, nil } -func (s *loadService) updateSegmentIndexStats(indexParams indexParam, l *loadIndex) error { - targetSegment, err := s.replica.getSegmentByID(l.segmentID) +func (lis *loadService) updateSegmentIndexStats(indexParams indexParam, indexMsg *msgstream.LoadIndexMsg) error { + targetSegment, err := lis.replica.getSegmentByID(indexMsg.SegmentID) if err != nil { return err } - fieldStatsKey := s.fieldsStatsIDs2Key(targetSegment.collectionID, l.fieldID) - _, ok := s.fieldIndexes[fieldStatsKey] + fieldStatsKey := lis.fieldsStatsIDs2Key(targetSegment.collectionID, indexMsg.FieldID) + _, ok := lis.fieldIndexes[fieldStatsKey] newIndexParams := make([]*commonpb.KeyValuePair, 0) for k, v := range indexParams { newIndexParams = append(newIndexParams, &commonpb.KeyValuePair{ @@ -193,38 +221,38 @@ func (s *loadService) updateSegmentIndexStats(indexParams indexParam, l *loadInd // sort index params by key sort.Slice(newIndexParams, func(i, j int) bool { return newIndexParams[i].Key < newIndexParams[j].Key }) if !ok { - s.fieldIndexes[fieldStatsKey] = make([]*internalpb2.IndexStats, 0) - s.fieldIndexes[fieldStatsKey] = append(s.fieldIndexes[fieldStatsKey], + lis.fieldIndexes[fieldStatsKey] = make([]*internalpb2.IndexStats, 0) + lis.fieldIndexes[fieldStatsKey] = append(lis.fieldIndexes[fieldStatsKey], &internalpb2.IndexStats{ IndexParams: newIndexParams, NumRelatedSegments: 1, }) } else { isNewIndex := true - for _, index := range s.fieldIndexes[fieldStatsKey] { - if s.indexParamsEqual(newIndexParams, index.IndexParams) { + for _, index := range lis.fieldIndexes[fieldStatsKey] { + if lis.indexParamsEqual(newIndexParams, index.IndexParams) { index.NumRelatedSegments++ isNewIndex = false } } if isNewIndex { - s.fieldIndexes[fieldStatsKey] = 
append(s.fieldIndexes[fieldStatsKey], + lis.fieldIndexes[fieldStatsKey] = append(lis.fieldIndexes[fieldStatsKey], &internalpb2.IndexStats{ IndexParams: newIndexParams, NumRelatedSegments: 1, }) } } - return targetSegment.setIndexParam(l.fieldID, newIndexParams) + return targetSegment.setIndexParam(indexMsg.FieldID, indexMsg.IndexParams) } -func (s *loadService) loadIndex(indexPath []string) ([][]byte, indexParam, error) { +func (lis *loadService) loadIndex(indexPath []string) ([][]byte, indexParam, error) { index := make([][]byte, 0) var indexParams indexParam for _, p := range indexPath { fmt.Println("load path = ", indexPath) - indexPiece, err := s.kv.Load(p) + indexPiece, err := (*lis.client).Load(p) if err != nil { return nil, nil, err } @@ -251,8 +279,8 @@ func (s *loadService) loadIndex(indexPath []string) ([][]byte, indexParam, error return index, indexParams, nil } -func (s *loadService) updateSegmentIndex(indexParams indexParam, bytesIndex [][]byte, l *loadIndex) error { - segment, err := s.replica.getSegmentByID(l.segmentID) +func (lis *loadService) updateSegmentIndex(indexParams indexParam, bytesIndex [][]byte, loadIndexMsg *msgstream.LoadIndexMsg) error { + segment, err := lis.replica.getSegmentByID(loadIndexMsg.SegmentID) if err != nil { return err } @@ -262,7 +290,7 @@ func (s *loadService) updateSegmentIndex(indexParams indexParam, bytesIndex [][] if err != nil { return err } - err = loadIndexInfo.appendFieldInfo(l.fieldName, l.fieldID) + err = loadIndexInfo.appendFieldInfo(loadIndexMsg.FieldName, loadIndexMsg.FieldID) if err != nil { return err } @@ -272,17 +300,17 @@ func (s *loadService) updateSegmentIndex(indexParams indexParam, bytesIndex [][] return err } } - err = loadIndexInfo.appendIndex(bytesIndex, l.indexPaths) + err = loadIndexInfo.appendIndex(bytesIndex, loadIndexMsg.IndexPaths) if err != nil { return err } return segment.updateSegmentIndex(loadIndexInfo) } -func (s *loadService) sendQueryNodeStats() error { +func (lis *loadService) sendQueryNodeStats() error { resultFieldsStats := make([]*internalpb2.FieldStats, 0) - for fieldStatsKey, indexStats := range s.fieldIndexes { - colID, fieldID, err := s.fieldsStatsKey2IDs(fieldStatsKey) + for fieldStatsKey, indexStats := range lis.fieldIndexes { + colID, fieldID, err := lis.fieldsStatsKey2IDs(fieldStatsKey) if err != nil { return err } @@ -294,306 +322,21 @@ func (s *loadService) sendQueryNodeStats() error { resultFieldsStats = append(resultFieldsStats, &fieldStats) } - s.fieldStatsChan <- resultFieldsStats + lis.fieldStatsChan <- resultFieldsStats fmt.Println("sent field stats") return nil } -func (s *loadService) checkIndexReady(indexParams indexParam, l *loadIndex) (bool, error) { - segment, err := s.replica.getSegmentByID(l.segmentID) +func (lis *loadService) checkIndexReady(indexParams indexParam, loadIndexMsg *msgstream.LoadIndexMsg) (bool, error) { + segment, err := lis.replica.getSegmentByID(loadIndexMsg.SegmentID) if err != nil { return false, err } - if !segment.matchIndexParam(l.fieldID, indexParams) { + if !segment.matchIndexParam(loadIndexMsg.FieldID, indexParams) { return false, nil } return true, nil -} - -func (s *loadService) getIndexInfo(collectionID UniqueID, segmentID UniqueID) (UniqueID, UniqueID, error) { - req := &milvuspb.DescribeSegmentRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_kDescribeSegment, - }, - CollectionID: collectionID, - SegmentID: segmentID, - } - response, err := s.masterClient.DescribeSegment(req) - if err != nil { - return 0, 0, err - } - return 
response.IndexID, response.BuildID, nil -} - -// -------------------------------------------- load segment -------------------------------------------- // -func (s *loadService) loadSegment(collectionID UniqueID, partitionID UniqueID, segmentIDs []UniqueID, fieldIDs []int64) error { - // TODO: interim solution - if len(fieldIDs) == 0 { - collection, err := s.replica.getCollectionByID(collectionID) - if err != nil { - return err - } - fieldIDs = make([]int64, 0) - for _, field := range collection.Schema().Fields { - fieldIDs = append(fieldIDs, field.FieldID) - } - } - for _, segmentID := range segmentIDs { - // we don't need index id yet - _, buildID, errIndex := s.getIndexInfo(collectionID, segmentID) - if errIndex == nil { - // we don't need load to vector fields - vectorFields, err := s.replica.getVecFieldsByCollectionID(segmentID) - if err != nil { - return err - } - fieldIDs = s.filterOutVectorFields(fieldIDs, vectorFields) - } - paths, srcFieldIDs, err := s.getInsertBinlogPaths(segmentID) - if err != nil { - return err - } - - targetFields := s.getTargetFields(paths, srcFieldIDs, fieldIDs) - collection, err := s.replica.getCollectionByID(collectionID) - if err != nil { - return err - } - segment := newSegment(collection, segmentID, partitionID, collectionID, segTypeSealed) - err = s.loadSegmentFieldsData(segment, targetFields) - if err != nil { - return err - } - if errIndex == nil { - indexPaths, err := s.getIndexPaths(buildID) - if err != nil { - return err - } - err = s.loadIndexImmediate(segment, indexPaths) - if err != nil { - // TODO: return or continue? - return err - } - } - } - return nil -} - -func (s *loadService) releaseSegment(segmentID UniqueID) error { - err := s.replica.removeSegment(segmentID) - return err -} - -func (s *loadService) seekSegment(positions []*internalpb2.MsgPosition) error { - // TODO: open seek - //for _, position := range positions { - // err := s.dmStream.Seek(position) - // if err != nil { - // return err - // } - //} - return nil -} - -func (s *loadService) getIndexPaths(buildID UniqueID) ([]string, error) { - if s.indexClient == nil { - return nil, errors.New("null index service client") - } - - indexFilePathRequest := &indexpb.IndexFilePathsRequest{ - // TODO: rename indexIDs to buildIDs - IndexIDs: []UniqueID{buildID}, - } - pathResponse, err := s.indexClient.GetIndexFilePaths(indexFilePathRequest) - if err != nil || pathResponse.Status.ErrorCode != commonpb.ErrorCode_SUCCESS { - return nil, err - } - - if len(pathResponse.FilePaths) <= 0 { - return nil, errors.New("illegal index file paths") - } - - return pathResponse.FilePaths[0].IndexFilePaths, nil -} - -func (s *loadService) loadIndexImmediate(segment *Segment, indexPaths []string) error { - // get vector field ids from schema to load index - vecFieldIDs, err := s.replica.getVecFieldsByCollectionID(segment.collectionID) - if err != nil { - return err - } - for id, name := range vecFieldIDs { - l := &loadIndex{ - segmentID: segment.ID(), - fieldName: name, - fieldID: id, - indexPaths: indexPaths, - } - err = s.execute(l) - if err != nil { - return err - } - // replace segment - err = s.replica.replaceGrowingSegmentBySealedSegment(segment) - if err != nil { - return err - } - } - return nil -} - -func (s *loadService) loadIndexDelayed(collectionID, segmentID UniqueID, indexPaths []string) error { - // get vector field ids from schema to load index - vecFieldIDs, err := s.replica.getVecFieldsByCollectionID(collectionID) - if err != nil { - return err - } - for id, name := range vecFieldIDs { 
- l := &loadIndex{ - segmentID: segmentID, - fieldName: name, - fieldID: id, - indexPaths: indexPaths, - } - - err = s.execute(l) - if err != nil { - return err - } - } - - return nil -} - -func (s *loadService) getInsertBinlogPaths(segmentID UniqueID) ([]*internalpb2.StringList, []int64, error) { - if s.dataClient == nil { - return nil, nil, errors.New("null data service client") - } - - insertBinlogPathRequest := &datapb.InsertBinlogPathRequest{ - SegmentID: segmentID, - } - - pathResponse, err := s.dataClient.GetInsertBinlogPaths(insertBinlogPathRequest) - if err != nil { - return nil, nil, err - } - - if len(pathResponse.FieldIDs) != len(pathResponse.Paths) { - return nil, nil, errors.New("illegal InsertBinlogPathsResponse") - } - - return pathResponse.Paths, pathResponse.FieldIDs, nil -} - -func (s *loadService) filterOutVectorFields(fieldIDs []int64, vectorFields map[int64]string) []int64 { - targetFields := make([]int64, 0) - for _, id := range fieldIDs { - if _, ok := vectorFields[id]; !ok { - targetFields = append(targetFields, id) - } - } - return targetFields -} - -func (s *loadService) getTargetFields(paths []*internalpb2.StringList, srcFieldIDS []int64, dstFields []int64) map[int64]*internalpb2.StringList { - targetFields := make(map[int64]*internalpb2.StringList) - - containsFunc := func(s []int64, e int64) bool { - for _, a := range s { - if a == e { - return true - } - } - return false - } - - for i, fieldID := range srcFieldIDS { - if containsFunc(dstFields, fieldID) { - targetFields[fieldID] = paths[i] - } - } - - return targetFields -} - -func (s *loadService) loadSegmentFieldsData(segment *Segment, targetFields map[int64]*internalpb2.StringList) error { - for id, p := range targetFields { - if id == timestampFieldID { - // seg core doesn't need timestamp field - continue - } - - paths := p.Values - blobs := make([]*storage.Blob, 0) - for _, path := range paths { - binLog, err := s.kv.Load(path) - if err != nil { - // TODO: return or continue? - return err - } - blobs = append(blobs, &storage.Blob{ - Key: strconv.FormatInt(id, 10), // TODO: key??? - Value: []byte(binLog), - }) - } - _, _, insertData, err := s.iCodec.Deserialize(blobs) - if err != nil { - // TODO: return or continue - return err - } - if len(insertData.Data) != 1 { - return errors.New("we expect only one field in deserialized insert data") - } - - for _, value := range insertData.Data { - var numRows int - var data interface{} - - switch fieldData := value.(type) { - case *storage.BoolFieldData: - numRows = fieldData.NumRows - data = fieldData.Data - case *storage.Int8FieldData: - numRows = fieldData.NumRows - data = fieldData.Data - case *storage.Int16FieldData: - numRows = fieldData.NumRows - data = fieldData.Data - case *storage.Int32FieldData: - numRows = fieldData.NumRows - data = fieldData.Data - case *storage.Int64FieldData: - numRows = fieldData.NumRows - data = fieldData.Data - case *storage.FloatFieldData: - numRows = fieldData.NumRows - data = fieldData.Data - case *storage.DoubleFieldData: - numRows = fieldData.NumRows - data = fieldData.Data - case storage.StringFieldData: - numRows = fieldData.NumRows - data = fieldData.Data - case *storage.FloatVectorFieldData: - numRows = fieldData.NumRows - data = fieldData.Data - case *storage.BinaryVectorFieldData: - numRows = fieldData.NumRows - data = fieldData.Data - default: - return errors.New("unexpected field data type") - } - err = segment.segmentLoadFieldData(id, numRows, data) - if err != nil { - // TODO: return or continue? 
- return err - } - } - } - - return nil } func newLoadService(ctx context.Context, masterClient MasterServiceInterface, dataClient DataServiceInterface, indexClient IndexServiceInterface, replica collectionReplica, dmStream msgstream.MsgStream) *loadService { @@ -608,27 +351,47 @@ func newLoadService(ctx context.Context, masterClient MasterServiceInterface, da BucketName: Params.MinioBucketName, } - client, err := minioKV.NewMinIOKV(ctx1, option) + MinioKV, err := minioKV.NewMinIOKV(ctx1, option) if err != nil { panic(err) } + // init msgStream + receiveBufSize := Params.LoadIndexReceiveBufSize + pulsarBufSize := Params.LoadIndexPulsarBufSize + + msgStreamURL := Params.PulsarAddress + + consumeChannels := Params.LoadIndexChannelNames + consumeSubName := Params.MsgChannelSubName + + loadIndexStream := pulsarms.NewPulsarMsgStream(ctx, receiveBufSize) + loadIndexStream.SetPulsarClient(msgStreamURL) + unmarshalDispatcher := util.NewUnmarshalDispatcher() + loadIndexStream.CreatePulsarConsumers(consumeChannels, consumeSubName, unmarshalDispatcher, pulsarBufSize) + + var stream msgstream.MsgStream = loadIndexStream + + // init index load requests channel size by message receive buffer size + indexLoadChanSize := receiveBufSize + + // init segment manager + loadIndexReqChan := make(chan []msgstream.TsMsg, indexLoadChanSize) + manager := newSegmentManager(ctx1, masterClient, dataClient, indexClient, replica, dmStream, loadIndexReqChan) + return &loadService{ ctx: ctx1, cancel: cancel, + client: MinioKV, - replica: replica, - + replica: replica, + queryNodeID: Params.QueryNodeID, fieldIndexes: make(map[string][]*internalpb2.IndexStats), fieldStatsChan: make(chan []*internalpb2.FieldStats, 1), - dmStream: dmStream, - - masterClient: masterClient, - dataClient: dataClient, - indexClient: indexClient, + loadIndexReqChan: loadIndexReqChan, + loadIndexMsgStream: stream, - kv: client, - iCodec: &storage.InsertCodec{}, + segManager: manager, } } diff --git a/internal/querynode/load_service_test.go b/internal/querynode/load_service_test.go index f22ab9e80e6b6ad206eaf43121f65995e28705bd..2c63746e75b73f142fe3b7847ffa399571c4586a 100644 --- a/internal/querynode/load_service_test.go +++ b/internal/querynode/load_service_test.go @@ -1,15 +1,14 @@ package querynode import ( - "context" "encoding/binary" "fmt" + "log" "math" "math/rand" - "path" + "sort" "strconv" "testing" - "time" "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" @@ -22,776 +21,210 @@ import ( "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2" "github.com/zilliztech/milvus-distributed/internal/proto/milvuspb" - "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" + "github.com/zilliztech/milvus-distributed/internal/querynode/client" "github.com/zilliztech/milvus-distributed/internal/storage" ) -//func TestLoadService_LoadIndex_FloatVector(t *testing.T) { -// node := newQueryNodeMock() -// collectionID := rand.Int63n(1000000) -// segmentID := rand.Int63n(1000000) -// initTestMeta(t, node, "collection0", collectionID, segmentID) -// -// // loadService and statsService -// suffix := "-test-search" + strconv.FormatInt(rand.Int63n(1000000), 10) -// oldSearchChannelNames := Params.SearchChannelNames -// newSearchChannelNames := makeNewChannelNames(oldSearchChannelNames, suffix) -// Params.SearchChannelNames = newSearchChannelNames -// -// oldSearchResultChannelNames := Params.SearchChannelNames -// newSearchResultChannelNames := 
makeNewChannelNames(oldSearchResultChannelNames, suffix) -// Params.SearchResultChannelNames = newSearchResultChannelNames -// -// oldLoadIndexChannelNames := Params.LoadIndexChannelNames -// newLoadIndexChannelNames := makeNewChannelNames(oldLoadIndexChannelNames, suffix) -// Params.LoadIndexChannelNames = newLoadIndexChannelNames -// -// oldStatsChannelName := Params.StatsChannelName -// newStatsChannelNames := makeNewChannelNames([]string{oldStatsChannelName}, suffix) -// Params.StatsChannelName = newStatsChannelNames[0] -// go node.Start() -// -// //generate insert data -// const msgLength = 1000 -// const receiveBufSize = 1024 -// const DIM = 16 -// var insertRowBlob []*commonpb.Blob -// var timestamps []uint64 -// var rowIDs []int64 -// var hashValues []uint32 -// for n := 0; n < msgLength; n++ { -// rowData := make([]byte, 0) -// for i := 0; i < DIM; i++ { -// vec := make([]byte, 4) -// binary.LittleEndian.PutUint32(vec, math.Float32bits(float32(n*i))) -// rowData = append(rowData, vec...) -// } -// age := make([]byte, 4) -// binary.LittleEndian.PutUint32(age, 1) -// rowData = append(rowData, age...) -// blob := &commonpb.Blob{ -// Value: rowData, -// } -// insertRowBlob = append(insertRowBlob, blob) -// timestamps = append(timestamps, uint64(n)) -// rowIDs = append(rowIDs, int64(n)) -// hashValues = append(hashValues, uint32(n)) -// } -// -// var insertMsg msgstream.TsMsg = &msgstream.InsertMsg{ -// BaseMsg: msgstream.BaseMsg{ -// HashValues: hashValues, -// }, -// InsertRequest: internalpb2.InsertRequest{ -// Base: &commonpb.MsgBase{ -// MsgType: commonpb.MsgType_kInsert, -// MsgID: 0, -// Timestamp: timestamps[0], -// SourceID: 0, -// }, -// CollectionName: "collection0", -// PartitionName: "default", -// SegmentID: segmentID, -// ChannelID: "0", -// Timestamps: timestamps, -// RowIDs: rowIDs, -// RowData: insertRowBlob, -// }, -// } -// insertMsgPack := msgstream.MsgPack{ -// BeginTs: 0, -// EndTs: math.MaxUint64, -// Msgs: []msgstream.TsMsg{insertMsg}, -// } -// -// // generate timeTick -// timeTickMsg := &msgstream.TimeTickMsg{ -// BaseMsg: msgstream.BaseMsg{ -// BeginTimestamp: 0, -// EndTimestamp: 0, -// HashValues: []uint32{0}, -// }, -// TimeTickMsg: internalpb2.TimeTickMsg{ -// Base: &commonpb.MsgBase{ -// MsgType: commonpb.MsgType_kTimeTick, -// MsgID: 0, -// Timestamp: math.MaxUint64, -// SourceID: 0, -// }, -// }, -// } -// timeTickMsgPack := &msgstream.MsgPack{ -// Msgs: []msgstream.TsMsg{timeTickMsg}, -// } -// -// // pulsar produce -// insertChannels := Params.InsertChannelNames -// ddChannels := Params.DDChannelNames -// -// insertStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// insertStream.SetPulsarClient(Params.PulsarAddress) -// insertStream.CreatePulsarProducers(insertChannels) -// ddStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// ddStream.SetPulsarClient(Params.PulsarAddress) -// ddStream.CreatePulsarProducers(ddChannels) -// -// var insertMsgStream msgstream.MsgStream = insertStream -// insertMsgStream.Start() -// var ddMsgStream msgstream.MsgStream = ddStream -// ddMsgStream.Start() -// -// err := insertMsgStream.Produce(&insertMsgPack) -// assert.NoError(t, err) -// err = insertMsgStream.Broadcast(timeTickMsgPack) -// assert.NoError(t, err) -// err = ddMsgStream.Broadcast(timeTickMsgPack) -// assert.NoError(t, err) -// -// // generator searchRowData -// var searchRowData []float32 -// for i := 0; i < DIM; i++ { -// searchRowData = append(searchRowData, float32(42*i)) -// } -// -// //generate 
search data and send search msg -// dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" -// var searchRowByteData []byte -// for i := range searchRowData { -// vec := make([]byte, 4) -// binary.LittleEndian.PutUint32(vec, math.Float32bits(searchRowData[i])) -// searchRowByteData = append(searchRowByteData, vec...) -// } -// placeholderValue := milvuspb.PlaceholderValue{ -// Tag: "$0", -// Type: milvuspb.PlaceholderType_VECTOR_FLOAT, -// Values: [][]byte{searchRowByteData}, -// } -// placeholderGroup := milvuspb.PlaceholderGroup{ -// Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue}, -// } -// placeGroupByte, err := proto.Marshal(&placeholderGroup) -// if err != nil { -// log.Print("marshal placeholderGroup failed") -// } -// query := milvuspb.SearchRequest{ -// CollectionName: "collection0", -// PartitionNames: []string{"default"}, -// Dsl: dslString, -// PlaceholderGroup: placeGroupByte, -// } -// queryByte, err := proto.Marshal(&query) -// if err != nil { -// log.Print("marshal query failed") -// } -// blob := commonpb.Blob{ -// Value: queryByte, -// } -// fn := func(n int64) *msgstream.MsgPack { -// searchMsg := &msgstream.SearchMsg{ -// BaseMsg: msgstream.BaseMsg{ -// HashValues: []uint32{0}, -// }, -// SearchRequest: internalpb2.SearchRequest{ -// Base: &commonpb.MsgBase{ -// MsgType: commonpb.MsgType_kSearch, -// MsgID: n, -// Timestamp: uint64(msgLength), -// SourceID: 1, -// }, -// ResultChannelID: "0", -// Query: &blob, -// }, -// } -// return &msgstream.MsgPack{ -// Msgs: []msgstream.TsMsg{searchMsg}, -// } -// } -// searchStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// searchStream.SetPulsarClient(Params.PulsarAddress) -// searchStream.CreatePulsarProducers(newSearchChannelNames) -// searchStream.Start() -// err = searchStream.Produce(fn(1)) -// assert.NoError(t, err) -// -// //get search result -// searchResultStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// searchResultStream.SetPulsarClient(Params.PulsarAddress) -// unmarshalDispatcher := util.NewUnmarshalDispatcher() -// searchResultStream.CreatePulsarConsumers(newSearchResultChannelNames, "loadIndexTestSubSearchResult", unmarshalDispatcher, receiveBufSize) -// searchResultStream.Start() -// searchResult := searchResultStream.Consume() -// assert.NotNil(t, searchResult) -// unMarshaledHit := milvuspb.Hits{} -// err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit) -// assert.Nil(t, err) -// -// // gen load index message pack -// indexParams := make(map[string]string) -// indexParams["index_type"] = "IVF_PQ" -// indexParams["index_mode"] = "cpu" -// indexParams["dim"] = "16" -// indexParams["k"] = "10" -// indexParams["nlist"] = "100" -// indexParams["nprobe"] = "10" -// indexParams["m"] = "4" -// indexParams["nbits"] = "8" -// indexParams["metric_type"] = "L2" -// indexParams["SLICE_SIZE"] = "4" -// -// var indexParamsKV []*commonpb.KeyValuePair -// for key, value := range indexParams { -// indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{ -// Key: key, -// Value: value, -// }) -// } -// -// // generator index -// typeParams := make(map[string]string) -// typeParams["dim"] = "16" -// var indexRowData []float32 -// for n := 0; n < msgLength; n++ { -// for i := 0; i < DIM; i++ { -// indexRowData = append(indexRowData, float32(n*i)) -// } -// } -// index, err := 
indexnode.NewCIndex(typeParams, indexParams) -// assert.Nil(t, err) -// err = index.BuildFloatVecIndexWithoutIds(indexRowData) -// assert.Equal(t, err, nil) -// -// option := &minioKV.Option{ -// Address: Params.MinioEndPoint, -// AccessKeyID: Params.MinioAccessKeyID, -// SecretAccessKeyID: Params.MinioSecretAccessKey, -// UseSSL: Params.MinioUseSSLStr, -// BucketName: Params.MinioBucketName, -// CreateBucket: true, -// } -// -// minioKV, err := minioKV.NewMinIOKV(node.queryNodeLoopCtx, option) -// assert.Equal(t, err, nil) -// //save index to minio -// binarySet, err := index.Serialize() -// assert.Equal(t, err, nil) -// indexPaths := make([]string, 0) -// var indexCodec storage.IndexCodec -// binarySet, err = indexCodec.Serialize(binarySet, indexParams) -// assert.NoError(t, err) -// for _, index := range binarySet { -// path := strconv.Itoa(int(segmentID)) + "/" + index.Key -// indexPaths = append(indexPaths, path) -// minioKV.Save(path, string(index.Value)) -// } -// -// //test index search result -// indexResult, err := index.QueryOnFloatVecIndexWithParam(searchRowData, indexParams) -// assert.Equal(t, err, nil) -// -// // create loadIndexClient -// fieldID := UniqueID(100) -// loadIndexChannelNames := Params.LoadIndexChannelNames -// client := client.NewQueryNodeClient(node.queryNodeLoopCtx, Params.PulsarAddress, loadIndexChannelNames) -// client.LoadIndex(indexPaths, segmentID, fieldID, "vec", indexParams) -// -// // init message stream consumer and do checks -// statsMs := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, Params.StatsReceiveBufSize) -// statsMs.SetPulsarClient(Params.PulsarAddress) -// statsMs.CreatePulsarConsumers([]string{Params.StatsChannelName}, Params.MsgChannelSubName, util.NewUnmarshalDispatcher(), Params.StatsReceiveBufSize) -// statsMs.Start() -// -// findFiledStats := false -// for { -// receiveMsg := msgstream.MsgStream(statsMs).Consume() -// assert.NotNil(t, receiveMsg) -// assert.NotEqual(t, len(receiveMsg.Msgs), 0) -// -// for _, msg := range receiveMsg.Msgs { -// statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg) -// if statsMsg.FieldStats == nil || len(statsMsg.FieldStats) == 0 { -// continue -// } -// findFiledStats = true -// assert.Equal(t, ok, true) -// assert.Equal(t, len(statsMsg.FieldStats), 1) -// fieldStats0 := statsMsg.FieldStats[0] -// assert.Equal(t, fieldStats0.FieldID, fieldID) -// assert.Equal(t, fieldStats0.CollectionID, collectionID) -// assert.Equal(t, len(fieldStats0.IndexStats), 1) -// indexStats0 := fieldStats0.IndexStats[0] -// params := indexStats0.IndexParams -// // sort index params by key -// sort.Slice(indexParamsKV, func(i, j int) bool { return indexParamsKV[i].Key < indexParamsKV[j].Key }) -// indexEqual := node.loadService.indexParamsEqual(params, indexParamsKV) -// assert.Equal(t, indexEqual, true) -// } -// -// if findFiledStats { -// break -// } -// } -// -// err = searchStream.Produce(fn(2)) -// assert.NoError(t, err) -// searchResult = searchResultStream.Consume() -// assert.NotNil(t, searchResult) -// err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit) -// assert.Nil(t, err) -// -// idsIndex := indexResult.IDs() -// idsSegment := unMarshaledHit.IDs -// assert.Equal(t, len(idsIndex), len(idsSegment)) -// for i := 0; i < len(idsIndex); i++ { -// assert.Equal(t, idsIndex[i], idsSegment[i]) -// } -// Params.SearchChannelNames = oldSearchChannelNames -// Params.SearchResultChannelNames = oldSearchResultChannelNames -// Params.LoadIndexChannelNames = 
oldLoadIndexChannelNames -// Params.StatsChannelName = oldStatsChannelName -// fmt.Println("loadIndex floatVector test Done!") -// -// defer assert.Equal(t, findFiledStats, true) -// <-node.queryNodeLoopCtx.Done() -// node.Stop() -//} -// -//func TestLoadService_LoadIndex_BinaryVector(t *testing.T) { -// node := newQueryNodeMock() -// collectionID := rand.Int63n(1000000) -// segmentID := rand.Int63n(1000000) -// initTestMeta(t, node, "collection0", collectionID, segmentID, true) -// -// // loadService and statsService -// suffix := "-test-search-binary" + strconv.FormatInt(rand.Int63n(1000000), 10) -// oldSearchChannelNames := Params.SearchChannelNames -// newSearchChannelNames := makeNewChannelNames(oldSearchChannelNames, suffix) -// Params.SearchChannelNames = newSearchChannelNames -// -// oldSearchResultChannelNames := Params.SearchChannelNames -// newSearchResultChannelNames := makeNewChannelNames(oldSearchResultChannelNames, suffix) -// Params.SearchResultChannelNames = newSearchResultChannelNames -// -// oldLoadIndexChannelNames := Params.LoadIndexChannelNames -// newLoadIndexChannelNames := makeNewChannelNames(oldLoadIndexChannelNames, suffix) -// Params.LoadIndexChannelNames = newLoadIndexChannelNames -// -// oldStatsChannelName := Params.StatsChannelName -// newStatsChannelNames := makeNewChannelNames([]string{oldStatsChannelName}, suffix) -// Params.StatsChannelName = newStatsChannelNames[0] -// go node.Start() -// -// const msgLength = 1000 -// const receiveBufSize = 1024 -// const DIM = 128 -// -// // generator index data -// var indexRowData []byte -// for n := 0; n < msgLength; n++ { -// for i := 0; i < DIM/8; i++ { -// indexRowData = append(indexRowData, byte(rand.Intn(8))) -// } -// } -// -// //generator insert data -// var insertRowBlob []*commonpb.Blob -// var timestamps []uint64 -// var rowIDs []int64 -// var hashValues []uint32 -// offset := 0 -// for n := 0; n < msgLength; n++ { -// rowData := make([]byte, 0) -// rowData = append(rowData, indexRowData[offset:offset+(DIM/8)]...) -// offset += DIM / 8 -// age := make([]byte, 4) -// binary.LittleEndian.PutUint32(age, 1) -// rowData = append(rowData, age...) 
-// blob := &commonpb.Blob{ -// Value: rowData, -// } -// insertRowBlob = append(insertRowBlob, blob) -// timestamps = append(timestamps, uint64(n)) -// rowIDs = append(rowIDs, int64(n)) -// hashValues = append(hashValues, uint32(n)) -// } -// -// var insertMsg msgstream.TsMsg = &msgstream.InsertMsg{ -// BaseMsg: msgstream.BaseMsg{ -// HashValues: hashValues, -// }, -// InsertRequest: internalpb2.InsertRequest{ -// Base: &commonpb.MsgBase{ -// MsgType: commonpb.MsgType_kInsert, -// MsgID: 0, -// Timestamp: timestamps[0], -// SourceID: 0, -// }, -// CollectionName: "collection0", -// PartitionName: "default", -// SegmentID: segmentID, -// ChannelID: "0", -// Timestamps: timestamps, -// RowIDs: rowIDs, -// RowData: insertRowBlob, -// }, -// } -// insertMsgPack := msgstream.MsgPack{ -// BeginTs: 0, -// EndTs: math.MaxUint64, -// Msgs: []msgstream.TsMsg{insertMsg}, -// } -// -// // generate timeTick -// timeTickMsg := &msgstream.TimeTickMsg{ -// BaseMsg: msgstream.BaseMsg{ -// BeginTimestamp: 0, -// EndTimestamp: 0, -// HashValues: []uint32{0}, -// }, -// TimeTickMsg: internalpb2.TimeTickMsg{ -// Base: &commonpb.MsgBase{ -// MsgType: commonpb.MsgType_kTimeTick, -// MsgID: 0, -// Timestamp: math.MaxUint64, -// SourceID: 0, -// }, -// }, -// } -// timeTickMsgPack := &msgstream.MsgPack{ -// Msgs: []msgstream.TsMsg{timeTickMsg}, -// } -// -// // pulsar produce -// insertChannels := Params.InsertChannelNames -// ddChannels := Params.DDChannelNames -// -// insertStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// insertStream.SetPulsarClient(Params.PulsarAddress) -// insertStream.CreatePulsarProducers(insertChannels) -// ddStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// ddStream.SetPulsarClient(Params.PulsarAddress) -// ddStream.CreatePulsarProducers(ddChannels) -// -// var insertMsgStream msgstream.MsgStream = insertStream -// insertMsgStream.Start() -// var ddMsgStream msgstream.MsgStream = ddStream -// ddMsgStream.Start() -// -// err := insertMsgStream.Produce(&insertMsgPack) -// assert.NoError(t, err) -// err = insertMsgStream.Broadcast(timeTickMsgPack) -// assert.NoError(t, err) -// err = ddMsgStream.Broadcast(timeTickMsgPack) -// assert.NoError(t, err) -// -// //generate search data and send search msg -// searchRowData := indexRowData[42*(DIM/8) : 43*(DIM/8)] -// dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"JACCARD\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" -// placeholderValue := milvuspb.PlaceholderValue{ -// Tag: "$0", -// Type: milvuspb.PlaceholderType_VECTOR_BINARY, -// Values: [][]byte{searchRowData}, -// } -// placeholderGroup := milvuspb.PlaceholderGroup{ -// Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue}, -// } -// placeGroupByte, err := proto.Marshal(&placeholderGroup) -// if err != nil { -// log.Print("marshal placeholderGroup failed") -// } -// query := milvuspb.SearchRequest{ -// CollectionName: "collection0", -// PartitionNames: []string{"default"}, -// Dsl: dslString, -// PlaceholderGroup: placeGroupByte, -// } -// queryByte, err := proto.Marshal(&query) -// if err != nil { -// log.Print("marshal query failed") -// } -// blob := commonpb.Blob{ -// Value: queryByte, -// } -// fn := func(n int64) *msgstream.MsgPack { -// searchMsg := &msgstream.SearchMsg{ -// BaseMsg: msgstream.BaseMsg{ -// HashValues: []uint32{0}, -// }, -// SearchRequest: internalpb2.SearchRequest{ -// Base: &commonpb.MsgBase{ -// MsgType: 
commonpb.MsgType_kSearch, -// MsgID: n, -// Timestamp: uint64(msgLength), -// SourceID: 1, -// }, -// ResultChannelID: "0", -// Query: &blob, -// }, -// } -// return &msgstream.MsgPack{ -// Msgs: []msgstream.TsMsg{searchMsg}, -// } -// } -// searchStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// searchStream.SetPulsarClient(Params.PulsarAddress) -// searchStream.CreatePulsarProducers(newSearchChannelNames) -// searchStream.Start() -// err = searchStream.Produce(fn(1)) -// assert.NoError(t, err) -// -// //get search result -// searchResultStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// searchResultStream.SetPulsarClient(Params.PulsarAddress) -// unmarshalDispatcher := util.NewUnmarshalDispatcher() -// searchResultStream.CreatePulsarConsumers(newSearchResultChannelNames, "loadIndexTestSubSearchResult2", unmarshalDispatcher, receiveBufSize) -// searchResultStream.Start() -// searchResult := searchResultStream.Consume() -// assert.NotNil(t, searchResult) -// unMarshaledHit := milvuspb.Hits{} -// err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit) -// assert.Nil(t, err) -// -// // gen load index message pack -// indexParams := make(map[string]string) -// indexParams["index_type"] = "BIN_IVF_FLAT" -// indexParams["index_mode"] = "cpu" -// indexParams["dim"] = "128" -// indexParams["k"] = "10" -// indexParams["nlist"] = "100" -// indexParams["nprobe"] = "10" -// indexParams["m"] = "4" -// indexParams["nbits"] = "8" -// indexParams["metric_type"] = "JACCARD" -// indexParams["SLICE_SIZE"] = "4" -// -// var indexParamsKV []*commonpb.KeyValuePair -// for key, value := range indexParams { -// indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{ -// Key: key, -// Value: value, -// }) -// } -// -// // generator index -// typeParams := make(map[string]string) -// typeParams["dim"] = "128" -// index, err := indexnode.NewCIndex(typeParams, indexParams) -// assert.Nil(t, err) -// err = index.BuildBinaryVecIndexWithoutIds(indexRowData) -// assert.Equal(t, err, nil) -// -// option := &minioKV.Option{ -// Address: Params.MinioEndPoint, -// AccessKeyID: Params.MinioAccessKeyID, -// SecretAccessKeyID: Params.MinioSecretAccessKey, -// UseSSL: Params.MinioUseSSLStr, -// BucketName: Params.MinioBucketName, -// CreateBucket: true, -// } -// -// minioKV, err := minioKV.NewMinIOKV(node.queryNodeLoopCtx, option) -// assert.Equal(t, err, nil) -// //save index to minio -// binarySet, err := index.Serialize() -// assert.Equal(t, err, nil) -// var indexCodec storage.IndexCodec -// binarySet, err = indexCodec.Serialize(binarySet, indexParams) -// assert.NoError(t, err) -// indexPaths := make([]string, 0) -// for _, index := range binarySet { -// path := strconv.Itoa(int(segmentID)) + "/" + index.Key -// indexPaths = append(indexPaths, path) -// minioKV.Save(path, string(index.Value)) -// } -// -// //test index search result -// indexResult, err := index.QueryOnBinaryVecIndexWithParam(searchRowData, indexParams) -// assert.Equal(t, err, nil) -// -// // create loadIndexClient -// fieldID := UniqueID(100) -// loadIndexChannelNames := Params.LoadIndexChannelNames -// client := client.NewQueryNodeClient(node.queryNodeLoopCtx, Params.PulsarAddress, loadIndexChannelNames) -// client.LoadIndex(indexPaths, segmentID, fieldID, "vec", indexParams) -// -// // init message stream consumer and do checks -// statsMs := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, Params.StatsReceiveBufSize) -// 
statsMs.SetPulsarClient(Params.PulsarAddress)
-//	statsMs.CreatePulsarConsumers([]string{Params.StatsChannelName}, Params.MsgChannelSubName, util.NewUnmarshalDispatcher(), Params.StatsReceiveBufSize)
-//	statsMs.Start()
-//
-//	findFiledStats := false
-//	for {
-//		receiveMsg := msgstream.MsgStream(statsMs).Consume()
-//		assert.NotNil(t, receiveMsg)
-//		assert.NotEqual(t, len(receiveMsg.Msgs), 0)
-//
-//		for _, msg := range receiveMsg.Msgs {
-//			statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg)
-//			if statsMsg.FieldStats == nil || len(statsMsg.FieldStats) == 0 {
-//				continue
-//			}
-//			findFiledStats = true
-//			assert.Equal(t, ok, true)
-//			assert.Equal(t, len(statsMsg.FieldStats), 1)
-//			fieldStats0 := statsMsg.FieldStats[0]
-//			assert.Equal(t, fieldStats0.FieldID, fieldID)
-//			assert.Equal(t, fieldStats0.CollectionID, collectionID)
-//			assert.Equal(t, len(fieldStats0.IndexStats), 1)
-//			indexStats0 := fieldStats0.IndexStats[0]
-//			params := indexStats0.IndexParams
-//			// sort index params by key
-//			sort.Slice(indexParamsKV, func(i, j int) bool { return indexParamsKV[i].Key < indexParamsKV[j].Key })
-//			indexEqual := node.loadService.indexParamsEqual(params, indexParamsKV)
-//			assert.Equal(t, indexEqual, true)
-//		}
-//
-//		if findFiledStats {
-//			break
-//		}
-//	}
-//
-//	err = searchStream.Produce(fn(2))
-//	assert.NoError(t, err)
-//	searchResult = searchResultStream.Consume()
-//	assert.NotNil(t, searchResult)
-//	err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit)
-//	assert.Nil(t, err)
-//
-//	idsIndex := indexResult.IDs()
-//	idsSegment := unMarshaledHit.IDs
-//	assert.Equal(t, len(idsIndex), len(idsSegment))
-//	for i := 0; i < len(idsIndex); i++ {
-//		assert.Equal(t, idsIndex[i], idsSegment[i])
-//	}
-//	Params.SearchChannelNames = oldSearchChannelNames
-//	Params.SearchResultChannelNames = oldSearchResultChannelNames
-//	Params.LoadIndexChannelNames = oldLoadIndexChannelNames
-//	Params.StatsChannelName = oldStatsChannelName
-//	fmt.Println("loadIndex binaryVector test Done!")
-//
-//	defer assert.Equal(t, findFiledStats, true)
-//	<-node.queryNodeLoopCtx.Done()
-//	node.Stop()
-//}
-
-///////////////////////////////////////////////////////////////////////////////////////////////////////////
-func generateInsertBinLog(collectionID UniqueID, partitionID UniqueID, segmentID UniqueID, keyPrefix string) ([]*internalpb2.StringList, []int64, error) {
-	const (
-		msgLength = 1000
-		DIM       = 16
-	)
-
-	idData := make([]int64, 0)
-	for n := 0; n < msgLength; n++ {
-		idData = append(idData, int64(n))
-	}
-
-	var timestamps []int64
-	for n := 0; n < msgLength; n++ {
-		timestamps = append(timestamps, int64(n+1))
-	}
-
-	var fieldAgeData []int32
-	for n := 0; n < msgLength; n++ {
-		fieldAgeData = append(fieldAgeData, int32(n))
-	}
-
-	fieldVecData := make([]float32, 0)
+func TestLoadService_LoadIndex_FloatVector(t *testing.T) {
+	node := newQueryNodeMock()
+	collectionID := rand.Int63n(1000000)
+	segmentID := rand.Int63n(1000000)
+	initTestMeta(t, node, "collection0", collectionID, segmentID)
+
+	// loadService and statsService
+	suffix := "-test-search" + strconv.FormatInt(rand.Int63n(1000000), 10)
+	oldSearchChannelNames := Params.SearchChannelNames
+	newSearchChannelNames := makeNewChannelNames(oldSearchChannelNames, suffix)
+	Params.SearchChannelNames = newSearchChannelNames
+
+	oldSearchResultChannelNames := Params.SearchResultChannelNames
+	newSearchResultChannelNames := makeNewChannelNames(oldSearchResultChannelNames, suffix)
+
Params.SearchResultChannelNames = newSearchResultChannelNames + + oldLoadIndexChannelNames := Params.LoadIndexChannelNames + newLoadIndexChannelNames := makeNewChannelNames(oldLoadIndexChannelNames, suffix) + Params.LoadIndexChannelNames = newLoadIndexChannelNames + + oldStatsChannelName := Params.StatsChannelName + newStatsChannelNames := makeNewChannelNames([]string{oldStatsChannelName}, suffix) + Params.StatsChannelName = newStatsChannelNames[0] + go node.Start() + + //generate insert data + const msgLength = 1000 + const receiveBufSize = 1024 + const DIM = 16 + var insertRowBlob []*commonpb.Blob + var timestamps []uint64 + var rowIDs []int64 + var hashValues []uint32 for n := 0; n < msgLength; n++ { + rowData := make([]byte, 0) for i := 0; i < DIM; i++ { - fieldVecData = append(fieldVecData, float32(n*i)*0.1) + vec := make([]byte, 4) + binary.LittleEndian.PutUint32(vec, math.Float32bits(float32(n*i))) + rowData = append(rowData, vec...) } + age := make([]byte, 4) + binary.LittleEndian.PutUint32(age, 1) + rowData = append(rowData, age...) + blob := &commonpb.Blob{ + Value: rowData, + } + insertRowBlob = append(insertRowBlob, blob) + timestamps = append(timestamps, uint64(n)) + rowIDs = append(rowIDs, int64(n)) + hashValues = append(hashValues, uint32(n)) } - insertData := &storage.InsertData{ - Data: map[int64]storage.FieldData{ - 0: &storage.Int64FieldData{ - NumRows: msgLength, - Data: idData, - }, - 1: &storage.Int64FieldData{ - NumRows: msgLength, - Data: timestamps, - }, - 100: &storage.FloatVectorFieldData{ - NumRows: msgLength, - Data: fieldVecData, - Dim: DIM, + var insertMsg msgstream.TsMsg = &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: hashValues, + }, + InsertRequest: internalpb2.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kInsert, + MsgID: 0, + Timestamp: timestamps[0], + SourceID: 0, }, - 101: &storage.Int32FieldData{ - NumRows: msgLength, - Data: fieldAgeData, + CollectionName: "collection0", + PartitionName: "default", + SegmentID: segmentID, + ChannelID: "0", + Timestamps: timestamps, + RowIDs: rowIDs, + RowData: insertRowBlob, + }, + } + insertMsgPack := msgstream.MsgPack{ + BeginTs: 0, + EndTs: math.MaxUint64, + Msgs: []msgstream.TsMsg{insertMsg}, + } + + // generate timeTick + timeTickMsg := &msgstream.TimeTickMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 0, + EndTimestamp: 0, + HashValues: []uint32{0}, + }, + TimeTickMsg: internalpb2.TimeTickMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kTimeTick, + MsgID: 0, + Timestamp: math.MaxUint64, + SourceID: 0, }, }, } + timeTickMsgPack := &msgstream.MsgPack{ + Msgs: []msgstream.TsMsg{timeTickMsg}, + } - // buffer data to binLogs - collMeta := genTestCollectionMeta("collection0", collectionID, false) - collMeta.Schema.Fields = append(collMeta.Schema.Fields, &schemapb.FieldSchema{ - FieldID: 0, - Name: "uid", - DataType: schemapb.DataType_INT64, - }) - collMeta.Schema.Fields = append(collMeta.Schema.Fields, &schemapb.FieldSchema{ - FieldID: 1, - Name: "timestamp", - DataType: schemapb.DataType_INT64, - }) - inCodec := storage.NewInsertCodec(collMeta) - binLogs, err := inCodec.Serialize(partitionID, segmentID, insertData) + // pulsar produce + insertChannels := Params.InsertChannelNames + ddChannels := Params.DDChannelNames - if err != nil { - return nil, nil, err + insertStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + insertStream.SetPulsarClient(Params.PulsarAddress) + insertStream.CreatePulsarProducers(insertChannels) + 
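Each RowData blob assembled above follows the flat row layout the query node's insert path expects: the DIM float32 vector values in little-endian byte order, then a 4-byte little-endian age column. A hypothetical helper, encodeFloatRow (not part of this patch, assuming the encoding/binary and math imports this file already has), shows the layout in one place:

// encodeFloatRow packs one row blob: DIM little-endian float32 values
// followed by a 4-byte little-endian age value, mirroring the loop above.
func encodeFloatRow(vec []float32, age uint32) []byte {
	row := make([]byte, 0, len(vec)*4+4)
	b := make([]byte, 4)
	for _, v := range vec {
		binary.LittleEndian.PutUint32(b, math.Float32bits(v))
		row = append(row, b...)
	}
	binary.LittleEndian.PutUint32(b, age)
	return append(row, b...)
}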
ddStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + ddStream.SetPulsarClient(Params.PulsarAddress) + ddStream.CreatePulsarProducers(ddChannels) + + var insertMsgStream msgstream.MsgStream = insertStream + insertMsgStream.Start() + var ddMsgStream msgstream.MsgStream = ddStream + ddMsgStream.Start() + + err := insertMsgStream.Produce(&insertMsgPack) + assert.NoError(t, err) + err = insertMsgStream.Broadcast(timeTickMsgPack) + assert.NoError(t, err) + err = ddMsgStream.Broadcast(timeTickMsgPack) + assert.NoError(t, err) + + // generator searchRowData + var searchRowData []float32 + for i := 0; i < DIM; i++ { + searchRowData = append(searchRowData, float32(42*i)) } - // create minio client - bucketName := Params.MinioBucketName - option := &minioKV.Option{ - Address: Params.MinioEndPoint, - AccessKeyID: Params.MinioAccessKeyID, - SecretAccessKeyID: Params.MinioSecretAccessKey, - UseSSL: Params.MinioUseSSLStr, - BucketName: bucketName, - CreateBucket: true, + //generate search data and send search msg + dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" + var searchRowByteData []byte + for i := range searchRowData { + vec := make([]byte, 4) + binary.LittleEndian.PutUint32(vec, math.Float32bits(searchRowData[i])) + searchRowByteData = append(searchRowByteData, vec...) + } + placeholderValue := milvuspb.PlaceholderValue{ + Tag: "$0", + Type: milvuspb.PlaceholderType_VECTOR_FLOAT, + Values: [][]byte{searchRowByteData}, } - kv, err := minioKV.NewMinIOKV(context.Background(), option) + placeholderGroup := milvuspb.PlaceholderGroup{ + Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue}, + } + placeGroupByte, err := proto.Marshal(&placeholderGroup) if err != nil { - return nil, nil, err + log.Print("marshal placeholderGroup failed") } - - // binLogs -> minIO/S3 - segIDStr := strconv.FormatInt(segmentID, 10) - keyPrefix = path.Join(keyPrefix, segIDStr) - - paths := make([]*internalpb2.StringList, 0) - fieldIDs := make([]int64, 0) - fmt.Println(".. 
saving binlog to MinIO ...", len(binLogs)) - for _, blob := range binLogs { - uid := rand.Int63n(100000000) - key := path.Join(keyPrefix, blob.Key, strconv.FormatInt(uid, 10)) - err = kv.Save(key, string(blob.Value[:])) - if err != nil { - return nil, nil, err + query := milvuspb.SearchRequest{ + CollectionName: "collection0", + PartitionNames: []string{"default"}, + Dsl: dslString, + PlaceholderGroup: placeGroupByte, + } + queryByte, err := proto.Marshal(&query) + if err != nil { + log.Print("marshal query failed") + } + blob := commonpb.Blob{ + Value: queryByte, + } + fn := func(n int64) *msgstream.MsgPack { + searchMsg := &msgstream.SearchMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: []uint32{0}, + }, + SearchRequest: internalpb2.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kSearch, + MsgID: n, + Timestamp: uint64(msgLength), + SourceID: 1, + }, + ResultChannelID: "0", + Query: &blob, + }, } - paths = append(paths, &internalpb2.StringList{ - Values: []string{key}, - }) - fieldID, err := strconv.Atoi(blob.Key) - if err != nil { - return nil, nil, err + return &msgstream.MsgPack{ + Msgs: []msgstream.TsMsg{searchMsg}, } - fieldIDs = append(fieldIDs, int64(fieldID)) } + searchStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + searchStream.SetPulsarClient(Params.PulsarAddress) + searchStream.CreatePulsarProducers(newSearchChannelNames) + searchStream.Start() + err = searchStream.Produce(fn(1)) + assert.NoError(t, err) - return paths, fieldIDs, nil -} - -func generateIndex(segmentID UniqueID) ([]string, error) { - const ( - msgLength = 1000 - DIM = 16 - ) + //get search result + searchResultStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + searchResultStream.SetPulsarClient(Params.PulsarAddress) + unmarshalDispatcher := util.NewUnmarshalDispatcher() + searchResultStream.CreatePulsarConsumers(newSearchResultChannelNames, "loadIndexTestSubSearchResult", unmarshalDispatcher, receiveBufSize) + searchResultStream.Start() + searchResult := searchResultStream.Consume() + assert.NotNil(t, searchResult) + unMarshaledHit := milvuspb.Hits{} + err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit) + assert.Nil(t, err) + // gen load index message pack indexParams := make(map[string]string) indexParams["index_type"] = "IVF_PQ" indexParams["index_mode"] = "cpu" @@ -812,24 +245,19 @@ func generateIndex(segmentID UniqueID) ([]string, error) { }) } + // generator index typeParams := make(map[string]string) - typeParams["dim"] = strconv.Itoa(DIM) + typeParams["dim"] = "16" var indexRowData []float32 for n := 0; n < msgLength; n++ { for i := 0; i < DIM; i++ { indexRowData = append(indexRowData, float32(n*i)) } } - index, err := indexnode.NewCIndex(typeParams, indexParams) - if err != nil { - return nil, err - } - + assert.Nil(t, err) err = index.BuildFloatVecIndexWithoutIds(indexRowData) - if err != nil { - return nil, err - } + assert.Equal(t, err, nil) option := &minioKV.Option{ Address: Params.MinioEndPoint, @@ -840,292 +268,407 @@ func generateIndex(segmentID UniqueID) ([]string, error) { CreateBucket: true, } - kv, err := minioKV.NewMinIOKV(context.Background(), option) - if err != nil { - return nil, err - } - - // save index to minio + minioKV, err := minioKV.NewMinIOKV(node.queryNodeLoopCtx, option) + assert.Equal(t, err, nil) + //save index to minio binarySet, err := index.Serialize() - if err != nil { - return nil, err - } - - // serialize index params + assert.Equal(t, err, 
nil) + indexPaths := make([]string, 0) var indexCodec storage.IndexCodec - serializedIndexBlobs, err := indexCodec.Serialize(binarySet, indexParams) - if err != nil { - return nil, err - } + binarySet, err = indexCodec.Serialize(binarySet, indexParams) + assert.NoError(t, err) + for _, index := range binarySet { + path := strconv.Itoa(int(segmentID)) + "/" + index.Key + indexPaths = append(indexPaths, path) + minioKV.Save(path, string(index.Value)) + } + + //test index search result + indexResult, err := index.QueryOnFloatVecIndexWithParam(searchRowData, indexParams) + assert.Equal(t, err, nil) + + // create loadIndexClient + fieldID := UniqueID(100) + loadIndexChannelNames := Params.LoadIndexChannelNames + client := client.NewQueryNodeClient(node.queryNodeLoopCtx, Params.PulsarAddress, loadIndexChannelNames) + client.LoadIndex(indexPaths, segmentID, fieldID, "vec", indexParams) + + // init message stream consumer and do checks + statsMs := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, Params.StatsReceiveBufSize) + statsMs.SetPulsarClient(Params.PulsarAddress) + statsMs.CreatePulsarConsumers([]string{Params.StatsChannelName}, Params.MsgChannelSubName, util.NewUnmarshalDispatcher(), Params.StatsReceiveBufSize) + statsMs.Start() + + findFiledStats := false + for { + receiveMsg := msgstream.MsgStream(statsMs).Consume() + assert.NotNil(t, receiveMsg) + assert.NotEqual(t, len(receiveMsg.Msgs), 0) + + for _, msg := range receiveMsg.Msgs { + statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg) + if statsMsg.FieldStats == nil || len(statsMsg.FieldStats) == 0 { + continue + } + findFiledStats = true + assert.Equal(t, ok, true) + assert.Equal(t, len(statsMsg.FieldStats), 1) + fieldStats0 := statsMsg.FieldStats[0] + assert.Equal(t, fieldStats0.FieldID, fieldID) + assert.Equal(t, fieldStats0.CollectionID, collectionID) + assert.Equal(t, len(fieldStats0.IndexStats), 1) + indexStats0 := fieldStats0.IndexStats[0] + params := indexStats0.IndexParams + // sort index params by key + sort.Slice(indexParamsKV, func(i, j int) bool { return indexParamsKV[i].Key < indexParamsKV[j].Key }) + indexEqual := node.loadService.indexParamsEqual(params, indexParamsKV) + assert.Equal(t, indexEqual, true) + } - indexPaths := make([]string, 0) - for _, index := range serializedIndexBlobs { - p := strconv.Itoa(int(segmentID)) + "/" + index.Key - indexPaths = append(indexPaths, p) - err := kv.Save(p, string(index.Value)) - if err != nil { - return nil, err + if findFiledStats { + break } } - return indexPaths, nil + err = searchStream.Produce(fn(2)) + assert.NoError(t, err) + searchResult = searchResultStream.Consume() + assert.NotNil(t, searchResult) + err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit) + assert.Nil(t, err) + + idsIndex := indexResult.IDs() + idsSegment := unMarshaledHit.IDs + assert.Equal(t, len(idsIndex), len(idsSegment)) + for i := 0; i < len(idsIndex); i++ { + assert.Equal(t, idsIndex[i], idsSegment[i]) + } + Params.SearchChannelNames = oldSearchChannelNames + Params.SearchResultChannelNames = oldSearchResultChannelNames + Params.LoadIndexChannelNames = oldLoadIndexChannelNames + Params.StatsChannelName = oldStatsChannelName + fmt.Println("loadIndex floatVector test Done!") + + defer assert.Equal(t, findFiledStats, true) + <-node.queryNodeLoopCtx.Done() + node.Stop() } -func doInsert(ctx context.Context, collectionName string, partitionTag string, segmentID UniqueID) error { +func TestLoadService_LoadIndex_BinaryVector(t *testing.T) { + node := 
newQueryNodeMock()
+	collectionID := rand.Int63n(1000000)
+	segmentID := rand.Int63n(1000000)
+	initTestMeta(t, node, "collection0", collectionID, segmentID, true)
+
+	// loadService and statsService
+	suffix := "-test-search-binary" + strconv.FormatInt(rand.Int63n(1000000), 10)
+	oldSearchChannelNames := Params.SearchChannelNames
+	newSearchChannelNames := makeNewChannelNames(oldSearchChannelNames, suffix)
+	Params.SearchChannelNames = newSearchChannelNames
+
+	oldSearchResultChannelNames := Params.SearchResultChannelNames
+	newSearchResultChannelNames := makeNewChannelNames(oldSearchResultChannelNames, suffix)
+	Params.SearchResultChannelNames = newSearchResultChannelNames
+
+	oldLoadIndexChannelNames := Params.LoadIndexChannelNames
+	newLoadIndexChannelNames := makeNewChannelNames(oldLoadIndexChannelNames, suffix)
+	Params.LoadIndexChannelNames = newLoadIndexChannelNames
+
+	oldStatsChannelName := Params.StatsChannelName
+	newStatsChannelNames := makeNewChannelNames([]string{oldStatsChannelName}, suffix)
+	Params.StatsChannelName = newStatsChannelNames[0]
+	go node.Start()
+
 	const msgLength = 1000
-	const DIM = 16
+	const receiveBufSize = 1024
+	const DIM = 128
-	var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
-	var rawData []byte
-	for _, ele := range vec {
-		buf := make([]byte, 4)
-		binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
-		rawData = append(rawData, buf...)
+	// generate index data
+	var indexRowData []byte
+	for n := 0; n < msgLength; n++ {
+		for i := 0; i < DIM/8; i++ {
+			indexRowData = append(indexRowData, byte(rand.Intn(8)))
+		}
 	}
-	bs := make([]byte, 4)
-	binary.LittleEndian.PutUint32(bs, 1)
-	rawData = append(rawData, bs...)
-	timeRange := TimeRange{
-		timestampMin: 0,
-		timestampMax: math.MaxUint64,
+	//generate insert data
+	var insertRowBlob []*commonpb.Blob
+	var timestamps []uint64
+	var rowIDs []int64
+	var hashValues []uint32
+	offset := 0
+	for n := 0; n < msgLength; n++ {
+		rowData := make([]byte, 0)
+		rowData = append(rowData, indexRowData[offset:offset+(DIM/8)]...)
+		offset += DIM / 8
+		age := make([]byte, 4)
+		binary.LittleEndian.PutUint32(age, 1)
+		rowData = append(rowData, age...)
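		// Row layout note: a binary vector packs one bit per dimension, so DIM=128
		// occupies DIM/8 = 16 bytes per row (the unit the JACCARD metric operates
		// on), followed by the same 4-byte little-endian age field as the float test.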
+ blob := &commonpb.Blob{ + Value: rowData, + } + insertRowBlob = append(insertRowBlob, blob) + timestamps = append(timestamps, uint64(n)) + rowIDs = append(rowIDs, int64(n)) + hashValues = append(hashValues, uint32(n)) } - // messages generate - insertMessages := make([]msgstream.TsMsg, 0) - for i := 0; i < msgLength; i++ { - var msg msgstream.TsMsg = &msgstream.InsertMsg{ - BaseMsg: msgstream.BaseMsg{ - HashValues: []uint32{ - uint32(i), - }, - }, - InsertRequest: internalpb2.InsertRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_kInsert, - MsgID: 0, - Timestamp: uint64(i + 1000), - SourceID: 0, - }, - CollectionName: collectionName, - PartitionName: partitionTag, - SegmentID: segmentID, - ChannelID: "0", - Timestamps: []uint64{uint64(i + 1000)}, - RowIDs: []int64{int64(i)}, - RowData: []*commonpb.Blob{ - {Value: rawData}, - }, + var insertMsg msgstream.TsMsg = &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: hashValues, + }, + InsertRequest: internalpb2.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kInsert, + MsgID: 0, + Timestamp: timestamps[0], + SourceID: 0, }, - } - insertMessages = append(insertMessages, msg) + CollectionName: "collection0", + PartitionName: "default", + SegmentID: segmentID, + ChannelID: "0", + Timestamps: timestamps, + RowIDs: rowIDs, + RowData: insertRowBlob, + }, } - - msgPack := msgstream.MsgPack{ - BeginTs: timeRange.timestampMin, - EndTs: timeRange.timestampMax, - Msgs: insertMessages, + insertMsgPack := msgstream.MsgPack{ + BeginTs: 0, + EndTs: math.MaxUint64, + Msgs: []msgstream.TsMsg{insertMsg}, } // generate timeTick - timeTickMsgPack := msgstream.MsgPack{} - baseMsg := msgstream.BaseMsg{ - BeginTimestamp: 1000, - EndTimestamp: 1500, - HashValues: []uint32{0}, - } - timeTickResult := internalpb2.TimeTickMsg{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_kTimeTick, - MsgID: 0, - Timestamp: 1000, - SourceID: 0, + timeTickMsg := &msgstream.TimeTickMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 0, + EndTimestamp: 0, + HashValues: []uint32{0}, + }, + TimeTickMsg: internalpb2.TimeTickMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kTimeTick, + MsgID: 0, + Timestamp: math.MaxUint64, + SourceID: 0, + }, }, } - timeTickMsg := &msgstream.TimeTickMsg{ - BaseMsg: baseMsg, - TimeTickMsg: timeTickResult, + timeTickMsgPack := &msgstream.MsgPack{ + Msgs: []msgstream.TsMsg{timeTickMsg}, } - timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg) // pulsar produce - const receiveBufSize = 1024 insertChannels := Params.InsertChannelNames ddChannels := Params.DDChannelNames - pulsarURL := Params.PulsarAddress - insertStream := pulsarms.NewPulsarMsgStream(ctx, receiveBufSize) - insertStream.SetPulsarClient(pulsarURL) + insertStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + insertStream.SetPulsarClient(Params.PulsarAddress) insertStream.CreatePulsarProducers(insertChannels) - unmarshalDispatcher := util.NewUnmarshalDispatcher() - insertStream.CreatePulsarConsumers(insertChannels, Params.MsgChannelSubName, unmarshalDispatcher, receiveBufSize) - - ddStream := pulsarms.NewPulsarMsgStream(ctx, receiveBufSize) - ddStream.SetPulsarClient(pulsarURL) + ddStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + ddStream.SetPulsarClient(Params.PulsarAddress) ddStream.CreatePulsarProducers(ddChannels) var insertMsgStream msgstream.MsgStream = insertStream insertMsgStream.Start() - var ddMsgStream msgstream.MsgStream = ddStream 
ddMsgStream.Start() - err := insertMsgStream.Produce(&msgPack) - if err != nil { - return err - } + err := insertMsgStream.Produce(&insertMsgPack) + assert.NoError(t, err) + err = insertMsgStream.Broadcast(timeTickMsgPack) + assert.NoError(t, err) + err = ddMsgStream.Broadcast(timeTickMsgPack) + assert.NoError(t, err) - err = insertMsgStream.Broadcast(&timeTickMsgPack) - if err != nil { - return err - } - err = ddMsgStream.Broadcast(&timeTickMsgPack) - if err != nil { - return err + //generate search data and send search msg + searchRowData := indexRowData[42*(DIM/8) : 43*(DIM/8)] + dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"JACCARD\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" + placeholderValue := milvuspb.PlaceholderValue{ + Tag: "$0", + Type: milvuspb.PlaceholderType_VECTOR_BINARY, + Values: [][]byte{searchRowData}, } - - return nil -} - -func sentTimeTick(ctx context.Context) error { - timeTickMsgPack := msgstream.MsgPack{} - baseMsg := msgstream.BaseMsg{ - BeginTimestamp: 1500, - EndTimestamp: 2000, - HashValues: []uint32{0}, + placeholderGroup := milvuspb.PlaceholderGroup{ + Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue}, } - timeTickResult := internalpb2.TimeTickMsg{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_kTimeTick, - MsgID: 0, - Timestamp: math.MaxUint64, - SourceID: 0, - }, + placeGroupByte, err := proto.Marshal(&placeholderGroup) + if err != nil { + log.Print("marshal placeholderGroup failed") } - timeTickMsg := &msgstream.TimeTickMsg{ - BaseMsg: baseMsg, - TimeTickMsg: timeTickResult, + query := milvuspb.SearchRequest{ + CollectionName: "collection0", + PartitionNames: []string{"default"}, + Dsl: dslString, + PlaceholderGroup: placeGroupByte, } - timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg) - - // pulsar produce - const receiveBufSize = 1024 - insertChannels := Params.InsertChannelNames - ddChannels := Params.DDChannelNames - pulsarURL := Params.PulsarAddress - - insertStream := pulsarms.NewPulsarMsgStream(ctx, receiveBufSize) - insertStream.SetPulsarClient(pulsarURL) - insertStream.CreatePulsarProducers(insertChannels) - unmarshalDispatcher := util.NewUnmarshalDispatcher() - insertStream.CreatePulsarConsumers(insertChannels, Params.MsgChannelSubName, unmarshalDispatcher, receiveBufSize) - - ddStream := pulsarms.NewPulsarMsgStream(ctx, receiveBufSize) - ddStream.SetPulsarClient(pulsarURL) - ddStream.CreatePulsarProducers(ddChannels) - - var insertMsgStream msgstream.MsgStream = insertStream - insertMsgStream.Start() - - var ddMsgStream msgstream.MsgStream = ddStream - ddMsgStream.Start() - - err := insertMsgStream.Broadcast(&timeTickMsgPack) + queryByte, err := proto.Marshal(&query) if err != nil { - return err + log.Print("marshal query failed") } - err = ddMsgStream.Broadcast(&timeTickMsgPack) - if err != nil { - return err + blob := commonpb.Blob{ + Value: queryByte, } - return nil -} - -func TestSegmentLoad_Search_Vector(t *testing.T) { - collectionID := UniqueID(0) - partitionID := UniqueID(1) - segmentID := UniqueID(2) - fieldIDs := []int64{0, 101} - - // mock write insert bin log - keyPrefix := path.Join("query-node-seg-manager-test-minio-prefix", strconv.FormatInt(collectionID, 10), strconv.FormatInt(partitionID, 10)) - Params.WriteNodeSegKvSubPath = keyPrefix - - node := newQueryNodeMock() - defer node.Stop() - - ctx := node.queryNodeLoopCtx - node.loadService = newLoadService(ctx, nil, nil, nil, node.replica, nil) - - collectionName := 
"collection0" - initTestMeta(t, node, collectionName, collectionID, 0) - - err := node.replica.addPartition(collectionID, partitionID) - assert.NoError(t, err) - - err = node.replica.addSegment(segmentID, partitionID, collectionID, segTypeSealed) + fn := func(n int64) *msgstream.MsgPack { + searchMsg := &msgstream.SearchMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: []uint32{0}, + }, + SearchRequest: internalpb2.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kSearch, + MsgID: n, + Timestamp: uint64(msgLength), + SourceID: 1, + }, + ResultChannelID: "0", + Query: &blob, + }, + } + return &msgstream.MsgPack{ + Msgs: []msgstream.TsMsg{searchMsg}, + } + } + searchStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + searchStream.SetPulsarClient(Params.PulsarAddress) + searchStream.CreatePulsarProducers(newSearchChannelNames) + searchStream.Start() + err = searchStream.Produce(fn(1)) assert.NoError(t, err) - paths, srcFieldIDs, err := generateInsertBinLog(collectionID, partitionID, segmentID, keyPrefix) - assert.NoError(t, err) + //get search result + searchResultStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + searchResultStream.SetPulsarClient(Params.PulsarAddress) + unmarshalDispatcher := util.NewUnmarshalDispatcher() + searchResultStream.CreatePulsarConsumers(newSearchResultChannelNames, "loadIndexTestSubSearchResult2", unmarshalDispatcher, receiveBufSize) + searchResultStream.Start() + searchResult := searchResultStream.Consume() + assert.NotNil(t, searchResult) + unMarshaledHit := milvuspb.Hits{} + err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit) + assert.Nil(t, err) - fieldsMap := node.loadService.getTargetFields(paths, srcFieldIDs, fieldIDs) - assert.Equal(t, len(fieldsMap), 2) + // gen load index message pack + indexParams := make(map[string]string) + indexParams["index_type"] = "BIN_IVF_FLAT" + indexParams["index_mode"] = "cpu" + indexParams["dim"] = "128" + indexParams["k"] = "10" + indexParams["nlist"] = "100" + indexParams["nprobe"] = "10" + indexParams["m"] = "4" + indexParams["nbits"] = "8" + indexParams["metric_type"] = "JACCARD" + indexParams["SLICE_SIZE"] = "4" - segment, err := node.replica.getSegmentByID(segmentID) - assert.NoError(t, err) + var indexParamsKV []*commonpb.KeyValuePair + for key, value := range indexParams { + indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{ + Key: key, + Value: value, + }) + } - err = node.loadService.loadSegmentFieldsData(segment, fieldsMap) - assert.NoError(t, err) + // generator index + typeParams := make(map[string]string) + typeParams["dim"] = "128" + index, err := indexnode.NewCIndex(typeParams, indexParams) + assert.Nil(t, err) + err = index.BuildBinaryVecIndexWithoutIds(indexRowData) + assert.Equal(t, err, nil) - indexPaths, err := generateIndex(segmentID) - assert.NoError(t, err) + option := &minioKV.Option{ + Address: Params.MinioEndPoint, + AccessKeyID: Params.MinioAccessKeyID, + SecretAccessKeyID: Params.MinioSecretAccessKey, + UseSSL: Params.MinioUseSSLStr, + BucketName: Params.MinioBucketName, + CreateBucket: true, + } - err = node.loadService.loadIndexImmediate(segment, indexPaths) + minioKV, err := minioKV.NewMinIOKV(node.queryNodeLoopCtx, option) + assert.Equal(t, err, nil) + //save index to minio + binarySet, err := index.Serialize() + assert.Equal(t, err, nil) + var indexCodec storage.IndexCodec + binarySet, err = indexCodec.Serialize(binarySet, indexParams) assert.NoError(t, err) + 
indexPaths := make([]string, 0) + for _, index := range binarySet { + path := strconv.Itoa(int(segmentID)) + "/" + index.Key + indexPaths = append(indexPaths, path) + minioKV.Save(path, string(index.Value)) + } + + //test index search result + indexResult, err := index.QueryOnBinaryVecIndexWithParam(searchRowData, indexParams) + assert.Equal(t, err, nil) + + // create loadIndexClient + fieldID := UniqueID(100) + loadIndexChannelNames := Params.LoadIndexChannelNames + client := client.NewQueryNodeClient(node.queryNodeLoopCtx, Params.PulsarAddress, loadIndexChannelNames) + client.LoadIndex(indexPaths, segmentID, fieldID, "vec", indexParams) + + // init message stream consumer and do checks + statsMs := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, Params.StatsReceiveBufSize) + statsMs.SetPulsarClient(Params.PulsarAddress) + statsMs.CreatePulsarConsumers([]string{Params.StatsChannelName}, Params.MsgChannelSubName, util.NewUnmarshalDispatcher(), Params.StatsReceiveBufSize) + statsMs.Start() + + findFiledStats := false + for { + receiveMsg := msgstream.MsgStream(statsMs).Consume() + assert.NotNil(t, receiveMsg) + assert.NotEqual(t, len(receiveMsg.Msgs), 0) + + for _, msg := range receiveMsg.Msgs { + statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg) + if statsMsg.FieldStats == nil || len(statsMsg.FieldStats) == 0 { + continue + } + findFiledStats = true + assert.Equal(t, ok, true) + assert.Equal(t, len(statsMsg.FieldStats), 1) + fieldStats0 := statsMsg.FieldStats[0] + assert.Equal(t, fieldStats0.FieldID, fieldID) + assert.Equal(t, fieldStats0.CollectionID, collectionID) + assert.Equal(t, len(fieldStats0.IndexStats), 1) + indexStats0 := fieldStats0.IndexStats[0] + params := indexStats0.IndexParams + // sort index params by key + sort.Slice(indexParamsKV, func(i, j int) bool { return indexParamsKV[i].Key < indexParamsKV[j].Key }) + indexEqual := node.loadService.indexParamsEqual(params, indexParamsKV) + assert.Equal(t, indexEqual, true) + } - // do search - dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" - - const DIM = 16 - var searchRawData []byte - var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - for _, ele := range vec { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) - searchRawData = append(searchRawData, buf...) 
- } - placeholderValue := milvuspb.PlaceholderValue{ - Tag: "$0", - Type: milvuspb.PlaceholderType_VECTOR_FLOAT, - Values: [][]byte{searchRawData}, - } - - placeholderGroup := milvuspb.PlaceholderGroup{ - Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue}, + if findFiledStats { + break + } } - placeHolderGroupBlob, err := proto.Marshal(&placeholderGroup) + err = searchStream.Produce(fn(2)) assert.NoError(t, err) - - searchTimestamp := Timestamp(1020) - collection, err := node.replica.getCollectionByID(collectionID) - assert.NoError(t, err) - plan, err := createPlan(*collection, dslString) - assert.NoError(t, err) - holder, err := parserPlaceholderGroup(plan, placeHolderGroupBlob) - assert.NoError(t, err) - placeholderGroups := make([]*PlaceholderGroup, 0) - placeholderGroups = append(placeholderGroups, holder) - - // wait for segment building index - time.Sleep(3 * time.Second) - - _, err = segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp}) + searchResult = searchResultStream.Consume() + assert.NotNil(t, searchResult) + err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit) assert.Nil(t, err) - plan.delete() - holder.delete() - - <-ctx.Done() + idsIndex := indexResult.IDs() + idsSegment := unMarshaledHit.IDs + assert.Equal(t, len(idsIndex), len(idsSegment)) + for i := 0; i < len(idsIndex); i++ { + assert.Equal(t, idsIndex[i], idsSegment[i]) + } + Params.SearchChannelNames = oldSearchChannelNames + Params.SearchResultChannelNames = oldSearchResultChannelNames + Params.LoadIndexChannelNames = oldLoadIndexChannelNames + Params.StatsChannelName = oldStatsChannelName + fmt.Println("loadIndex binaryVector test Done!") + + defer assert.Equal(t, findFiledStats, true) + <-node.queryNodeLoopCtx.Done() + node.Stop() } diff --git a/internal/querynode/param_table.go b/internal/querynode/param_table.go index e8a9a0a2fe9b6b75c5aff1647772b0a64d1bc28f..7a41c50e3ecc399c11ceee15662ca930be16aee7 100644 --- a/internal/querynode/param_table.go +++ b/internal/querynode/param_table.go @@ -58,6 +58,11 @@ type ParamTable struct { StatsChannelName string StatsReceiveBufSize int64 + // load index + LoadIndexChannelNames []string + LoadIndexReceiveBufSize int64 + LoadIndexPulsarBufSize int64 + GracefulTime int64 MsgChannelSubName string DefaultPartitionTag string @@ -157,6 +162,10 @@ func (p *ParamTable) Init() { p.initStatsPublishInterval() p.initStatsChannelName() p.initStatsReceiveBufSize() + + p.initLoadIndexChannelNames() + p.initLoadIndexReceiveBufSize() + p.initLoadIndexPulsarBufSize() } // ---------------------------------------------------------- query node @@ -478,3 +487,19 @@ func (p *ParamTable) initSliceIndex() { } p.SliceIndex = -1 } + +func (p *ParamTable) initLoadIndexChannelNames() { + loadIndexChannelName, err := p.Load("msgChannel.chanNamePrefix.cmd") + if err != nil { + panic(err) + } + p.LoadIndexChannelNames = []string{loadIndexChannelName} +} + +func (p *ParamTable) initLoadIndexReceiveBufSize() { + p.LoadIndexReceiveBufSize = p.ParseInt64("queryNode.msgStream.loadIndex.recvBufSize") +} + +func (p *ParamTable) initLoadIndexPulsarBufSize() { + p.LoadIndexPulsarBufSize = p.ParseInt64("queryNode.msgStream.loadIndex.pulsarBufSize") +} diff --git a/internal/querynode/param_table_test.go b/internal/querynode/param_table_test.go index 04bf7807f28c7bcc3df7de9449e31c56417cad51..d58eec9784331e8d263031c7e51c42277ad43d1a 100644 --- a/internal/querynode/param_table_test.go +++ b/internal/querynode/param_table_test.go @@ -60,6 
+60,24 @@ func TestParamTable_minio(t *testing.T) { }) } +func TestParamTable_LoadIndex(t *testing.T) { + t.Run("Test channel names", func(t *testing.T) { + names := Params.LoadIndexChannelNames + assert.Equal(t, len(names), 1) + assert.Contains(t, names[0], "cmd") + }) + + t.Run("Test recvBufSize", func(t *testing.T) { + size := Params.LoadIndexReceiveBufSize + assert.Equal(t, size, int64(512)) + }) + + t.Run("Test pulsarBufSize", func(t *testing.T) { + size := Params.LoadIndexPulsarBufSize + assert.Equal(t, size, int64(512)) + }) +} + func TestParamTable_insertChannelRange(t *testing.T) { channelRange := Params.InsertChannelRange assert.Equal(t, 2, len(channelRange)) diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index c399915a165df6dc62db81ff3c9f45645388c708..c6b85d19414da5694acc072ee046433004954de9 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -455,7 +455,7 @@ func (node *QueryNode) LoadSegments(in *queryPb.LoadSegmentRequest) (*commonpb.S if in.LastSegmentState.State == datapb.SegmentState_SegmentGrowing { segmentNum := len(segmentIDs) positions := in.LastSegmentState.StartPositions - err = node.loadService.seekSegment(positions) + err = node.loadService.segManager.seekSegment(positions) if err != nil { status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, @@ -466,7 +466,7 @@ func (node *QueryNode) LoadSegments(in *queryPb.LoadSegmentRequest) (*commonpb.S segmentIDs = segmentIDs[:segmentNum-1] } - err = node.loadService.loadSegment(collectionID, partitionID, segmentIDs, fieldIDs) + err = node.loadService.segManager.loadSegment(collectionID, partitionID, segmentIDs, fieldIDs) if err != nil { status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, @@ -493,7 +493,7 @@ func (node *QueryNode) ReleaseSegments(in *queryPb.ReleaseSegmentRequest) (*comm // release all fields in the segments for _, id := range in.SegmentIDs { - err := node.loadService.releaseSegment(id) + err := node.loadService.segManager.releaseSegment(id) if err != nil { status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, diff --git a/internal/querynode/query_node_test.go b/internal/querynode/query_node_test.go index d489e9ed91a1903a2f76c2774bb4a9a1ee131ab9..287adc23c7ee5e0deb26b1828cb3535d46e0bd64 100644 --- a/internal/querynode/query_node_test.go +++ b/internal/querynode/query_node_test.go @@ -159,6 +159,7 @@ func refreshChannelNames() { Params.SearchChannelNames = makeNewChannelNames(Params.SearchChannelNames, suffix) Params.SearchResultChannelNames = makeNewChannelNames(Params.SearchResultChannelNames, suffix) Params.StatsChannelName = Params.StatsChannelName + suffix + Params.LoadIndexChannelNames = makeNewChannelNames(Params.LoadIndexChannelNames, suffix) } func (q *queryServiceMock) RegisterNode(req *querypb.RegisterNodeRequest) (*querypb.RegisterNodeResponse, error) { diff --git a/internal/querynode/segment.go b/internal/querynode/segment.go index 98e594cabbb71d4eab4497483dbedde491942065..56da38bff0197d1f13e58038275525f4f3613fa3 100644 --- a/internal/querynode/segment.go +++ b/internal/querynode/segment.go @@ -126,8 +126,6 @@ func deleteSegment(segment *Segment) { */ cPtr := segment.segmentPtr C.DeleteSegment(cPtr) - segment.segmentPtr = nil - segment = nil } func (s *Segment) getRowCount() int64 { @@ -135,9 +133,6 @@ func (s *Segment) getRowCount() int64 { long int getRowCount(CSegmentInterface c_segment); */ - if s.segmentPtr == nil { - return -1 - } var rowCount = 
C.GetRowCount(s.segmentPtr) return int64(rowCount) } @@ -147,9 +142,6 @@ func (s *Segment) getDeletedCount() int64 { long int getDeletedCount(CSegmentInterface c_segment); */ - if s.segmentPtr == nil { - return -1 - } var deletedCount = C.GetDeletedCount(s.segmentPtr) return int64(deletedCount) } @@ -159,9 +151,6 @@ func (s *Segment) getMemSize() int64 { long int GetMemoryUsageInBytes(CSegmentInterface c_segment); */ - if s.segmentPtr == nil { - return -1 - } var memoryUsageInBytes = C.GetMemoryUsageInBytes(s.segmentPtr) return int64(memoryUsageInBytes) @@ -179,9 +168,7 @@ func (s *Segment) segmentSearch(plan *Plan, long int* result_ids, float* result_distances); */ - if s.segmentPtr == nil { - return nil, errors.New("null seg core pointer") - } + cPlaceholderGroups := make([]C.CPlaceholderGroup, 0) for _, pg := range placeHolderGroups { cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup) @@ -207,9 +194,6 @@ func (s *Segment) segmentSearch(plan *Plan, func (s *Segment) fillTargetEntry(plan *Plan, result *SearchResult) error { - if s.segmentPtr == nil { - return errors.New("null seg core pointer") - } var status = C.FillTargetEntry(s.segmentPtr, plan.cPlan, result.cQueryResult) errorCode := status.error_code @@ -225,9 +209,6 @@ func (s *Segment) fillTargetEntry(plan *Plan, // segment, err := loadService.replica.getSegmentByID(segmentID) func (s *Segment) updateSegmentIndex(loadIndexInfo *LoadIndexInfo) error { - if s.segmentPtr == nil { - return errors.New("null seg core pointer") - } var status C.CStatus if s.segmentType == segTypeGrowing { @@ -256,7 +237,7 @@ func (s *Segment) setIndexParam(fieldID int64, indexParamKv []*commonpb.KeyValue defer s.paramMutex.Unlock() indexParamMap := make(indexParam) if indexParamKv == nil { - return errors.New("empty loadIndexMsg's indexParam") + return errors.New("loadIndexMsg's indexParam empty") } for _, param := range indexParamKv { indexParamMap[param.Key] = param.Value @@ -320,9 +301,6 @@ func (s *Segment) segmentInsert(offset int64, entityIDs *[]UniqueID, timestamps int sizeof_per_row, signed long int count); */ - if s.segmentPtr == nil { - return errors.New("null seg core pointer") - } // Blobs to one big blob var numOfRow = len(*entityIDs) var sizeofPerRow = len((*records)[0].Value) @@ -373,9 +351,6 @@ func (s *Segment) segmentDelete(offset int64, entityIDs *[]UniqueID, timestamps const long* primary_keys, const unsigned long* timestamps); */ - if s.segmentPtr == nil { - return errors.New("null seg core pointer") - } var cOffset = C.long(offset) var cSize = C.long(len(*entityIDs)) var cEntityIdsPtr = (*C.long)(&(*entityIDs)[0]) @@ -400,9 +375,6 @@ func (s *Segment) segmentLoadFieldData(fieldID int64, rowCount int, data interfa CStatus LoadFieldData(CSegmentInterface c_segment, CLoadFieldDataInfo load_field_data_info); */ - if s.segmentPtr == nil { - return errors.New("null seg core pointer") - } if s.segmentType != segTypeSealed { return errors.New("illegal segment type when loading field data") } diff --git a/internal/querynode/segment_manager.go b/internal/querynode/segment_manager.go new file mode 100644 index 0000000000000000000000000000000000000000..1e17d18580920727b3873f80128be7cd120cbbb4 --- /dev/null +++ b/internal/querynode/segment_manager.go @@ -0,0 +1,341 @@ +package querynode + +import ( + "context" + "errors" + "strconv" + + "github.com/zilliztech/milvus-distributed/internal/kv" + miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio" + "github.com/zilliztech/milvus-distributed/internal/msgstream" + 
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/datapb" + "github.com/zilliztech/milvus-distributed/internal/proto/indexpb" + internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2" + "github.com/zilliztech/milvus-distributed/internal/proto/milvuspb" + "github.com/zilliztech/milvus-distributed/internal/storage" +) + +type segmentManager struct { + replica collectionReplica + + dmStream msgstream.MsgStream + loadIndexReqChan chan []msgstream.TsMsg + + masterClient MasterServiceInterface + dataClient DataServiceInterface + indexClient IndexServiceInterface + + kv kv.Base // minio kv + iCodec *storage.InsertCodec +} + +func (s *segmentManager) seekSegment(positions []*internalPb.MsgPosition) error { + // TODO: open seek + //for _, position := range positions { + // err := s.dmStream.Seek(position) + // if err != nil { + // return err + // } + //} + return nil +} + +func (s *segmentManager) getIndexInfo(collectionID UniqueID, segmentID UniqueID) (UniqueID, UniqueID, error) { + req := &milvuspb.DescribeSegmentRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kDescribeSegment, + }, + CollectionID: collectionID, + SegmentID: segmentID, + } + response, err := s.masterClient.DescribeSegment(req) + if err != nil { + return 0, 0, err + } + return response.IndexID, response.BuildID, nil +} + +func (s *segmentManager) loadSegment(collectionID UniqueID, partitionID UniqueID, segmentIDs []UniqueID, fieldIDs []int64) error { + // TODO: interim solution + if len(fieldIDs) == 0 { + collection, err := s.replica.getCollectionByID(collectionID) + if err != nil { + return err + } + fieldIDs = make([]int64, 0) + for _, field := range collection.Schema().Fields { + fieldIDs = append(fieldIDs, field.FieldID) + } + } + for _, segmentID := range segmentIDs { + // we don't need index id yet + _, buildID, err := s.getIndexInfo(collectionID, segmentID) + if err == nil { + // we don't need load to vector fields + vectorFields, err := s.replica.getVecFieldsBySegmentID(segmentID) + if err != nil { + return err + } + fieldIDs = s.filterOutVectorFields(fieldIDs, vectorFields) + } + paths, srcFieldIDs, err := s.getInsertBinlogPaths(segmentID) + if err != nil { + return err + } + + targetFields := s.getTargetFields(paths, srcFieldIDs, fieldIDs) + // replace segment + err = s.replica.removeSegment(segmentID) + if err != nil { + return err + } + err = s.replica.addSegment(segmentID, partitionID, collectionID, segTypeSealed) + if err != nil { + return err + } + err = s.loadSegmentFieldsData(segmentID, targetFields) + if err != nil { + return err + } + indexPaths, err := s.getIndexPaths(buildID) + if err != nil { + return err + } + err = s.loadIndex(segmentID, indexPaths) + if err != nil { + // TODO: return or continue? 
+ return err + } + } + return nil +} + +func (s *segmentManager) releaseSegment(segmentID UniqueID) error { + err := s.replica.removeSegment(segmentID) + return err +} + +//------------------------------------------------------------------------------------------------- internal functions +func (s *segmentManager) getInsertBinlogPaths(segmentID UniqueID) ([]*internalPb.StringList, []int64, error) { + if s.dataClient == nil { + return nil, nil, errors.New("null data service client") + } + + insertBinlogPathRequest := &datapb.InsertBinlogPathRequest{ + SegmentID: segmentID, + } + + pathResponse, err := s.dataClient.GetInsertBinlogPaths(insertBinlogPathRequest) + if err != nil { + return nil, nil, err + } + + if len(pathResponse.FieldIDs) != len(pathResponse.Paths) { + return nil, nil, errors.New("illegal InsertBinlogPathsResponse") + } + + return pathResponse.Paths, pathResponse.FieldIDs, nil +} + +func (s *segmentManager) filterOutVectorFields(fieldIDs []int64, vectorFields map[int64]string) []int64 { + targetFields := make([]int64, 0) + for _, id := range fieldIDs { + if _, ok := vectorFields[id]; !ok { + targetFields = append(targetFields, id) + } + } + return targetFields +} + +func (s *segmentManager) getTargetFields(paths []*internalPb.StringList, srcFieldIDS []int64, dstFields []int64) map[int64]*internalPb.StringList { + targetFields := make(map[int64]*internalPb.StringList) + + containsFunc := func(s []int64, e int64) bool { + for _, a := range s { + if a == e { + return true + } + } + return false + } + + for i, fieldID := range srcFieldIDS { + if containsFunc(dstFields, fieldID) { + targetFields[fieldID] = paths[i] + } + } + + return targetFields +} + +func (s *segmentManager) loadSegmentFieldsData(segmentID UniqueID, targetFields map[int64]*internalPb.StringList) error { + for id, p := range targetFields { + if id == timestampFieldID { + // seg core doesn't need timestamp field + continue + } + + paths := p.Values + blobs := make([]*storage.Blob, 0) + for _, path := range paths { + binLog, err := s.kv.Load(path) + if err != nil { + // TODO: return or continue? + return err + } + blobs = append(blobs, &storage.Blob{ + Key: strconv.FormatInt(id, 10), // TODO: key??? 
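				// The blob key appears to be the field ID rendered as a string;
				// the test-side generateInsertBinLog makes the same assumption when
				// it recovers field IDs via strconv.Atoi(blob.Key).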
+ Value: []byte(binLog), + }) + } + _, _, insertData, err := s.iCodec.Deserialize(blobs) + if err != nil { + // TODO: return or continue + return err + } + if len(insertData.Data) != 1 { + return errors.New("we expect only one field in deserialized insert data") + } + + for _, value := range insertData.Data { + var numRows int + var data interface{} + + switch fieldData := value.(type) { + case *storage.BoolFieldData: + numRows = fieldData.NumRows + data = fieldData.Data + case *storage.Int8FieldData: + numRows = fieldData.NumRows + data = fieldData.Data + case *storage.Int16FieldData: + numRows = fieldData.NumRows + data = fieldData.Data + case *storage.Int32FieldData: + numRows = fieldData.NumRows + data = fieldData.Data + case *storage.Int64FieldData: + numRows = fieldData.NumRows + data = fieldData.Data + case *storage.FloatFieldData: + numRows = fieldData.NumRows + data = fieldData.Data + case *storage.DoubleFieldData: + numRows = fieldData.NumRows + data = fieldData.Data + case storage.StringFieldData: + numRows = fieldData.NumRows + data = fieldData.Data + case *storage.FloatVectorFieldData: + numRows = fieldData.NumRows + data = fieldData.Data + case *storage.BinaryVectorFieldData: + numRows = fieldData.NumRows + data = fieldData.Data + default: + return errors.New("unexpected field data type") + } + + segment, err := s.replica.getSegmentByID(segmentID) + if err != nil { + // TODO: return or continue? + return err + } + err = segment.segmentLoadFieldData(id, numRows, data) + if err != nil { + // TODO: return or continue? + return err + } + } + } + + return nil +} + +func (s *segmentManager) getIndexPaths(buildID UniqueID) ([]string, error) { + if s.indexClient == nil { + return nil, errors.New("null index service client") + } + + indexFilePathRequest := &indexpb.IndexFilePathsRequest{ + // TODO: rename indexIDs to buildIDs + IndexIDs: []UniqueID{buildID}, + } + pathResponse, err := s.indexClient.GetIndexFilePaths(indexFilePathRequest) + if err != nil || pathResponse.Status.ErrorCode != commonpb.ErrorCode_SUCCESS { + return nil, err + } + + if len(pathResponse.FilePaths) <= 0 { + return nil, errors.New("illegal index file paths") + } + + return pathResponse.FilePaths[0].IndexFilePaths, nil +} + +func (s *segmentManager) loadIndex(segmentID UniqueID, indexPaths []string) error { + // get vector field ids from schema to load index + vecFieldIDs, err := s.replica.getVecFieldsBySegmentID(segmentID) + if err != nil { + return err + } + for id, name := range vecFieldIDs { + // non-blocking sending + go s.sendLoadIndex(indexPaths, segmentID, id, name) + } + + return nil +} + +func (s *segmentManager) sendLoadIndex(indexPaths []string, + segmentID int64, + fieldID int64, + fieldName string) { + loadIndexRequest := internalPb.LoadIndex{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kSearchResult, + }, + SegmentID: segmentID, + FieldName: fieldName, + FieldID: fieldID, + IndexPaths: indexPaths, + } + + loadIndexMsg := &msgstream.LoadIndexMsg{ + LoadIndex: loadIndexRequest, + } + + messages := []msgstream.TsMsg{loadIndexMsg} + s.loadIndexReqChan <- messages +} + +func newSegmentManager(ctx context.Context, masterClient MasterServiceInterface, dataClient DataServiceInterface, indexClient IndexServiceInterface, replica collectionReplica, dmStream msgstream.MsgStream, loadIndexReqChan chan []msgstream.TsMsg) *segmentManager { + bucketName := Params.MinioBucketName + option := &miniokv.Option{ + Address: Params.MinioEndPoint, + AccessKeyID: Params.MinioAccessKeyID, + SecretAccessKeyID: 
Params.MinioSecretAccessKey, + UseSSL: Params.MinioUseSSLStr, + BucketName: bucketName, + CreateBucket: true, + } + + minioKV, err := miniokv.NewMinIOKV(ctx, option) + if err != nil { + panic(err) + } + + return &segmentManager{ + replica: replica, + dmStream: dmStream, + loadIndexReqChan: loadIndexReqChan, + + masterClient: masterClient, + dataClient: dataClient, + indexClient: indexClient, + + kv: minioKV, + iCodec: &storage.InsertCodec{}, + } +} diff --git a/internal/querynode/segment_manager_test.go b/internal/querynode/segment_manager_test.go new file mode 100644 index 0000000000000000000000000000000000000000..02afa508adba001eae65701882763906baa42e49 --- /dev/null +++ b/internal/querynode/segment_manager_test.go @@ -0,0 +1,590 @@ +package querynode + +import ( + "context" + "encoding/binary" + "fmt" + "math" + "math/rand" + "path" + "strconv" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" + + "github.com/zilliztech/milvus-distributed/internal/indexnode" + minioKV "github.com/zilliztech/milvus-distributed/internal/kv/minio" + "github.com/zilliztech/milvus-distributed/internal/msgstream" + "github.com/zilliztech/milvus-distributed/internal/msgstream/pulsarms" + "github.com/zilliztech/milvus-distributed/internal/msgstream/util" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2" + "github.com/zilliztech/milvus-distributed/internal/proto/milvuspb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" + "github.com/zilliztech/milvus-distributed/internal/storage" +) + +func generateInsertBinLog(collectionID UniqueID, partitionID UniqueID, segmentID UniqueID, keyPrefix string) ([]*internalPb.StringList, []int64, error) { + const ( + msgLength = 1000 + DIM = 16 + ) + + idData := make([]int64, 0) + for n := 0; n < msgLength; n++ { + idData = append(idData, int64(n)) + } + + var timestamps []int64 + for n := 0; n < msgLength; n++ { + timestamps = append(timestamps, int64(n+1)) + } + + var fieldAgeData []int32 + for n := 0; n < msgLength; n++ { + fieldAgeData = append(fieldAgeData, int32(n)) + } + + fieldVecData := make([]float32, 0) + for n := 0; n < msgLength; n++ { + for i := 0; i < DIM; i++ { + fieldVecData = append(fieldVecData, float32(n*i)*0.1) + } + } + + insertData := &storage.InsertData{ + Data: map[int64]storage.FieldData{ + 0: &storage.Int64FieldData{ + NumRows: msgLength, + Data: idData, + }, + 1: &storage.Int64FieldData{ + NumRows: msgLength, + Data: timestamps, + }, + 100: &storage.FloatVectorFieldData{ + NumRows: msgLength, + Data: fieldVecData, + Dim: DIM, + }, + 101: &storage.Int32FieldData{ + NumRows: msgLength, + Data: fieldAgeData, + }, + }, + } + + // buffer data to binLogs + collMeta := genTestCollectionMeta("collection0", collectionID, false) + collMeta.Schema.Fields = append(collMeta.Schema.Fields, &schemapb.FieldSchema{ + FieldID: 0, + Name: "uid", + DataType: schemapb.DataType_INT64, + }) + collMeta.Schema.Fields = append(collMeta.Schema.Fields, &schemapb.FieldSchema{ + FieldID: 1, + Name: "timestamp", + DataType: schemapb.DataType_INT64, + }) + inCodec := storage.NewInsertCodec(collMeta) + binLogs, err := inCodec.Serialize(partitionID, segmentID, insertData) + + if err != nil { + return nil, nil, err + } + + // create minio client + bucketName := Params.MinioBucketName + option := &minioKV.Option{ + Address: Params.MinioEndPoint, + AccessKeyID: Params.MinioAccessKeyID, + SecretAccessKeyID: 
Params.MinioSecretAccessKey, + UseSSL: Params.MinioUseSSLStr, + BucketName: bucketName, + CreateBucket: true, + } + kv, err := minioKV.NewMinIOKV(context.Background(), option) + if err != nil { + return nil, nil, err + } + + // binLogs -> minIO/S3 + segIDStr := strconv.FormatInt(segmentID, 10) + keyPrefix = path.Join(keyPrefix, segIDStr) + + paths := make([]*internalPb.StringList, 0) + fieldIDs := make([]int64, 0) + fmt.Println(".. saving binlog to MinIO ...", len(binLogs)) + for _, blob := range binLogs { + uid := rand.Int63n(100000000) + key := path.Join(keyPrefix, blob.Key, strconv.FormatInt(uid, 10)) + err = kv.Save(key, string(blob.Value[:])) + if err != nil { + return nil, nil, err + } + paths = append(paths, &internalPb.StringList{ + Values: []string{key}, + }) + fieldID, err := strconv.Atoi(blob.Key) + if err != nil { + return nil, nil, err + } + fieldIDs = append(fieldIDs, int64(fieldID)) + } + + return paths, fieldIDs, nil +} + +func generateIndex(segmentID UniqueID) ([]string, error) { + const ( + msgLength = 1000 + DIM = 16 + ) + + indexParams := make(map[string]string) + indexParams["index_type"] = "IVF_PQ" + indexParams["index_mode"] = "cpu" + indexParams["dim"] = "16" + indexParams["k"] = "10" + indexParams["nlist"] = "100" + indexParams["nprobe"] = "10" + indexParams["m"] = "4" + indexParams["nbits"] = "8" + indexParams["metric_type"] = "L2" + indexParams["SLICE_SIZE"] = "4" + + var indexParamsKV []*commonpb.KeyValuePair + for key, value := range indexParams { + indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{ + Key: key, + Value: value, + }) + } + + typeParams := make(map[string]string) + typeParams["dim"] = strconv.Itoa(DIM) + var indexRowData []float32 + for n := 0; n < msgLength; n++ { + for i := 0; i < DIM; i++ { + indexRowData = append(indexRowData, float32(n*i)) + } + } + + index, err := indexnode.NewCIndex(typeParams, indexParams) + if err != nil { + return nil, err + } + + err = index.BuildFloatVecIndexWithoutIds(indexRowData) + if err != nil { + return nil, err + } + + option := &minioKV.Option{ + Address: Params.MinioEndPoint, + AccessKeyID: Params.MinioAccessKeyID, + SecretAccessKeyID: Params.MinioSecretAccessKey, + UseSSL: Params.MinioUseSSLStr, + BucketName: Params.MinioBucketName, + CreateBucket: true, + } + + kv, err := minioKV.NewMinIOKV(context.Background(), option) + if err != nil { + return nil, err + } + + // save index to minio + binarySet, err := index.Serialize() + if err != nil { + return nil, err + } + + // serialize index params + var indexCodec storage.IndexCodec + serializedIndexBlobs, err := indexCodec.Serialize(binarySet, indexParams) + if err != nil { + return nil, err + } + + indexPaths := make([]string, 0) + for _, index := range serializedIndexBlobs { + p := strconv.Itoa(int(segmentID)) + "/" + index.Key + indexPaths = append(indexPaths, p) + err := kv.Save(p, string(index.Value)) + if err != nil { + return nil, err + } + } + + return indexPaths, nil +} + +func doInsert(ctx context.Context, collectionName string, partitionTag string, segmentID UniqueID) error { + const msgLength = 1000 + const DIM = 16 + + var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + var rawData []byte + for _, ele := range vec { + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) + rawData = append(rawData, buf...) + } + bs := make([]byte, 4) + binary.LittleEndian.PutUint32(bs, 1) + rawData = append(rawData, bs...) 
+ + timeRange := TimeRange{ + timestampMin: 0, + timestampMax: math.MaxUint64, + } + + // messages generate + insertMessages := make([]msgstream.TsMsg, 0) + for i := 0; i < msgLength; i++ { + var msg msgstream.TsMsg = &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: []uint32{ + uint32(i), + }, + }, + InsertRequest: internalPb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kInsert, + MsgID: 0, + Timestamp: uint64(i + 1000), + SourceID: 0, + }, + CollectionName: collectionName, + PartitionName: partitionTag, + SegmentID: segmentID, + ChannelID: "0", + Timestamps: []uint64{uint64(i + 1000)}, + RowIDs: []int64{int64(i)}, + RowData: []*commonpb.Blob{ + {Value: rawData}, + }, + }, + } + insertMessages = append(insertMessages, msg) + } + + msgPack := msgstream.MsgPack{ + BeginTs: timeRange.timestampMin, + EndTs: timeRange.timestampMax, + Msgs: insertMessages, + } + + // generate timeTick + timeTickMsgPack := msgstream.MsgPack{} + baseMsg := msgstream.BaseMsg{ + BeginTimestamp: 1000, + EndTimestamp: 1500, + HashValues: []uint32{0}, + } + timeTickResult := internalPb.TimeTickMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kTimeTick, + MsgID: 0, + Timestamp: 1000, + SourceID: 0, + }, + } + timeTickMsg := &msgstream.TimeTickMsg{ + BaseMsg: baseMsg, + TimeTickMsg: timeTickResult, + } + timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg) + + // pulsar produce + const receiveBufSize = 1024 + insertChannels := Params.InsertChannelNames + ddChannels := Params.DDChannelNames + pulsarURL := Params.PulsarAddress + + insertStream := pulsarms.NewPulsarMsgStream(ctx, receiveBufSize) + insertStream.SetPulsarClient(pulsarURL) + insertStream.CreatePulsarProducers(insertChannels) + unmarshalDispatcher := util.NewUnmarshalDispatcher() + insertStream.CreatePulsarConsumers(insertChannels, Params.MsgChannelSubName, unmarshalDispatcher, receiveBufSize) + + ddStream := pulsarms.NewPulsarMsgStream(ctx, receiveBufSize) + ddStream.SetPulsarClient(pulsarURL) + ddStream.CreatePulsarProducers(ddChannels) + + var insertMsgStream msgstream.MsgStream = insertStream + insertMsgStream.Start() + + var ddMsgStream msgstream.MsgStream = ddStream + ddMsgStream.Start() + + err := insertMsgStream.Produce(&msgPack) + if err != nil { + return err + } + + err = insertMsgStream.Broadcast(&timeTickMsgPack) + if err != nil { + return err + } + err = ddMsgStream.Broadcast(&timeTickMsgPack) + if err != nil { + return err + } + + return nil +} + +func sentTimeTick(ctx context.Context) error { + timeTickMsgPack := msgstream.MsgPack{} + baseMsg := msgstream.BaseMsg{ + BeginTimestamp: 1500, + EndTimestamp: 2000, + HashValues: []uint32{0}, + } + timeTickResult := internalPb.TimeTickMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kTimeTick, + MsgID: 0, + Timestamp: math.MaxUint64, + SourceID: 0, + }, + } + timeTickMsg := &msgstream.TimeTickMsg{ + BaseMsg: baseMsg, + TimeTickMsg: timeTickResult, + } + timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg) + + // pulsar produce + const receiveBufSize = 1024 + insertChannels := Params.InsertChannelNames + ddChannels := Params.DDChannelNames + pulsarURL := Params.PulsarAddress + + insertStream := pulsarms.NewPulsarMsgStream(ctx, receiveBufSize) + insertStream.SetPulsarClient(pulsarURL) + insertStream.CreatePulsarProducers(insertChannels) + unmarshalDispatcher := util.NewUnmarshalDispatcher() + insertStream.CreatePulsarConsumers(insertChannels, Params.MsgChannelSubName, unmarshalDispatcher, receiveBufSize) + + ddStream 
+func sentTimeTick(ctx context.Context) error {
+	timeTickMsgPack := msgstream.MsgPack{}
+	baseMsg := msgstream.BaseMsg{
+		BeginTimestamp: 1500,
+		EndTimestamp:   2000,
+		HashValues:     []uint32{0},
+	}
+	timeTickResult := internalPb.TimeTickMsg{
+		Base: &commonpb.MsgBase{
+			MsgType:   commonpb.MsgType_kTimeTick,
+			MsgID:     0,
+			Timestamp: math.MaxUint64,
+			SourceID:  0,
+		},
+	}
+	timeTickMsg := &msgstream.TimeTickMsg{
+		BaseMsg:     baseMsg,
+		TimeTickMsg: timeTickResult,
+	}
+	timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg)
+
+	// broadcast to Pulsar
+	const receiveBufSize = 1024
+	insertChannels := Params.InsertChannelNames
+	ddChannels := Params.DDChannelNames
+	pulsarURL := Params.PulsarAddress
+
+	insertStream := pulsarms.NewPulsarMsgStream(ctx, receiveBufSize)
+	insertStream.SetPulsarClient(pulsarURL)
+	insertStream.CreatePulsarProducers(insertChannels)
+	unmarshalDispatcher := util.NewUnmarshalDispatcher()
+	insertStream.CreatePulsarConsumers(insertChannels, Params.MsgChannelSubName, unmarshalDispatcher, receiveBufSize)
+
+	ddStream := pulsarms.NewPulsarMsgStream(ctx, receiveBufSize)
+	ddStream.SetPulsarClient(pulsarURL)
+	ddStream.CreatePulsarProducers(ddChannels)
+
+	var insertMsgStream msgstream.MsgStream = insertStream
+	insertMsgStream.Start()
+
+	var ddMsgStream msgstream.MsgStream = ddStream
+	ddMsgStream.Start()
+
+	err := insertMsgStream.Broadcast(&timeTickMsgPack)
+	if err != nil {
+		return err
+	}
+	err = ddMsgStream.Broadcast(&timeTickMsgPack)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func TestSegmentManager_load_release_and_search(t *testing.T) {
+	collectionID := UniqueID(0)
+	partitionID := UniqueID(1)
+	segmentID := UniqueID(2)
+	fieldIDs := []int64{0, 101}
+
+	// mock the insert binlog a write node would have produced
+	keyPrefix := path.Join("query-node-seg-manager-test-minio-prefix", strconv.FormatInt(collectionID, 10), strconv.FormatInt(partitionID, 10))
+	Params.WriteNodeSegKvSubPath = keyPrefix
+
+	node := newQueryNodeMock()
+	defer node.Stop()
+
+	ctx := node.queryNodeLoopCtx
+	node.loadService = newLoadService(ctx, nil, nil, nil, node.replica, nil)
+	go node.loadService.start()
+
+	collectionName := "collection0"
+	initTestMeta(t, node, collectionName, collectionID, 0)
+
+	err := node.replica.addPartition(collectionID, partitionID)
+	assert.NoError(t, err)
+
+	err = node.replica.addSegment(segmentID, partitionID, collectionID, segTypeSealed)
+	assert.NoError(t, err)
+
+	paths, srcFieldIDs, err := generateInsertBinLog(collectionID, partitionID, segmentID, keyPrefix)
+	assert.NoError(t, err)
+
+	fieldsMap := node.loadService.segManager.getTargetFields(paths, srcFieldIDs, fieldIDs)
+	assert.Equal(t, len(fieldsMap), 2)
+
+	err = node.loadService.segManager.loadSegmentFieldsData(segmentID, fieldsMap)
+	assert.NoError(t, err)
+
+	indexPaths, err := generateIndex(segmentID)
+	assert.NoError(t, err)
+
+	err = node.loadService.segManager.loadIndex(segmentID, indexPaths)
+	assert.NoError(t, err)
+
+	// do search
+	dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
+
+	const DIM = 16
+	var searchRawData []byte
+	var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
+	for _, ele := range vec {
+		buf := make([]byte, 4)
+		binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
+		searchRawData = append(searchRawData, buf...)
+	}
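+
+	// Bind the serialized query vector to the "$0" tag referenced in dslString;
+	// the marshalled PlaceholderGroup below is what the plan parser consumes.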
+	placeholderValue := milvuspb.PlaceholderValue{
+		Tag:    "$0",
+		Type:   milvuspb.PlaceholderType_VECTOR_FLOAT,
+		Values: [][]byte{searchRawData},
+	}
+
+	placeholderGroup := milvuspb.PlaceholderGroup{
+		Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue},
+	}
+
+	placeHolderGroupBlob, err := proto.Marshal(&placeholderGroup)
+	assert.NoError(t, err)
+
+	searchTimestamp := Timestamp(1020)
+	collection, err := node.replica.getCollectionByID(collectionID)
+	assert.NoError(t, err)
+	plan, err := createPlan(*collection, dslString)
+	assert.NoError(t, err)
+	holder, err := parserPlaceholderGroup(plan, placeHolderGroupBlob)
+	assert.NoError(t, err)
+	placeholderGroups := make([]*PlaceholderGroup, 0)
+	placeholderGroups = append(placeholderGroups, holder)
+
+	// wait for the segment's index to finish building
+	time.Sleep(3 * time.Second)
+
+	segment, err := node.replica.getSegmentByID(segmentID)
+	assert.NoError(t, err)
+	_, err = segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp})
+	assert.NoError(t, err)
+
+	plan.delete()
+	holder.delete()
+
+	<-ctx.Done()
+}
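+
+// TestSegmentManager_with_seek below is kept for reference only: it drives
+// LoadSegments against a growing segment through live Pulsar streams, so it
+// stays commented out until it can run without an external Pulsar (see NOTE).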
+//// NOTE: start pulsar before test
+//func TestSegmentManager_with_seek(t *testing.T) {
+//	collectionID := UniqueID(0)
+//	partitionID := UniqueID(1)
+//	//segmentID := UniqueID(2)
+//	fieldIDs := []int64{0, 101}
+//
+//	//// mock write insert bin log
+//	//keyPrefix := path.Join("query-node-seg-manager-test-minio-prefix", strconv.FormatInt(collectionID, 10), strconv.FormatInt(partitionID, 10))
+//	//Params.WriteNodeSegKvSubPath = keyPrefix + "/"
+//	node := newQueryNodeMock()
+//
+//	ctx := node.queryNodeLoopCtx
+//	go node.Start()
+//
+//	collectionName := "collection0"
+//	initTestMeta(t, node, collectionName, collectionID, 0)
+//
+//	err := node.replica.addPartition(collectionID, partitionID)
+//	assert.NoError(t, err)
+//
+//	//err = node.replica.addSegment(segmentID, partitionID, collectionID, segTypeSealed)
+//	//assert.NoError(t, err)
+//
+//	//paths, srcFieldIDs, err := generateInsertBinLog(collectionID, partitionID, segmentID, keyPrefix)
+//	//assert.NoError(t, err)
+//
+//	//fieldsMap := node.segManager.getTargetFields(paths, srcFieldIDs, fieldIDs)
+//	//assert.Equal(t, len(fieldsMap), 2)
+//
+//	segmentIDToInsert := UniqueID(3)
+//	err = doInsert(ctx, collectionName, "default", segmentIDToInsert)
+//	assert.NoError(t, err)
+//
+//	startPositions := make([]*internalPb.MsgPosition, 0)
+//	for _, ch := range Params.InsertChannelNames {
+//		startPositions = append(startPositions, &internalPb.MsgPosition{
+//			ChannelName: ch,
+//		})
+//	}
+//	var positions []*internalPb.MsgPosition
+//	lastSegStates := &datapb.SegmentStatesResponse{
+//		State:          datapb.SegmentState_SegmentGrowing,
+//		StartPositions: positions,
+//	}
+//	loadReq := &querypb.LoadSegmentRequest{
+//		CollectionID:     collectionID,
+//		PartitionID:      partitionID,
+//		SegmentIDs:       []UniqueID{segmentIDToInsert},
+//		FieldIDs:         fieldIDs,
+//		LastSegmentState: lastSegStates,
+//	}
+//	_, err = node.LoadSegments(loadReq)
+//	assert.NoError(t, err)
+//
+//	err = sentTimeTick(ctx)
+//	assert.NoError(t, err)
+//
+//	// do search
+//	dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
+//
+//	const DIM = 16
+//	var searchRawData []byte
+//	var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
+//	for _, ele := range vec {
+//		buf := make([]byte, 4)
+//		binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
+//		searchRawData = append(searchRawData, buf...)
+//	}
+//	placeholderValue := milvuspb.PlaceholderValue{
+//		Tag:    "$0",
+//		Type:   milvuspb.PlaceholderType_VECTOR_FLOAT,
+//		Values: [][]byte{searchRawData},
+//	}
+//
+//	placeholderGroup := milvuspb.PlaceholderGroup{
+//		Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue},
+//	}
+//
+//	placeHolderGroupBlob, err := proto.Marshal(&placeholderGroup)
+//	assert.NoError(t, err)
+//
+//	//searchTimestamp := Timestamp(1020)
+//	collection, err := node.replica.getCollectionByID(collectionID)
+//	assert.NoError(t, err)
+//	plan, err := createPlan(*collection, dslString)
+//	assert.NoError(t, err)
+//	holder, err := parserPlaceholderGroup(plan, placeHolderGroupBlob)
+//	assert.NoError(t, err)
+//	placeholderGroups := make([]*PlaceholderGroup, 0)
+//	placeholderGroups = append(placeholderGroups, holder)
+//
+//	// wait for segment building index
+//	time.Sleep(3 * time.Second)
+//
+//	//segment, err := node.replica.getSegmentByID(segmentIDToInsert)
+//	//assert.NoError(t, err)
+//	//_, err = segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp})
+//	//assert.Nil(t, err)
+//
+//	plan.delete()
+//	holder.delete()
+//
+//	<-ctx.Done()
+//	err = node.Stop()
+//	assert.NoError(t, err)
+//}