From 3aca35b10c7f8cd26bb334903f54a29eb42f4573 Mon Sep 17 00:00:00 2001
From: yukun <kun.yu@zilliz.com>
Date: Mon, 18 Jan 2021 15:05:49 +0800
Subject: [PATCH] Add rocksmq unittest

Signed-off-by: yukun <kun.yu@zilliz.com>
---
 internal/util/rocksmq/rocksmq.go      | 24 ++++++----
 internal/util/rocksmq/rocksmq_test.go | 66 +++++++++++++++++++++++++++
 2 files changed, 80 insertions(+), 10 deletions(-)
 create mode 100644 internal/util/rocksmq/rocksmq_test.go

diff --git a/internal/util/rocksmq/rocksmq.go b/internal/util/rocksmq/rocksmq.go
index dfa63946a..5a4651d51 100644
--- a/internal/util/rocksmq/rocksmq.go
+++ b/internal/util/rocksmq/rocksmq.go
@@ -16,7 +16,9 @@ import (
 type UniqueID = typeutil.UniqueID
 
 const (
-	FixedChannelNameLen = 32
+	DefaultMessageID        = "-1"
+	FixedChannelNameLen     = 32
+	RocksDBLRUCacheCapacity = 3 << 30
 )
 
 /**
@@ -83,9 +85,9 @@ type RocksMQ struct {
 	//tsoTicker *time.Ticker
 }
 
-func NewRocksMQ(name string) (*RocksMQ, error) {
+func NewRocksMQ(name string, idAllocator master.IDAllocator) (*RocksMQ, error) {
 	bbto := gorocksdb.NewDefaultBlockBasedTableOptions()
-	bbto.SetBlockCache(gorocksdb.NewLRUCache(3 << 30))
+	bbto.SetBlockCache(gorocksdb.NewLRUCache(RocksDBLRUCacheCapacity))
 	opts := gorocksdb.NewDefaultOptions()
 	opts.SetBlockBasedTableFactory(bbto)
 	opts.SetCreateIfMissing(true)
@@ -99,8 +101,9 @@ func NewRocksMQ(name string) (*RocksMQ, error) {
 	mkv := memkv.NewMemoryKV()
 
 	rmq := &RocksMQ{
-		store: db,
-		kv:    mkv,
+		store:       db,
+		kv:          mkv,
+		idAllocator: idAllocator,
 	}
 	return rmq, nil
 }
@@ -176,8 +179,8 @@ func NewRocksMQ(name string) (*RocksMQ, error) {
 //}
 
 func (rmq *RocksMQ) checkKeyExist(key string) bool {
-	_, err := rmq.kv.Load(key)
-	return err == nil
+	val, _ := rmq.kv.Load(key)
+	return val != ""
 }
 
 func (rmq *RocksMQ) CreateChannel(channelName string) error {
@@ -229,7 +232,7 @@ func (rmq *RocksMQ) CreateConsumerGroup(groupName string, channelName string) er
 	if rmq.checkKeyExist(key) {
 		return errors.New("ConsumerGroup " + groupName + " already exists.")
 	}
-	err := rmq.kv.Save(key, "-1")
+	err := rmq.kv.Save(key, DefaultMessageID)
 	if err != nil {
 		return err
 	}
@@ -316,11 +319,12 @@ func (rmq *RocksMQ) Consume(groupName string, channelName string, n int) ([]Cons
 	}
 	dataKey := fixChanName + "/" + currentID
 
-	// msgID is "-1" means this is the first consume operation
-	if currentID == "-1" {
+	// msgID is DefaultMessageID means this is the first consume operation
+	if currentID == DefaultMessageID {
 		iter.SeekToFirst()
 	} else {
 		iter.Seek([]byte(dataKey))
+		iter.Next()
 	}
 
 	offset := 0
diff --git a/internal/util/rocksmq/rocksmq_test.go b/internal/util/rocksmq/rocksmq_test.go
new file mode 100644
index 000000000..9f4804716
--- /dev/null
+++ b/internal/util/rocksmq/rocksmq_test.go
@@ -0,0 +1,66 @@
+package rocksmq
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
+	master "github.com/zilliztech/milvus-distributed/internal/master"
+	"go.etcd.io/etcd/clientv3"
+)
+
+func TestFixChannelName(t *testing.T) {
+	name := "abcd"
+	fixName, err := fixChannelName(name)
+	assert.Nil(t, err)
+	assert.Equal(t, len(fixName), FixedChannelNameLen)
+}
+
+func TestRocksMQ(t *testing.T) {
+	master.Init()
+
+	etcdAddr := master.Params.EtcdAddress
+	cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddr}})
+	assert.Nil(t, err)
+	etcdKV := etcdkv.NewEtcdKV(cli, "/etcd/test/root")
+	idAllocator := master.NewGlobalIDAllocator("dummy", etcdKV)
+	_ = idAllocator.Initialize()
+
+	name := "/tmp/rocksmq"
+	rmq, err := NewRocksMQ(name, idAllocator)
+	assert.Nil(t, err)
+
+	channelName := "channel_a"
+	msgA := "a_message"
+	pMsgs := make([]ProducerMessage, 1)
+	pMsgA := ProducerMessage{payload: []byte(msgA)}
+	pMsgs[0] = pMsgA
+
+	_ = idAllocator.UpdateID()
+	err = rmq.Produce(channelName, pMsgs)
+	assert.Nil(t, err)
+
+	pMsgB := ProducerMessage{payload: []byte("b_message")}
+	pMsgC := ProducerMessage{payload: []byte("c_message")}
+
+	pMsgs[0] = pMsgB
+	pMsgs = append(pMsgs, pMsgC)
+	_ = idAllocator.UpdateID()
+	err = rmq.Produce(channelName, pMsgs)
+	assert.Nil(t, err)
+
+	groupName := "query_node"
+	_ = rmq.DestroyConsumerGroup(groupName, channelName)
+	err = rmq.CreateConsumerGroup(groupName, channelName)
+	assert.Nil(t, err)
+	cMsgs, err := rmq.Consume(groupName, channelName, 1)
+	assert.Nil(t, err)
+	assert.Equal(t, len(cMsgs), 1)
+	assert.Equal(t, string(cMsgs[0].payload), "a_message")
+
+	cMsgs, err = rmq.Consume(groupName, channelName, 2)
+	assert.Nil(t, err)
+	assert.Equal(t, len(cMsgs), 2)
+	assert.Equal(t, string(cMsgs[0].payload), "b_message")
+	assert.Equal(t, string(cMsgs[1].payload), "c_message")
+}
-- 
GitLab