From e3fadc45d43c64e0a7c0ceb391751dc681022c5f Mon Sep 17 00:00:00 2001
From: yukun <kun.yu@zilliz.com>
Date: Thu, 4 Feb 2021 15:23:21 +0800
Subject: [PATCH] Fix for new msgstream interface

Signed-off-by: yukun <kun.yu@zilliz.com>
---
 internal/dataservice/meta.go                  |  2 +-
 internal/dataservice/segment_allocator.go     |  4 +-
 .../dataservice/segment_allocator_test.go     |  9 +-
 internal/dataservice/server.go                |  5 +-
 internal/dataservice/watcher_test.go          |  2 +-
 .../msgstream/pulsarms/pulsar_msgstream.go    |  2 +-
 internal/msgstream/rmqms/factory.go           | 28 ++++++
 internal/msgstream/rmqms/rmq_msgstream.go     | 89 +++++++++++++++----
 8 files changed, 111 insertions(+), 30 deletions(-)
 create mode 100644 internal/msgstream/rmqms/factory.go

diff --git a/internal/dataservice/meta.go b/internal/dataservice/meta.go
index b69cc746e..61b67842b 100644
--- a/internal/dataservice/meta.go
+++ b/internal/dataservice/meta.go
@@ -183,7 +183,7 @@ func (meta *meta) UpdateSegment(segmentInfo *datapb.SegmentInfo) error {
 
 func (meta *meta) DropSegment(segmentID UniqueID) error {
 	meta.ddLock.Lock()
-	meta.ddLock.Unlock()
+	defer meta.ddLock.Unlock()
 
 	if _, ok := meta.segID2Info[segmentID]; !ok {
 		return newErrSegmentNotFound(segmentID)
diff --git a/internal/dataservice/segment_allocator.go b/internal/dataservice/segment_allocator.go
index d15d3b66a..a85dd4111 100644
--- a/internal/dataservice/segment_allocator.go
+++ b/internal/dataservice/segment_allocator.go
@@ -71,7 +71,7 @@ type (
 	}
 )
 
-func newSegmentAllocator(meta *meta, allocator allocator) (*segmentAllocatorImpl, error) {
+func newSegmentAllocator(meta *meta, allocator allocator) *segmentAllocatorImpl {
 	segmentAllocator := &segmentAllocatorImpl{
 		mt:                     meta,
 		segments:               make(map[UniqueID]*segmentStatus),
@@ -80,7 +80,7 @@ func newSegmentAllocator(meta *meta, allocator allocator) (*segmentAllocatorImpl
 		segmentThresholdFactor: Params.SegmentSizeFactor,
 		allocator:              allocator,
 	}
-	return segmentAllocator, nil
+	return segmentAllocator
 }
 
 func (allocator *segmentAllocatorImpl) OpenSegment(segmentInfo *datapb.SegmentInfo) error {
diff --git a/internal/dataservice/segment_allocator_test.go b/internal/dataservice/segment_allocator_test.go
index d0e248edc..9f81783ef 100644
--- a/internal/dataservice/segment_allocator_test.go
+++ b/internal/dataservice/segment_allocator_test.go
@@ -17,8 +17,7 @@ func TestAllocSegment(t *testing.T) {
 	mockAllocator := newMockAllocator()
 	meta, err := newMemoryMeta(mockAllocator)
 	assert.Nil(t, err)
-	segAllocator, err := newSegmentAllocator(meta, mockAllocator)
-	assert.Nil(t, err)
+	segAllocator := newSegmentAllocator(meta, mockAllocator)
 
 	schema := newTestSchema()
 	collID, err := mockAllocator.allocID()
@@ -68,8 +67,7 @@ func TestSealSegment(t *testing.T) {
 	mockAllocator := newMockAllocator()
 	meta, err := newMemoryMeta(mockAllocator)
 	assert.Nil(t, err)
-	segAllocator, err := newSegmentAllocator(meta, mockAllocator)
-	assert.Nil(t, err)
+	segAllocator := newSegmentAllocator(meta, mockAllocator)
 
 	schema := newTestSchema()
 	collID, err := mockAllocator.allocID()
@@ -105,8 +103,7 @@ func TestExpireSegment(t *testing.T) {
 	mockAllocator := newMockAllocator()
 	meta, err := newMemoryMeta(mockAllocator)
 	assert.Nil(t, err)
-	segAllocator, err := newSegmentAllocator(meta, mockAllocator)
-	assert.Nil(t, err)
+	segAllocator := newSegmentAllocator(meta, mockAllocator)
 
 	schema := newTestSchema()
 	collID, err := mockAllocator.allocID()
diff --git a/internal/dataservice/server.go b/internal/dataservice/server.go
index a5537e11c..fcb2c5cce 100644
--- a/internal/dataservice/server.go
+++ b/internal/dataservice/server.go
@@ -134,10 +134,7 @@ func (s *Server) Start() error {
 		return err
 	}
 	s.statsHandler = newStatsHandler(s.meta)
-	s.segAllocator, err = newSegmentAllocator(s.meta, s.allocator)
-	if err != nil {
-		return err
-	}
+	s.segAllocator = newSegmentAllocator(s.meta, s.allocator)
 	s.ddHandler = newDDHandler(s.meta, s.segAllocator)
 	s.initSegmentInfoChannel()
 	if err = s.loadMetaFromMaster(); err != nil {
diff --git a/internal/dataservice/watcher_test.go b/internal/dataservice/watcher_test.go
index e476f9f02..11e8fbe5f 100644
--- a/internal/dataservice/watcher_test.go
+++ b/internal/dataservice/watcher_test.go
@@ -21,7 +21,7 @@ func TestDataNodeTTWatcher(t *testing.T) {
 	allocator := newMockAllocator()
 	meta, err := newMemoryMeta(allocator)
 	assert.Nil(t, err)
-	segAllocator, err := newSegmentAllocator(meta, allocator)
+	segAllocator := newSegmentAllocator(meta, allocator)
 	assert.Nil(t, err)
 	watcher := newDataNodeTimeTickWatcher(meta, segAllocator, cluster)
 
diff --git a/internal/msgstream/pulsarms/pulsar_msgstream.go b/internal/msgstream/pulsarms/pulsar_msgstream.go
index 5f9986861..aedca490c 100644
--- a/internal/msgstream/pulsarms/pulsar_msgstream.go
+++ b/internal/msgstream/pulsarms/pulsar_msgstream.go
@@ -747,7 +747,7 @@ func checkTimeTickMsg(msg map[Consumer]Timestamp,
 	for consumer := range msg {
 		mu.RLock()
 		v := msg[consumer]
-		mu.Unlock()
+		mu.RUnlock()
 		if v != maxTime {
 			isChannelReady[consumer] = false
 		} else {
diff --git a/internal/msgstream/rmqms/factory.go b/internal/msgstream/rmqms/factory.go
new file mode 100644
index 000000000..7e0e3902f
--- /dev/null
+++ b/internal/msgstream/rmqms/factory.go
@@ -0,0 +1,28 @@
+package rmqms
+
+import (
+	"context"
+
+	"github.com/zilliztech/milvus-distributed/internal/msgstream"
+)
+
+type Factory struct {
+	dispatcherFactory msgstream.ProtoUDFactory
+	address           string
+	receiveBufSize    int64
+	pulsarBufSize     int64
+}
+
+func (f *Factory) NewMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
+	return newRmqMsgStream(ctx, f.receiveBufSize, f.dispatcherFactory.NewUnmarshalDispatcher())
+}
+
+func NewFactory(address string, receiveBufSize int64, pulsarBufSize int64) *Factory {
+	f := &Factory{
+		dispatcherFactory: msgstream.ProtoUDFactory{},
+		address:           address,
+		receiveBufSize:    receiveBufSize,
+		pulsarBufSize:     pulsarBufSize,
+	}
+	return f
+}
diff --git a/internal/msgstream/rmqms/rmq_msgstream.go b/internal/msgstream/rmqms/rmq_msgstream.go
index ec8cfa3a3..b2033f446 100644
--- a/internal/msgstream/rmqms/rmq_msgstream.go
+++ b/internal/msgstream/rmqms/rmq_msgstream.go
@@ -16,6 +16,17 @@ import (
 	"github.com/zilliztech/milvus-distributed/internal/msgstream"
 )
 
+type TsMsg = msgstream.TsMsg
+type MsgPack = msgstream.MsgPack
+type MsgType = msgstream.MsgType
+type UniqueID = msgstream.UniqueID
+type BaseMsg = msgstream.BaseMsg
+type Timestamp = msgstream.Timestamp
+type IntPrimaryKey = msgstream.IntPrimaryKey
+type TimeTickMsg = msgstream.TimeTickMsg
+type QueryNodeStatsMsg = msgstream.QueryNodeStatsMsg
+type RepackFunc = msgstream.RepackFunc
+
 type RmqMsgStream struct {
 	isServing        int64
 	ctx              context.Context
@@ -23,7 +34,6 @@ type RmqMsgStream struct {
 	serverLoopCtx    context.Context
 	serverLoopCancel func()
 
-	rmq        *rocksmq.RocksMQ
 	repackFunc msgstream.RepackFunc
 	consumers  []rocksmq.Consumer
 	producers  []string
@@ -35,17 +45,18 @@ type RmqMsgStream struct {
 	streamCancel func()
 }
 
-func NewRmqMsgStream(ctx context.Context, rmq *rocksmq.RocksMQ, receiveBufSize int64) *RmqMsgStream {
+func newRmqMsgStream(ctx context.Context, receiveBufSize int64,
+	unmarshal msgstream.UnmarshalDispatcher) (*RmqMsgStream, error) {
 	streamCtx, streamCancel := context.WithCancel(ctx)
 	receiveBuf := make(chan *msgstream.MsgPack, receiveBufSize)
 	stream := &RmqMsgStream{
 		ctx:          streamCtx,
-		rmq:          nil,
 		receiveBuf:   receiveBuf,
+		unmarshal:    unmarshal,
 		streamCancel: streamCancel,
 	}
 
-	return stream
+	return stream, nil
 }
 
 func (ms *RmqMsgStream) Start() {
@@ -59,25 +70,32 @@ func (ms *RmqMsgStream) Start() {
 func (ms *RmqMsgStream) Close() {
 }
 
-func (ms *RmqMsgStream) CreateProducers(channels []string) error {
+type propertiesReaderWriter struct {
+	ppMap map[string]string
+}
+
+func (ms *RmqMsgStream) SetRepackFunc(repackFunc RepackFunc) {
+	ms.repackFunc = repackFunc
+}
+
+func (ms *RmqMsgStream) AsProducer(channels []string) {
 	for _, channel := range channels {
 		// TODO(yhz): Here may allow to create an existing channel
-		if err := ms.rmq.CreateChannel(channel); err != nil {
-			return err
+		if err := rocksmq.Rmq.CreateChannel(channel); err != nil {
+			errMsg := "Failed to create producer " + channel + ", error = " + err.Error()
+			panic(errMsg)
 		}
 	}
-	return nil
 }
 
-func (ms *RmqMsgStream) CreateConsumers(channels []string, groupName string) error {
+func (ms *RmqMsgStream) AsConsumer(channels []string, groupName string) {
 	for _, channelName := range channels {
-		if err := ms.rmq.CreateConsumerGroup(groupName, channelName); err != nil {
-			return err
+		if err := rocksmq.Rmq.CreateConsumerGroup(groupName, channelName); err != nil {
+			panic(err.Error())
 		}
 		msgNum := make(chan int)
 		ms.consumers = append(ms.consumers, rocksmq.Consumer{GroupName: groupName, ChannelName: channelName, MsgNum: msgNum})
 	}
-	return nil
 }
 
 func (ms *RmqMsgStream) Produce(pack *msgstream.MsgPack) error {
@@ -172,7 +190,30 @@ func (ms *RmqMsgStream) Produce(pack *msgstream.MsgPack) error {
 			}
 			msg := make([]rocksmq.ProducerMessage, 0)
 			msg = append(msg, *rocksmq.NewProducerMessage(m))
-			if err := ms.rmq.Produce(ms.producers[k], msg); err != nil {
+			if err := rocksmq.Rmq.Produce(ms.producers[k], msg); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+func (ms *RmqMsgStream) Broadcast(msgPack *MsgPack) error {
+	producerLen := len(ms.producers)
+	for _, v := range msgPack.Msgs {
+		mb, err := v.Marshal(v)
+		if err != nil {
+			return err
+		}
+		m, err := msgstream.ConvertToByteArray(mb)
+		if err != nil {
+			return err
+		}
+		msg := make([]rocksmq.ProducerMessage, 0)
+		msg = append(msg, *rocksmq.NewProducerMessage(m))
+
+		for i := 0; i < producerLen; i++ {
+			if err := rocksmq.Rmq.Produce(ms.producers[i], msg); err != nil {
 				return err
 			}
 		}
@@ -221,7 +262,7 @@ func (ms *RmqMsgStream) bufMsgPackToChannel() {
 				}
 
 				msgNum := value.Interface().(int)
-				rmqMsg, err := ms.rmq.Consume(ms.consumers[chosen].GroupName, ms.consumers[chosen].ChannelName, msgNum)
+				rmqMsg, err := rocksmq.Rmq.Consume(ms.consumers[chosen].GroupName, ms.consumers[chosen].ChannelName, msgNum)
 				if err != nil {
 					log.Printf("Failed to consume message in rocksmq, error = %v", err)
 					continue
@@ -261,5 +302,23 @@ func (ms *RmqMsgStream) bufMsgPackToChannel() {
 }
 
 func (ms *RmqMsgStream) Chan() <-chan *msgstream.MsgPack {
-	return nil
+	return ms.receiveBuf
+}
+
+func (ms *RmqMsgStream) Seek(offset *msgstream.MsgPosition) error {
+	for i := 0; i < len(ms.consumers); i++ {
+		if ms.consumers[i].ChannelName == offset.ChannelName {
+			messageID, err := strconv.ParseInt(offset.MsgID, 10, 64)
+			if err != nil {
+				return err
+			}
+			err = rocksmq.Rmq.Seek(ms.consumers[i].GroupName, ms.consumers[i].ChannelName, messageID)
+			if err != nil {
+				return err
+			}
+			return nil
+		}
+	}
+
+	return errors.New("msgStream seek fail")
 }
-- 
GitLab