diff --git a/internal/masterservice/task.go b/internal/masterservice/task.go
index 373f00c0b4dfde335a6464060ea3f4cb78bb8f46..5919b894635ed14f26f848b667411c55c639eae6 100644
--- a/internal/masterservice/task.go
+++ b/internal/masterservice/task.go
@@ -414,13 +414,10 @@ func (t *ShowPartitionReqTask) IgnoreTimeStamp() bool {
 }
 
 func (t *ShowPartitionReqTask) Execute() error {
-	coll, err := t.core.MetaTable.GetCollectionByID(t.Req.CollectionID)
+	coll, err := t.core.MetaTable.GetCollectionByName(t.Req.CollectionName)
 	if err != nil {
 		return err
 	}
-	if coll.Schema.Name != t.Req.CollectionName {
-		return errors.Errorf("collection %s not exist", t.Req.CollectionName)
-	}
 	for _, partID := range coll.PartitionIDs {
 		partMeta, err := t.core.MetaTable.GetPartitionByID(partID)
 		if err != nil {
diff --git a/internal/proxynode/impl.go b/internal/proxynode/impl.go
index 6b248738adb672799e1d00ad5aaddf4ec563fef1..a4ac7b2257e1beefb146a5c284a4f0dfe35af772 100644
--- a/internal/proxynode/impl.go
+++ b/internal/proxynode/impl.go
@@ -40,6 +40,7 @@ func (node *NodeImpl) CreateCollection(request *milvuspb.CreateCollectionRequest
 		Condition:               NewTaskCondition(ctx),
 		CreateCollectionRequest: request,
 		masterClient:            node.masterClient,
+		dataServiceClient:       node.dataServiceClient,
 	}
 	var cancel func()
 	cct.ctx, cancel = context.WithTimeout(ctx, reqTimeoutInterval)
@@ -79,6 +80,7 @@ func (node *NodeImpl) DropCollection(request *milvuspb.DropCollectionRequest) (*
 		Condition:             NewTaskCondition(ctx),
 		DropCollectionRequest: request,
 		masterClient:          node.masterClient,
+		dataServiceClient:     node.dataServiceClient,
 	}
 	var cancel func()
 	dct.ctx, cancel = context.WithTimeout(ctx, reqTimeoutInterval)
@@ -569,8 +571,9 @@ func (node *NodeImpl) Insert(request *milvuspb.InsertRequest) (*milvuspb.InsertR
 	span.SetTag("partition tag", request.PartitionName)
 	log.Println("insert into: ", request.CollectionName)
 	it := &InsertTask{
-		ctx:       ctx,
-		Condition: NewTaskCondition(ctx),
+		ctx:               ctx,
+		Condition:         NewTaskCondition(ctx),
+		dataServiceClient: node.dataServiceClient,
 		BaseInsertTask: BaseInsertTask{
 			BaseMsg: msgstream.BaseMsg{
 				HashValues: request.HashKeys,
@@ -585,8 +588,7 @@ func (node *NodeImpl) Insert(request *milvuspb.InsertRequest) (*milvuspb.InsertR
 				RowData: request.RowData,
 			},
 		},
-		manipulationMsgStream: node.manipulationMsgStream,
-		rowIDAllocator:        node.idAllocator,
+		rowIDAllocator: node.idAllocator,
 	}
 	if len(it.PartitionName) <= 0 {
 		it.PartitionName = Params.DefaultPartitionTag
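The master-side fix above replaces a lookup by ID followed by a name-consistency check with a single lookup by name, dropping the hand-rolled "collection not exist" path. A minimal sketch of why the two are equivalent, using a hypothetical in-memory meta table (`collectionMeta` and `metaTable` are illustrative, not the real `MetaTable`):

```go
package main

import (
	"errors"
	"fmt"
)

type collectionMeta struct {
	ID           int64
	Name         string
	PartitionIDs []int64
}

// metaTable is a toy stand-in for the real MetaTable, indexed both ways.
type metaTable struct {
	byID   map[int64]*collectionMeta
	byName map[string]*collectionMeta
}

// GetCollectionByName resolves the request in one step; a missing name
// produces the same "not exist" error the removed check used to build.
func (m *metaTable) GetCollectionByName(name string) (*collectionMeta, error) {
	coll, ok := m.byName[name]
	if !ok {
		return nil, errors.New("collection " + name + " not exist")
	}
	return coll, nil
}

func main() {
	coll := &collectionMeta{ID: 1, Name: "demo", PartitionIDs: []int64{10, 11}}
	mt := &metaTable{
		byID:   map[int64]*collectionMeta{coll.ID: coll},
		byName: map[string]*collectionMeta{coll.Name: coll},
	}
	c, err := mt.GetCollectionByName("demo")
	fmt.Println(c.ID, err) // 1 <nil>
}
```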
panic("expect a slice") + } + if ss1.Len() != ss2.Len() { + return false + } + for i := 0; i < ss1.Len(); i++ { + if !SliceContain(s2, ss1.Index(i).Interface()) { + return false + } + } + return true +} + +func SortedSliceEqual(s1 interface{}, s2 interface{}) bool { + ss1 := reflect.ValueOf(s1) + ss2 := reflect.ValueOf(s2) + if ss1.Kind() != reflect.Slice { + panic("expect a slice") + } + if ss2.Kind() != reflect.Slice { + panic("expect a slice") + } + if ss1.Len() != ss2.Len() { + return false + } + for i := 0; i < ss1.Len(); i++ { + if ss2.Index(i).Interface() != ss1.Index(i).Interface() { + return false + } + } + return true +} + +type InsertChannelsMap struct { + collectionID2InsertChannels map[UniqueID]int // the value of map is the location of insertChannels & insertMsgStreams + insertChannels [][]string // it's a little confusing to use []string as the key of map + insertMsgStreams []msgstream.MsgStream // maybe there's a better way to implement Set, just agilely now + droppedBitMap []int // 0 -> normal, 1 -> dropped + mtx sync.RWMutex + nodeInstance *NodeImpl +} + +func (m *InsertChannelsMap) createInsertMsgStream(collID UniqueID, channels []string) error { + m.mtx.Lock() + defer m.mtx.Unlock() + + _, ok := m.collectionID2InsertChannels[collID] + if ok { + return errors.New("impossible and forbidden to create message stream twice") + } + sort.Slice(channels, func(i, j int) bool { + return channels[i] <= channels[j] + }) + for loc, existedChannels := range m.insertChannels { + if m.droppedBitMap[loc] == 0 && SortedSliceEqual(existedChannels, channels) { + m.collectionID2InsertChannels[collID] = loc + return nil + } + } + m.insertChannels = append(m.insertChannels, channels) + m.collectionID2InsertChannels[collID] = len(m.insertChannels) - 1 + stream := pulsarms.NewPulsarMsgStream(context.Background(), Params.MsgStreamInsertBufSize) + stream.SetPulsarClient(Params.PulsarAddress) + stream.CreatePulsarProducers(channels) + repack := func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) { + return insertRepackFunc(tsMsgs, hashKeys, m.nodeInstance.segAssigner, true) + } + stream.SetRepackFunc(repack) + stream.Start() + m.insertMsgStreams = append(m.insertMsgStreams, stream) + m.droppedBitMap = append(m.droppedBitMap, 0) + + return nil +} + +func (m *InsertChannelsMap) closeInsertMsgStream(collID UniqueID) error { + m.mtx.Lock() + defer m.mtx.Unlock() + + loc, ok := m.collectionID2InsertChannels[collID] + if !ok { + return errors.New("cannot find collection with id: " + strconv.Itoa(int(collID))) + } + if m.droppedBitMap[loc] != 0 { + return errors.New("insert message stream already closed") + } + m.insertMsgStreams[loc].Close() + log.Print("close insert message stream ...") + + m.droppedBitMap[loc] = 1 + delete(m.collectionID2InsertChannels, collID) + + return nil +} + +func (m *InsertChannelsMap) getInsertMsgStream(collID UniqueID) (msgstream.MsgStream, error) { + m.mtx.RLock() + defer m.mtx.RUnlock() + + loc, ok := m.collectionID2InsertChannels[collID] + if !ok { + return nil, errors.New("cannot find collection with id: " + strconv.Itoa(int(collID))) + } + + if m.droppedBitMap[loc] != 0 { + return nil, errors.New("insert message stream already closed") + } + + return m.insertMsgStreams[loc], nil +} + +func newInsertChannelsMap(node *NodeImpl) *InsertChannelsMap { + return &InsertChannelsMap{ + collectionID2InsertChannels: make(map[UniqueID]int), + insertChannels: make([][]string, 0), + insertMsgStreams: make([]msgstream.MsgStream, 0), + nodeInstance: 
diff --git a/internal/proxynode/proxy_node.go b/internal/proxynode/proxy_node.go
index c7d2623796fde5ab0738bd3143488f537cd8016b..506b0f5c2cf7c26d928e4ec01dcafd2b70cf34e8 100644
--- a/internal/proxynode/proxy_node.go
+++ b/internal/proxynode/proxy_node.go
@@ -244,6 +244,9 @@ func (node *NodeImpl) Init() error {
 }
 
 func (node *NodeImpl) Start() error {
+	initGlobalInsertChannelsMap(node)
+	log.Println("init global insert channels map ...")
+
 	initGlobalMetaCache(node.ctx, node)
 	log.Println("init global meta cache ...")
 
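`Start` wires up the process-wide `InsertChannelsMap` before the meta cache because the insert path (see `task.go` below) resolves streams through this global with a get-or-create flow. A condensed sketch of that flow; `fetchInsertChannels` is a hypothetical stand-in for the `dataServiceClient.GetInsertChannels` call:

```go
// getOrCreateStream sketches the lookup pattern used on insert:
// try the cache, create the stream on a miss, then look up again.
func getOrCreateStream(collID UniqueID) (msgstream.MsgStream, error) {
	if stream, err := globalInsertChannelsMap.getInsertMsgStream(collID); err == nil {
		return stream, nil // fast path: stream already registered
	}
	channels, err := fetchInsertChannels(collID) // hypothetical data-service call
	if err != nil {
		return nil, err
	}
	if err := globalInsertChannelsMap.createInsertMsgStream(collID, channels); err != nil {
		return nil, err
	}
	return globalInsertChannelsMap.getInsertMsgStream(collID)
}
```

Note the check-then-act shape: two concurrent first inserts can both miss, and the loser of the create race sees an error rather than a duplicate stream; the mutex inside `InsertChannelsMap` keeps the bookkeeping itself consistent either way.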
diff --git a/internal/proxynode/task.go b/internal/proxynode/task.go
index 8149ae0ef22d817e06e87323e53f9f0436c062fc..54b28ea7e1571aa26bb9f2d871c5a651697a8430 100644
--- a/internal/proxynode/task.go
+++ b/internal/proxynode/task.go
@@ -7,6 +7,8 @@ import (
 	"math"
 	"strconv"
 
+	"github.com/zilliztech/milvus-distributed/internal/proto/datapb"
+
 	"github.com/opentracing/opentracing-go"
 	oplog "github.com/opentracing/opentracing-go/log"
 
@@ -41,10 +43,10 @@ type BaseInsertTask = msgstream.InsertMsg
 type InsertTask struct {
 	BaseInsertTask
 	Condition
-	result                *milvuspb.InsertResponse
-	manipulationMsgStream *pulsarms.PulsarMsgStream
-	ctx                   context.Context
-	rowIDAllocator        *allocator.IDAllocator
+	dataServiceClient DataServiceClient
+	result            *milvuspb.InsertResponse
+	ctx               context.Context
+	rowIDAllocator    *allocator.IDAllocator
 }
 
 func (it *InsertTask) OnEnqueue() error {
@@ -161,8 +163,6 @@ func (it *InsertTask) Execute() error {
 	}
 	tsMsg.SetMsgContext(ctx)
 	span.LogFields(oplog.String("send msg", "send msg"))
-	msgPack.Msgs[0] = tsMsg
-	err = it.manipulationMsgStream.Produce(msgPack)
 	it.result = &milvuspb.InsertResponse{
 		Status: &commonpb.Status{
 			ErrorCode: commonpb.ErrorCode_SUCCESS,
@@ -171,11 +171,45 @@ func (it *InsertTask) Execute() error {
 		RowIDBegin: rowIDBegin,
 		RowIDEnd:   rowIDEnd,
 	}
+
+	msgPack.Msgs[0] = tsMsg
+
+	stream, err := globalInsertChannelsMap.getInsertMsgStream(description.CollectionID)
+	if err != nil {
+		collectionInsertChannels, err := it.dataServiceClient.GetInsertChannels(&datapb.InsertChannelRequest{
+			Base: &commonpb.MsgBase{
+				MsgType:   commonpb.MsgType_kInsert, // todo
+				MsgID:     it.Base.MsgID,            // todo
+				Timestamp: 0,                        // todo
+				SourceID:  Params.ProxyID,
+			},
+			DbID:         0, // todo
+			CollectionID: description.CollectionID,
+		})
+		if err != nil {
+			return err
+		}
+		err = globalInsertChannelsMap.createInsertMsgStream(description.CollectionID, collectionInsertChannels)
+		if err != nil {
+			return err
+		}
+	}
+	stream, err = globalInsertChannelsMap.getInsertMsgStream(description.CollectionID)
+	if err != nil {
+		it.result.Status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR
+		it.result.Status.Reason = err.Error()
+		span.LogFields(oplog.Error(err))
+		return err
+	}
+
+	err = stream.Produce(msgPack)
 	if err != nil {
 		it.result.Status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR
 		it.result.Status.Reason = err.Error()
 		span.LogFields(oplog.Error(err))
+		return err
 	}
+
 	return nil
 }
 
@@ -188,10 +222,11 @@ func (it *InsertTask) PostExecute() error {
 type CreateCollectionTask struct {
 	Condition
 	*milvuspb.CreateCollectionRequest
-	masterClient MasterClient
-	result       *commonpb.Status
-	ctx          context.Context
-	schema       *schemapb.CollectionSchema
+	masterClient      MasterClient
+	dataServiceClient DataServiceClient
+	result            *commonpb.Status
+	ctx               context.Context
+	schema            *schemapb.CollectionSchema
 }
 
 func (cct *CreateCollectionTask) OnEnqueue() error {
@@ -293,7 +328,37 @@ func (cct *CreateCollectionTask) PreExecute() error {
 func (cct *CreateCollectionTask) Execute() error {
 	var err error
 	cct.result, err = cct.masterClient.CreateCollection(cct.CreateCollectionRequest)
-	return err
+	if err != nil {
+		return err
+	}
+	if cct.result.ErrorCode == commonpb.ErrorCode_SUCCESS {
+		err = globalMetaCache.Sync(cct.CollectionName)
+		if err != nil {
+			return err
+		}
+		desc, err := globalMetaCache.Get(cct.CollectionName)
+		if err != nil {
+			return err
+		}
+		collectionInsertChannels, err := cct.dataServiceClient.GetInsertChannels(&datapb.InsertChannelRequest{
+			Base: &commonpb.MsgBase{
+				MsgType:   commonpb.MsgType_kInsert, // todo
+				MsgID:     cct.Base.MsgID,           // todo
+				Timestamp: 0,                        // todo
+				SourceID:  Params.ProxyID,
+			},
+			DbID:         0, // todo
+			CollectionID: desc.CollectionID,
+		})
+		if err != nil {
+			return err
+		}
+		err = globalInsertChannelsMap.createInsertMsgStream(desc.CollectionID, collectionInsertChannels)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
 }
 
 func (cct *CreateCollectionTask) PostExecute() error {
@@ -303,9 +368,10 @@ func (cct *CreateCollectionTask) PostExecute() error {
 type DropCollectionTask struct {
 	Condition
 	*milvuspb.DropCollectionRequest
-	masterClient MasterClient
-	result       *commonpb.Status
-	ctx          context.Context
+	masterClient      MasterClient
+	dataServiceClient DataServiceClient
+	result            *commonpb.Status
+	ctx               context.Context
 }
 
 func (dct *DropCollectionTask) OnEnqueue() error {
@@ -350,6 +416,16 @@ func (dct *DropCollectionTask) PreExecute() error {
 func (dct *DropCollectionTask) Execute() error {
 	var err error
 	dct.result, err = dct.masterClient.DropCollection(dct.DropCollectionRequest)
+	if err == nil && dct.result.ErrorCode == commonpb.ErrorCode_SUCCESS {
+		// Resolve the collection ID before syncing the cache: after a
+		// successful drop, Sync removes the entry, and an unchecked Get
+		// would hand back a nil desc and panic on desc.CollectionID.
+		desc, descErr := globalMetaCache.Get(dct.CollectionName)
+		_ = globalMetaCache.Sync(dct.CollectionName)
+		if descErr == nil {
+			_ = globalInsertChannelsMap.closeInsertMsgStream(desc.CollectionID)
+		}
+	}
 	return err
 }
 
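On the drop path, the stream registered at `CreateCollection` time (or lazily on first insert) is released once the master reports success. A small sketch of what the map guarantees after that point, assuming `collID` was previously registered:

```go
// closeInsertMsgStream closes the Pulsar stream, marks its slot dropped,
// and deletes the ID mapping, so later lookups fail cleanly instead of
// producing into a closed stream.
if err := globalInsertChannelsMap.closeInsertMsgStream(collID); err != nil {
	log.Println("close insert msg stream failed: ", err)
}
if _, err := globalInsertChannelsMap.getInsertMsgStream(collID); err != nil {
	log.Println("stream unavailable after close: ", err) // expected
}
```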