diff --git a/internal/proxynode/channels_mgr.go b/internal/proxynode/channels_mgr.go index 3df91ed47517c5fb6e23b133bf3bbf4603725248..2e43557cc9c36b5c704da6814542c972e9d2c763 100644 --- a/internal/proxynode/channels_mgr.go +++ b/internal/proxynode/channels_mgr.go @@ -17,9 +17,10 @@ type pChan = string type channelsMgr interface { getChannels(collectionID UniqueID) ([]pChan, error) getVChannels(collectionID UniqueID) ([]vChan, error) - createDQLMsgStream(collectionID UniqueID) error - getDQLMsgStream(collectionID UniqueID) error - removeDQLMsgStream(collectionID UniqueID) error + createDQLStream(collectionID UniqueID) error + getDQLStream(collectionID UniqueID) (msgstream.MsgStream, error) + removeDQLStream(collectionID UniqueID) error + removeAllDQLStream() error createDMLMsgStream(collectionID UniqueID) error getDMLStream(collectionID UniqueID) (msgstream.MsgStream, error) removeDMLStream(collectionID UniqueID) error @@ -61,6 +62,8 @@ func getUniqueIntGeneratorIns() uniqueIntGenerator { return uniqueIntGeneratorIns } +type getChannelsFuncType = func(collectionID UniqueID) (map[vChan]pChan, error) + type masterService interface { GetChannels(collectionID UniqueID) (map[vChan]pChan, error) } @@ -100,21 +103,51 @@ func (m *mockMaster) GetChannels(collectionID UniqueID) (map[vChan]pChan, error) return channels, nil } -type channelsMgrImpl struct { +type queryService interface { + GetChannels(collectionID UniqueID) (map[vChan]pChan, error) +} + +type mockQueryService struct { + collectionID2Channels map[UniqueID]map[vChan]pChan +} + +func newMockQueryService() *mockQueryService { + return &mockQueryService{ + collectionID2Channels: make(map[UniqueID]map[vChan]pChan), + } +} + +func (m *mockQueryService) GetChannels(collectionID UniqueID) (map[vChan]pChan, error) { + channels, ok := m.collectionID2Channels[collectionID] + if ok { + return channels, nil + } + + channels = make(map[vChan]pChan) + l := rand.Uint64()%10 + 1 + for i := 0; uint64(i) < l; i++ { + channels[genUniqueStr()] = genUniqueStr() + } + + m.collectionID2Channels[collectionID] = channels + return channels, nil +} + +type singleTypeChannelsMgr struct { collectionID2VIDs map[UniqueID][]int // id are sorted collMtx sync.RWMutex id2vchans map[int][]vChan id2vchansMtx sync.RWMutex - id2DMLStream map[int]msgstream.MsgStream - id2UsageHistogramOfDMLStream map[int]int - dmlStreamMtx sync.RWMutex + id2Stream map[int]msgstream.MsgStream + id2UsageHistogramOfStream map[int]int + streamMtx sync.RWMutex vchans2pchans map[vChan]pChan vchans2pchansMtx sync.RWMutex - master masterService + getChannelsFunc getChannelsFuncType msgStreamFactory msgstream.Factory } @@ -135,7 +168,7 @@ func getAllValues(m map[vChan]pChan) []pChan { return values } -func (mgr *channelsMgrImpl) getLatestVID(collectionID UniqueID) (int, error) { +func (mgr *singleTypeChannelsMgr) getLatestVID(collectionID UniqueID) (int, error) { mgr.collMtx.RLock() defer mgr.collMtx.RUnlock() @@ -147,7 +180,7 @@ func (mgr *channelsMgrImpl) getLatestVID(collectionID UniqueID) (int, error) { return ids[len(ids)-1], nil } -func (mgr *channelsMgrImpl) getAllVIDs(collectionID UniqueID) ([]int, error) { +func (mgr *singleTypeChannelsMgr) getAllVIDs(collectionID UniqueID) ([]int, error) { mgr.collMtx.RLock() defer mgr.collMtx.RUnlock() @@ -159,7 +192,7 @@ func (mgr *channelsMgrImpl) getAllVIDs(collectionID UniqueID) ([]int, error) { return ids, nil } -func (mgr *channelsMgrImpl) getVChansByVID(vid int) ([]vChan, error) { +func (mgr *singleTypeChannelsMgr) getVChansByVID(vid int) ([]vChan, error) { mgr.id2vchansMtx.RLock() defer mgr.id2vchansMtx.RUnlock() @@ -171,7 +204,7 @@ func (mgr *channelsMgrImpl) getVChansByVID(vid int) ([]vChan, error) { return vchans, nil } -func (mgr *channelsMgrImpl) getPChansByVChans(vchans []vChan) ([]pChan, error) { +func (mgr *singleTypeChannelsMgr) getPChansByVChans(vchans []vChan) ([]pChan, error) { mgr.vchans2pchansMtx.RLock() defer mgr.vchans2pchansMtx.RUnlock() @@ -187,21 +220,21 @@ func (mgr *channelsMgrImpl) getPChansByVChans(vchans []vChan) ([]pChan, error) { return pchans, nil } -func (mgr *channelsMgrImpl) updateVChans(vid int, vchans []vChan) { +func (mgr *singleTypeChannelsMgr) updateVChans(vid int, vchans []vChan) { mgr.id2vchansMtx.Lock() defer mgr.id2vchansMtx.Unlock() mgr.id2vchans[vid] = vchans } -func (mgr *channelsMgrImpl) deleteVChansByVID(vid int) { +func (mgr *singleTypeChannelsMgr) deleteVChansByVID(vid int) { mgr.id2vchansMtx.Lock() defer mgr.id2vchansMtx.Unlock() delete(mgr.id2vchans, vid) } -func (mgr *channelsMgrImpl) deleteVChansByVIDs(vids []int) { +func (mgr *singleTypeChannelsMgr) deleteVChansByVIDs(vids []int) { mgr.id2vchansMtx.Lock() defer mgr.id2vchansMtx.Unlock() @@ -210,23 +243,23 @@ func (mgr *channelsMgrImpl) deleteVChansByVIDs(vids []int) { } } -func (mgr *channelsMgrImpl) deleteDMLStreamByVID(vid int) { - mgr.dmlStreamMtx.Lock() - defer mgr.dmlStreamMtx.Unlock() +func (mgr *singleTypeChannelsMgr) deleteStreamByVID(vid int) { + mgr.streamMtx.Lock() + defer mgr.streamMtx.Unlock() - delete(mgr.id2DMLStream, vid) + delete(mgr.id2Stream, vid) } -func (mgr *channelsMgrImpl) deleteDMLStreamByVIDs(vids []int) { - mgr.dmlStreamMtx.Lock() - defer mgr.dmlStreamMtx.Unlock() +func (mgr *singleTypeChannelsMgr) deleteStreamByVIDs(vids []int) { + mgr.streamMtx.Lock() + defer mgr.streamMtx.Unlock() for _, vid := range vids { - delete(mgr.id2DMLStream, vid) + delete(mgr.id2Stream, vid) } } -func (mgr *channelsMgrImpl) updateChannels(channels map[vChan]pChan) { +func (mgr *singleTypeChannelsMgr) updateChannels(channels map[vChan]pChan) { mgr.vchans2pchansMtx.Lock() defer mgr.vchans2pchansMtx.Unlock() @@ -235,44 +268,44 @@ func (mgr *channelsMgrImpl) updateChannels(channels map[vChan]pChan) { } } -func (mgr *channelsMgrImpl) deleteAllChannels() { +func (mgr *singleTypeChannelsMgr) deleteAllChannels() { mgr.vchans2pchansMtx.Lock() defer mgr.vchans2pchansMtx.Unlock() mgr.vchans2pchans = nil } -func (mgr *channelsMgrImpl) deleteAllDMLStream() { +func (mgr *singleTypeChannelsMgr) deleteAllStream() { mgr.id2vchansMtx.Lock() defer mgr.id2vchansMtx.Unlock() - mgr.id2UsageHistogramOfDMLStream = nil - mgr.id2DMLStream = nil + mgr.id2UsageHistogramOfStream = nil + mgr.id2Stream = nil } -func (mgr *channelsMgrImpl) deleteAllVChans() { +func (mgr *singleTypeChannelsMgr) deleteAllVChans() { mgr.id2vchansMtx.Lock() defer mgr.id2vchansMtx.Unlock() mgr.id2vchans = nil } -func (mgr *channelsMgrImpl) deleteAllCollection() { +func (mgr *singleTypeChannelsMgr) deleteAllCollection() { mgr.collMtx.Lock() defer mgr.collMtx.Unlock() mgr.collectionID2VIDs = nil } -func (mgr *channelsMgrImpl) addDMLStream(vid int, stream msgstream.MsgStream) { - mgr.dmlStreamMtx.Lock() - defer mgr.dmlStreamMtx.Unlock() +func (mgr *singleTypeChannelsMgr) addStream(vid int, stream msgstream.MsgStream) { + mgr.streamMtx.Lock() + defer mgr.streamMtx.Unlock() - mgr.id2DMLStream[vid] = stream - mgr.id2UsageHistogramOfDMLStream[vid] = 0 + mgr.id2Stream[vid] = stream + mgr.id2UsageHistogramOfStream[vid] = 0 } -func (mgr *channelsMgrImpl) updateCollection(collectionID UniqueID, id int) { +func (mgr *singleTypeChannelsMgr) updateCollection(collectionID UniqueID, id int) { mgr.collMtx.Lock() defer mgr.collMtx.Unlock() @@ -289,7 +322,7 @@ func (mgr *channelsMgrImpl) updateCollection(collectionID UniqueID, id int) { } } -func (mgr *channelsMgrImpl) getChannels(collectionID UniqueID) ([]pChan, error) { +func (mgr *singleTypeChannelsMgr) getChannels(collectionID UniqueID) ([]pChan, error) { id, err := mgr.getLatestVID(collectionID) if err == nil { vchans, err := mgr.getVChansByVID(id) @@ -304,7 +337,7 @@ func (mgr *channelsMgrImpl) getChannels(collectionID UniqueID) ([]pChan, error) return nil, err } -func (mgr *channelsMgrImpl) getVChannels(collectionID UniqueID) ([]vChan, error) { +func (mgr *singleTypeChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan, error) { id, err := mgr.getLatestVID(collectionID) if err == nil { return mgr.getVChansByVID(id) @@ -314,20 +347,8 @@ func (mgr *channelsMgrImpl) getVChannels(collectionID UniqueID) ([]vChan, error) return nil, err } -func (mgr *channelsMgrImpl) createDQLMsgStream(collectionID UniqueID) error { - panic("implement me") -} - -func (mgr *channelsMgrImpl) getDQLMsgStream(collectionID UniqueID) error { - panic("implement me") -} - -func (mgr *channelsMgrImpl) removeDQLMsgStream(collectionID UniqueID) error { - panic("implement me") -} - -func (mgr *channelsMgrImpl) createDMLMsgStream(collectionID UniqueID) error { - channels, err := mgr.master.GetChannels(collectionID) +func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) error { + channels, err := mgr.getChannelsFunc(collectionID) if err != nil { return err } @@ -352,23 +373,23 @@ func (mgr *channelsMgrImpl) createDMLMsgStream(collectionID UniqueID) error { runtime.SetFinalizer(stream, func(stream msgstream.MsgStream) { stream.Close() }) - mgr.addDMLStream(id, stream) + mgr.addStream(id, stream) mgr.updateCollection(collectionID, id) return nil } -func (mgr *channelsMgrImpl) getDMLStream(collectionID UniqueID) (msgstream.MsgStream, error) { - mgr.dmlStreamMtx.RLock() - defer mgr.dmlStreamMtx.RUnlock() +func (mgr *singleTypeChannelsMgr) getStream(collectionID UniqueID) (msgstream.MsgStream, error) { + mgr.streamMtx.RLock() + defer mgr.streamMtx.RUnlock() vid, err := mgr.getLatestVID(collectionID) if err != nil { return nil, err } - stream, ok := mgr.id2DMLStream[vid] + stream, ok := mgr.id2Stream[vid] if !ok { return nil, fmt.Errorf("no dml stream for collection %v", collectionID) } @@ -376,35 +397,87 @@ func (mgr *channelsMgrImpl) getDMLStream(collectionID UniqueID) (msgstream.MsgSt return stream, nil } -func (mgr *channelsMgrImpl) removeDMLStream(collectionID UniqueID) error { +func (mgr *singleTypeChannelsMgr) removeStream(collectionID UniqueID) error { ids, err := mgr.getAllVIDs(collectionID) if err != nil { return err } mgr.deleteVChansByVIDs(ids) - mgr.deleteDMLStreamByVIDs(ids) + mgr.deleteStreamByVIDs(ids) return nil } -func (mgr *channelsMgrImpl) removeAllDMLStream() error { +func (mgr *singleTypeChannelsMgr) removeAllStream() error { mgr.deleteAllChannels() - mgr.deleteAllDMLStream() + mgr.deleteAllStream() mgr.deleteAllVChans() mgr.deleteAllCollection() return nil } -func newChannelsMgr(master masterService, factory msgstream.Factory) *channelsMgrImpl { +func newSingleTypeChannelsMgr(getChannelsFunc getChannelsFuncType, msgStreamFactory msgstream.Factory) *singleTypeChannelsMgr { + return &singleTypeChannelsMgr{ + collectionID2VIDs: make(map[UniqueID][]int), + id2vchans: make(map[int][]vChan), + id2Stream: make(map[int]msgstream.MsgStream), + id2UsageHistogramOfStream: make(map[int]int), + vchans2pchans: make(map[vChan]pChan), + getChannelsFunc: getChannelsFunc, + msgStreamFactory: msgStreamFactory, + } +} + +type channelsMgrImpl struct { + dmlChannelsMgr *singleTypeChannelsMgr + dqlChannelsMgr *singleTypeChannelsMgr +} + +func (mgr *channelsMgrImpl) getChannels(collectionID UniqueID) ([]pChan, error) { + return mgr.dmlChannelsMgr.getChannels(collectionID) +} + +func (mgr *channelsMgrImpl) getVChannels(collectionID UniqueID) ([]vChan, error) { + return mgr.dmlChannelsMgr.getVChannels(collectionID) +} + +func (mgr *channelsMgrImpl) createDQLStream(collectionID UniqueID) error { + return mgr.dqlChannelsMgr.createMsgStream(collectionID) +} + +func (mgr *channelsMgrImpl) getDQLStream(collectionID UniqueID) (msgstream.MsgStream, error) { + return mgr.dqlChannelsMgr.getStream(collectionID) +} + +func (mgr *channelsMgrImpl) removeDQLStream(collectionID UniqueID) error { + return mgr.dqlChannelsMgr.removeStream(collectionID) +} + +func (mgr *channelsMgrImpl) removeAllDQLStream() error { + return mgr.dqlChannelsMgr.removeAllStream() +} + +func (mgr *channelsMgrImpl) createDMLMsgStream(collectionID UniqueID) error { + return mgr.dmlChannelsMgr.createMsgStream(collectionID) +} + +func (mgr *channelsMgrImpl) getDMLStream(collectionID UniqueID) (msgstream.MsgStream, error) { + return mgr.dmlChannelsMgr.getStream(collectionID) +} + +func (mgr *channelsMgrImpl) removeDMLStream(collectionID UniqueID) error { + return mgr.dmlChannelsMgr.removeStream(collectionID) +} + +func (mgr *channelsMgrImpl) removeAllDMLStream() error { + return mgr.dmlChannelsMgr.removeAllStream() +} + +func newChannelsMgr(master masterService, query queryService, msgStreamFactory msgstream.Factory) channelsMgr { return &channelsMgrImpl{ - collectionID2VIDs: make(map[UniqueID][]int), - id2vchans: make(map[int][]vChan), - id2DMLStream: make(map[int]msgstream.MsgStream), - id2UsageHistogramOfDMLStream: make(map[int]int), - vchans2pchans: make(map[vChan]pChan), - master: master, - msgStreamFactory: factory, + dmlChannelsMgr: newSingleTypeChannelsMgr(master.GetChannels, msgStreamFactory), + dqlChannelsMgr: newSingleTypeChannelsMgr(query.GetChannels, msgStreamFactory), } } diff --git a/internal/proxynode/channels_mgr_test.go b/internal/proxynode/channels_mgr_test.go index 3064f89291d66e5de004e8eeab434d34071d8cb0..bd8574c9fd894fc82418ec714d3ee1dc7f5ce751 100644 --- a/internal/proxynode/channels_mgr_test.go +++ b/internal/proxynode/channels_mgr_test.go @@ -24,8 +24,9 @@ func TestNaiveUniqueIntGenerator_get(t *testing.T) { func TestChannelsMgrImpl_getChannels(t *testing.T) { master := newMockMaster() + query := newMockQueryService() factory := msgstream.NewSimpleMsgStreamFactory() - mgr := newChannelsMgr(master, factory) + mgr := newChannelsMgr(master, query, factory) defer mgr.removeAllDMLStream() collID := UniqueID(getUniqueIntGeneratorIns().get()) @@ -41,8 +42,9 @@ func TestChannelsMgrImpl_getChannels(t *testing.T) { func TestChannelsMgrImpl_getVChannels(t *testing.T) { master := newMockMaster() + query := newMockQueryService() factory := msgstream.NewSimpleMsgStreamFactory() - mgr := newChannelsMgr(master, factory) + mgr := newChannelsMgr(master, query, factory) defer mgr.removeAllDMLStream() collID := UniqueID(getUniqueIntGeneratorIns().get()) @@ -58,8 +60,9 @@ func TestChannelsMgrImpl_getVChannels(t *testing.T) { func TestChannelsMgrImpl_createDMLMsgStream(t *testing.T) { master := newMockMaster() + query := newMockQueryService() factory := msgstream.NewSimpleMsgStreamFactory() - mgr := newChannelsMgr(master, factory) + mgr := newChannelsMgr(master, query, factory) defer mgr.removeAllDMLStream() collID := UniqueID(getUniqueIntGeneratorIns().get()) @@ -79,8 +82,9 @@ func TestChannelsMgrImpl_createDMLMsgStream(t *testing.T) { func TestChannelsMgrImpl_getDMLMsgStream(t *testing.T) { master := newMockMaster() + query := newMockQueryService() factory := msgstream.NewSimpleMsgStreamFactory() - mgr := newChannelsMgr(master, factory) + mgr := newChannelsMgr(master, query, factory) defer mgr.removeAllDMLStream() collID := UniqueID(getUniqueIntGeneratorIns().get()) @@ -96,8 +100,9 @@ func TestChannelsMgrImpl_getDMLMsgStream(t *testing.T) { func TestChannelsMgrImpl_removeDMLMsgStream(t *testing.T) { master := newMockMaster() + query := newMockQueryService() factory := msgstream.NewSimpleMsgStreamFactory() - mgr := newChannelsMgr(master, factory) + mgr := newChannelsMgr(master, query, factory) defer mgr.removeAllDMLStream() collID := UniqueID(getUniqueIntGeneratorIns().get()) @@ -122,8 +127,9 @@ func TestChannelsMgrImpl_removeDMLMsgStream(t *testing.T) { func TestChannelsMgrImpl_removeAllDMLMsgStream(t *testing.T) { master := newMockMaster() + query := newMockQueryService() factory := msgstream.NewSimpleMsgStreamFactory() - mgr := newChannelsMgr(master, factory) + mgr := newChannelsMgr(master, query, factory) defer mgr.removeAllDMLStream() num := 10 @@ -133,3 +139,85 @@ func TestChannelsMgrImpl_removeAllDMLMsgStream(t *testing.T) { assert.Equal(t, nil, err) } } + +func TestChannelsMgrImpl_createDQLMsgStream(t *testing.T) { + master := newMockMaster() + query := newMockQueryService() + factory := msgstream.NewSimpleMsgStreamFactory() + mgr := newChannelsMgr(master, query, factory) + defer mgr.removeAllDQLStream() + + collID := UniqueID(getUniqueIntGeneratorIns().get()) + _, err := mgr.getChannels(collID) + assert.NotEqual(t, nil, err) + _, err = mgr.getVChannels(collID) + assert.NotEqual(t, nil, err) + + err = mgr.createDQLStream(collID) + assert.Equal(t, nil, err) + + _, err = mgr.getChannels(collID) + assert.Equal(t, nil, err) + _, err = mgr.getVChannels(collID) + assert.Equal(t, nil, err) +} + +func TestChannelsMgrImpl_getDQLMsgStream(t *testing.T) { + master := newMockMaster() + query := newMockQueryService() + factory := msgstream.NewSimpleMsgStreamFactory() + mgr := newChannelsMgr(master, query, factory) + defer mgr.removeAllDQLStream() + + collID := UniqueID(getUniqueIntGeneratorIns().get()) + _, err := mgr.getDQLStream(collID) + assert.NotEqual(t, nil, err) + + err = mgr.createDQLStream(collID) + assert.Equal(t, nil, err) + + _, err = mgr.getDQLStream(collID) + assert.Equal(t, nil, err) +} + +func TestChannelsMgrImpl_removeDQLMsgStream(t *testing.T) { + master := newMockMaster() + query := newMockQueryService() + factory := msgstream.NewSimpleMsgStreamFactory() + mgr := newChannelsMgr(master, query, factory) + defer mgr.removeAllDQLStream() + + collID := UniqueID(getUniqueIntGeneratorIns().get()) + _, err := mgr.getDQLStream(collID) + assert.NotEqual(t, nil, err) + + err = mgr.removeDQLStream(collID) + assert.NotEqual(t, nil, err) + + err = mgr.createDQLStream(collID) + assert.Equal(t, nil, err) + + _, err = mgr.getDQLStream(collID) + assert.Equal(t, nil, err) + + err = mgr.removeDQLStream(collID) + assert.Equal(t, nil, err) + + _, err = mgr.getDQLStream(collID) + assert.NotEqual(t, nil, err) +} + +func TestChannelsMgrImpl_removeAllDQLMsgStream(t *testing.T) { + master := newMockMaster() + query := newMockQueryService() + factory := msgstream.NewSimpleMsgStreamFactory() + mgr := newChannelsMgr(master, query, factory) + defer mgr.removeAllDQLStream() + + num := 10 + for i := 0; i < num; i++ { + collID := UniqueID(getUniqueIntGeneratorIns().get()) + err := mgr.createDQLStream(collID) + assert.Equal(t, nil, err) + } +} diff --git a/internal/proxynode/proxy_node.go b/internal/proxynode/proxy_node.go index 04a2e6faedd7944ee510c8f8fca91f4c60deec75..09e694d542121cf2fcfb56fe62060dd35a1c891a 100644 --- a/internal/proxynode/proxy_node.go +++ b/internal/proxynode/proxy_node.go @@ -228,9 +228,10 @@ func (node *ProxyNode) Init() error { node.segAssigner = segAssigner node.segAssigner.PeerID = Params.ProxyID - // TODO(dragondriver): use real master service instance + // TODO(dragondriver): use real master service & query service instance mockMasterIns := newMockMaster() - chMgr := newChannelsMgr(mockMasterIns, node.msFactory) + mockQueryIns := newMockQueryService() + chMgr := newChannelsMgr(mockMasterIns, mockQueryIns, node.msFactory) node.chMgr = chMgr node.sched, err = NewTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.msFactory)