diff --git a/internal/msgstream/rmqms/factory.go b/internal/msgstream/rmqms/factory.go index 7e0e3902f79ea0a358d1f9394865a43de87a681b..0da978f1ca10dd38e1e0ec14bc6cb4d657c678e8 100644 --- a/internal/msgstream/rmqms/factory.go +++ b/internal/msgstream/rmqms/factory.go @@ -8,21 +8,18 @@ import ( type Factory struct { dispatcherFactory msgstream.ProtoUDFactory - address string receiveBufSize int64 - pulsarBufSize int64 + rmqBufSize int64 } func (f *Factory) NewMsgStream(ctx context.Context) (msgstream.MsgStream, error) { - return newRmqMsgStream(ctx, f.receiveBufSize, f.dispatcherFactory.NewUnmarshalDispatcher()) + return newRmqMsgStream(ctx, f.receiveBufSize, f.rmqBufSize, f.dispatcherFactory.NewUnmarshalDispatcher()) } func NewFactory(address string, receiveBufSize int64, pulsarBufSize int64) *Factory { f := &Factory{ dispatcherFactory: msgstream.ProtoUDFactory{}, - address: address, receiveBufSize: receiveBufSize, - pulsarBufSize: pulsarBufSize, } return f } diff --git a/internal/msgstream/rmqms/rmq_msgstream.go b/internal/msgstream/rmqms/rmq_msgstream.go index b2033f4464dc6fd3daae28c467137d31c6833185..db7ff5749177010e8e1381b287c3e9a6410e5f82 100644 --- a/internal/msgstream/rmqms/rmq_msgstream.go +++ b/internal/msgstream/rmqms/rmq_msgstream.go @@ -42,18 +42,24 @@ type RmqMsgStream struct { receiveBuf chan *msgstream.MsgPack wait *sync.WaitGroup // tso ticker - streamCancel func() + streamCancel func() + rmqBufSize int64 + consumerReflects []reflect.SelectCase } -func newRmqMsgStream(ctx context.Context, receiveBufSize int64, +func newRmqMsgStream(ctx context.Context, receiveBufSize int64, rmqBufSize int64, unmarshal msgstream.UnmarshalDispatcher) (*RmqMsgStream, error) { + streamCtx, streamCancel := context.WithCancel(ctx) receiveBuf := make(chan *msgstream.MsgPack, receiveBufSize) + consumerReflects := make([]reflect.SelectCase, 0) stream := &RmqMsgStream{ - ctx: streamCtx, - receiveBuf: receiveBuf, - unmarshal: unmarshal, - streamCancel: streamCancel, + ctx: streamCtx, + receiveBuf: receiveBuf, + unmarshal: unmarshal, + streamCancel: streamCancel, + rmqBufSize: rmqBufSize, + consumerReflects: consumerReflects, } return stream, nil @@ -68,6 +74,17 @@ func (ms *RmqMsgStream) Start() { } func (ms *RmqMsgStream) Close() { + ms.streamCancel() + + for _, producer := range ms.producers { + if producer != "" { + _ = rocksmq.Rmq.DestroyChannel(producer) + } + } + for _, consumer := range ms.consumers { + _ = rocksmq.Rmq.DestroyConsumerGroup(consumer.GroupName, consumer.ChannelName) + close(consumer.MsgNum) + } } type propertiesReaderWriter struct { @@ -85,16 +102,22 @@ func (ms *RmqMsgStream) AsProducer(channels []string) { errMsg := "Failed to create producer " + channel + ", error = " + err.Error() panic(errMsg) } + ms.producers = append(ms.producers, channel) } } func (ms *RmqMsgStream) AsConsumer(channels []string, groupName string) { for _, channelName := range channels { - if err := rocksmq.Rmq.CreateConsumerGroup(groupName, channelName); err != nil { + consumer, err := rocksmq.Rmq.CreateConsumerGroup(groupName, channelName) + if err != nil { panic(err.Error()) } - msgNum := make(chan int) - ms.consumers = append(ms.consumers, rocksmq.Consumer{GroupName: groupName, ChannelName: channelName, MsgNum: msgNum}) + consumer.MsgNum = make(chan int, ms.rmqBufSize) + ms.consumers = append(ms.consumers, *consumer) + ms.consumerReflects = append(ms.consumerReflects, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(consumer.MsgNum), + }) } } @@ -240,12 +263,6 @@ func (ms *RmqMsgStream) Consume() *msgstream.MsgPack { func (ms *RmqMsgStream) bufMsgPackToChannel() { defer ms.wait.Done() - cases := make([]reflect.SelectCase, len(ms.consumers)) - for i := 0; i < len(ms.consumers); i++ { - ch := ms.consumers[i].MsgNum - cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)} - } - for { select { case <-ms.ctx.Done(): @@ -255,7 +272,7 @@ func (ms *RmqMsgStream) bufMsgPackToChannel() { tsMsgList := make([]msgstream.TsMsg, 0) for { - chosen, value, ok := reflect.Select(cases) + chosen, value, ok := reflect.Select(ms.consumerReflects) if !ok { log.Printf("channel closed") return diff --git a/internal/msgstream/rmqms/rmq_msgstream_test.go b/internal/msgstream/rmqms/rmq_msgstream_test.go new file mode 100644 index 0000000000000000000000000000000000000000..6d4a23b386bd9fcf134386e4b2da8bc73e3ad1a8 --- /dev/null +++ b/internal/msgstream/rmqms/rmq_msgstream_test.go @@ -0,0 +1,343 @@ +package rmqms + +import ( + "context" + "fmt" + "log" + "os" + "testing" + + etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd" + "github.com/zilliztech/milvus-distributed/internal/util/rocksmq" + "go.etcd.io/etcd/clientv3" + + "github.com/zilliztech/milvus-distributed/internal/msgstream" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2" +) + +var rocksmqName string = "/tmp/rocksmq" + +func repackFunc(msgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error) { + result := make(map[int32]*MsgPack) + for i, request := range msgs { + keys := hashKeys[i] + for _, channelID := range keys { + _, ok := result[channelID] + if ok == false { + msgPack := MsgPack{} + result[channelID] = &msgPack + } + result[channelID].Msgs = append(result[channelID].Msgs, request) + } + } + return result, nil +} + +func getTsMsg(msgType MsgType, reqID UniqueID, hashValue uint32) TsMsg { + baseMsg := BaseMsg{ + BeginTimestamp: 0, + EndTimestamp: 0, + HashValues: []uint32{hashValue}, + } + switch msgType { + case commonpb.MsgType_kInsert: + insertRequest := internalpb2.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kInsert, + MsgID: reqID, + Timestamp: 11, + SourceID: reqID, + }, + CollectionName: "Collection", + PartitionName: "Partition", + SegmentID: 1, + ChannelID: "0", + Timestamps: []Timestamp{uint64(reqID)}, + RowIDs: []int64{1}, + RowData: []*commonpb.Blob{{}}, + } + insertMsg := &msgstream.InsertMsg{ + BaseMsg: baseMsg, + InsertRequest: insertRequest, + } + return insertMsg + case commonpb.MsgType_kDelete: + deleteRequest := internalpb2.DeleteRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kDelete, + MsgID: reqID, + Timestamp: 11, + SourceID: reqID, + }, + CollectionName: "Collection", + ChannelID: "1", + Timestamps: []Timestamp{1}, + PrimaryKeys: []IntPrimaryKey{1}, + } + deleteMsg := &msgstream.DeleteMsg{ + BaseMsg: baseMsg, + DeleteRequest: deleteRequest, + } + return deleteMsg + case commonpb.MsgType_kSearch: + searchRequest := internalpb2.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kSearch, + MsgID: reqID, + Timestamp: 11, + SourceID: reqID, + }, + Query: nil, + ResultChannelID: "0", + } + searchMsg := &msgstream.SearchMsg{ + BaseMsg: baseMsg, + SearchRequest: searchRequest, + } + return searchMsg + case commonpb.MsgType_kSearchResult: + searchResult := internalpb2.SearchResults{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kSearchResult, + MsgID: reqID, + Timestamp: 1, + SourceID: reqID, + }, + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}, + ResultChannelID: "0", + } + searchResultMsg := &msgstream.SearchResultMsg{ + BaseMsg: baseMsg, + SearchResults: searchResult, + } + return searchResultMsg + case commonpb.MsgType_kTimeTick: + timeTickResult := internalpb2.TimeTickMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kTimeTick, + MsgID: reqID, + Timestamp: 1, + SourceID: reqID, + }, + } + timeTickMsg := &TimeTickMsg{ + BaseMsg: baseMsg, + TimeTickMsg: timeTickResult, + } + return timeTickMsg + case commonpb.MsgType_kQueryNodeStats: + queryNodeSegStats := internalpb2.QueryNodeStats{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_kQueryNodeStats, + SourceID: reqID, + }, + } + queryNodeSegStatsMsg := &QueryNodeStatsMsg{ + BaseMsg: baseMsg, + QueryNodeStats: queryNodeSegStats, + } + return queryNodeSegStatsMsg + } + return nil +} + +func initRmq(name string) *etcdkv.EtcdKV { + etcdAddr := os.Getenv("ETCD_ADDRESS") + if etcdAddr == "" { + etcdAddr = "localhost:2379" + } + cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}}) + if err != nil { + log.Fatalf("New clientv3 error = %v", err) + } + etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root") + idAllocator := rocksmq.NewGlobalIDAllocator("dummy", etcdKV) + _ = idAllocator.Initialize() + + err = rocksmq.InitRmq(name, idAllocator) + + if err != nil { + log.Fatalf("InitRmq error = %v", err) + } + return etcdKV +} + +func Close(intputStream, outputStream msgstream.MsgStream, etcdKV *etcdkv.EtcdKV) { + intputStream.Close() + outputStream.Close() + etcdKV.Close() + _ = os.RemoveAll(rocksmqName) +} + +func initRmqStream(producerChannels []string, + consumerChannels []string, + consumerGroupName string, + opts ...RepackFunc) (msgstream.MsgStream, msgstream.MsgStream) { + factory := msgstream.ProtoUDFactory{} + + inputStream, _ := newRmqMsgStream(context.Background(), 100, 100, factory.NewUnmarshalDispatcher()) + inputStream.AsProducer(producerChannels) + for _, opt := range opts { + inputStream.SetRepackFunc(opt) + } + inputStream.Start() + var input msgstream.MsgStream = inputStream + + outputStream, _ := newRmqMsgStream(context.Background(), 100, 100, factory.NewUnmarshalDispatcher()) + outputStream.AsConsumer(consumerChannels, consumerGroupName) + outputStream.Start() + var output msgstream.MsgStream = outputStream + + return input, output +} + +func receiveMsg(outputStream msgstream.MsgStream, msgCount int) { + receiveCount := 0 + for { + result := outputStream.Consume() + if len(result.Msgs) > 0 { + msgs := result.Msgs + for _, v := range msgs { + receiveCount++ + fmt.Println("msg type: ", v.Type(), ", msg value: ", v) + } + } + if receiveCount >= msgCount { + break + } + } +} + +func TestStream_RmqMsgStream_Insert(t *testing.T) { + producerChannels := []string{"insert1", "insert2"} + consumerChannels := []string{"insert1", "insert2"} + consumerGroupName := "InsertGroup" + + msgPack := msgstream.MsgPack{} + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kInsert, 1, 1)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kInsert, 3, 3)) + + etcdKV := initRmq("/tmp/rocksmq_insert") + inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerGroupName) + err := inputStream.Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } + + receiveMsg(outputStream, len(msgPack.Msgs)) + Close(inputStream, outputStream, etcdKV) +} + +func TestStream_RmqMsgStream_Delete(t *testing.T) { + producerChannels := []string{"delete"} + consumerChannels := []string{"delete"} + consumerSubName := "subDelete" + + msgPack := msgstream.MsgPack{} + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kDelete, 1, 1)) + + etcdKV := initRmq("/tmp/rocksmq_delete") + inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName) + err := inputStream.Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } + receiveMsg(outputStream, len(msgPack.Msgs)) + Close(inputStream, outputStream, etcdKV) +} + +func TestStream_RmqMsgStream_Search(t *testing.T) { + producerChannels := []string{"search"} + consumerChannels := []string{"search"} + consumerSubName := "subSearch" + + msgPack := msgstream.MsgPack{} + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kSearch, 1, 1)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kSearch, 3, 3)) + + etcdKV := initRmq("/tmp/rocksmq_search") + inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName) + err := inputStream.Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } + receiveMsg(outputStream, len(msgPack.Msgs)) + Close(inputStream, outputStream, etcdKV) +} + +func TestStream_RmqMsgStream_SearchResult(t *testing.T) { + producerChannels := []string{"searchResult"} + consumerChannels := []string{"searchResult"} + consumerSubName := "subSearchResult" + + msgPack := msgstream.MsgPack{} + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kSearchResult, 1, 1)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kSearchResult, 3, 3)) + + etcdKV := initRmq("/tmp/rocksmq_searchresult") + inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName) + err := inputStream.Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } + receiveMsg(outputStream, len(msgPack.Msgs)) + Close(inputStream, outputStream, etcdKV) +} + +func TestStream_RmqMsgStream_TimeTick(t *testing.T) { + producerChannels := []string{"timeTick"} + consumerChannels := []string{"timeTick"} + consumerSubName := "subTimeTick" + + msgPack := msgstream.MsgPack{} + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kTimeTick, 1, 1)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kTimeTick, 3, 3)) + + etcdKV := initRmq("/tmp/rocksmq_timetick") + inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName) + err := inputStream.Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } + receiveMsg(outputStream, len(msgPack.Msgs)) + Close(inputStream, outputStream, etcdKV) +} + +func TestStream_RmqMsgStream_BroadCast(t *testing.T) { + producerChannels := []string{"insert1", "insert2"} + consumerChannels := []string{"insert1", "insert2"} + consumerSubName := "subInsert" + + msgPack := msgstream.MsgPack{} + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kTimeTick, 1, 1)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kTimeTick, 3, 3)) + + etcdKV := initRmq("/tmp/rocksmq_broadcast") + inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName) + err := inputStream.Broadcast(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } + receiveMsg(outputStream, len(consumerChannels)*len(msgPack.Msgs)) + Close(inputStream, outputStream, etcdKV) +} + +func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) { + producerChannels := []string{"insert1", "insert2"} + consumerChannels := []string{"insert1", "insert2"} + consumerSubName := "subInsert" + + msgPack := msgstream.MsgPack{} + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kInsert, 1, 1)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_kInsert, 3, 3)) + + etcdKV := initRmq("/tmp/rocksmq_repackfunc") + inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName, repackFunc) + err := inputStream.Produce(&msgPack) + if err != nil { + log.Fatalf("produce error = %v", err) + } + receiveMsg(outputStream, len(msgPack.Msgs)) + Close(inputStream, outputStream, etcdKV) +} diff --git a/internal/util/rocksmq/rocksmq.go b/internal/util/rocksmq/rocksmq.go index 796f1e385185ab001ffe7d334e46ade081eac5f5..322b5a0e19e31f13906da1a542f9817d771f66fb 100644 --- a/internal/util/rocksmq/rocksmq.go +++ b/internal/util/rocksmq/rocksmq.go @@ -76,7 +76,7 @@ type RocksMQ struct { produceMu sync.Mutex consumeMu sync.Mutex - notify map[string][]Consumer + notify map[string][]*Consumer //ctx context.Context //serverLoopWg sync.WaitGroup //serverLoopCtx context.Context @@ -107,7 +107,7 @@ func NewRocksMQ(name string, idAllocator IDAllocator) (*RocksMQ, error) { idAllocator: idAllocator, } rmq.channels = make(map[string]*Channel) - rmq.notify = make(map[string][]Consumer) + rmq.notify = make(map[string][]*Consumer) return rmq, nil } @@ -166,17 +166,24 @@ func (rmq *RocksMQ) DestroyChannel(channelName string) error { return nil } -func (rmq *RocksMQ) CreateConsumerGroup(groupName string, channelName string) error { +func (rmq *RocksMQ) CreateConsumerGroup(groupName string, channelName string) (*Consumer, error) { key := groupName + "/" + channelName + "/current_id" if rmq.checkKeyExist(key) { - return errors.New("ConsumerGroup " + groupName + " already exists.") + return nil, errors.New("ConsumerGroup " + groupName + " already exists.") } err := rmq.kv.Save(key, DefaultMessageID) if err != nil { - return err + return nil, err } - return nil + //msgNum := make(chan int, 100) + consumer := Consumer{ + GroupName: groupName, + ChannelName: channelName, + //MsgNum: msgNum, + } + rmq.notify[channelName] = append(rmq.notify[channelName], &consumer) + return &consumer, nil } func (rmq *RocksMQ) DestroyConsumerGroup(groupName string, channelName string) error { @@ -243,7 +250,9 @@ func (rmq *RocksMQ) Produce(channelName string, messages []ProducerMessage) erro } for _, consumer := range rmq.notify[channelName] { - consumer.MsgNum <- msgLen + if consumer.MsgNum != nil { + consumer.MsgNum <- msgLen + } } return nil } diff --git a/internal/util/rocksmq/rocksmq_test.go b/internal/util/rocksmq/rocksmq_test.go index e3b92a2f9bc2e0b089e182544be2882db19ae000..f8e0d58c728d802bf8c782c187c62f310a053026 100644 --- a/internal/util/rocksmq/rocksmq_test.go +++ b/internal/util/rocksmq/rocksmq_test.go @@ -61,7 +61,7 @@ func TestRocksMQ(t *testing.T) { groupName := "test_group" _ = rmq.DestroyConsumerGroup(groupName, channelName) - err = rmq.CreateConsumerGroup(groupName, channelName) + _, err = rmq.CreateConsumerGroup(groupName, channelName) assert.Nil(t, err) cMsgs, err := rmq.Consume(groupName, channelName, 1) assert.Nil(t, err) @@ -122,7 +122,7 @@ func TestRocksMQ_Loop(t *testing.T) { // Consume loopNum message once groupName := "test_group" _ = rmq.DestroyConsumerGroup(groupName, channelName) - err = rmq.CreateConsumerGroup(groupName, channelName) + _, err = rmq.CreateConsumerGroup(groupName, channelName) assert.Nil(t, err) cMsgs, err := rmq.Consume(groupName, channelName, loopNum) assert.Nil(t, err) @@ -189,7 +189,7 @@ func TestRocksMQ_Goroutines(t *testing.T) { groupName := "test_group" _ = rmq.DestroyConsumerGroup(groupName, channelName) - err = rmq.CreateConsumerGroup(groupName, channelName) + _, err = rmq.CreateConsumerGroup(groupName, channelName) assert.Nil(t, err) // Consume one message in each goroutine for i := 0; i < loopNum; i++ {