Select Git revision
cluster_test.go
cluster_test.go 4.97 KiB
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package datacoord
import (
"context"
"fmt"
"testing"
memkv "github.com/milvus-io/milvus/internal/kv/mem"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/stretchr/testify/assert"
)
type SpyClusterStore struct {
*NodesInfo
ch chan interface{}
}
func (s *SpyClusterStore) SetNode(nodeID UniqueID, node *NodeInfo) {
s.NodesInfo.SetNode(nodeID, node)
s.ch <- struct{}{}
}
func (s *SpyClusterStore) DeleteNode(nodeID UniqueID) {
s.NodesInfo.DeleteNode(nodeID)
s.ch <- struct{}{}
}
func spyWatchPolicy(ch chan interface{}) channelAssignPolicy {
return func(cluster []*NodeInfo, channel string, collectionID UniqueID) []*NodeInfo {
for _, node := range cluster {
for _, c := range node.info.GetChannels() {
if c.GetName() == channel && c.GetCollectionID() == collectionID {
ch <- struct{}{}
return nil
}
}
}
ret := make([]*NodeInfo, 0)
c := &datapb.ChannelStatus{
Name: channel,
State: datapb.ChannelWatchState_Uncomplete,
CollectionID: collectionID,
}
n := cluster[0].Clone(AddChannels([]*datapb.ChannelStatus{c}))
ret = append(ret, n)
return ret
}
}
func TestClusterCreate(t *testing.T) {
ch := make(chan interface{})
kv := memkv.NewMemoryKV()
spyClusterStore := &SpyClusterStore{
NodesInfo: NewNodesInfo(),
ch: ch,
}
cluster, err := NewCluster(context.TODO(), kv, spyClusterStore, dummyPosProvider{})
assert.Nil(t, err)
defer cluster.Close()
addr := "localhost:8080"
info := &datapb.DataNodeInfo{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
}
nodes := []*NodeInfo{NewNodeInfo(context.TODO(), info)}
cluster.Startup(nodes)
<-ch
dataNodes := cluster.GetNodes()
assert.EqualValues(t, 1, len(dataNodes))
assert.EqualValues(t, "localhost:8080", dataNodes[0].info.GetAddress())
}
func TestRegister(t *testing.T) {
registerPolicy := newEmptyRegisterPolicy()
ch := make(chan interface{})
kv := memkv.NewMemoryKV()
spyClusterStore := &SpyClusterStore{
NodesInfo: NewNodesInfo(),
ch: ch,
}
cluster, err := NewCluster(context.TODO(), kv, spyClusterStore, dummyPosProvider{}, withRegisterPolicy(registerPolicy))
assert.Nil(t, err)
defer cluster.Close()
addr := "localhost:8080"
cluster.Startup(nil)
info := &datapb.DataNodeInfo{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
}
node := NewNodeInfo(context.TODO(), info)
cluster.Register(node)
<-ch
dataNodes := cluster.GetNodes()
assert.EqualValues(t, 1, len(dataNodes))
assert.EqualValues(t, "localhost:8080", dataNodes[0].info.GetAddress())
}
func TestUnregister(t *testing.T) {
unregisterPolicy := newEmptyUnregisterPolicy()
ch := make(chan interface{})
kv := memkv.NewMemoryKV()
spyClusterStore := &SpyClusterStore{
NodesInfo: NewNodesInfo(),
ch: ch,
}
cluster, err := NewCluster(context.TODO(), kv, spyClusterStore, dummyPosProvider{}, withUnregistorPolicy(unregisterPolicy))
assert.Nil(t, err)
defer cluster.Close()
addr := "localhost:8080"
info := &datapb.DataNodeInfo{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
}
nodes := []*NodeInfo{NewNodeInfo(context.TODO(), info)}
cluster.Startup(nodes)
<-ch
dataNodes := cluster.GetNodes()
assert.EqualValues(t, 1, len(dataNodes))
assert.EqualValues(t, "localhost:8080", dataNodes[0].info.GetAddress())
cluster.UnRegister(nodes[0])
<-ch
dataNodes = cluster.GetNodes()
assert.EqualValues(t, 0, len(dataNodes))
}
func TestWatchIfNeeded(t *testing.T) {
ch := make(chan interface{})
kv := memkv.NewMemoryKV()
spyClusterStore := &SpyClusterStore{
NodesInfo: NewNodesInfo(),
ch: ch,
}
pch := make(chan interface{})
spyPolicy := spyWatchPolicy(pch)
cluster, err := NewCluster(context.TODO(), kv, spyClusterStore, dummyPosProvider{}, withAssignPolicy(spyPolicy))
assert.Nil(t, err)
defer cluster.Close()
addr := "localhost:8080"
info := &datapb.DataNodeInfo{
Address: addr,
Version: 1,
Channels: []*datapb.ChannelStatus{},
}
node := NewNodeInfo(context.TODO(), info)
node.client, err = newMockDataNodeClient(1, make(chan interface{}))
assert.Nil(t, err)
nodes := []*NodeInfo{node}
cluster.Startup(nodes)
fmt.Println("11111")
<-ch
chName := "ch1"
cluster.Watch(chName, 0)
fmt.Println("222")
<-ch
dataNodes := cluster.GetNodes()
assert.EqualValues(t, 1, len(dataNodes[0].info.GetChannels()))
assert.EqualValues(t, chName, dataNodes[0].info.Channels[0].Name)
cluster.Watch(chName, 0)
<-pch
}