From 76bc3651817c7d3af46ff37eac5adafa26f64037 Mon Sep 17 00:00:00 2001 From: xige-16 <xi.ge@zilliz.com> Date: Wed, 20 Jan 2021 17:34:50 +0800 Subject: [PATCH] Add seek function for pulsarTtMsgStream Signed-off-by: xige-16 <xi.ge@zilliz.com> --- internal/msgstream/msg.go | 11 + internal/msgstream/msgstream.go | 201 +---------- .../msgstream/pulsarms/pulsar_msgstream.go | 319 +++++++++++++---- .../pulsarms/pulsar_msgstream_test.go | 332 ++++++++++++++---- internal/msgstream/util/repack_func.go | 132 +++++++ internal/msgstream/util/unmarshal_test.go | 65 ++-- internal/msgstream/util/unpack.go | 162 --------- internal/util/typeutil/convension.go | 29 ++ 8 files changed, 722 insertions(+), 529 deletions(-) create mode 100644 internal/msgstream/util/repack_func.go delete mode 100644 internal/msgstream/util/unpack.go diff --git a/internal/msgstream/msg.go b/internal/msgstream/msg.go index 4b80037f6..06ae51112 100644 --- a/internal/msgstream/msg.go +++ b/internal/msgstream/msg.go @@ -20,6 +20,8 @@ type TsMsg interface { HashKeys() []uint32 Marshal(TsMsg) ([]byte, error) Unmarshal([]byte) (TsMsg, error) + Position() *MsgPosition + SetPosition(*MsgPosition) } type BaseMsg struct { @@ -27,6 +29,7 @@ type BaseMsg struct { BeginTimestamp Timestamp EndTimestamp Timestamp HashValues []uint32 + MsgPosition *MsgPosition } func (bm *BaseMsg) BeginTs() Timestamp { @@ -41,6 +44,14 @@ func (bm *BaseMsg) HashKeys() []uint32 { return bm.HashValues } +func (bm *BaseMsg) Position() *MsgPosition { + return bm.MsgPosition +} + +func (bm *BaseMsg) SetPosition(position *MsgPosition) { + bm.MsgPosition = position +} + /////////////////////////////////////////Insert////////////////////////////////////////// type InsertMsg struct { BaseMsg diff --git a/internal/msgstream/msgstream.go b/internal/msgstream/msgstream.go index 6886629bd..d4f924157 100644 --- a/internal/msgstream/msgstream.go +++ b/internal/msgstream/msgstream.go @@ -1,10 +1,6 @@ package msgstream import ( - "sync" - - "github.com/zilliztech/milvus-distributed/internal/errors" - "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2" "github.com/zilliztech/milvus-distributed/internal/util/typeutil" ) @@ -12,11 +8,14 @@ import ( type UniqueID = typeutil.UniqueID type Timestamp = typeutil.Timestamp type IntPrimaryKey = typeutil.IntPrimaryKey +type MsgPosition = internalpb2.MsgPosition type MsgPack struct { - BeginTs Timestamp - EndTs Timestamp - Msgs []TsMsg + BeginTs Timestamp + EndTs Timestamp + Msgs []TsMsg + StartPositions []*MsgPosition + endPositions []*MsgPosition } type RepackFunc func(msgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) @@ -29,191 +28,5 @@ type MsgStream interface { Broadcast(*MsgPack) error Consume() *MsgPack Chan() <-chan *MsgPack -} - -//TODO test InMemMsgStream -/* -type InMemMsgStream struct { - buffer chan *MsgPack -} - -func (ms *InMemMsgStream) Start() {} -func (ms *InMemMsgStream) Close() {} - -func (ms *InMemMsgStream) ProduceOne(msg TsMsg) error { - msgPack := MsgPack{} - msgPack.BeginTs = msg.BeginTs() - msgPack.EndTs = msg.EndTs() - msgPack.Msgs = append(msgPack.Msgs, msg) - buffer <- &msgPack - return nil -} - -func (ms *InMemMsgStream) Produce(msgPack *MsgPack) error { - buffer <- msgPack - return nil -} - -func (ms *InMemMsgStream) Broadcast(msgPack *MsgPack) error { - return ms.Produce(msgPack) -} - -func (ms *InMemMsgStream) Consume() *MsgPack { - select { - case msgPack := <-ms.buffer: - return msgPack - } -} - -func (ms *InMemMsgStream) Chan() <- chan *MsgPack { - return buffer -} -*/ - -func CheckTimeTickMsg(msg map[int]Timestamp, isChannelReady []bool, mu *sync.RWMutex) (Timestamp, bool) { - checkMap := make(map[Timestamp]int) - var maxTime Timestamp = 0 - for _, v := range msg { - checkMap[v]++ - if v > maxTime { - maxTime = v - } - } - if len(checkMap) <= 1 { - for i := range msg { - isChannelReady[i] = false - } - return maxTime, true - } - for i := range msg { - mu.RLock() - v := msg[i] - mu.Unlock() - if v != maxTime { - isChannelReady[i] = false - } else { - isChannelReady[i] = true - } - } - - return 0, false -} - -func InsertRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { - result := make(map[int32]*MsgPack) - for i, request := range tsMsgs { - if request.Type() != commonpb.MsgType_kInsert { - return nil, errors.New("msg's must be Insert") - } - insertRequest := request.(*InsertMsg) - keys := hashKeys[i] - - timestampLen := len(insertRequest.Timestamps) - rowIDLen := len(insertRequest.RowIDs) - rowDataLen := len(insertRequest.RowData) - keysLen := len(keys) - - if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen { - return nil, errors.New("the length of hashValue, timestamps, rowIDs, RowData are not equal") - } - for index, key := range keys { - _, ok := result[key] - if !ok { - msgPack := MsgPack{} - result[key] = &msgPack - } - - sliceRequest := internalpb2.InsertRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_kInsert, - MsgID: insertRequest.Base.MsgID, - Timestamp: insertRequest.Timestamps[index], - SourceID: insertRequest.Base.SourceID, - }, - CollectionName: insertRequest.CollectionName, - PartitionName: insertRequest.PartitionName, - SegmentID: insertRequest.SegmentID, - ChannelID: insertRequest.ChannelID, - Timestamps: []uint64{insertRequest.Timestamps[index]}, - RowIDs: []int64{insertRequest.RowIDs[index]}, - RowData: []*commonpb.Blob{insertRequest.RowData[index]}, - } - - insertMsg := &InsertMsg{ - BaseMsg: BaseMsg{ - MsgCtx: request.GetMsgContext(), - }, - InsertRequest: sliceRequest, - } - result[key].Msgs = append(result[key].Msgs, insertMsg) - } - } - return result, nil -} - -func DeleteRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { - result := make(map[int32]*MsgPack) - for i, request := range tsMsgs { - if request.Type() != commonpb.MsgType_kDelete { - return nil, errors.New("msg's must be Delete") - } - deleteRequest := request.(*DeleteMsg) - keys := hashKeys[i] - - timestampLen := len(deleteRequest.Timestamps) - primaryKeysLen := len(deleteRequest.PrimaryKeys) - keysLen := len(keys) - - if keysLen != timestampLen || keysLen != primaryKeysLen { - return nil, errors.New("the length of hashValue, timestamps, primaryKeys are not equal") - } - - for index, key := range keys { - _, ok := result[key] - if !ok { - msgPack := MsgPack{} - result[key] = &msgPack - } - - sliceRequest := internalpb2.DeleteRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_kDelete, - MsgID: deleteRequest.Base.MsgID, - Timestamp: deleteRequest.Timestamps[index], - SourceID: deleteRequest.Base.SourceID, - }, - CollectionName: deleteRequest.CollectionName, - ChannelID: deleteRequest.ChannelID, - Timestamps: []uint64{deleteRequest.Timestamps[index]}, - PrimaryKeys: []int64{deleteRequest.PrimaryKeys[index]}, - } - - deleteMsg := &DeleteMsg{ - BaseMsg: BaseMsg{ - MsgCtx: request.GetMsgContext(), - }, - DeleteRequest: sliceRequest, - } - result[key].Msgs = append(result[key].Msgs, deleteMsg) - } - } - return result, nil -} - -func DefaultRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { - result := make(map[int32]*MsgPack) - for i, request := range tsMsgs { - keys := hashKeys[i] - if len(keys) != 1 { - return nil, errors.New("len(msg.hashValue) must equal 1") - } - key := keys[0] - _, ok := result[key] - if !ok { - msgPack := MsgPack{} - result[key] = &msgPack - } - result[key].Msgs = append(result[key].Msgs, request) - } - return result, nil + Seek(offset *MsgPosition) error } diff --git a/internal/msgstream/pulsarms/pulsar_msgstream.go b/internal/msgstream/pulsarms/pulsar_msgstream.go index 3cf133a3d..95f02d60c 100644 --- a/internal/msgstream/pulsarms/pulsar_msgstream.go +++ b/internal/msgstream/pulsarms/pulsar_msgstream.go @@ -3,6 +3,7 @@ package pulsarms import ( "context" "log" + "path/filepath" "reflect" "strconv" "strings" @@ -14,10 +15,13 @@ import ( "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" oplog "github.com/opentracing/opentracing-go/log" + "github.com/zilliztech/milvus-distributed/internal/errors" "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/msgstream/util" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2" + "github.com/zilliztech/milvus-distributed/internal/util/typeutil" ) type TsMsg = msgstream.TsMsg @@ -30,24 +34,33 @@ type IntPrimaryKey = msgstream.IntPrimaryKey type TimeTickMsg = msgstream.TimeTickMsg type QueryNodeStatsMsg = msgstream.QueryNodeStatsMsg type RepackFunc = msgstream.RepackFunc +type Consumer = pulsar.Consumer +type Producer = pulsar.Producer type PulsarMsgStream struct { - ctx context.Context - client *pulsar.Client - producers []*pulsar.Producer - consumers []*pulsar.Consumer - repackFunc RepackFunc - unmarshal *util.UnmarshalDispatcher - receiveBuf chan *MsgPack - wait *sync.WaitGroup - streamCancel func() + ctx context.Context + client pulsar.Client + producers []Producer + consumers []Consumer + consumerChannels []string + repackFunc RepackFunc + unmarshal *util.UnmarshalDispatcher + receiveBuf chan *MsgPack + wait *sync.WaitGroup + streamCancel func() } func NewPulsarMsgStream(ctx context.Context, receiveBufSize int64) *PulsarMsgStream { streamCtx, streamCancel := context.WithCancel(ctx) + producers := make([]Producer, 0) + consumers := make([]Consumer, 0) + consumerChannels := make([]string, 0) stream := &PulsarMsgStream{ - ctx: streamCtx, - streamCancel: streamCancel, + ctx: streamCtx, + streamCancel: streamCancel, + producers: producers, + consumers: consumers, + consumerChannels: consumerChannels, } stream.receiveBuf = make(chan *MsgPack, receiveBufSize) return stream @@ -58,20 +71,21 @@ func (ms *PulsarMsgStream) SetPulsarClient(address string) { if err != nil { log.Printf("Set pulsar client failed, error = %v", err) } - ms.client = &client + ms.client = client } func (ms *PulsarMsgStream) CreatePulsarProducers(channels []string) { for i := 0; i < len(channels); i++ { fn := func() error { - pp, err := (*ms.client).CreateProducer(pulsar.ProducerOptions{Topic: channels[i]}) + pp, err := ms.client.CreateProducer(pulsar.ProducerOptions{Topic: channels[i]}) if err != nil { return err } if pp == nil { return errors.New("pulsar is not ready, producer is nil") } - ms.producers = append(ms.producers, &pp) + + ms.producers = append(ms.producers, pp) return nil } err := util.Retry(10, time.Millisecond*200, fn) @@ -90,7 +104,7 @@ func (ms *PulsarMsgStream) CreatePulsarConsumers(channels []string, for i := 0; i < len(channels); i++ { fn := func() error { receiveChannel := make(chan pulsar.ConsumerMessage, pulsarBufSize) - pc, err := (*ms.client).Subscribe(pulsar.ConsumerOptions{ + pc, err := ms.client.Subscribe(pulsar.ConsumerOptions{ Topic: channels[i], SubscriptionName: subName, Type: pulsar.KeyShared, @@ -103,7 +117,8 @@ func (ms *PulsarMsgStream) CreatePulsarConsumers(channels []string, if pc == nil { return errors.New("pulsar is not ready, consumer is nil") } - ms.consumers = append(ms.consumers, &pc) + + ms.consumers = append(ms.consumers, pc) return nil } err := util.Retry(10, time.Millisecond*200, fn) @@ -131,16 +146,16 @@ func (ms *PulsarMsgStream) Close() { for _, producer := range ms.producers { if producer != nil { - (*producer).Close() + producer.Close() } } for _, consumer := range ms.consumers { if consumer != nil { - (*consumer).Close() + consumer.Close() } } if ms.client != nil { - (*ms.client).Close() + ms.client.Close() } } @@ -204,11 +219,11 @@ func (ms *PulsarMsgStream) Produce(msgPack *MsgPack) error { msgType := (tsMsgs[0]).Type() switch msgType { case commonpb.MsgType_kInsert: - result, err = msgstream.InsertRepackFunc(tsMsgs, reBucketValues) + result, err = util.InsertRepackFunc(tsMsgs, reBucketValues) case commonpb.MsgType_kDelete: - result, err = msgstream.DeleteRepackFunc(tsMsgs, reBucketValues) + result, err = util.DeleteRepackFunc(tsMsgs, reBucketValues) default: - result, err = msgstream.DefaultRepackFunc(tsMsgs, reBucketValues) + result, err = util.DefaultRepackFunc(tsMsgs, reBucketValues) } } if err != nil { @@ -253,7 +268,7 @@ func (ms *PulsarMsgStream) Produce(msgPack *MsgPack) error { child.LogFields(oplog.String("inject success", "inject success")) } - if _, err := (*ms.producers[k]).Send( + if _, err := ms.producers[k].Send( context.Background(), msg, ); err != nil { @@ -308,7 +323,7 @@ func (ms *PulsarMsgStream) Broadcast(msgPack *MsgPack) error { child.LogFields(oplog.String("inject success", "inject success")) } for i := 0; i < producerLen; i++ { - if _, err := (*ms.producers[i]).Send( + if _, err := ms.producers[i].Send( context.Background(), msg, ); err != nil { @@ -347,7 +362,7 @@ func (ms *PulsarMsgStream) bufMsgPackToChannel() { cases := make([]reflect.SelectCase, len(ms.consumers)) for i := 0; i < len(ms.consumers); i++ { - ch := (*ms.consumers[i]).Chan() + ch := ms.consumers[i].Chan() cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)} } @@ -372,7 +387,7 @@ func (ms *PulsarMsgStream) bufMsgPackToChannel() { log.Printf("type assertion failed, not consumer message type") continue } - (*ms.consumers[chosen]).AckID(pulsarMsg.ID()) + ms.consumers[chosen].AckID(pulsarMsg.ID()) headerMsg := commonpb.MsgHeader{} err := proto.Unmarshal(pulsarMsg.Payload(), &headerMsg) @@ -406,7 +421,7 @@ func (ms *PulsarMsgStream) bufMsgPackToChannel() { noMoreMessage := true for i := 0; i < len(ms.consumers); i++ { - if len((*ms.consumers[i]).Chan()) > 0 { + if len(ms.consumers[i].Chan()) > 0 { noMoreMessage = false } } @@ -428,10 +443,27 @@ func (ms *PulsarMsgStream) Chan() <-chan *MsgPack { return ms.receiveBuf } +func (ms *PulsarMsgStream) Seek(mp *internalpb2.MsgPosition) error { + for index, channel := range ms.consumerChannels { + if channel == mp.ChannelName { + messageID, err := typeutil.StringToPulsarMsgID(mp.MsgID) + if err != nil { + return err + } + err = ms.consumers[index].Seek(messageID) + if err != nil { + return err + } + return nil + } + } + + return errors.New("msgStream seek fail") +} + type PulsarTtMsgStream struct { PulsarMsgStream - inputBuf []TsMsg - unsolvedBuf []TsMsg + unsolvedBuf map[Consumer][]TsMsg lastTimeStamp Timestamp } @@ -457,12 +489,14 @@ func (ms *PulsarTtMsgStream) Start() { func (ms *PulsarTtMsgStream) bufMsgPackToChannel() { defer ms.wait.Done() - ms.unsolvedBuf = make([]TsMsg, 0) - ms.inputBuf = make([]TsMsg, 0) - isChannelReady := make([]bool, len(ms.consumers)) - eofMsgTimeStamp := make(map[int]Timestamp) + ms.unsolvedBuf = make(map[Consumer][]TsMsg) + isChannelReady := make(map[Consumer]bool) + eofMsgTimeStamp := make(map[Consumer]Timestamp) spans := make(map[Timestamp]opentracing.Span) ctxs := make(map[Timestamp]context.Context) + for _, consumer := range ms.consumers { + ms.unsolvedBuf[consumer] = make([]TsMsg, 0) + } for { select { case <-ms.ctx.Done(): @@ -471,50 +505,72 @@ func (ms *PulsarTtMsgStream) bufMsgPackToChannel() { wg := sync.WaitGroup{} mu := sync.Mutex{} findMapMutex := sync.RWMutex{} - for i := 0; i < len(ms.consumers); i++ { - if isChannelReady[i] { + for _, consumer := range ms.consumers { + if isChannelReady[consumer] { continue } wg.Add(1) - go ms.findTimeTick(i, eofMsgTimeStamp, &wg, &mu, &findMapMutex) + go ms.findTimeTick(consumer, eofMsgTimeStamp, &wg, &mu, &findMapMutex) } wg.Wait() - timeStamp, ok := msgstream.CheckTimeTickMsg(eofMsgTimeStamp, isChannelReady, &findMapMutex) + timeStamp, ok := checkTimeTickMsg(eofMsgTimeStamp, isChannelReady, &findMapMutex) if !ok || timeStamp <= ms.lastTimeStamp { - log.Printf("All timeTick's timestamps are inconsistent") + //log.Printf("All timeTick's timestamps are inconsistent") continue } timeTickBuf := make([]TsMsg, 0) - ms.inputBuf = append(ms.inputBuf, ms.unsolvedBuf...) - ms.unsolvedBuf = ms.unsolvedBuf[:0] - for _, v := range ms.inputBuf { - var ctx context.Context - var span opentracing.Span - if v.Type() == commonpb.MsgType_kInsert { - if _, ok := spans[v.BeginTs()]; !ok { - span, ctx = opentracing.StartSpanFromContext(v.GetMsgContext(), "after find time tick") - ctxs[v.BeginTs()] = ctx - spans[v.BeginTs()] = span + msgPositions := make([]*internalpb2.MsgPosition, 0) + for consumer, msgs := range ms.unsolvedBuf { + tempBuffer := make([]TsMsg, 0) + var timeTickMsg TsMsg + for _, v := range msgs { + if v.Type() == commonpb.MsgType_kTimeTick { + timeTickMsg = v + continue } - } - if v.EndTs() <= timeStamp { - timeTickBuf = append(timeTickBuf, v) + var ctx context.Context + var span opentracing.Span if v.Type() == commonpb.MsgType_kInsert { - v.SetMsgContext(ctxs[v.BeginTs()]) - spans[v.BeginTs()].Finish() - delete(spans, v.BeginTs()) + if _, ok := spans[v.BeginTs()]; !ok { + span, ctx = opentracing.StartSpanFromContext(v.GetMsgContext(), "after find time tick") + ctxs[v.BeginTs()] = ctx + spans[v.BeginTs()] = span + } } + if v.EndTs() <= timeStamp { + timeTickBuf = append(timeTickBuf, v) + if v.Type() == commonpb.MsgType_kInsert { + v.SetMsgContext(ctxs[v.BeginTs()]) + spans[v.BeginTs()].Finish() + delete(spans, v.BeginTs()) + } + } else { + tempBuffer = append(tempBuffer, v) + } + } + ms.unsolvedBuf[consumer] = tempBuffer + + if len(tempBuffer) > 0 { + msgPositions = append(msgPositions, &internalpb2.MsgPosition{ + ChannelName: tempBuffer[0].Position().ChannelName, + MsgID: tempBuffer[0].Position().MsgID, + Timestamp: timeStamp, + }) } else { - ms.unsolvedBuf = append(ms.unsolvedBuf, v) + msgPositions = append(msgPositions, &internalpb2.MsgPosition{ + ChannelName: timeTickMsg.Position().ChannelName, + MsgID: timeTickMsg.Position().MsgID, + Timestamp: timeStamp, + }) } } - ms.inputBuf = ms.inputBuf[:0] msgPack := MsgPack{ - BeginTs: ms.lastTimeStamp, - EndTs: timeStamp, - Msgs: timeTickBuf, + BeginTs: ms.lastTimeStamp, + EndTs: timeStamp, + Msgs: timeTickBuf, + StartPositions: msgPositions, } ms.receiveBuf <- &msgPack @@ -523,8 +579,8 @@ func (ms *PulsarTtMsgStream) bufMsgPackToChannel() { } } -func (ms *PulsarTtMsgStream) findTimeTick(channelIndex int, - eofMsgMap map[int]Timestamp, +func (ms *PulsarTtMsgStream) findTimeTick(consumer Consumer, + eofMsgMap map[Consumer]Timestamp, wg *sync.WaitGroup, mu *sync.Mutex, findMapMutex *sync.RWMutex) { @@ -533,12 +589,12 @@ func (ms *PulsarTtMsgStream) findTimeTick(channelIndex int, select { case <-ms.ctx.Done(): return - case pulsarMsg, ok := <-(*ms.consumers[channelIndex]).Chan(): + case pulsarMsg, ok := <-consumer.Chan(): if !ok { log.Printf("consumer closed!") return } - (*ms.consumers[channelIndex]).Ack(pulsarMsg) + consumer.Ack(pulsarMsg) headerMsg := commonpb.MsgHeader{} err := proto.Unmarshal(pulsarMsg.Payload(), &headerMsg) @@ -553,6 +609,11 @@ func (ms *PulsarTtMsgStream) findTimeTick(channelIndex int, if err != nil { log.Printf("Failed to unmarshal, error = %v", err) } + // set pulsar info to tsMsg + tsMsg.SetPosition(&msgstream.MsgPosition{ + ChannelName: pulsarMsg.Topic(), + MsgID: typeutil.PulsarMsgIDToString(pulsarMsg.ID()), + }) if tsMsg.Type() == commonpb.MsgType_kInsert { tracer := opentracing.GlobalTracer() @@ -571,15 +632,139 @@ func (ms *PulsarTtMsgStream) findTimeTick(channelIndex int, span.Finish() } + mu.Lock() + ms.unsolvedBuf[consumer] = append(ms.unsolvedBuf[consumer], tsMsg) + mu.Unlock() + if headerMsg.Base.MsgType == commonpb.MsgType_kTimeTick { findMapMutex.Lock() - eofMsgMap[channelIndex] = tsMsg.(*TimeTickMsg).Base.Timestamp + eofMsgMap[consumer] = tsMsg.(*TimeTickMsg).Base.Timestamp findMapMutex.Unlock() return } - mu.Lock() - ms.inputBuf = append(ms.inputBuf, tsMsg) - mu.Unlock() } } } + +func (ms *PulsarTtMsgStream) Seek(mp *internalpb2.MsgPosition) error { + for index, channel := range ms.consumerChannels { + if filepath.Base(channel) == filepath.Base(mp.ChannelName) { + messageID, err := typeutil.StringToPulsarMsgID(mp.MsgID) + if err != nil { + return err + } + consumer := ms.consumers[index] + err = (consumer).Seek(messageID) + if err != nil { + return err + } + + for { + select { + case <-ms.ctx.Done(): + return nil + case pulsarMsg, ok := <-consumer.Chan(): + if !ok { + return errors.New("consumer closed") + } + consumer.Ack(pulsarMsg) + + headerMsg := commonpb.MsgHeader{} + err := proto.Unmarshal(pulsarMsg.Payload(), &headerMsg) + if err != nil { + log.Printf("Failed to unmarshal msgHeader, error = %v", err) + } + unMarshalFunc := (*ms.unmarshal).TempMap[headerMsg.Base.MsgType] + if unMarshalFunc == nil { + panic("null unMarshalFunc for " + headerMsg.Base.MsgType.String() + " msg type") + } + tsMsg, err := unMarshalFunc(pulsarMsg.Payload()) + if err != nil { + log.Printf("Failed to unmarshal pulsarMsg, error = %v", err) + } + if tsMsg.Type() == commonpb.MsgType_kTimeTick { + if tsMsg.BeginTs() >= mp.Timestamp { + return nil + } + continue + } + if tsMsg.BeginTs() > mp.Timestamp { + ms.unsolvedBuf[consumer] = append(ms.unsolvedBuf[consumer], tsMsg) + } + } + } + } + } + + return errors.New("msgStream seek fail") +} + +func checkTimeTickMsg(msg map[Consumer]Timestamp, + isChannelReady map[Consumer]bool, + mu *sync.RWMutex) (Timestamp, bool) { + checkMap := make(map[Timestamp]int) + var maxTime Timestamp = 0 + for _, v := range msg { + checkMap[v]++ + if v > maxTime { + maxTime = v + } + } + if len(checkMap) <= 1 { + for consumer := range msg { + isChannelReady[consumer] = false + } + return maxTime, true + } + for consumer := range msg { + mu.RLock() + v := msg[consumer] + mu.Unlock() + if v != maxTime { + isChannelReady[consumer] = false + } else { + isChannelReady[consumer] = true + } + } + + return 0, false +} + +//TODO test InMemMsgStream +/* +type InMemMsgStream struct { + buffer chan *MsgPack +} + +func (ms *InMemMsgStream) Start() {} +func (ms *InMemMsgStream) Close() {} + +func (ms *InMemMsgStream) ProduceOne(msg TsMsg) error { + msgPack := MsgPack{} + msgPack.BeginTs = msg.BeginTs() + msgPack.EndTs = msg.EndTs() + msgPack.Msgs = append(msgPack.Msgs, msg) + buffer <- &msgPack + return nil +} + +func (ms *InMemMsgStream) Produce(msgPack *MsgPack) error { + buffer <- msgPack + return nil +} + +func (ms *InMemMsgStream) Broadcast(msgPack *MsgPack) error { + return ms.Produce(msgPack) +} + +func (ms *InMemMsgStream) Consume() *MsgPack { + select { + case msgPack := <-ms.buffer: + return msgPack + } +} + +func (ms *InMemMsgStream) Chan() <- chan *MsgPack { + return buffer +} +*/ diff --git a/internal/msgstream/pulsarms/pulsar_msgstream_test.go b/internal/msgstream/pulsarms/pulsar_msgstream_test.go index b314008ee..e4c4f4c60 100644 --- a/internal/msgstream/pulsarms/pulsar_msgstream_test.go +++ b/internal/msgstream/pulsarms/pulsar_msgstream_test.go @@ -7,6 +7,8 @@ import ( "os" "testing" + "github.com/stretchr/testify/assert" + "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/msgstream/util" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" @@ -22,11 +24,156 @@ func TestMain(m *testing.M) { os.Exit(exitCode) } +func repackFunc(msgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { + result := make(map[int32]*MsgPack) + for i, request := range msgs { + keys := hashKeys[i] + for _, channelID := range keys { + _, ok := result[channelID] + if ok == false { + msgPack := MsgPack{} + result[channelID] = &msgPack + } + result[channelID].Msgs = append(result[channelID].Msgs, request) + } + } + return result, nil +} + +func getTsMsg(msgType MsgType, reqID UniqueID, hashValue uint32) TsMsg { + baseMsg := BaseMsg{ + BeginTimestamp: 0, + EndTimestamp: 0, + HashValues: []uint32{hashValue}, + } + switch msgType { + case commonpb.MsgType_kInsert: + insertRequest := internalpb2.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kInsert, + MsgID: reqID, + Timestamp: 11, + SourceID: reqID, + }, + CollectionName: "Collection", + PartitionName: "Partition", + SegmentID: 1, + ChannelID: "0", + Timestamps: []Timestamp{uint64(reqID)}, + RowIDs: []int64{1}, + RowData: []*commonpb.Blob{{}}, + } + insertMsg := &msgstream.InsertMsg{ + BaseMsg: baseMsg, + InsertRequest: insertRequest, + } + return insertMsg + case commonpb.MsgType_kDelete: + deleteRequest := internalpb2.DeleteRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kDelete, + MsgID: reqID, + Timestamp: 11, + SourceID: reqID, + }, + CollectionName: "Collection", + ChannelID: "1", + Timestamps: []Timestamp{1}, + PrimaryKeys: []IntPrimaryKey{1}, + } + deleteMsg := &msgstream.DeleteMsg{ + BaseMsg: baseMsg, + DeleteRequest: deleteRequest, + } + return deleteMsg + case commonpb.MsgType_kSearch: + searchRequest := internalpb2.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kSearch, + MsgID: reqID, + Timestamp: 11, + SourceID: reqID, + }, + Query: nil, + ResultChannelID: "0", + } + searchMsg := &msgstream.SearchMsg{ + BaseMsg: baseMsg, + SearchRequest: searchRequest, + } + return searchMsg + case commonpb.MsgType_kSearchResult: + searchResult := internalpb2.SearchResults{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kSearchResult, + MsgID: reqID, + Timestamp: 1, + SourceID: reqID, + }, + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}, + ResultChannelID: "0", + } + searchResultMsg := &msgstream.SearchResultMsg{ + BaseMsg: baseMsg, + SearchResults: searchResult, + } + return searchResultMsg + case commonpb.MsgType_kTimeTick: + timeTickResult := internalpb2.TimeTickMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kTimeTick, + MsgID: reqID, + Timestamp: 1, + SourceID: reqID, + }, + } + timeTickMsg := &TimeTickMsg{ + BaseMsg: baseMsg, + TimeTickMsg: timeTickResult, + } + return timeTickMsg + case commonpb.MsgType_kQueryNodeStats: + queryNodeSegStats := internalpb2.QueryNodeStats{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kQueryNodeStats, + SourceID: reqID, + }, + } + queryNodeSegStatsMsg := &QueryNodeStatsMsg{ + BaseMsg: baseMsg, + QueryNodeStats: queryNodeSegStats, + } + return queryNodeSegStatsMsg + } + return nil +} + +func getTimeTickMsg(reqID UniqueID, hashValue uint32, time uint64) TsMsg { + baseMsg := BaseMsg{ + BeginTimestamp: 0, + EndTimestamp: 0, + HashValues: []uint32{hashValue}, + } + timeTickResult := internalpb2.TimeTickMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kTimeTick, + MsgID: reqID, + Timestamp: time, + SourceID: reqID, + }, + } + timeTickMsg := &TimeTickMsg{ + BaseMsg: baseMsg, + TimeTickMsg: timeTickResult, + } + return timeTickMsg +} + func initPulsarStream(pulsarAddress string, producerChannels []string, consumerChannels []string, consumerSubName string, - opts ...msgstream.RepackFunc) (*msgstream.MsgStream, *msgstream.MsgStream) { + opts ...RepackFunc) (msgstream.MsgStream, msgstream.MsgStream) { // set input stream inputStream := NewPulsarMsgStream(context.Background(), 100) @@ -46,14 +193,14 @@ func initPulsarStream(pulsarAddress string, outputStream.Start() var output msgstream.MsgStream = outputStream - return &input, &output + return input, output } func initPulsarTtStream(pulsarAddress string, producerChannels []string, consumerChannels []string, consumerSubName string, - opts ...msgstream.RepackFunc) (*msgstream.MsgStream, *msgstream.MsgStream) { + opts ...RepackFunc) (msgstream.MsgStream, msgstream.MsgStream) { // set input stream inputStream := NewPulsarMsgStream(context.Background(), 100) @@ -73,13 +220,13 @@ func initPulsarTtStream(pulsarAddress string, outputStream.Start() var output msgstream.MsgStream = outputStream - return &input, &output + return input, output } -func receiveMsg(outputStream *msgstream.MsgStream, msgCount int) { +func receiveMsg(outputStream msgstream.MsgStream, msgCount int) { receiveCount := 0 for { - result := (*outputStream).Consume() + result := outputStream.Consume() if len(result.Msgs) > 0 { msgs := result.Msgs for _, v := range msgs { @@ -100,18 +247,18 @@ func TestStream_PulsarMsgStream_Insert(t *testing.T) { consumerSubName := "subInsert" msgPack := msgstream.MsgPack{} - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kInsert, 1, 1)) - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kInsert, 3, 3)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kInsert, 1, 1)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kInsert, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - err := (*inputStream).Produce(&msgPack) + err := inputStream.Produce(&msgPack) if err != nil { log.Fatalf("produce error = %v", err) } receiveMsg(outputStream, len(msgPack.Msgs)) - (*inputStream).Close() - (*outputStream).Close() + inputStream.Close() + outputStream.Close() } @@ -122,17 +269,17 @@ func TestStream_PulsarMsgStream_Delete(t *testing.T) { consumerSubName := "subDelete" msgPack := msgstream.MsgPack{} - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kDelete, 1, 1)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kDelete, 1, 1)) //msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kDelete, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - err := (*inputStream).Produce(&msgPack) + err := inputStream.Produce(&msgPack) if err != nil { log.Fatalf("produce error = %v", err) } receiveMsg(outputStream, len(msgPack.Msgs)) - (*inputStream).Close() - (*outputStream).Close() + inputStream.Close() + outputStream.Close() } func TestStream_PulsarMsgStream_Search(t *testing.T) { @@ -142,17 +289,17 @@ func TestStream_PulsarMsgStream_Search(t *testing.T) { consumerSubName := "subSearch" msgPack := msgstream.MsgPack{} - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kSearch, 1, 1)) - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kSearch, 3, 3)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kSearch, 1, 1)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kSearch, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - err := (*inputStream).Produce(&msgPack) + err := inputStream.Produce(&msgPack) if err != nil { log.Fatalf("produce error = %v", err) } receiveMsg(outputStream, len(msgPack.Msgs)) - (*inputStream).Close() - (*outputStream).Close() + inputStream.Close() + outputStream.Close() } func TestStream_PulsarMsgStream_SearchResult(t *testing.T) { @@ -162,17 +309,17 @@ func TestStream_PulsarMsgStream_SearchResult(t *testing.T) { consumerSubName := "subSearchResult" msgPack := msgstream.MsgPack{} - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kSearchResult, 1, 1)) - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kSearchResult, 3, 3)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kSearchResult, 1, 1)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kSearchResult, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - err := (*inputStream).Produce(&msgPack) + err := inputStream.Produce(&msgPack) if err != nil { log.Fatalf("produce error = %v", err) } receiveMsg(outputStream, len(msgPack.Msgs)) - (*inputStream).Close() - (*outputStream).Close() + inputStream.Close() + outputStream.Close() } func TestStream_PulsarMsgStream_TimeTick(t *testing.T) { @@ -182,17 +329,17 @@ func TestStream_PulsarMsgStream_TimeTick(t *testing.T) { consumerSubName := "subTimeTick" msgPack := msgstream.MsgPack{} - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kTimeTick, 1, 1)) - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kTimeTick, 3, 3)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kTimeTick, 1, 1)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kTimeTick, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - err := (*inputStream).Produce(&msgPack) + err := inputStream.Produce(&msgPack) if err != nil { log.Fatalf("produce error = %v", err) } receiveMsg(outputStream, len(msgPack.Msgs)) - (*inputStream).Close() - (*outputStream).Close() + inputStream.Close() + outputStream.Close() } func TestStream_PulsarMsgStream_BroadCast(t *testing.T) { @@ -202,17 +349,17 @@ func TestStream_PulsarMsgStream_BroadCast(t *testing.T) { consumerSubName := "subInsert" msgPack := msgstream.MsgPack{} - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kTimeTick, 1, 1)) - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kTimeTick, 3, 3)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kTimeTick, 1, 1)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kTimeTick, 3, 3)) inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - err := (*inputStream).Broadcast(&msgPack) + err := inputStream.Broadcast(&msgPack) if err != nil { log.Fatalf("produce error = %v", err) } receiveMsg(outputStream, len(consumerChannels)*len(msgPack.Msgs)) - (*inputStream).Close() - (*outputStream).Close() + inputStream.Close() + outputStream.Close() } func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) { @@ -222,17 +369,17 @@ func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) { consumerSubName := "subInsert" msgPack := msgstream.MsgPack{} - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kInsert, 1, 1)) - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kInsert, 3, 3)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kInsert, 1, 1)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kInsert, 3, 3)) - inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName, util.RepackFunc) - err := (*inputStream).Produce(&msgPack) + inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName, repackFunc) + err := inputStream.Produce(&msgPack) if err != nil { log.Fatalf("produce error = %v", err) } receiveMsg(outputStream, len(msgPack.Msgs)) - (*inputStream).Close() - (*outputStream).Close() + inputStream.Close() + outputStream.Close() } func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) { @@ -286,7 +433,7 @@ func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) { if err != nil { log.Fatalf("produce error = %v", err) } - receiveMsg(&output, len(msgPack.Msgs)*2) + receiveMsg(output, len(msgPack.Msgs)*2) (*inputStream).Close() (*outputStream).Close() } @@ -339,7 +486,7 @@ func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) { if err != nil { log.Fatalf("produce error = %v", err) } - receiveMsg(&output, len(msgPack.Msgs)*2) + receiveMsg(output, len(msgPack.Msgs)*2) (*inputStream).Close() (*outputStream).Close() } @@ -351,10 +498,10 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) { consumerSubName := "subInsert" msgPack := msgstream.MsgPack{} - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kTimeTick, 1, 1)) - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kSearch, 2, 2)) - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kSearchResult, 3, 3)) - msgPack.Msgs = append(msgPack.Msgs, util.GetTsMsg(commonpb.MsgType_kQueryNodeStats, 4, 4)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kTimeTick, 1, 1)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kSearch, 2, 2)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kSearchResult, 3, 3)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kQueryNodeStats, 4, 4)) inputStream := NewPulsarMsgStream(context.Background(), 100) inputStream.SetPulsarClient(pulsarAddress) @@ -372,7 +519,7 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) { if err != nil { log.Fatalf("produce error = %v", err) } - receiveMsg(&output, len(msgPack.Msgs)) + receiveMsg(output, len(msgPack.Msgs)) (*inputStream).Close() (*outputStream).Close() } @@ -384,31 +531,84 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) { consumerSubName := "subInsert" msgPack0 := msgstream.MsgPack{} - msgPack0.Msgs = append(msgPack0.Msgs, util.GetTimeTickMsg(0, 0, 0)) + msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0, 0, 0)) msgPack1 := msgstream.MsgPack{} - msgPack1.Msgs = append(msgPack1.Msgs, util.GetTsMsg(commonpb.MsgType_kInsert, 1, 1)) - msgPack1.Msgs = append(msgPack1.Msgs, util.GetTsMsg(commonpb.MsgType_kInsert, 3, 3)) + msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_kInsert, 1, 1)) + msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_kInsert, 3, 3)) msgPack2 := msgstream.MsgPack{} - msgPack2.Msgs = append(msgPack2.Msgs, util.GetTimeTickMsg(5, 5, 5)) + msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5, 5, 5)) inputStream, outputStream := initPulsarTtStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - err := (*inputStream).Broadcast(&msgPack0) + err := inputStream.Broadcast(&msgPack0) if err != nil { log.Fatalf("broadcast error = %v", err) } - err = (*inputStream).Produce(&msgPack1) + err = inputStream.Produce(&msgPack1) if err != nil { log.Fatalf("produce error = %v", err) } - err = (*inputStream).Broadcast(&msgPack2) + err = inputStream.Broadcast(&msgPack2) if err != nil { log.Fatalf("broadcast error = %v", err) } receiveMsg(outputStream, len(msgPack1.Msgs)) - (*inputStream).Close() - (*outputStream).Close() + inputStream.Close() + outputStream.Close() +} + +func TestStream_PulsarTtMsgStream_Seek(t *testing.T) { + pulsarAddress, _ := Params.Load("_PulsarAddress") + producerChannels := []string{"seek_insert1", "seek_insert2"} + consumerChannels := []string{"seek_insert1", "seek_insert2"} + consumerSubName := "subInsert" + + msgPack0 := MsgPack{} + msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0, 0, 0)) + + msgPack1 := MsgPack{} + msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_kInsert, 1, 1)) + msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_kInsert, 19, 19)) + + msgPack2 := MsgPack{} + msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5, 5, 5)) + + msgPack3 := MsgPack{} + msgPack3.Msgs = append(msgPack3.Msgs, getTsMsg(commonpb.MsgType_kInsert, 14, 14)) + msgPack3.Msgs = append(msgPack3.Msgs, getTsMsg(commonpb.MsgType_kInsert, 9, 9)) + + msgPack4 := MsgPack{} + msgPack4.Msgs = append(msgPack2.Msgs, getTimeTickMsg(11, 11, 11)) + + msgPack5 := MsgPack{} + msgPack5.Msgs = append(msgPack5.Msgs, getTimeTickMsg(15, 15, 15)) + + inputStream, outputStream := initPulsarTtStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) + err := inputStream.Broadcast(&msgPack0) + assert.Nil(t, err) + err = inputStream.Produce(&msgPack1) + assert.Nil(t, err) + err = inputStream.Broadcast(&msgPack2) + assert.Nil(t, err) + err = inputStream.Produce(&msgPack3) + assert.Nil(t, err) + err = inputStream.Broadcast(&msgPack4) + assert.Nil(t, err) + + outputStream.Consume() + receivedMsg := outputStream.Consume() + for _, position := range receivedMsg.StartPositions { + outputStream.Seek(position) + } + err = inputStream.Broadcast(&msgPack5) + assert.Nil(t, err) + seekMsg := outputStream.Consume() + for _, msg := range seekMsg.Msgs { + assert.Equal(t, msg.BeginTs(), uint64(14)) + } + inputStream.Close() + outputStream.Close() } func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) { @@ -418,29 +618,29 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) { consumerSubName := "subInsert" msgPack0 := msgstream.MsgPack{} - msgPack0.Msgs = append(msgPack0.Msgs, util.GetTimeTickMsg(0, 0, 0)) + msgPack0.Msgs = append(msgPack0.Msgs, getTimeTickMsg(0, 0, 0)) msgPack1 := msgstream.MsgPack{} - msgPack1.Msgs = append(msgPack1.Msgs, util.GetTsMsg(commonpb.MsgType_kInsert, 1, 1)) - msgPack1.Msgs = append(msgPack1.Msgs, util.GetTsMsg(commonpb.MsgType_kInsert, 3, 3)) + msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_kInsert, 1, 1)) + msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_kInsert, 3, 3)) msgPack2 := msgstream.MsgPack{} - msgPack2.Msgs = append(msgPack2.Msgs, util.GetTimeTickMsg(5, 5, 5)) + msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5, 5, 5)) inputStream, outputStream := initPulsarTtStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName) - err := (*inputStream).Broadcast(&msgPack0) + err := inputStream.Broadcast(&msgPack0) if err != nil { log.Fatalf("broadcast error = %v", err) } - err = (*inputStream).Produce(&msgPack1) + err = inputStream.Produce(&msgPack1) if err != nil { log.Fatalf("produce error = %v", err) } - err = (*inputStream).Broadcast(&msgPack2) + err = inputStream.Broadcast(&msgPack2) if err != nil { log.Fatalf("broadcast error = %v", err) } receiveMsg(outputStream, len(msgPack1.Msgs)) - (*inputStream).Close() - (*outputStream).Close() + inputStream.Close() + outputStream.Close() } diff --git a/internal/msgstream/util/repack_func.go b/internal/msgstream/util/repack_func.go new file mode 100644 index 000000000..176778065 --- /dev/null +++ b/internal/msgstream/util/repack_func.go @@ -0,0 +1,132 @@ +package util + +import ( + "github.com/zilliztech/milvus-distributed/internal/errors" + "github.com/zilliztech/milvus-distributed/internal/msgstream" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2" +) + +type MsgStream = msgstream.MsgStream +type TsMsg = msgstream.TsMsg +type MsgPack = msgstream.MsgPack +type BaseMsg = msgstream.BaseMsg + +func InsertRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { + result := make(map[int32]*MsgPack) + for i, request := range tsMsgs { + if request.Type() != commonpb.MsgType_kInsert { + return nil, errors.New("msg's must be Insert") + } + insertRequest := request.(*msgstream.InsertMsg) + keys := hashKeys[i] + + timestampLen := len(insertRequest.Timestamps) + rowIDLen := len(insertRequest.RowIDs) + rowDataLen := len(insertRequest.RowData) + keysLen := len(keys) + + if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen { + return nil, errors.New("the length of hashValue, timestamps, rowIDs, RowData are not equal") + } + for index, key := range keys { + _, ok := result[key] + if !ok { + msgPack := MsgPack{} + result[key] = &msgPack + } + + sliceRequest := internalpb2.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kInsert, + MsgID: insertRequest.Base.MsgID, + Timestamp: insertRequest.Timestamps[index], + SourceID: insertRequest.Base.SourceID, + }, + CollectionName: insertRequest.CollectionName, + PartitionName: insertRequest.PartitionName, + SegmentID: insertRequest.SegmentID, + ChannelID: insertRequest.ChannelID, + Timestamps: []uint64{insertRequest.Timestamps[index]}, + RowIDs: []int64{insertRequest.RowIDs[index]}, + RowData: []*commonpb.Blob{insertRequest.RowData[index]}, + } + + insertMsg := &msgstream.InsertMsg{ + BaseMsg: BaseMsg{ + MsgCtx: request.GetMsgContext(), + }, + InsertRequest: sliceRequest, + } + result[key].Msgs = append(result[key].Msgs, insertMsg) + } + } + return result, nil +} + +func DeleteRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { + result := make(map[int32]*MsgPack) + for i, request := range tsMsgs { + if request.Type() != commonpb.MsgType_kDelete { + return nil, errors.New("msg's must be Delete") + } + deleteRequest := request.(*msgstream.DeleteMsg) + keys := hashKeys[i] + + timestampLen := len(deleteRequest.Timestamps) + primaryKeysLen := len(deleteRequest.PrimaryKeys) + keysLen := len(keys) + + if keysLen != timestampLen || keysLen != primaryKeysLen { + return nil, errors.New("the length of hashValue, timestamps, primaryKeys are not equal") + } + + for index, key := range keys { + _, ok := result[key] + if !ok { + msgPack := MsgPack{} + result[key] = &msgPack + } + + sliceRequest := internalpb2.DeleteRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kDelete, + MsgID: deleteRequest.Base.MsgID, + Timestamp: deleteRequest.Timestamps[index], + SourceID: deleteRequest.Base.SourceID, + }, + CollectionName: deleteRequest.CollectionName, + ChannelID: deleteRequest.ChannelID, + Timestamps: []uint64{deleteRequest.Timestamps[index]}, + PrimaryKeys: []int64{deleteRequest.PrimaryKeys[index]}, + } + + deleteMsg := &msgstream.DeleteMsg{ + BaseMsg: BaseMsg{ + MsgCtx: request.GetMsgContext(), + }, + DeleteRequest: sliceRequest, + } + result[key].Msgs = append(result[key].Msgs, deleteMsg) + } + } + return result, nil +} + +func DefaultRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { + result := make(map[int32]*MsgPack) + for i, request := range tsMsgs { + keys := hashKeys[i] + if len(keys) != 1 { + return nil, errors.New("len(msg.hashValue) must equal 1") + } + key := keys[0] + _, ok := result[key] + if !ok { + msgPack := MsgPack{} + result[key] = &msgPack + } + result[key].Msgs = append(result[key].Msgs, request) + } + return result, nil +} diff --git a/internal/msgstream/util/unmarshal_test.go b/internal/msgstream/util/unmarshal_test.go index 05eeaac54..b73230668 100644 --- a/internal/msgstream/util/unmarshal_test.go +++ b/internal/msgstream/util/unmarshal_test.go @@ -14,6 +14,8 @@ import ( var Params paramtable.BaseTable +type Timestamp = msgstream.Timestamp + func newInsertMsgUnmarshal(input []byte) (msgstream.TsMsg, error) { insertRequest := internalpb2.InsertRequest{} err := proto.Unmarshal(input, &insertRequest) @@ -27,25 +29,32 @@ func newInsertMsgUnmarshal(input []byte) (msgstream.TsMsg, error) { } func TestStream_unmarshal_Insert(t *testing.T) { - //pulsarAddress, _ := Params.Load("_PulsarAddress") - //producerChannels := []string{"insert1", "insert2"} - //consumerChannels := []string{"insert1", "insert2"} - //consumerSubName := "subInsert" - msgPack := msgstream.MsgPack{} - msgPack.Msgs = append(msgPack.Msgs, GetTsMsg(commonpb.MsgType_kInsert, 1, 1)) - msgPack.Msgs = append(msgPack.Msgs, GetTsMsg(commonpb.MsgType_kInsert, 3, 3)) - - //inputStream := pulsarms.NewPulsarMsgStream(context.Background(), 100) - //inputStream.SetPulsarClient(pulsarAddress) - //inputStream.CreatePulsarProducers(producerChannels) - //inputStream.Start() + insertMsg := &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 0, + EndTimestamp: 0, + HashValues: []uint32{1}, + }, + InsertRequest: internalpb2.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kInsert, + MsgID: 1, + Timestamp: 11, + SourceID: 1, + }, + CollectionName: "Collection", + PartitionName: "Partition", + SegmentID: 1, + ChannelID: "0", + Timestamps: []Timestamp{uint64(1)}, + RowIDs: []int64{1}, + RowData: []*commonpb.Blob{{}}, + }, + } + msgPack.Msgs = append(msgPack.Msgs, insertMsg) - //outputStream := pulsarms.NewPulsarMsgStream(context.Background(), 100) - //outputStream.SetPulsarClient(pulsarAddress) unmarshalDispatcher := NewUnmarshalDispatcher() - - //add a new unmarshall func for msgType kInsert unmarshalDispatcher.AddMsgTemplate(commonpb.MsgType_kInsert, newInsertMsgUnmarshal) for _, v := range msgPack.Msgs { @@ -58,28 +67,4 @@ func TestStream_unmarshal_Insert(t *testing.T) { assert.Nil(t, err) fmt.Println("msg type: ", msg.Type(), ", msg value: ", msg, "msg tag: ") } - - //outputStream.CreatePulsarConsumers(consumerChannels, consumerSubName, unmarshalDispatcher, 100) - //outputStream.Start() - - //err := inputStream.Produce(&msgPack) - //if err != nil { - // log.Fatalf("produce error = %v", err) - //} - //receiveCount := 0 - //for { - // result := (*outputStream).Consume() - // if len(result.Msgs) > 0 { - // msgs := result.Msgs - // for _, v := range msgs { - // receiveCount++ - // fmt.Println("msg type: ", v.Type(), ", msg value: ", v, "msg tag: ") - // } - // } - // if receiveCount >= len(msgPack.Msgs) { - // break - // } - //} - //inputStream.Close() - //outputStream.Close() } diff --git a/internal/msgstream/util/unpack.go b/internal/msgstream/util/unpack.go deleted file mode 100644 index c093749d3..000000000 --- a/internal/msgstream/util/unpack.go +++ /dev/null @@ -1,162 +0,0 @@ -package util - -import ( - "github.com/zilliztech/milvus-distributed/internal/msgstream" - "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" - "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2" -) - -type TsMsg = msgstream.TsMsg -type MsgPack = msgstream.MsgPack -type MsgType = msgstream.MsgType -type UniqueID = msgstream.UniqueID -type BaseMsg = msgstream.BaseMsg -type Timestamp = msgstream.Timestamp -type IntPrimaryKey = msgstream.IntPrimaryKey -type TimeTickMsg = msgstream.TimeTickMsg -type QueryNodeStatsMsg = msgstream.QueryNodeStatsMsg - -func RepackFunc(msgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { - result := make(map[int32]*MsgPack) - for i, request := range msgs { - keys := hashKeys[i] - for _, channelID := range keys { - _, ok := result[channelID] - if !ok { - msgPack := MsgPack{} - result[channelID] = &msgPack - } - result[channelID].Msgs = append(result[channelID].Msgs, request) - } - } - return result, nil -} - -func GetTsMsg(msgType MsgType, reqID UniqueID, hashValue uint32) TsMsg { - baseMsg := BaseMsg{ - BeginTimestamp: 0, - EndTimestamp: 0, - HashValues: []uint32{hashValue}, - } - switch msgType { - case commonpb.MsgType_kInsert: - insertRequest := internalpb2.InsertRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_kInsert, - MsgID: reqID, - Timestamp: 11, - SourceID: reqID, - }, - CollectionName: "Collection", - PartitionName: "Partition", - SegmentID: 1, - ChannelID: "0", - Timestamps: []Timestamp{1}, - RowIDs: []int64{1}, - RowData: []*commonpb.Blob{{}}, - } - insertMsg := &msgstream.InsertMsg{ - BaseMsg: baseMsg, - InsertRequest: insertRequest, - } - return insertMsg - case commonpb.MsgType_kDelete: - deleteRequest := internalpb2.DeleteRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_kDelete, - MsgID: reqID, - Timestamp: 11, - SourceID: reqID, - }, - CollectionName: "Collection", - ChannelID: "1", - Timestamps: []Timestamp{1}, - PrimaryKeys: []IntPrimaryKey{1}, - } - deleteMsg := &msgstream.DeleteMsg{ - BaseMsg: baseMsg, - DeleteRequest: deleteRequest, - } - return deleteMsg - case commonpb.MsgType_kSearch: - searchRequest := internalpb2.SearchRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_kSearch, - MsgID: reqID, - Timestamp: 11, - SourceID: reqID, - }, - Query: nil, - ResultChannelID: "0", - } - searchMsg := &msgstream.SearchMsg{ - BaseMsg: baseMsg, - SearchRequest: searchRequest, - } - return searchMsg - case commonpb.MsgType_kSearchResult: - searchResult := internalpb2.SearchResults{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_kSearchResult, - MsgID: reqID, - Timestamp: 1, - SourceID: reqID, - }, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}, - ResultChannelID: "0", - } - searchResultMsg := &msgstream.SearchResultMsg{ - BaseMsg: baseMsg, - SearchResults: searchResult, - } - return searchResultMsg - case commonpb.MsgType_kTimeTick: - timeTickResult := internalpb2.TimeTickMsg{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_kTimeTick, - MsgID: reqID, - Timestamp: 1, - SourceID: reqID, - }, - } - timeTickMsg := &TimeTickMsg{ - BaseMsg: baseMsg, - TimeTickMsg: timeTickResult, - } - return timeTickMsg - case commonpb.MsgType_kQueryNodeStats: - queryNodeSegStats := internalpb2.QueryNodeStats{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_kQueryNodeStats, - SourceID: reqID, - }, - } - queryNodeSegStatsMsg := &QueryNodeStatsMsg{ - BaseMsg: baseMsg, - QueryNodeStats: queryNodeSegStats, - } - return queryNodeSegStatsMsg - } - return nil -} - -func GetTimeTickMsg(reqID UniqueID, hashValue uint32, time uint64) TsMsg { - baseMsg := BaseMsg{ - BeginTimestamp: 0, - EndTimestamp: 0, - HashValues: []uint32{hashValue}, - } - timeTickResult := internalpb2.TimeTickMsg{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_kTimeTick, - MsgID: reqID, - Timestamp: time, - SourceID: reqID, - }, - } - timeTickMsg := &TimeTickMsg{ - BaseMsg: baseMsg, - TimeTickMsg: timeTickResult, - } - return timeTickMsg -} diff --git a/internal/util/typeutil/convension.go b/internal/util/typeutil/convension.go index a444f79b8..68d6e2da1 100644 --- a/internal/util/typeutil/convension.go +++ b/internal/util/typeutil/convension.go @@ -2,6 +2,10 @@ package typeutil import ( "encoding/binary" + "fmt" + "reflect" + + "github.com/apache/pulsar-client-go/pulsar" "github.com/zilliztech/milvus-distributed/internal/errors" ) @@ -37,3 +41,28 @@ func Uint64ToBytes(v uint64) []byte { binary.BigEndian.PutUint64(b, v) return b } + +func PulsarMsgIDToString(messageID pulsar.MessageID) string { + return string(messageID.Serialize()) +} + +func StringToPulsarMsgID(msgString string) (pulsar.MessageID, error) { + return pulsar.DeserializeMessageID([]byte(msgString)) +} + +func SliceRemoveDuplicate(a interface{}) (ret []interface{}) { + if reflect.TypeOf(a).Kind() != reflect.Slice { + fmt.Printf("input is not slice but %T\n", a) + return ret + } + + va := reflect.ValueOf(a) + for i := 0; i < va.Len(); i++ { + if i > 0 && reflect.DeepEqual(va.Index(i-1).Interface(), va.Index(i).Interface()) { + continue + } + ret = append(ret, va.Index(i).Interface()) + } + + return ret +} -- GitLab