diff --git a/.github/workflows/publish-builder.yaml b/.github/workflows/publish-builder.yaml
index 68d244f09783239fed40a6adafd837ae80190e6c..7b4a0fdd240a422604c2875f7eae62346464a620 100644
--- a/.github/workflows/publish-builder.yaml
+++ b/.github/workflows/publish-builder.yaml
@@ -29,7 +29,7 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v2
       - name: Check Dockerfile
-        uses: reviewdog/action-hadolint@v1
+        uses: reviewdog/action-hadolint@v1.16.1
         with:
           github_token: ${{ secrets.GITHUB_TOKEN }}
           reporter: github-pr-check # Default is github-pr-check
diff --git a/.github/workflows/publish-test-images.yaml b/.github/workflows/publish-test-images.yaml
index 5c60f96bea06154d4f806d2df3353b5cae3d1c2c..7e9dee765aa522c32f7350c408b1fc62ac3123d2 100644
--- a/.github/workflows/publish-test-images.yaml
+++ b/.github/workflows/publish-test-images.yaml
@@ -25,7 +25,7 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v2
       - name: Check Dockerfile
-        uses: reviewdog/action-hadolint@v1
+        uses: reviewdog/action-hadolint@v1.16.1
         with:
           github_token: ${{ secrets.GITHUB_TOKEN }}
           reporter: github-pr-check # Default is github-pr-check
diff --git a/internal/datanode/data_sync_service_test.go b/internal/datanode/data_sync_service_test.go
index c962167f2da71c99ecb4706d9ddec6e361344eb7..e9a1f95f1b01498b209d2ae86f00a1f8d7843d8b 100644
--- a/internal/datanode/data_sync_service_test.go
+++ b/internal/datanode/data_sync_service_test.go
@@ -100,12 +100,12 @@ func TestDataSyncService_Start(t *testing.T) {
 	var ddMsgStream msgstream.MsgStream = ddStream
 	ddMsgStream.Start()
 
-	err = insertMsgStream.Produce(ctx, &msgPack)
+	err = insertMsgStream.Produce(&msgPack)
 	assert.NoError(t, err)
-	err = insertMsgStream.Broadcast(ctx, &timeTickMsgPack)
+	err = insertMsgStream.Broadcast(&timeTickMsgPack)
 	assert.NoError(t, err)
-	err = ddMsgStream.Broadcast(ctx, &timeTickMsgPack)
+	err = ddMsgStream.Broadcast(&timeTickMsgPack)
 	assert.NoError(t, err)
 
 	// dataSync
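The datanode test above now calls Produce and Broadcast without threading a context through the stream. A minimal sketch of the calling convention this patch moves to, assuming interfaces shaped like the ones changed here (the MsgStream/TsMsg mirrors below are illustrative stand-ins, not the repo's definitions):

```go
package msgstreamsketch

import (
	"context"
	"fmt"
)

// Minimal mirrors of the surfaces touched by this patch (sketch only).
type TsMsg interface {
	TraceCtx() context.Context
	SetTraceCtx(ctx context.Context)
}

type MsgPack struct{ Msgs []TsMsg }

type MsgStream interface {
	Produce(*MsgPack) error   // was Produce(context.Context, *MsgPack)
	Broadcast(*MsgPack) error // was Broadcast(context.Context, *MsgPack)
	Consume() *MsgPack        // was Consume() (*MsgPack, context.Context)
}

// publish shows the post-patch call shape: the caller no longer hands a
// context to the stream; each message carries its own trace context.
func publish(ctx context.Context, stream MsgStream, pack *MsgPack) error {
	for _, m := range pack.Msgs {
		m.SetTraceCtx(ctx) // attach the request context to every message
	}
	if err := stream.Produce(pack); err != nil {
		return fmt.Errorf("produce failed: %w", err)
	}
	return nil
}
```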
diff --git a/internal/datanode/flow_graph_dd_node.go b/internal/datanode/flow_graph_dd_node.go
index a4b90888396b5c3b389994e15efda698486af4e6..fa6c3e1c6d0eb249840b617c1e3f904d4a60e738 100644
--- a/internal/datanode/flow_graph_dd_node.go
+++ b/internal/datanode/flow_graph_dd_node.go
@@ -11,6 +11,7 @@ import (
 	"github.com/golang/protobuf/proto"
 	"go.uber.org/zap"
 
+	"github.com/opentracing/opentracing-go"
 	"github.com/zilliztech/milvus-distributed/internal/kv"
 	miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
 	"github.com/zilliztech/milvus-distributed/internal/log"
@@ -18,6 +19,8 @@ import (
 	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
 	"github.com/zilliztech/milvus-distributed/internal/storage"
+	"github.com/zilliztech/milvus-distributed/internal/util/flowgraph"
+	"github.com/zilliztech/milvus-distributed/internal/util/trace"
 )
 
 type ddNode struct {
@@ -69,7 +72,7 @@ func (ddNode *ddNode) Name() string {
 	return "ddNode"
 }
 
-func (ddNode *ddNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
+func (ddNode *ddNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
 	if len(in) != 1 {
 		log.Error("Invalid operate message input in ddNode", zap.Int("input length", len(in)))
@@ -83,7 +86,13 @@
 	}
 
 	if msMsg == nil {
-		return []Msg{}, ctx
+		return []Msg{}
+	}
+	var spans []opentracing.Span
+	for _, msg := range msMsg.TsMessages() {
+		sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+		spans = append(spans, sp)
+		msg.SetTraceCtx(ctx)
 	}
 
 	ddNode.ddMsg = &ddMsg{
@@ -165,8 +174,12 @@ func (ddNode *ddNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Con
 	default:
 	}
 
+	for _, span := range spans {
+		span.Finish()
+	}
+
 	var res Msg = ddNode.ddMsg
-	return []Msg{res}, ctx
+	return []Msg{res}
 }
 
 /*
@@ -245,6 +258,10 @@ func flushTxn(ddlData *sync.Map,
 }
 
 func (ddNode *ddNode) createCollection(msg *msgstream.CreateCollectionMsg) {
+	sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+	msg.SetTraceCtx(ctx)
+	defer sp.Finish()
+
 	collectionID := msg.CollectionID
 
 	// add collection
@@ -295,6 +312,10 @@ func (ddNode *ddNode) createCollection(msg *msgstream.CreateCollectionMsg) {
 	dropCollection will drop collection in ddRecords but won't drop collection in replica
 */
 func (ddNode *ddNode) dropCollection(msg *msgstream.DropCollectionMsg) {
+	sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+	msg.SetTraceCtx(ctx)
+	defer sp.Finish()
+
 	collectionID := msg.CollectionID
 
 	// remove collection
@@ -327,6 +348,10 @@ func (ddNode *ddNode) dropCollection(msg *msgstream.DropCollectionMsg) {
 }
 
 func (ddNode *ddNode) createPartition(msg *msgstream.CreatePartitionMsg) {
+	sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+	msg.SetTraceCtx(ctx)
+	defer sp.Finish()
+
 	partitionID := msg.PartitionID
 	collectionID := msg.CollectionID
 
@@ -363,6 +388,9 @@ func (ddNode *ddNode) createPartition(msg *msgstream.CreatePartitionMsg) {
 }
 
 func (ddNode *ddNode) dropPartition(msg *msgstream.DropPartitionMsg) {
+	sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+	msg.SetTraceCtx(ctx)
+	defer sp.Finish()
 	partitionID := msg.PartitionID
 	collectionID := msg.CollectionID
 
diff --git a/internal/datanode/flow_graph_dd_node_test.go b/internal/datanode/flow_graph_dd_node_test.go
index 8fa6ffa56bb9c3f559baa16741640098889cfd03..efa2f848d015e98f90ea1aa3ae1a897f5bd55a91 100644
--- a/internal/datanode/flow_graph_dd_node_test.go
+++ b/internal/datanode/flow_graph_dd_node_test.go
@@ -160,5 +160,5 @@ func TestFlowGraphDDNode_Operate(t *testing.T) {
 	msgStream := flowgraph.GenerateMsgStreamMsg(tsMessages, Timestamp(0), Timestamp(3), startPos, startPos)
 	var inMsg Msg = msgStream
-	ddNode.Operate(ctx, []Msg{inMsg})
+	ddNode.Operate([]Msg{inMsg})
 }
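ddNode.Operate now owns the span lifecycle itself: one span per incoming message, opened from the message's own trace context and finished before the node returns. A self-contained sketch of that pattern using opentracing directly (the Msg interface is a stand-in; the repo's trace helper also tolerates a missing context, which the guard below approximates):

```go
package ddnodesketch

import (
	"context"

	"github.com/opentracing/opentracing-go"
)

// Msg is a stand-in for the flowgraph message type used in the diff.
type Msg interface {
	TraceCtx() context.Context
	SetTraceCtx(ctx context.Context)
}

// operate mirrors the pattern the patch applies in ddNode.Operate: per-message
// spans opened on entry, the enriched contexts stored back on the messages,
// and every span finished before the output leaves the node.
func operate(msgs []Msg) []Msg {
	var spans []opentracing.Span
	for _, m := range msgs {
		ctx := m.TraceCtx()
		if ctx == nil {
			ctx = context.Background() // message arrived without a trace
		}
		sp, ctx := opentracing.StartSpanFromContext(ctx, "ddNode")
		spans = append(spans, sp)
		m.SetTraceCtx(ctx) // downstream nodes continue the trace from here
	}
	defer func() {
		for _, sp := range spans {
			sp.Finish()
		}
	}()

	// ... per-message DDL handling would happen here ...
	return msgs
}
```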
diff --git a/internal/datanode/flow_graph_filter_dm_node.go b/internal/datanode/flow_graph_filter_dm_node.go
index 8c3f065e0b3b45f608f64e65f148b5b0c02ad29d..fe33458f26f105712c6340271b3c43c7e043872c 100644
--- a/internal/datanode/flow_graph_filter_dm_node.go
+++ b/internal/datanode/flow_graph_filter_dm_node.go
@@ -1,15 +1,16 @@
 package datanode
 
 import (
-	"context"
 	"math"
 
-	"go.uber.org/zap"
-
+	"github.com/opentracing/opentracing-go"
 	"github.com/zilliztech/milvus-distributed/internal/log"
 	"github.com/zilliztech/milvus-distributed/internal/msgstream"
 	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
+	"github.com/zilliztech/milvus-distributed/internal/util/flowgraph"
+	"github.com/zilliztech/milvus-distributed/internal/util/trace"
+	"go.uber.org/zap"
 )
 
 type filterDmNode struct {
@@ -21,7 +22,7 @@ func (fdmNode *filterDmNode) Name() string {
 	return "fdmNode"
 }
 
-func (fdmNode *filterDmNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
+func (fdmNode *filterDmNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
 	if len(in) != 2 {
 		log.Error("Invalid operate message input in filterDmNode", zap.Int("input length", len(in)))
@@ -41,7 +42,13 @@ func (fdmNode *filterDmNode) Operate(ctx context.Context, in []Msg) ([]Msg, cont
 	}
 
 	if msgStreamMsg == nil || ddMsg == nil {
-		return []Msg{}, ctx
+		return []Msg{}
+	}
+	var spans []opentracing.Span
+	for _, msg := range msgStreamMsg.TsMessages() {
+		sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+		spans = append(spans, sp)
+		msg.SetTraceCtx(ctx)
 	}
 
 	fdmNode.ddMsg = ddMsg
@@ -77,11 +84,18 @@ func (fdmNode *filterDmNode) Operate(ctx context.Context, in []Msg) ([]Msg, cont
 	iMsg.endPositions = append(iMsg.endPositions, msgStreamMsg.EndPositions()...)
 	iMsg.gcRecord = ddMsg.gcRecord
 	var res Msg = &iMsg
-	return []Msg{res}, ctx
+	for _, sp := range spans {
+		sp.Finish()
+	}
+	return []Msg{res}
 }
 
 func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg) *msgstream.InsertMsg {
 	// No dd record, do all insert requests.
+	sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+	msg.SetTraceCtx(ctx)
+	defer sp.Finish()
+
 	records, ok := fdmNode.ddMsg.collectionRecords[msg.CollectionID]
 	if !ok {
 		return msg
diff --git a/internal/datanode/flow_graph_gc_node.go b/internal/datanode/flow_graph_gc_node.go
index 982869b2eaa307089cbc39aa288d87fa5385cebf..571de3ddcb5b97e8a7fd636fb03d2db3d7944fac 100644
--- a/internal/datanode/flow_graph_gc_node.go
+++ b/internal/datanode/flow_graph_gc_node.go
@@ -1,7 +1,7 @@
 package datanode
 
 import (
-	"context"
+	"github.com/zilliztech/milvus-distributed/internal/util/flowgraph"
 
 	"go.uber.org/zap"
 
@@ -17,7 +17,7 @@ func (gcNode *gcNode) Name() string {
 	return "gcNode"
 }
 
-func (gcNode *gcNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
+func (gcNode *gcNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
 	if len(in) != 1 {
 		log.Error("Invalid operate message input in gcNode", zap.Int("input length", len(in)))
@@ -31,7 +31,7 @@ func (gcNode *gcNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Con
 	}
 
 	if gcMsg == nil {
-		return []Msg{}, ctx
+		return []Msg{}
 	}
 
 	// drop collections
@@ -42,7 +42,7 @@ func (gcNode *gcNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Con
 		}
 	}
 
-	return nil, ctx
+	return nil
 }
 
 func newGCNode(replica Replica) *gcNode {
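Both nodes above drop the `(ctx, in) ([]Msg, ctx)` shape in favour of `Operate(in []Msg) []Msg`. A small sketch of what a node looks like under that contract, including the defensive fan-in check these datanode nodes keep (all types below are stand-ins, not the flowgraph package's own):

```go
package flowgraphsketch

// Msg and Node mirror the post-patch flowgraph contract assumed by these
// datanode nodes: Operate no longer receives or returns a context.
type Msg interface{ TimeTick() uint64 }

type Node interface {
	Name() string
	Operate(in []Msg) []Msg
}

type gcMsg struct{ tick uint64 }

func (g *gcMsg) TimeTick() uint64 { return g.tick }

// gcLikeNode shows the defensive input handling the patched nodes keep:
// check the fan-in length, then type-assert the concrete message.
type gcLikeNode struct{}

func (n *gcLikeNode) Name() string { return "gcLikeNode" }

func (n *gcLikeNode) Operate(in []Msg) []Msg {
	if len(in) != 1 {
		return []Msg{} // wrong fan-in: drop the tick, as the diff's nodes do
	}
	m, ok := in[0].(*gcMsg)
	if !ok || m == nil {
		return []Msg{}
	}
	// ... resource cleanup keyed off m would happen here ...
	return nil // a sink node: nothing flows further downstream
}
```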
diff --git a/internal/datanode/flow_graph_insert_buffer_node.go b/internal/datanode/flow_graph_insert_buffer_node.go
index c4e875a567def660172dc03fcd41ea95891a8f33..a0703c10564036490860f42ad67c6a9731b02b65 100644
--- a/internal/datanode/flow_graph_insert_buffer_node.go
+++ b/internal/datanode/flow_graph_insert_buffer_node.go
@@ -11,11 +11,14 @@ import (
 
 	"go.uber.org/zap"
 
+	"github.com/opentracing/opentracing-go"
 	"github.com/zilliztech/milvus-distributed/internal/kv"
 	miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
 	"github.com/zilliztech/milvus-distributed/internal/log"
 	"github.com/zilliztech/milvus-distributed/internal/msgstream"
 	"github.com/zilliztech/milvus-distributed/internal/storage"
+	"github.com/zilliztech/milvus-distributed/internal/util/flowgraph"
+	"github.com/zilliztech/milvus-distributed/internal/util/trace"
 
 	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
@@ -31,26 +34,25 @@ const (
 type (
 	InsertData = storage.InsertData
 	Blob       = storage.Blob
-
-	insertBufferNode struct {
-		BaseNode
-		insertBuffer *insertBuffer
-		replica      Replica
-		flushMeta    *binlogMeta
-		flushMap     sync.Map
-
-		minIOKV kv.Base
-
-		timeTickStream          msgstream.MsgStream
-		segmentStatisticsStream msgstream.MsgStream
-		completeFlushStream     msgstream.MsgStream
-	}
-
-	insertBuffer struct {
-		insertData map[UniqueID]*InsertData // SegmentID to InsertData
-		maxSize    int32
-	}
 )
+
+type insertBufferNode struct {
+	BaseNode
+	insertBuffer *insertBuffer
+	replica      Replica
+	flushMeta    *binlogMeta
+	flushMap     sync.Map
+
+	minIOKV kv.Base
+
+	timeTickStream          msgstream.MsgStream
+	segmentStatisticsStream msgstream.MsgStream
+	completeFlushStream     msgstream.MsgStream
+}
+
+type insertBuffer struct {
+	insertData map[UniqueID]*InsertData // SegmentID to InsertData
+	maxSize    int32
+}
 
 func (ib *insertBuffer) size(segmentID UniqueID) int32 {
 	if ib.insertData == nil || len(ib.insertData) <= 0 {
@@ -85,7 +87,7 @@ func (ibNode *insertBufferNode) Name() string {
 	return "ibNode"
 }
 
-func (ibNode *insertBufferNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) {
+func (ibNode *insertBufferNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
 	if len(in) != 1 {
 		log.Error("Invalid operate message input in insertBufferNode", zap.Int("input length", len(in)))
@@ -99,12 +101,20 @@ func (ibNode *insertBufferNode) Operate(ctx context.Context, in []Msg) ([]Msg, c
 	}
 
 	if iMsg == nil {
-		return []Msg{}, ctx
+		return []Msg{}
+	}
+
+	var spans []opentracing.Span
+	for _, msg := range iMsg.insertMessages {
+		sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
+		spans = append(spans, sp)
+		msg.SetTraceCtx(ctx)
 	}
 
 	// Updating segment statistics
 	uniqueSeg := make(map[UniqueID]int64)
 	for _, msg := range iMsg.insertMessages {
+
 		currentSegID := msg.GetSegmentID()
 		collID := msg.GetCollectionID()
 		partitionID := msg.GetPartitionID()
@@ -537,8 +547,11 @@ func (ibNode *insertBufferNode) Operate(ctx context.Context, in []Msg) ([]Msg, c
 		gcRecord:  iMsg.gcRecord,
 		timeRange: iMsg.timeRange,
 	}
+	for _, sp := range spans {
+		sp.Finish()
+	}
 
-	return []Msg{res}, ctx
+	return []Msg{res}
 }
 
 func flushSegmentTxn(collMeta *etcdpb.CollectionMeta, segID UniqueID, partitionID UniqueID, collID UniqueID,
@@ -639,7 +652,7 @@ func (ibNode *insertBufferNode) completeFlush(segID UniqueID, finishCh <-chan bo
 	}
 	msgPack.Msgs = append(msgPack.Msgs, msg)
 
-	err := ibNode.completeFlushStream.Produce(context.TODO(), &msgPack)
+	err := ibNode.completeFlushStream.Produce(&msgPack)
 	if err != nil {
 		log.Error(".. Produce complete flush msg failed ..", zap.Error(err))
 	}
@@ -663,7 +676,7 @@ func (ibNode *insertBufferNode) writeHardTimeTick(ts Timestamp) error {
 		},
 	}
 	msgPack.Msgs = append(msgPack.Msgs, &timeTickMsg)
-	return ibNode.timeTickStream.Produce(context.TODO(), &msgPack)
+	return ibNode.timeTickStream.Produce(&msgPack)
 }
 
 func (ibNode *insertBufferNode) updateSegStatistics(segIDs []UniqueID) error {
@@ -698,7 +711,7 @@ func (ibNode *insertBufferNode) updateSegStatistics(segIDs []UniqueID) error {
 	var msgPack = msgstream.MsgPack{
 		Msgs: []msgstream.TsMsg{msg},
 	}
-	return ibNode.segmentStatisticsStream.Produce(context.TODO(), &msgPack)
+	return ibNode.segmentStatisticsStream.Produce(&msgPack)
 }
 
 func (ibNode *insertBufferNode) getCollectionSchemaByID(collectionID UniqueID) (*schemapb.CollectionSchema, error) {
diff --git a/internal/datanode/flow_graph_insert_buffer_node_test.go b/internal/datanode/flow_graph_insert_buffer_node_test.go
index c71e1d319710b74180652fb51dfe898d18482951..0ba783e7b6982d87e21e9cc86f2950bed4344dcd 100644
--- a/internal/datanode/flow_graph_insert_buffer_node_test.go
+++ b/internal/datanode/flow_graph_insert_buffer_node_test.go
@@ -52,7 +52,7 @@ func TestFlowGraphInsertBufferNode_Operate(t *testing.T) {
 	iBNode := newInsertBufferNode(ctx, newBinlogMeta(), replica, msFactory)
 	inMsg := genInsertMsg()
 	var iMsg flowgraph.Msg = &inMsg
-	iBNode.Operate(ctx, []flowgraph.Msg{iMsg})
+	iBNode.Operate([]flowgraph.Msg{iMsg})
 }
 
 func genInsertMsg() insertMsg {
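writeHardTimeTick, updateSegStatistics and completeFlush above no longer pass `context.TODO()` into Produce; an internally generated message simply carries whatever context it needs itself. A sketch under that assumption (mirror types, not the repo's):

```go
package ibnodesketch

import "context"

// Mirrors (sketch) of the types the patched insertBufferNode talks to.
type TsMsg interface{ SetTraceCtx(ctx context.Context) }
type MsgPack struct{ Msgs []TsMsg }
type MsgStream interface{ Produce(*MsgPack) error }

// writeTimeTick shows the post-patch shape of writeHardTimeTick: no throwaway
// context threaded into Produce; if tracing is wanted, it is attached to the
// message before the pack is handed over.
func writeTimeTick(stream MsgStream, msg TsMsg) error {
	msg.SetTraceCtx(context.Background()) // internal tick: no incoming trace to continue
	pack := MsgPack{Msgs: []TsMsg{msg}}
	return stream.Produce(&pack)
}
```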
diff --git a/internal/datanode/flow_graph_message.go b/internal/datanode/flow_graph_message.go
index f8e6e1aef5a122f360b0e964232287ac256ecc36..0e9d33d43fcc0ce76c44d04a30ff01324b169135 100644
--- a/internal/datanode/flow_graph_message.go
+++ b/internal/datanode/flow_graph_message.go
@@ -11,55 +11,53 @@ type (
 	MsgStreamMsg = flowgraph.MsgStreamMsg
 )
 
-type (
-	key2SegMsg struct {
-		tsMessages []msgstream.TsMsg
-		timeRange  TimeRange
-	}
+type key2SegMsg struct {
+	tsMessages []msgstream.TsMsg
+	timeRange  TimeRange
+}
 
-	ddMsg struct {
-		collectionRecords map[UniqueID][]*metaOperateRecord
-		partitionRecords  map[UniqueID][]*metaOperateRecord
-		flushMessages     []*flushMsg
-		gcRecord          *gcRecord
-		timeRange         TimeRange
-	}
+type ddMsg struct {
+	collectionRecords map[UniqueID][]*metaOperateRecord
+	partitionRecords  map[UniqueID][]*metaOperateRecord
+	flushMessages     []*flushMsg
+	gcRecord          *gcRecord
+	timeRange         TimeRange
+}
 
-	metaOperateRecord struct {
-		createOrDrop bool // create: true, drop: false
-		timestamp    Timestamp
-	}
+type metaOperateRecord struct {
+	createOrDrop bool // create: true, drop: false
+	timestamp    Timestamp
+}
 
-	insertMsg struct {
-		insertMessages []*msgstream.InsertMsg
-		flushMessages  []*flushMsg
-		gcRecord       *gcRecord
-		timeRange      TimeRange
-		startPositions []*internalpb.MsgPosition
-		endPositions   []*internalpb.MsgPosition
-	}
+type insertMsg struct {
+	insertMessages []*msgstream.InsertMsg
+	flushMessages  []*flushMsg
+	gcRecord       *gcRecord
+	timeRange      TimeRange
+	startPositions []*internalpb.MsgPosition
+	endPositions   []*internalpb.MsgPosition
+}
 
-	deleteMsg struct {
-		deleteMessages []*msgstream.DeleteMsg
-		timeRange      TimeRange
-	}
+type deleteMsg struct {
+	deleteMessages []*msgstream.DeleteMsg
+	timeRange      TimeRange
+}
 
-	gcMsg struct {
-		gcRecord  *gcRecord
-		timeRange TimeRange
-	}
+type gcMsg struct {
+	gcRecord  *gcRecord
+	timeRange TimeRange
+}
 
-	gcRecord struct {
-		collections []UniqueID
-	}
+type gcRecord struct {
+	collections []UniqueID
+}
 
-	flushMsg struct {
-		msgID        UniqueID
-		timestamp    Timestamp
-		segmentIDs   []UniqueID
-		collectionID UniqueID
-	}
-)
+type flushMsg struct {
+	msgID        UniqueID
+	timestamp    Timestamp
+	segmentIDs   []UniqueID
+	collectionID UniqueID
+}
 
 func (ksMsg *key2SegMsg) TimeTick() Timestamp {
 	return ksMsg.timeRange.timestampMax
diff --git a/internal/dataservice/server.go b/internal/dataservice/server.go
index 085532317f22d214f060fbe7dbc8aa3d338feb19..0cc29b52f574d8ab85506a24362c896b4e74a917 100644
--- a/internal/dataservice/server.go
+++ b/internal/dataservice/server.go
@@ -307,7 +307,7 @@ func (s *Server) startStatsChannel(ctx context.Context) {
 			return
 		default:
 		}
-		msgPack, _ := statsStream.Consume()
+		msgPack := statsStream.Consume()
 		for _, msg := range msgPack.Msgs {
 			statistics, ok := msg.(*msgstream.SegmentStatisticsMsg)
 			if !ok {
@@ -338,7 +338,7 @@ func (s *Server) startSegmentFlushChannel(ctx context.Context) {
 			return
 		default:
 		}
-		msgPack, _ := flushStream.Consume()
+		msgPack := flushStream.Consume()
 		for _, msg := range msgPack.Msgs {
 			if msg.Type() != commonpb.MsgType_SegmentFlushDone {
 				continue
@@ -368,7 +368,7 @@ func (s *Server) startDDChannel(ctx context.Context) {
 			return
 		default:
 		}
-		msgPack, ctx := ddStream.Consume()
+		msgPack := ddStream.Consume()
 		for _, msg := range msgPack.Msgs {
 			if err := s.ddHandler.HandleDDMsg(ctx, msg); err != nil {
 				log.Error("handle dd msg error", zap.Error(err))
@@ -622,10 +622,10 @@ func (s *Server) openNewSegment(ctx context.Context, collectionID UniqueID, part
 			Segment: segmentInfo,
 		},
 	}
-	msgPack := &msgstream.MsgPack{
+	msgPack := msgstream.MsgPack{
 		Msgs: []msgstream.TsMsg{infoMsg},
 	}
-	if err = s.segmentInfoStream.Produce(ctx, msgPack); err != nil {
+	if err = s.segmentInfoStream.Produce(&msgPack); err != nil {
 		return err
 	}
 	return nil
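With Consume returning only the pack, the dataservice loops above reuse the server's own context for handlers instead of one fabricated by the stream. A sketch of that loop shape (stand-in types, error handling reduced to the essentials):

```go
package ddchannelsketch

import "context"

type MsgPack struct{ Msgs []interface{} }

type MsgStream interface{ Consume() *MsgPack }

// consumeLoop mirrors the post-patch startDDChannel shape: Consume returns
// only the pack, and the handler is called with the server's own context.
func consumeLoop(ctx context.Context, stream MsgStream, handle func(context.Context, interface{}) error) {
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}
		pack := stream.Consume()
		if pack == nil {
			return // stream closed
		}
		for _, msg := range pack.Msgs {
			_ = handle(ctx, msg) // the real code logs handler errors
		}
	}
}
```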
diff --git a/internal/masterservice/master_service.go b/internal/masterservice/master_service.go
index 756a97bdbf7a26b96d9a413d9116139d60516abc..44eb529877bc8f4700c2616677188f8794bddb7d 100644
--- a/internal/masterservice/master_service.go
+++ b/internal/masterservice/master_service.go
@@ -445,10 +445,10 @@ func (c *Core) setMsgStreams() error {
 			TimeTickMsg: timeTickResult,
 		}
 		msgPack.Msgs = append(msgPack.Msgs, timeTickMsg)
-		if err := timeTickStream.Broadcast(c.ctx, &msgPack); err != nil {
+		if err := timeTickStream.Broadcast(&msgPack); err != nil {
 			return err
 		}
-		if err := ddStream.Broadcast(c.ctx, &msgPack); err != nil {
+		if err := ddStream.Broadcast(&msgPack); err != nil {
 			return err
 		}
 		return nil
@@ -457,6 +457,7 @@ func (c *Core) setMsgStreams() error {
 	c.DdCreateCollectionReq = func(ctx context.Context, req *internalpb.CreateCollectionRequest) error {
 		msgPack := ms.MsgPack{}
 		baseMsg := ms.BaseMsg{
+			Ctx:            ctx,
 			BeginTimestamp: req.Base.Timestamp,
 			EndTimestamp:   req.Base.Timestamp,
 			HashValues:     []uint32{0},
@@ -466,7 +467,7 @@ func (c *Core) setMsgStreams() error {
 			CreateCollectionRequest: *req,
 		}
 		msgPack.Msgs = append(msgPack.Msgs, collMsg)
-		if err := ddStream.Broadcast(ctx, &msgPack); err != nil {
+		if err := ddStream.Broadcast(&msgPack); err != nil {
 			return err
 		}
 		return nil
@@ -475,6 +476,7 @@ func (c *Core) setMsgStreams() error {
 	c.DdDropCollectionReq = func(ctx context.Context, req *internalpb.DropCollectionRequest) error {
 		msgPack := ms.MsgPack{}
 		baseMsg := ms.BaseMsg{
+			Ctx:            ctx,
 			BeginTimestamp: req.Base.Timestamp,
 			EndTimestamp:   req.Base.Timestamp,
 			HashValues:     []uint32{0},
@@ -484,7 +486,7 @@ func (c *Core) setMsgStreams() error {
 			DropCollectionRequest: *req,
 		}
 		msgPack.Msgs = append(msgPack.Msgs, collMsg)
-		if err := ddStream.Broadcast(ctx, &msgPack); err != nil {
+		if err := ddStream.Broadcast(&msgPack); err != nil {
 			return err
 		}
 		return nil
@@ -493,6 +495,7 @@ func (c *Core) setMsgStreams() error {
 	c.DdCreatePartitionReq = func(ctx context.Context, req *internalpb.CreatePartitionRequest) error {
 		msgPack := ms.MsgPack{}
 		baseMsg := ms.BaseMsg{
+			Ctx:            ctx,
 			BeginTimestamp: req.Base.Timestamp,
 			EndTimestamp:   req.Base.Timestamp,
 			HashValues:     []uint32{0},
@@ -502,7 +505,7 @@ func (c *Core) setMsgStreams() error {
 			CreatePartitionRequest: *req,
 		}
 		msgPack.Msgs = append(msgPack.Msgs, collMsg)
-		if err := ddStream.Broadcast(ctx, &msgPack); err != nil {
+		if err := ddStream.Broadcast(&msgPack); err != nil {
 			return err
 		}
 		return nil
@@ -511,6 +514,7 @@ func (c *Core) setMsgStreams() error {
 	c.DdDropPartitionReq = func(ctx context.Context, req *internalpb.DropPartitionRequest) error {
 		msgPack := ms.MsgPack{}
 		baseMsg := ms.BaseMsg{
+			Ctx:            ctx,
 			BeginTimestamp: req.Base.Timestamp,
 			EndTimestamp:   req.Base.Timestamp,
 			HashValues:     []uint32{0},
@@ -520,7 +524,7 @@ func (c *Core) setMsgStreams() error {
 			DropPartitionRequest: *req,
 		}
 		msgPack.Msgs = append(msgPack.Msgs, collMsg)
-		if err := ddStream.Broadcast(ctx, &msgPack); err != nil {
+		if err := ddStream.Broadcast(&msgPack); err != nil {
 			return err
 		}
 		return nil
diff --git a/internal/masterservice/master_service_test.go b/internal/masterservice/master_service_test.go
index acc7e9e5b079083eb4fe2187c47c2a2850d567f6..79e2876de3cb77ee4e08136d8ad070a3195a9a72 100644
--- a/internal/masterservice/master_service_test.go
+++ b/internal/masterservice/master_service_test.go
@@ -274,7 +274,7 @@ func TestMasterService(t *testing.T) {
 			TimeTickMsg: timeTickResult,
 		}
 		msgPack.Msgs = append(msgPack.Msgs, timeTickMsg)
-		err := proxyTimeTickStream.Broadcast(ctx, &msgPack)
+		err := proxyTimeTickStream.Broadcast(&msgPack)
 		assert.Nil(t, err)
 
 		ttmsg, ok := <-timeTickStream.Chan()
@@ -585,7 +585,7 @@ func TestMasterService(t *testing.T) {
 			},
 		}
 		msgPack.Msgs = append(msgPack.Msgs, segMsg)
-		err = dataServiceSegmentStream.Broadcast(ctx, &msgPack)
+		err = dataServiceSegmentStream.Broadcast(&msgPack)
 		assert.Nil(t, err)
 		time.Sleep(time.Second)
 
@@ -744,7 +744,7 @@ func TestMasterService(t *testing.T) {
 			},
 		}
 		msgPack.Msgs = append(msgPack.Msgs, segMsg)
-		err = dataServiceSegmentStream.Broadcast(ctx, &msgPack)
+		err = dataServiceSegmentStream.Broadcast(&msgPack)
 		assert.Nil(t, err)
 		time.Sleep(time.Second)
 
@@ -765,7 +765,7 @@ func TestMasterService(t *testing.T) {
 			},
 		}
 		msgPack.Msgs = []ms.TsMsg{flushMsg}
-		err = dataServiceSegmentStream.Broadcast(ctx, &msgPack)
+		err = dataServiceSegmentStream.Broadcast(&msgPack)
 		assert.Nil(t, err)
 		time.Sleep(time.Second)
 
diff --git a/internal/masterservice/task.go b/internal/masterservice/task.go
index 441f24f99512e5e4c1de268dc076300c753cb8c8..9796f12d61c4f5b980f7d67779b31a4f46c02122 100644
--- a/internal/masterservice/task.go
+++ b/internal/masterservice/task.go
@@ -6,8 +6,6 @@ import (
 	"fmt"
 
 	"github.com/golang/protobuf/proto"
-	"go.uber.org/zap"
-
 	"github.com/zilliztech/milvus-distributed/internal/log"
 	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
@@ -15,6 +13,7 @@ import (
 	"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
 	"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
+	"go.uber.org/zap"
 )
 
 type reqTask interface {
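The master service closures above now stamp the request context onto `BaseMsg.Ctx` before broadcasting, since Broadcast no longer accepts one. A sketch of that shape (field names mirror the diff; the surrounding types are illustrative):

```go
package ddreqsketch

import "context"

// BaseMsg mirrors the field the patch adds: the message carries its own
// context for tracing.
type BaseMsg struct {
	Ctx            context.Context
	BeginTimestamp uint64
	EndTimestamp   uint64
	HashValues     []uint32
}

type CreateCollectionMsg struct {
	BaseMsg
	// request payload elided
}

type MsgPack struct{ Msgs []interface{} }
type MsgStream interface{ Broadcast(*MsgPack) error }

// ddCreateCollection shows the post-patch DdCreateCollectionReq shape: the
// request context is stored on the message, and Broadcast takes only the pack.
func ddCreateCollection(ctx context.Context, ts uint64, stream MsgStream) error {
	msg := &CreateCollectionMsg{BaseMsg: BaseMsg{
		Ctx:            ctx,
		BeginTimestamp: ts,
		EndTimestamp:   ts,
		HashValues:     []uint32{0},
	}}
	pack := MsgPack{Msgs: []interface{}{msg}}
	return stream.Broadcast(&pack)
}
```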
diff --git a/internal/msgstream/memms/mem_msgstream.go b/internal/msgstream/memms/mem_msgstream.go
index 31b1834cbc260604de6f1a83b64f19f8b4c7a933..9d32576241872248c85abfe7f431d4cbf2649813 100644
--- a/internal/msgstream/memms/mem_msgstream.go
+++ b/internal/msgstream/memms/mem_msgstream.go
@@ -94,7 +94,7 @@ func (mms *MemMsgStream) AsConsumer(channels []string, groupName string) {
 	}
 }
 
-func (mms *MemMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error {
+func (mms *MemMsgStream) Produce(pack *msgstream.MsgPack) error {
 	tsMsgs := pack.Msgs
 	if len(tsMsgs) <= 0 {
 		log.Printf("Warning: Receive empty msgPack")
@@ -150,7 +150,7 @@ func (mms *MemMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) e
 	return nil
 }
 
-func (mms *MemMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error {
+func (mms *MemMsgStream) Broadcast(msgPack *msgstream.MsgPack) error {
 	for _, channelName := range mms.producers {
 		err := Mmq.Produce(channelName, msgPack)
 		if err != nil {
@@ -161,18 +161,18 @@ func (mms *MemMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error
 	return nil
 }
 
-func (mms *MemMsgStream) Consume() (*msgstream.MsgPack, context.Context) {
+func (mms *MemMsgStream) Consume() *msgstream.MsgPack {
 	for {
 		select {
 		case cm, ok := <-mms.receiveBuf:
 			if !ok {
 				log.Println("buf chan closed")
-				return nil, nil
+				return nil
 			}
-			return cm, nil
+			return cm
 		case <-mms.ctx.Done():
 			log.Printf("context closed")
-			return nil, nil
+			return nil
 		}
 	}
 }
diff --git a/internal/msgstream/memms/mem_msgstream_test.go b/internal/msgstream/memms/mem_msgstream_test.go
index 2f019e17ca9200048fbae0cc309ac5e27835a57f..96ad4033c7da9e91bf8e07c667a020afe8eb2143 100644
--- a/internal/msgstream/memms/mem_msgstream_test.go
+++ b/internal/msgstream/memms/mem_msgstream_test.go
@@ -101,7 +101,7 @@ func TestStream_GlobalMmq_Func(t *testing.T) {
 	if err != nil {
 		log.Fatalf("global mmq produce error = %v", err)
 	}
-	cm, _ := consumerStreams[0].Consume()
+	cm := consumerStreams[0].Consume()
 	assert.Equal(t, cm, &msg, "global mmq consume error")
 
 	err = Mmq.Broadcast(&msg)
@@ -109,7 +109,7 @@ func TestStream_GlobalMmq_Func(t *testing.T) {
 		log.Fatalf("global mmq broadcast error = %v", err)
 	}
 	for _, cs := range consumerStreams {
-		cm, _ := cs.Consume()
+		cm := cs.Consume()
 		assert.Equal(t, cm, &msg, "global mmq consume error")
 	}
 
@@ -142,12 +142,12 @@ func TestStream_MemMsgStream_Produce(t *testing.T) {
 	msgPack := msgstream.MsgPack{}
 	var hashValue uint32 = 2
 	msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 1, hashValue))
-	err := produceStream.Produce(context.Background(), &msgPack)
+	err := produceStream.Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("new msgstream error = %v", err)
 	}
 
-	msg, _ := consumerStreams[hashValue].Consume()
+	msg := consumerStreams[hashValue].Consume()
 	if msg == nil {
 		log.Fatalf("msgstream consume error")
 	}
@@ -167,13 +167,13 @@ func TestStream_MemMsgStream_BroadCast(t *testing.T) {
 	msgPack := msgstream.MsgPack{}
 	msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 1, 100))
-	err := produceStream.Broadcast(context.Background(), &msgPack)
+	err := produceStream.Broadcast(&msgPack)
 	if err != nil {
 		log.Fatalf("new msgstream error = %v", err)
 	}
 
 	for _, consumer := range consumerStreams {
-		msg, _ := consumer.Consume()
+		msg := consumer.Consume()
 		if msg == nil {
 			log.Fatalf("msgstream consume error")
 		}
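The in-memory stream's Consume collapses to a single return value. Sketch of the same select pattern in isolation (assumes a receive-buffer channel and a stream-owned context, as in the diff):

```go
package memmssketch

import "context"

type MsgPack struct{}

// consume mirrors the patched MemMsgStream.Consume: one return value, nil when
// the buffer closes or the stream's context is cancelled.
func consume(ctx context.Context, receiveBuf <-chan *MsgPack) *MsgPack {
	select {
	case pack, ok := <-receiveBuf:
		if !ok {
			return nil // buffer channel closed
		}
		return pack
	case <-ctx.Done():
		return nil // stream shut down
	}
}
```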
diff --git a/internal/msgstream/msg.go b/internal/msgstream/msg.go
index e06a1fef8034b9ff747325c68e7618ecd058afd1..8b0eba7e0a2e6affb778c86d77e9b16bb83ffef2 100644
--- a/internal/msgstream/msg.go
+++ b/internal/msgstream/msg.go
@@ -1,6 +1,7 @@
 package msgstream
 
 import (
+	"context"
 	"errors"
 
 	"github.com/golang/protobuf/proto"
@@ -13,6 +14,8 @@ type MsgType = commonpb.MsgType
 type MarshalType = interface{}
 
 type TsMsg interface {
+	TraceCtx() context.Context
+	SetTraceCtx(ctx context.Context)
 	ID() UniqueID
 	BeginTs() Timestamp
 	EndTs() Timestamp
@@ -25,6 +28,7 @@ type TsMsg interface {
 }
 
 type BaseMsg struct {
+	Ctx            context.Context
 	BeginTimestamp Timestamp
 	EndTimestamp   Timestamp
 	HashValues     []uint32
@@ -66,6 +70,13 @@ type InsertMsg struct {
 	internalpb.InsertRequest
 }
 
+func (it *InsertMsg) TraceCtx() context.Context {
+	return it.BaseMsg.Ctx
+}
+func (it *InsertMsg) SetTraceCtx(ctx context.Context) {
+	it.BaseMsg.Ctx = ctx
+}
+
 func (it *InsertMsg) ID() UniqueID {
 	return it.Base.MsgID
 }
@@ -118,6 +129,14 @@ type FlushCompletedMsg struct {
 	internalpb.SegmentFlushCompletedMsg
 }
 
+func (fl *FlushCompletedMsg) TraceCtx() context.Context {
+	return fl.BaseMsg.Ctx
+}
+
+func (fl *FlushCompletedMsg) SetTraceCtx(ctx context.Context) {
+	fl.BaseMsg.Ctx = ctx
+}
+
 func (fl *FlushCompletedMsg) ID() UniqueID {
 	return fl.Base.MsgID
 }
@@ -160,6 +179,14 @@ type FlushMsg struct {
 	internalpb.FlushMsg
 }
 
+func (fl *FlushMsg) TraceCtx() context.Context {
+	return fl.BaseMsg.Ctx
+}
+
+func (fl *FlushMsg) SetTraceCtx(ctx context.Context) {
+	fl.BaseMsg.Ctx = ctx
+}
+
 func (fl *FlushMsg) ID() UniqueID {
 	return fl.Base.MsgID
 }
@@ -201,6 +228,14 @@ type DeleteMsg struct {
 	internalpb.DeleteRequest
 }
 
+func (dt *DeleteMsg) TraceCtx() context.Context {
+	return dt.BaseMsg.Ctx
+}
+
+func (dt *DeleteMsg) SetTraceCtx(ctx context.Context) {
+	dt.BaseMsg.Ctx = ctx
+}
+
 func (dt *DeleteMsg) ID() UniqueID {
 	return dt.Base.MsgID
 }
@@ -254,6 +289,14 @@ type SearchMsg struct {
 	internalpb.SearchRequest
 }
 
+func (st *SearchMsg) TraceCtx() context.Context {
+	return st.BaseMsg.Ctx
+}
+
+func (st *SearchMsg) SetTraceCtx(ctx context.Context) {
+	st.BaseMsg.Ctx = ctx
+}
+
 func (st *SearchMsg) ID() UniqueID {
 	return st.Base.MsgID
 }
@@ -295,6 +338,14 @@ type SearchResultMsg struct {
 	internalpb.SearchResults
 }
 
+func (srt *SearchResultMsg) TraceCtx() context.Context {
+	return srt.BaseMsg.Ctx
+}
+
+func (srt *SearchResultMsg) SetTraceCtx(ctx context.Context) {
+	srt.BaseMsg.Ctx = ctx
+}
+
 func (srt *SearchResultMsg) ID() UniqueID {
 	return srt.Base.MsgID
 }
@@ -336,6 +387,14 @@ type TimeTickMsg struct {
 	internalpb.TimeTickMsg
 }
 
+func (tst *TimeTickMsg) TraceCtx() context.Context {
+	return tst.BaseMsg.Ctx
+}
+
+func (tst *TimeTickMsg) SetTraceCtx(ctx context.Context) {
+	tst.BaseMsg.Ctx = ctx
+}
+
 func (tst *TimeTickMsg) ID() UniqueID {
 	return tst.Base.MsgID
 }
@@ -378,6 +437,14 @@ type QueryNodeStatsMsg struct {
 	internalpb.QueryNodeStats
 }
 
+func (qs *QueryNodeStatsMsg) TraceCtx() context.Context {
+	return qs.BaseMsg.Ctx
+}
+
+func (qs *QueryNodeStatsMsg) SetTraceCtx(ctx context.Context) {
+	qs.BaseMsg.Ctx = ctx
+}
+
 func (qs *QueryNodeStatsMsg) ID() UniqueID {
 	return qs.Base.MsgID
 }
@@ -417,6 +484,14 @@ type SegmentStatisticsMsg struct {
 	internalpb.SegmentStatistics
 }
 
+func (ss *SegmentStatisticsMsg) TraceCtx() context.Context {
+	return ss.BaseMsg.Ctx
+}
+
+func (ss *SegmentStatisticsMsg) SetTraceCtx(ctx context.Context) {
+	ss.BaseMsg.Ctx = ctx
+}
+
 func (ss *SegmentStatisticsMsg) ID() UniqueID {
 	return ss.Base.MsgID
 }
@@ -466,6 +541,14 @@ type CreateCollectionMsg struct {
 	internalpb.CreateCollectionRequest
 }
 
+func (cc *CreateCollectionMsg) TraceCtx() context.Context {
+	return cc.BaseMsg.Ctx
+}
+
+func (cc *CreateCollectionMsg) SetTraceCtx(ctx context.Context) {
+	cc.BaseMsg.Ctx = ctx
+}
+
 func (cc *CreateCollectionMsg) ID() UniqueID {
 	return cc.Base.MsgID
 }
@@ -507,6 +590,14 @@ type DropCollectionMsg struct {
 	internalpb.DropCollectionRequest
 }
 
+func (dc *DropCollectionMsg) TraceCtx() context.Context {
+	return dc.BaseMsg.Ctx
+}
+
+func (dc *DropCollectionMsg) SetTraceCtx(ctx context.Context) {
+	dc.BaseMsg.Ctx = ctx
+}
+
 func (dc *DropCollectionMsg) ID() UniqueID {
 	return dc.Base.MsgID
 }
@@ -548,15 +639,23 @@ type CreatePartitionMsg struct {
 	internalpb.CreatePartitionRequest
 }
 
-func (cc *CreatePartitionMsg) ID() UniqueID {
-	return cc.Base.MsgID
+func (cp *CreatePartitionMsg) TraceCtx() context.Context {
+	return cp.BaseMsg.Ctx
 }
 
-func (cc *CreatePartitionMsg) Type() MsgType {
-	return cc.Base.MsgType
+func (cp *CreatePartitionMsg) SetTraceCtx(ctx context.Context) {
+	cp.BaseMsg.Ctx = ctx
+}
+
+func (cp *CreatePartitionMsg) ID() UniqueID {
+	return cp.Base.MsgID
 }
 
-func (cc *CreatePartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
+func (cp *CreatePartitionMsg) Type() MsgType {
+	return cp.Base.MsgType
+}
+
+func (cp *CreatePartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
 	createPartitionMsg := input.(*CreatePartitionMsg)
 	createPartitionRequest := &createPartitionMsg.CreatePartitionRequest
 	mb, err := proto.Marshal(createPartitionRequest)
@@ -566,7 +665,7 @@ func (cc *CreatePartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
 	return mb, nil
 }
 
-func (cc *CreatePartitionMsg) Unmarshal(input MarshalType) (TsMsg, error) {
+func (cp *CreatePartitionMsg) Unmarshal(input MarshalType) (TsMsg, error) {
 	createPartitionRequest := internalpb.CreatePartitionRequest{}
 	in, err := ConvertToByteArray(input)
 	if err != nil {
@@ -589,15 +688,23 @@ type DropPartitionMsg struct {
 	internalpb.DropPartitionRequest
 }
 
-func (dc *DropPartitionMsg) ID() UniqueID {
-	return dc.Base.MsgID
+func (dp *DropPartitionMsg) TraceCtx() context.Context {
+	return dp.BaseMsg.Ctx
 }
 
-func (dc *DropPartitionMsg) Type() MsgType {
-	return dc.Base.MsgType
+func (dp *DropPartitionMsg) SetTraceCtx(ctx context.Context) {
+	dp.BaseMsg.Ctx = ctx
+}
+
+func (dp *DropPartitionMsg) ID() UniqueID {
+	return dp.Base.MsgID
+}
+
+func (dp *DropPartitionMsg) Type() MsgType {
+	return dp.Base.MsgType
 }
 
-func (dc *DropPartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
+func (dp *DropPartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
 	dropPartitionMsg := input.(*DropPartitionMsg)
 	dropPartitionRequest := &dropPartitionMsg.DropPartitionRequest
 	mb, err := proto.Marshal(dropPartitionRequest)
@@ -607,7 +714,7 @@ func (dc *DropPartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
 	return mb, nil
 }
 
-func (dc *DropPartitionMsg) Unmarshal(input MarshalType) (TsMsg, error) {
+func (dp *DropPartitionMsg) Unmarshal(input MarshalType) (TsMsg, error) {
 	dropPartitionRequest := internalpb.DropPartitionRequest{}
 	in, err := ConvertToByteArray(input)
 	if err != nil {
@@ -630,6 +737,14 @@ type LoadIndexMsg struct {
 	internalpb.LoadIndex
 }
 
+func (lim *LoadIndexMsg) TraceCtx() context.Context {
+	return lim.BaseMsg.Ctx
+}
+
+func (lim *LoadIndexMsg) SetTraceCtx(ctx context.Context) {
+	lim.BaseMsg.Ctx = ctx
+}
+
 func (lim *LoadIndexMsg) ID() UniqueID {
 	return lim.Base.MsgID
 }
@@ -669,6 +784,14 @@ type SegmentInfoMsg struct {
 	datapb.SegmentMsg
 }
 
+func (sim *SegmentInfoMsg) TraceCtx() context.Context {
+	return sim.BaseMsg.Ctx
+}
+
+func (sim *SegmentInfoMsg) SetTraceCtx(ctx context.Context) {
+	sim.BaseMsg.Ctx = ctx
+}
+
 func (sim *SegmentInfoMsg) ID() UniqueID {
 	return sim.Base.MsgID
 }
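Every concrete message type above gains TraceCtx/SetTraceCtx backed by the new `BaseMsg.Ctx` field. For a new message kind, the same surface could come from methods on the embedded base, as sketched below; this is an equivalent shape for illustration, not the code the patch actually adds:

```go
package msgsketch

import "context"

// BaseMsg carries the per-message trace context the patch introduces.
type BaseMsg struct {
	Ctx            context.Context
	BeginTimestamp uint64
	EndTimestamp   uint64
	HashValues     []uint32
}

// The patch writes these two accessors out once per concrete message type;
// defining them on the base is an equivalent way to satisfy the TsMsg surface.
func (b *BaseMsg) TraceCtx() context.Context       { return b.Ctx }
func (b *BaseMsg) SetTraceCtx(ctx context.Context) { b.Ctx = ctx }

// A hypothetical message type picks the accessors up by embedding.
type PingMsg struct {
	BaseMsg
	Payload string
}
```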
diff --git a/internal/msgstream/msgstream.go b/internal/msgstream/msgstream.go
index 431ab879b8cfe7e156c987e5a17c106ef7566dd7..abd3a8080eab5a2e6d7f8292852aa78454691f9d 100644
--- a/internal/msgstream/msgstream.go
+++ b/internal/msgstream/msgstream.go
@@ -30,9 +30,9 @@ type MsgStream interface {
 	AsConsumer(channels []string, subName string)
 	SetRepackFunc(repackFunc RepackFunc)
 
-	Produce(context.Context, *MsgPack) error
-	Broadcast(context.Context, *MsgPack) error
-	Consume() (*MsgPack, context.Context)
+	Produce(*MsgPack) error
+	Broadcast(*MsgPack) error
+	Consume() *MsgPack
 	Seek(offset *MsgPosition) error
 }
diff --git a/internal/msgstream/pulsarms/msg_test.go b/internal/msgstream/pulsarms/msg_test.go
index 621ab5aeefff02b46e5d268aa0d1a211b6bc004a..6cd72e15d8b39cbfa6d8e4878cd11ca55e930d00 100644
--- a/internal/msgstream/pulsarms/msg_test.go
+++ b/internal/msgstream/pulsarms/msg_test.go
@@ -132,7 +132,6 @@ func getInsertTask(reqID msgstream.UniqueID, hashValue uint32) msgstream.TsMsg {
 }
 
 func TestStream_task_Insert(t *testing.T) {
-	ctx := context.Background()
 	pulsarAddress, _ := Params.Load("_PulsarAddress")
 	producerChannels := []string{"insert1", "insert2"}
 	consumerChannels := []string{"insert1", "insert2"}
@@ -155,13 +154,13 @@ func TestStream_task_Insert(t *testing.T) {
 	outputStream.AsConsumer(consumerChannels, consumerSubName)
 	outputStream.Start()
 
-	err := inputStream.Produce(ctx, &msgPack)
+	err := inputStream.Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
 	receiveCount := 0
 	for {
-		result, _ := outputStream.Consume()
+		result := outputStream.Consume()
 		if len(result.Msgs) > 0 {
 			msgs := result.Msgs
 			for _, v := range msgs {
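Callers that relied on the context formerly returned by Consume can instead pick the trace up from each message. A sketch of recovering the parent span from `msg.TraceCtx()` with plain opentracing (stand-in stream and message types):

```go
package consumesketch

import (
	"context"

	"github.com/opentracing/opentracing-go"
)

type TsMsg interface{ TraceCtx() context.Context }
type MsgPack struct{ Msgs []TsMsg }
type MsgStream interface{ Consume() *MsgPack }

// handlePack shows per-message trace recovery under the new Consume signature:
// each TsMsg exposes its own TraceCtx, so a child span can be opened per message.
func handlePack(stream MsgStream) {
	pack := stream.Consume()
	if pack == nil {
		return
	}
	for _, msg := range pack.Msgs {
		ctx := msg.TraceCtx()
		if ctx == nil {
			ctx = context.Background()
		}
		if parent := opentracing.SpanFromContext(ctx); parent != nil {
			child := opentracing.GlobalTracer().StartSpan(
				"handle msg", opentracing.ChildOf(parent.Context()))
			// ... process the message ...
			child.Finish()
		}
	}
}
```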
diff --git a/internal/msgstream/pulsarms/pulsar_msgstream.go b/internal/msgstream/pulsarms/pulsar_msgstream.go
index d391542f8a1c2daa619a0b15bb39b6ffd7db75a7..fce42cb5ab800f4bc1be1980b66836a4945bf968 100644
--- a/internal/msgstream/pulsarms/pulsar_msgstream.go
+++ b/internal/msgstream/pulsarms/pulsar_msgstream.go
@@ -11,9 +11,9 @@ import (
 
 	"github.com/apache/pulsar-client-go/pulsar"
 	"github.com/golang/protobuf/proto"
-	"github.com/opentracing/opentracing-go"
 	"go.uber.org/zap"
 
+	"github.com/opentracing/opentracing-go"
 	"github.com/zilliztech/milvus-distributed/internal/log"
 	"github.com/zilliztech/milvus-distributed/internal/msgstream"
 	"github.com/zilliztech/milvus-distributed/internal/msgstream/util"
@@ -54,8 +54,6 @@ type PulsarMsgStream struct {
 	producerLock     *sync.Mutex
 	consumerLock     *sync.Mutex
 	consumerReflects []reflect.SelectCase
-
-	scMap *sync.Map
 }
 
 func newPulsarMsgStream(ctx context.Context,
@@ -99,7 +97,6 @@ func newPulsarMsgStream(ctx context.Context,
 		producerLock:     &sync.Mutex{},
 		consumerLock:     &sync.Mutex{},
 		wait:             &sync.WaitGroup{},
-		scMap:            &sync.Map{},
 	}
 
 	return stream, nil
@@ -195,7 +192,7 @@ func (ms *PulsarMsgStream) Close() {
 	}
 }
 
-func (ms *PulsarMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error {
+func (ms *PulsarMsgStream) Produce(msgPack *msgstream.MsgPack) error {
 	tsMsgs := msgPack.Msgs
 	if len(tsMsgs) <= 0 {
 		log.Debug("Warning: Receive empty msgPack")
@@ -257,7 +254,7 @@ func (ms *PulsarMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error
 
 			msg := &pulsar.ProducerMessage{Payload: m, Properties: map[string]string{}}
 
-			sp, spanCtx := trace.MsgSpanFromCtx(ctx, v.Msgs[i])
+			sp, spanCtx := trace.MsgSpanFromCtx(v.Msgs[i].TraceCtx(), v.Msgs[i])
 			trace.InjectContextToPulsarMsgProperties(sp.Context(), msg.Properties)
 
 			if _, err := ms.producers[channel].Send(
@@ -274,7 +271,7 @@ func (ms *PulsarMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error
 	return nil
 }
 
-func (ms *PulsarMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error {
+func (ms *PulsarMsgStream) Broadcast(msgPack *msgstream.MsgPack) error {
 	for _, v := range msgPack.Msgs {
 		mb, err := v.Marshal(v)
 		if err != nil {
@@ -288,7 +285,7 @@ func (ms *PulsarMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) erro
 
 		msg := &pulsar.ProducerMessage{Payload: m, Properties: map[string]string{}}
 
-		sp, spanCtx := trace.MsgSpanFromCtx(ctx, v)
+		sp, spanCtx := trace.MsgSpanFromCtx(v.TraceCtx(), v)
 		trace.InjectContextToPulsarMsgProperties(sp.Context(), msg.Properties)
 
 		ms.producerLock.Lock()
@@ -308,31 +305,18 @@ func (ms *PulsarMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) erro
 	return nil
 }
 
-func (ms *PulsarMsgStream) Consume() (*MsgPack, context.Context) {
+func (ms *PulsarMsgStream) Consume() *msgstream.MsgPack {
 	for {
 		select {
 		case cm, ok := <-ms.receiveBuf:
 			if !ok {
 				log.Debug("buf chan closed")
-				return nil, nil
-			}
-			var ctx context.Context
-			var opts []opentracing.StartSpanOption
-			for _, msg := range cm.Msgs {
-				sc, loaded := ms.scMap.LoadAndDelete(msg.ID())
-				if loaded {
-					opts = append(opts, opentracing.ChildOf(sc.(opentracing.SpanContext)))
-				}
+				return nil
 			}
-			if len(opts) != 0 {
-				ctx = context.Background()
-			}
-			sp, ctx := trace.StartSpanFromContext(ctx, opts...)
-			sp.Finish()
-			return cm, ctx
+			return cm
 		case <-ms.ctx.Done():
 			//log.Debug("context closed")
-			return nil, nil
+			return nil
 		}
 	}
 }
@@ -368,7 +352,7 @@ func (ms *PulsarMsgStream) receiveMsg(consumer Consumer) {
 
 			sp, ok := trace.ExtractFromPulsarMsgProperties(tsMsg, pulsarMsg.Properties())
 			if ok {
-				ms.scMap.Store(tsMsg.ID(), sp.Context())
+				tsMsg.SetTraceCtx(opentracing.ContextWithSpan(context.Background(), sp))
 			}
 
 			msgPack := MsgPack{Msgs: []TsMsg{tsMsg}}
@@ -460,6 +444,10 @@ func (ms *PulsarMsgStream) bufMsgPackToChannel() {
 				log.Error("Failed to unmarshal tsMsg", zap.Error(err))
 				continue
 			}
+			sp, ok := trace.ExtractFromPulsarMsgProperties(tsMsg, pulsarMsg.Properties())
+			if ok {
+				tsMsg.SetTraceCtx(opentracing.ContextWithSpan(context.Background(), sp))
+			}
 
 			tsMsg.SetPosition(&msgstream.MsgPosition{
 				ChannelName: filepath.Base(pulsarMsg.Topic()),
@@ -736,7 +724,7 @@ func (ms *PulsarTtMsgStream) findTimeTick(consumer Consumer,
 
 			sp, ok := trace.ExtractFromPulsarMsgProperties(tsMsg, pulsarMsg.Properties())
 			if ok {
-				ms.scMap.Store(tsMsg.ID(), sp.Context())
+				tsMsg.SetTraceCtx(opentracing.ContextWithSpan(context.Background(), sp))
 			}
 
 			ms.unsolvedMutex.Lock()
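The Pulsar stream now injects the span context into the message properties on the producer side and, on the consumer side, extracts it and parks it on the message via SetTraceCtx, instead of keeping a stream-level `scMap`. The repo does this through its internal trace helpers; the sketch below shows the same idea with opentracing's TextMap carrier only, so no helper signatures are assumed (the extracted span is left for the caller to finish):

```go
package tracepropsketch

import (
	"context"

	"github.com/opentracing/opentracing-go"
)

// injectTrace writes a span's context into a string-map "properties" field
// like the one on pulsar.ProducerMessage.
func injectTrace(sp opentracing.Span, properties map[string]string) {
	_ = opentracing.GlobalTracer().Inject(
		sp.Context(), opentracing.TextMap, opentracing.TextMapCarrier(properties))
}

// extractTrace rebuilds a span from those properties on the consumer side and
// parks it on a context, which is what the patched stream stores on the
// message via SetTraceCtx. The caller finishes the span once the message has
// been handled (e.g. via opentracing.SpanFromContext).
func extractTrace(operation string, properties map[string]string) context.Context {
	carrier := opentracing.TextMapCarrier(properties)
	parent, err := opentracing.GlobalTracer().Extract(opentracing.TextMap, carrier)
	if err != nil {
		return context.Background() // nothing to continue; start fresh downstream
	}
	sp := opentracing.GlobalTracer().StartSpan(operation, opentracing.FollowsFrom(parent))
	return opentracing.ContextWithSpan(context.Background(), sp)
}
```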
diff --git a/internal/msgstream/pulsarms/pulsar_msgstream_test.go b/internal/msgstream/pulsarms/pulsar_msgstream_test.go
index c67973962cca419376ee1c5382333992fec707c9..42b2ff15f56bcf869ab67381e007a70c4ec8d598 100644
--- a/internal/msgstream/pulsarms/pulsar_msgstream_test.go
+++ b/internal/msgstream/pulsarms/pulsar_msgstream_test.go
@@ -223,7 +223,7 @@ func initPulsarTtStream(pulsarAddress string,
 func receiveMsg(outputStream msgstream.MsgStream, msgCount int) {
 	receiveCount := 0
 	for {
-		result, _ := outputStream.Consume()
+		result := outputStream.Consume()
 		if len(result.Msgs) > 0 {
 			msgs := result.Msgs
 			for _, v := range msgs {
@@ -238,7 +238,6 @@ func receiveMsg(outputStream msgstream.MsgStream, msgCount int) {
 }
 
 func TestStream_PulsarMsgStream_Insert(t *testing.T) {
-	ctx := context.Background()
 	pulsarAddress, _ := Params.Load("_PulsarAddress")
 	c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
 	producerChannels := []string{c1, c2}
@@ -250,7 +249,7 @@ func TestStream_PulsarMsgStream_Insert(t *testing.T) {
 	msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, 3, 3))
 
 	inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
-	err := inputStream.Produce(ctx, &msgPack)
+	err := inputStream.Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -262,7 +261,6 @@ func TestStream_PulsarMsgStream_Insert(t *testing.T) {
 }
 
 func TestStream_PulsarMsgStream_Delete(t *testing.T) {
-	ctx := context.Background()
 	pulsarAddress, _ := Params.Load("_PulsarAddress")
 	c := funcutil.RandomString(8)
 	producerChannels := []string{c}
@@ -273,7 +271,7 @@ func TestStream_PulsarMsgStream_Delete(t *testing.T) {
 	//msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Delete, 3, 3))
 
 	inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
-	err := inputStream.Produce(ctx, &msgPack)
+	err := inputStream.Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -283,7 +281,6 @@ func TestStream_PulsarMsgStream_Delete(t *testing.T) {
 }
 
 func TestStream_PulsarMsgStream_Search(t *testing.T) {
-	ctx := context.Background()
 	pulsarAddress, _ := Params.Load("_PulsarAddress")
 	c := funcutil.RandomString(8)
 	producerChannels := []string{c}
@@ -295,7 +292,7 @@ func TestStream_PulsarMsgStream_Search(t *testing.T) {
 	msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 3, 3))
 
 	inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
-	err := inputStream.Produce(ctx, &msgPack)
+	err := inputStream.Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -305,7 +302,6 @@ func TestStream_PulsarMsgStream_Search(t *testing.T) {
 }
 
 func TestStream_PulsarMsgStream_SearchResult(t *testing.T) {
-	ctx := context.Background()
 	pulsarAddress, _ := Params.Load("_PulsarAddress")
 	c := funcutil.RandomString(8)
 	producerChannels := []string{c}
@@ -316,7 +312,7 @@ func TestStream_PulsarMsgStream_SearchResult(t *testing.T) {
 	msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_SearchResult, 3, 3))
 
 	inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
-	err := inputStream.Produce(ctx, &msgPack)
+	err := inputStream.Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -326,7 +322,6 @@ func TestStream_PulsarMsgStream_SearchResult(t *testing.T) {
 }
 
 func TestStream_PulsarMsgStream_TimeTick(t *testing.T) {
-	ctx := context.Background()
 	pulsarAddress, _ := Params.Load("_PulsarAddress")
 	c := funcutil.RandomString(8)
 	producerChannels := []string{c}
@@ -337,7 +332,7 @@ func TestStream_PulsarMsgStream_TimeTick(t *testing.T) {
 	msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_TimeTick, 3, 3))
 
 	inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
-	err := inputStream.Produce(ctx, &msgPack)
+	err := inputStream.Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -347,7 +342,6 @@ func TestStream_PulsarMsgStream_TimeTick(t *testing.T) {
 }
 
 func TestStream_PulsarMsgStream_BroadCast(t *testing.T) {
-	ctx := context.Background()
 	pulsarAddress, _ := Params.Load("_PulsarAddress")
 	c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
 	producerChannels := []string{c1, c2}
@@ -359,7 +353,7 @@ func TestStream_PulsarMsgStream_BroadCast(t *testing.T) {
 	msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_TimeTick, 3, 3))
 
 	inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
-	err := inputStream.Broadcast(ctx, &msgPack)
+	err := inputStream.Broadcast(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -369,7 +363,6 @@ func TestStream_PulsarMsgStream_BroadCast(t *testing.T) {
 }
 
 func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) {
-	ctx := context.Background()
 	pulsarAddress, _ := Params.Load("_PulsarAddress")
 	c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
 	producerChannels := []string{c1, c2}
@@ -381,7 +374,7 @@ func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) {
 	msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, 3, 3))
 
 	inputStream, outputStream := initPulsarStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName, repackFunc)
-	err := inputStream.Produce(ctx, &msgPack)
+	err := inputStream.Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -391,7 +384,6 @@ func TestStream_PulsarMsgStream_RepackFunc(t *testing.T) {
 }
 
 func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) {
-	ctx := context.Background()
 	pulsarAddress, _ := Params.Load("_PulsarAddress")
 	c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
 	producerChannels := []string{c1, c2}
@@ -436,7 +428,7 @@ func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) {
 	outputStream.Start()
 	var output msgstream.MsgStream = outputStream
 
-	err := (*inputStream).Produce(ctx, &msgPack)
+	err := (*inputStream).Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -446,7 +438,6 @@ func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) {
 }
 
 func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) {
-	ctx := context.Background()
 	pulsarAddress, _ := Params.Load("_PulsarAddress")
 	c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
 	producerChannels := []string{c1, c2}
@@ -489,7 +480,7 @@ func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) {
 	outputStream.Start()
 	var output msgstream.MsgStream = outputStream
 
-	err := (*inputStream).Produce(ctx, &msgPack)
+	err := (*inputStream).Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -499,7 +490,6 @@ func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) {
 }
 
 func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) {
-	ctx := context.Background()
 	pulsarAddress, _ := Params.Load("_PulsarAddress")
 	c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
 	producerChannels := []string{c1, c2}
@@ -522,7 +512,7 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) {
 	outputStream.Start()
 	var output msgstream.MsgStream = outputStream
 
-	err := (*inputStream).Produce(ctx, &msgPack)
+	err := (*inputStream).Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -532,7 +522,6 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) {
 }
 
 func TestStream_PulsarTtMsgStream_Insert(t *testing.T) {
-	ctx := context.Background()
 	pulsarAddress, _ := Params.Load("_PulsarAddress")
 	c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
 	producerChannels := []string{c1, c2}
@@ -549,15 +538,15 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) {
 	msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5, 5, 5))
 
 	inputStream, outputStream := initPulsarTtStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
-	err := inputStream.Broadcast(ctx, &msgPack0)
+	err := inputStream.Broadcast(&msgPack0)
 	if err != nil {
 		log.Fatalf("broadcast error = %v", err)
 	}
-	err = inputStream.Produce(ctx, &msgPack1)
+	err = inputStream.Produce(&msgPack1)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
-	err = inputStream.Broadcast(ctx, &msgPack2)
+	err = inputStream.Broadcast(&msgPack2)
 	if err != nil {
 		log.Fatalf("broadcast error = %v", err)
 	}
@@ -567,7 +556,6 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) {
 }
 
 func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
-	ctx := context.Background()
 	pulsarAddress, _ := Params.Load("_PulsarAddress")
 	c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
 	producerChannels := []string{c1, c2}
@@ -595,23 +583,23 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
 	msgPack5.Msgs = append(msgPack5.Msgs, getTimeTickMsg(15, 15, 15))
 
 	inputStream, outputStream := initPulsarTtStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
-	err := inputStream.Broadcast(ctx, &msgPack0)
+	err := inputStream.Broadcast(&msgPack0)
 	assert.Nil(t, err)
-	err = inputStream.Produce(ctx, &msgPack1)
+	err = inputStream.Produce(&msgPack1)
 	assert.Nil(t, err)
-	err = inputStream.Broadcast(ctx, &msgPack2)
+	err = inputStream.Broadcast(&msgPack2)
 	assert.Nil(t, err)
-	err = inputStream.Produce(ctx, &msgPack3)
+	err = inputStream.Produce(&msgPack3)
 	assert.Nil(t, err)
-	err = inputStream.Broadcast(ctx, &msgPack4)
+	err = inputStream.Broadcast(&msgPack4)
 	assert.Nil(t, err)
 
 	outputStream.Consume()
-	receivedMsg, _ := outputStream.Consume()
+	receivedMsg := outputStream.Consume()
 	for _, position := range receivedMsg.StartPositions {
 		outputStream.Seek(position)
 	}
-	err = inputStream.Broadcast(ctx, &msgPack5)
+	err = inputStream.Broadcast(&msgPack5)
 	assert.Nil(t, err)
 	//seekMsg, _ := outputStream.Consume()
 	//for _, msg := range seekMsg.Msgs {
@@ -622,7 +610,6 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) {
 }
 
 func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
-	ctx := context.Background()
 	pulsarAddress, _ := Params.Load("_PulsarAddress")
 	c1, c2 := funcutil.RandomString(8), funcutil.RandomString(8)
 	producerChannels := []string{c1, c2}
@@ -640,15 +627,15 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
 	msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5, 5, 5))
 
 	inputStream, outputStream := initPulsarTtStream(pulsarAddress, producerChannels, consumerChannels, consumerSubName)
-	err := inputStream.Broadcast(ctx, &msgPack0)
+	err := inputStream.Broadcast(&msgPack0)
 	if err != nil {
 		log.Fatalf("broadcast error = %v", err)
 	}
-	err = inputStream.Produce(ctx, &msgPack1)
+	err = inputStream.Produce(&msgPack1)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
-	err = inputStream.Broadcast(ctx, &msgPack2)
+	err = inputStream.Broadcast(&msgPack2)
 	if err != nil {
 		log.Fatalf("broadcast error = %v", err)
 	}
diff --git a/internal/msgstream/rmqms/rmq_msgstream.go b/internal/msgstream/rmqms/rmq_msgstream.go
index 49c3d6bbb1aa75ec936230e2df008633545cc3cc..0b5114330bd62b680caaa9fcb178f8598a10b3ed 100644
--- a/internal/msgstream/rmqms/rmq_msgstream.go
+++ b/internal/msgstream/rmqms/rmq_msgstream.go
@@ -161,7 +161,7 @@ func (rms *RmqMsgStream) AsConsumer(channels []string, groupName string) {
 	}
 }
 
-func (rms *RmqMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error {
+func (rms *RmqMsgStream) Produce(pack *msgstream.MsgPack) error {
 	tsMsgs := pack.Msgs
 	if len(tsMsgs) <= 0 {
 		log.Debug("Warning: Receive empty msgPack")
@@ -228,7 +228,7 @@ func (rms *RmqMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) e
 	return nil
 }
 
-func (rms *RmqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error {
+func (rms *RmqMsgStream) Broadcast(msgPack *msgstream.MsgPack) error {
 	for _, v := range msgPack.Msgs {
 		mb, err := v.Marshal(v)
 		if err != nil {
@@ -255,18 +255,18 @@ func (rms *RmqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) error
 	return nil
 }
 
-func (rms *RmqMsgStream) Consume() (*msgstream.MsgPack, context.Context) {
+func (rms *RmqMsgStream) Consume() *msgstream.MsgPack {
 	for {
 		select {
 		case cm, ok := <-rms.receiveBuf:
 			if !ok {
 				log.Debug("buf chan closed")
-				return nil, nil
+				return nil
 			}
-			return cm, nil
+			return cm
 		case <-rms.ctx.Done():
 			//log.Debug("context closed")
-			return nil, nil
+			return nil
 		}
 	}
 }
diff --git a/internal/msgstream/rmqms/rmq_msgstream_test.go b/internal/msgstream/rmqms/rmq_msgstream_test.go
index 81d511922ae9e3a20f4acc7068871232ab1b34f1..ba1f7354cf1d7d6b2e0d1d101fe12deb56d6fef3 100644
--- a/internal/msgstream/rmqms/rmq_msgstream_test.go
+++ b/internal/msgstream/rmqms/rmq_msgstream_test.go
@@ -239,7 +239,7 @@ func initRmqTtStream(producerChannels []string,
 func receiveMsg(outputStream msgstream.MsgStream, msgCount int) {
 	receiveCount := 0
 	for {
-		result, _ := outputStream.Consume()
+		result := outputStream.Consume()
 		if len(result.Msgs) > 0 {
 			msgs := result.Msgs
 			for _, v := range msgs {
@@ -254,7 +254,6 @@ func receiveMsg(outputStream msgstream.MsgStream, msgCount int) {
 }
 
 func TestStream_RmqMsgStream_Insert(t *testing.T) {
-	ctx := context.Background()
 	producerChannels := []string{"insert1", "insert2"}
 	consumerChannels := []string{"insert1", "insert2"}
 	consumerGroupName := "InsertGroup"
@@ -266,7 +265,7 @@ func TestStream_RmqMsgStream_Insert(t *testing.T) {
 	rocksdbName := "/tmp/rocksmq_insert"
 	etcdKV := initRmq(rocksdbName)
 	inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerGroupName)
-	err := inputStream.Produce(ctx, &msgPack)
+	err := inputStream.Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -276,7 +275,6 @@ func TestStream_RmqMsgStream_Insert(t *testing.T) {
 }
 
 func TestStream_RmqMsgStream_Delete(t *testing.T) {
-	ctx := context.Background()
 	producerChannels := []string{"delete"}
 	consumerChannels := []string{"delete"}
 	consumerSubName := "subDelete"
@@ -287,7 +285,7 @@ func TestStream_RmqMsgStream_Delete(t *testing.T) {
 	rocksdbName := "/tmp/rocksmq_delete"
 	etcdKV := initRmq(rocksdbName)
 	inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
-	err := inputStream.Produce(ctx, &msgPack)
+	err := inputStream.Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -296,7 +294,6 @@ func TestStream_RmqMsgStream_Delete(t *testing.T) {
 }
 
 func TestStream_RmqMsgStream_Search(t *testing.T) {
-	ctx := context.Background()
 	producerChannels := []string{"search"}
 	consumerChannels := []string{"search"}
 	consumerSubName := "subSearch"
@@ -308,7 +305,7 @@ func TestStream_RmqMsgStream_Search(t *testing.T) {
 	rocksdbName := "/tmp/rocksmq_search"
 	etcdKV := initRmq(rocksdbName)
 	inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
-	err := inputStream.Produce(ctx, &msgPack)
+	err := inputStream.Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -317,8 +314,6 @@ func TestStream_RmqMsgStream_Search(t *testing.T) {
 }
 
 func TestStream_RmqMsgStream_SearchResult(t *testing.T) {
-	ctx := context.Background()
-
 	producerChannels := []string{"searchResult"}
 	consumerChannels := []string{"searchResult"}
 	consumerSubName := "subSearchResult"
@@ -330,7 +325,7 @@ func TestStream_RmqMsgStream_SearchResult(t *testing.T) {
 	rocksdbName := "/tmp/rocksmq_searchresult"
 	etcdKV := initRmq(rocksdbName)
 	inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
-	err := inputStream.Produce(ctx, &msgPack)
+	err := inputStream.Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -339,7 +334,6 @@ func TestStream_RmqMsgStream_SearchResult(t *testing.T) {
 }
 
 func TestStream_RmqMsgStream_TimeTick(t *testing.T) {
-	ctx := context.Background()
 	producerChannels := []string{"timeTick"}
 	consumerChannels := []string{"timeTick"}
 	consumerSubName := "subTimeTick"
@@ -351,7 +345,7 @@ func TestStream_RmqMsgStream_TimeTick(t *testing.T) {
 	rocksdbName := "/tmp/rocksmq_timetick"
 	etcdKV := initRmq(rocksdbName)
 	inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
-	err := inputStream.Produce(ctx, &msgPack)
+	err := inputStream.Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -360,7 +354,6 @@ func TestStream_RmqMsgStream_TimeTick(t *testing.T) {
 }
 
 func TestStream_RmqMsgStream_BroadCast(t *testing.T) {
-	ctx := context.Background()
 	producerChannels := []string{"insert1", "insert2"}
 	consumerChannels := []string{"insert1", "insert2"}
 	consumerSubName := "subInsert"
@@ -372,7 +365,7 @@ func TestStream_RmqMsgStream_BroadCast(t *testing.T) {
 	rocksdbName := "/tmp/rocksmq_broadcast"
 	etcdKV := initRmq(rocksdbName)
 	inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName)
-	err := inputStream.Broadcast(ctx, &msgPack)
+	err := inputStream.Broadcast(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -381,8 +374,6 @@ func TestStream_RmqMsgStream_BroadCast(t *testing.T) {
 }
 
 func TestStream_RmqMsgStream_RepackFunc(t *testing.T) {
-	ctx := context.Background()
-
 	producerChannels := []string{"insert1", "insert2"}
 	consumerChannels := []string{"insert1", "insert2"}
 	consumerSubName := "subInsert"
@@ -394,7 +385,7 @@ func TestStream_RmqMsgStream_RepackFunc(t *testing.T) {
 	rocksdbName := "/tmp/rocksmq_repackfunc"
 	etcdKV := initRmq(rocksdbName)
 	inputStream, outputStream := initRmqStream(producerChannels, consumerChannels, consumerSubName, repackFunc)
-	err := inputStream.Produce(ctx, &msgPack)
+	err := inputStream.Produce(&msgPack)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
@@ -403,8 +394,6 @@ func TestStream_RmqMsgStream_RepackFunc(t *testing.T) {
 }
 
 func TestStream_PulsarTtMsgStream_Insert(t *testing.T) {
-	ctx := context.Background()
-
 	producerChannels := []string{"insert1", "insert2"}
 	consumerChannels := []string{"insert1", "insert2"}
 	consumerSubName := "subInsert"
@@ -423,15 +412,15 @@ func TestStream_PulsarTtMsgStream_Insert(t *testing.T) {
 	etcdKV := initRmq(rocksdbName)
 	inputStream, outputStream := initRmqTtStream(producerChannels, consumerChannels, consumerSubName)
 
-	err := inputStream.Broadcast(ctx, &msgPack0)
+	err := inputStream.Broadcast(&msgPack0)
 	if err != nil {
 		log.Fatalf("broadcast error = %v", err)
 	}
-	err = inputStream.Produce(ctx, &msgPack1)
+	err = inputStream.Produce(&msgPack1)
 	if err != nil {
 		log.Fatalf("produce error = %v", err)
 	}
-	err = inputStream.Broadcast(ctx, &msgPack2)
+	err = inputStream.Broadcast(&msgPack2)
 	if err != nil {
 		log.Fatalf("broadcast error = %v", err)
 	}
diff --git a/internal/msgstream/util/repack_func.go b/internal/msgstream/util/repack_func.go
index 3e7cfec59908ddd1513cb7e86266dca7b723311d..abbe2e9206354e4707bd706d5f806ea62d6013f5 100644
--- a/internal/msgstream/util/repack_func.go
+++ b/internal/msgstream/util/repack_func.go
@@ -57,6 +57,9 @@ func InsertRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e
 		}
 
 		insertMsg := &msgstream.InsertMsg{
+			BaseMsg: BaseMsg{
+				Ctx: request.TraceCtx(),
+			},
 			InsertRequest: sliceRequest,
 		}
 		result[key].Msgs = append(result[key].Msgs, insertMsg)
@@ -103,6 +106,9 @@ func DeleteRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e
 		}
 
 		deleteMsg := &msgstream.DeleteMsg{
+			BaseMsg: BaseMsg{
+				Ctx: request.TraceCtx(),
+			},
 			DeleteRequest: sliceRequest,
 		}
 		result[key].Msgs = append(result[key].Msgs, deleteMsg)
diff --git a/internal/proxynode/repack_func.go b/internal/proxynode/repack_func.go
index ff74331484b35c799b75524f0d9a8df12967dbaa..4efa0d031139394df2cef7eb46ed2e9240e9506e 100644
--- a/internal/proxynode/repack_func.go
+++ b/internal/proxynode/repack_func.go
@@ -183,6 +183,7 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg,
 		// if slice, todo: a common function to calculate size of slice,
 		// if map, a little complicated
 		size := 0
+		size += int(unsafe.Sizeof(msg.Ctx))
 		size += int(unsafe.Sizeof(msg.BeginTimestamp))
 		size += int(unsafe.Sizeof(msg.EndTimestamp))
 		size += int(unsafe.Sizeof(msg.HashValues))
@@ -262,6 +263,9 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg,
 				RowData:        []*commonpb.Blob{row},
 			}
 			insertMsg := &msgstream.InsertMsg{
+				BaseMsg: msgstream.BaseMsg{
+					Ctx: request.TraceCtx(),
+				},
 				InsertRequest: sliceRequest,
 			}
 			if together { // all rows with same hash value are accumulated to only one message
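InsertRepackFunc and DeleteRepackFunc above copy the original request's trace context into every sliced message. The sketch below shows why: without that copy the per-bucket messages would start with a nil Ctx and the trace would stop at the repack boundary (stand-in types):

```go
package repacksketch

import "context"

type BaseMsg struct{ Ctx context.Context }

type InsertMsg struct {
	BaseMsg
	Rows []int // stand-in payload
}

func (m *InsertMsg) TraceCtx() context.Context { return m.Ctx }

// repack groups rows into per-hash buckets, carrying the request's trace
// context onto each newly built message, as the patched repack funcs do.
func repack(request *InsertMsg, buckets map[int32][]int) map[int32]*InsertMsg {
	out := make(map[int32]*InsertMsg, len(buckets))
	for key, rows := range buckets {
		out[key] = &InsertMsg{
			BaseMsg: BaseMsg{Ctx: request.TraceCtx()}, // keep the trace alive
			Rows:    rows,
		}
	}
	return out
}
```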
3e7cfec59908ddd1513cb7e86266dca7b723311d..abbe2e9206354e4707bd706d5f806ea62d6013f5 100644 --- a/internal/msgstream/util/repack_func.go +++ b/internal/msgstream/util/repack_func.go @@ -57,6 +57,9 @@ func InsertRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e } insertMsg := &msgstream.InsertMsg{ + BaseMsg: BaseMsg{ + Ctx: request.TraceCtx(), + }, InsertRequest: sliceRequest, } result[key].Msgs = append(result[key].Msgs, insertMsg) @@ -103,6 +106,9 @@ func DeleteRepackFunc(tsMsgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, e } deleteMsg := &msgstream.DeleteMsg{ + BaseMsg: BaseMsg{ + Ctx: request.TraceCtx(), + }, DeleteRequest: sliceRequest, } result[key].Msgs = append(result[key].Msgs, deleteMsg) diff --git a/internal/proxynode/repack_func.go b/internal/proxynode/repack_func.go index ff74331484b35c799b75524f0d9a8df12967dbaa..4efa0d031139394df2cef7eb46ed2e9240e9506e 100644 --- a/internal/proxynode/repack_func.go +++ b/internal/proxynode/repack_func.go @@ -183,6 +183,7 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg, // if slice, todo: a common function to calculate size of slice, // if map, a little complicated size := 0 + size += int(unsafe.Sizeof(msg.Ctx)) size += int(unsafe.Sizeof(msg.BeginTimestamp)) size += int(unsafe.Sizeof(msg.EndTimestamp)) size += int(unsafe.Sizeof(msg.HashValues)) @@ -262,6 +263,9 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg, RowData: []*commonpb.Blob{row}, } insertMsg := &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + Ctx: request.TraceCtx(), + }, InsertRequest: sliceRequest, } if together { // all rows with same hash value are accumulated to only one message diff --git a/internal/proxynode/task.go b/internal/proxynode/task.go index 0fa481466c3d8b21a4e8464cf8d8e70e5be7d45e..140e1ccf8735d3641d5b2d592e355895bfe2c435 100644 --- a/internal/proxynode/task.go +++ b/internal/proxynode/task.go @@ -52,7 +52,7 @@ const ( ) type task interface { - Ctx() context.Context + TraceCtx() context.Context ID() UniqueID // return ReqID SetID(uid UniqueID) // set ReqID Name() string @@ -79,7 +79,7 @@ type InsertTask struct { rowIDAllocator *allocator.IDAllocator } -func (it *InsertTask) Ctx() context.Context { +func (it *InsertTask) TraceCtx() context.Context { return it.ctx } @@ -185,7 +185,8 @@ func (it *InsertTask) Execute(ctx context.Context) error { } var tsMsg msgstream.TsMsg = &it.BaseInsertTask - msgPack := &msgstream.MsgPack{ + it.BaseMsg.Ctx = ctx + msgPack := msgstream.MsgPack{ BeginTs: it.BeginTs(), EndTs: it.EndTs(), Msgs: make([]msgstream.TsMsg, 1), @@ -231,7 +232,7 @@ func (it *InsertTask) Execute(ctx context.Context) error { return err } - err = stream.Produce(ctx, msgPack) + err = stream.Produce(&msgPack) if err != nil { it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError it.result.Status.Reason = err.Error() @@ -255,7 +256,7 @@ type CreateCollectionTask struct { schema *schemapb.CollectionSchema } -func (cct *CreateCollectionTask) Ctx() context.Context { +func (cct *CreateCollectionTask) TraceCtx() context.Context { return cct.ctx } @@ -403,7 +404,7 @@ type DropCollectionTask struct { result *commonpb.Status } -func (dct *DropCollectionTask) Ctx() context.Context { +func (dct *DropCollectionTask) TraceCtx() context.Context { return dct.ctx } @@ -484,7 +485,7 @@ type SearchTask struct { query *milvuspb.SearchRequest } -func (st *SearchTask) Ctx() context.Context { +func (st *SearchTask) TraceCtx() context.Context { return st.ctx } @@ -596,18 +597,19 @@ func (st *SearchTask) Execute(ctx context.Context) error { var 
tsMsg msgstream.TsMsg = &msgstream.SearchMsg{ SearchRequest: *st.SearchRequest, BaseMsg: msgstream.BaseMsg{ + Ctx: ctx, HashValues: []uint32{uint32(Params.ProxyID)}, BeginTimestamp: st.Base.Timestamp, EndTimestamp: st.Base.Timestamp, }, } - msgPack := &msgstream.MsgPack{ + msgPack := msgstream.MsgPack{ BeginTs: st.Base.Timestamp, EndTs: st.Base.Timestamp, Msgs: make([]msgstream.TsMsg, 1), } msgPack.Msgs[0] = tsMsg - err := st.queryMsgStream.Produce(ctx, msgPack) + err := st.queryMsgStream.Produce(&msgPack) log.Debug("proxynode", zap.Int("length of searchMsg", len(msgPack.Msgs))) if err != nil { log.Debug("proxynode", zap.String("send search request failed", err.Error())) @@ -990,7 +992,7 @@ func printSearchResult(partialSearchResult *internalpb.SearchResults) { func (st *SearchTask) PostExecute(ctx context.Context) error { for { select { - case <-st.Ctx().Done(): + case <-st.TraceCtx().Done(): log.Debug("proxynode", zap.Int64("SearchTask: wait to finish failed, timeout!, taskID:", st.ID())) return fmt.Errorf("SearchTask:wait to finish failed, timeout: %d", st.ID()) case searchResults := <-st.resultBuf: @@ -1073,7 +1075,7 @@ type HasCollectionTask struct { result *milvuspb.BoolResponse } -func (hct *HasCollectionTask) Ctx() context.Context { +func (hct *HasCollectionTask) TraceCtx() context.Context { return hct.ctx } @@ -1144,7 +1146,7 @@ type DescribeCollectionTask struct { result *milvuspb.DescribeCollectionResponse } -func (dct *DescribeCollectionTask) Ctx() context.Context { +func (dct *DescribeCollectionTask) TraceCtx() context.Context { return dct.ctx } @@ -1215,7 +1217,7 @@ type GetCollectionsStatisticsTask struct { result *milvuspb.GetCollectionStatisticsResponse } -func (g *GetCollectionsStatisticsTask) Ctx() context.Context { +func (g *GetCollectionsStatisticsTask) TraceCtx() context.Context { return g.ctx } @@ -1302,7 +1304,7 @@ type ShowCollectionsTask struct { result *milvuspb.ShowCollectionsResponse } -func (sct *ShowCollectionsTask) Ctx() context.Context { +func (sct *ShowCollectionsTask) TraceCtx() context.Context { return sct.ctx } @@ -1370,7 +1372,7 @@ type CreatePartitionTask struct { result *commonpb.Status } -func (cpt *CreatePartitionTask) Ctx() context.Context { +func (cpt *CreatePartitionTask) TraceCtx() context.Context { return cpt.ctx } @@ -1447,7 +1449,7 @@ type DropPartitionTask struct { result *commonpb.Status } -func (dpt *DropPartitionTask) Ctx() context.Context { +func (dpt *DropPartitionTask) TraceCtx() context.Context { return dpt.ctx } @@ -1524,7 +1526,7 @@ type HasPartitionTask struct { result *milvuspb.BoolResponse } -func (hpt *HasPartitionTask) Ctx() context.Context { +func (hpt *HasPartitionTask) TraceCtx() context.Context { return hpt.ctx } @@ -1600,7 +1602,7 @@ type ShowPartitionsTask struct { result *milvuspb.ShowPartitionsResponse } -func (spt *ShowPartitionsTask) Ctx() context.Context { +func (spt *ShowPartitionsTask) TraceCtx() context.Context { return spt.ctx } @@ -1671,7 +1673,7 @@ type CreateIndexTask struct { result *commonpb.Status } -func (cit *CreateIndexTask) Ctx() context.Context { +func (cit *CreateIndexTask) TraceCtx() context.Context { return cit.ctx } @@ -1749,7 +1751,7 @@ type DescribeIndexTask struct { result *milvuspb.DescribeIndexResponse } -func (dit *DescribeIndexTask) Ctx() context.Context { +func (dit *DescribeIndexTask) TraceCtx() context.Context { return dit.ctx } @@ -1832,7 +1834,7 @@ type DropIndexTask struct { result *commonpb.Status } -func (dit *DropIndexTask) Ctx() context.Context { +func (dit *DropIndexTask) 
TraceCtx() context.Context { return dit.ctx } @@ -1911,7 +1913,7 @@ type GetIndexStateTask struct { result *milvuspb.GetIndexStateResponse } -func (gist *GetIndexStateTask) Ctx() context.Context { +func (gist *GetIndexStateTask) TraceCtx() context.Context { return gist.ctx } @@ -2142,7 +2144,7 @@ type FlushTask struct { result *commonpb.Status } -func (ft *FlushTask) Ctx() context.Context { +func (ft *FlushTask) TraceCtx() context.Context { return ft.ctx } @@ -2228,7 +2230,7 @@ type LoadCollectionTask struct { result *commonpb.Status } -func (lct *LoadCollectionTask) Ctx() context.Context { +func (lct *LoadCollectionTask) TraceCtx() context.Context { return lct.ctx } @@ -2323,7 +2325,7 @@ type ReleaseCollectionTask struct { result *commonpb.Status } -func (rct *ReleaseCollectionTask) Ctx() context.Context { +func (rct *ReleaseCollectionTask) TraceCtx() context.Context { return rct.ctx } @@ -2404,6 +2406,10 @@ type LoadPartitionTask struct { result *commonpb.Status } +func (lpt *LoadPartitionTask) TraceCtx() context.Context { + return lpt.ctx +} + func (lpt *LoadPartitionTask) ID() UniqueID { return lpt.Base.MsgID } @@ -2495,7 +2501,7 @@ type ReleasePartitionTask struct { result *commonpb.Status } -func (rpt *ReleasePartitionTask) Ctx() context.Context { +func (rpt *ReleasePartitionTask) TraceCtx() context.Context { return rpt.ctx } diff --git a/internal/proxynode/task_scheduler.go b/internal/proxynode/task_scheduler.go index 450a698b9a3217909227997d12d71534875a4ba3..13d380083e3cc6717e554049c0e06023bbcdbac4 100644 --- a/internal/proxynode/task_scheduler.go +++ b/internal/proxynode/task_scheduler.go @@ -302,7 +302,7 @@ func (sched *TaskScheduler) getTaskByReqID(collMeta UniqueID) task { } func (sched *TaskScheduler) processTask(t task, q TaskQueue) { - span, ctx := trace.StartSpanFromContext(t.Ctx(), + span, ctx := trace.StartSpanFromContext(t.TraceCtx(), opentracing.Tags{ "Type": t.Name(), "ID": t.ID(), @@ -409,6 +409,8 @@ func (sched *TaskScheduler) queryResultLoop() { continue } for _, tsMsg := range msgPack.Msgs { + sp, ctx := trace.StartSpanFromContext(tsMsg.TraceCtx()) + tsMsg.SetTraceCtx(ctx) searchResultMsg, _ := tsMsg.(*msgstream.SearchResultMsg) reqID := searchResultMsg.Base.MsgID reqIDStr := strconv.FormatInt(reqID, 10) @@ -443,6 +445,7 @@ func (sched *TaskScheduler) queryResultLoop() { // log.Printf("task with reqID %v is nil", reqID) } } + sp.Finish() } case <-sched.ctx.Done(): log.Debug("proxynode server is closed ...") diff --git a/internal/proxynode/timetick.go b/internal/proxynode/timetick.go index 30911031a68e1ee9ce2d91f351484ba4bcfa9ccd..97ec662de9bf539a59ac904a25576980cbae7cc8 100644 --- a/internal/proxynode/timetick.go +++ b/internal/proxynode/timetick.go @@ -86,7 +86,7 @@ func (tt *timeTick) tick() error { }, } msgPack.Msgs = append(msgPack.Msgs, timeTickMsg) - err := tt.tickMsgStream.Produce(tt.ctx, &msgPack) + err := tt.tickMsgStream.Produce(&msgPack) if err != nil { log.Warn("proxynode", zap.String("error", err.Error())) } diff --git a/internal/proxyservice/timetick.go b/internal/proxyservice/timetick.go index 4e5cb8eaa496094e4019200952848f14c3eca203..74e2b2ed09d4904f27e799c9f9f517863dafa5a7 100644 --- a/internal/proxyservice/timetick.go +++ b/internal/proxyservice/timetick.go @@ -55,7 +55,7 @@ func (tt *TimeTick) Start() error { log.Debug("proxyservice", zap.Stringer("msg type", msg.Type())) } for _, channel := range tt.channels { - err = channel.Broadcast(tt.ctx, &msgPack) + err = channel.Broadcast(&msgPack) if err != nil { log.Error("proxyservice", 
zap.String("send time tick error", err.Error())) } diff --git a/internal/querynode/data_sync_service_test.go b/internal/querynode/data_sync_service_test.go index fcd932796ff48175406c8bed789d1159b34108d5..0680bd043e7bb7dd19ec6d00332c27399fda94d8 100644 --- a/internal/querynode/data_sync_service_test.go +++ b/internal/querynode/data_sync_service_test.go @@ -1,7 +1,6 @@ package querynode import ( - "context" "encoding/binary" "math" "testing" @@ -16,8 +15,6 @@ import ( // NOTE: start pulsar before test func TestDataSyncService_Start(t *testing.T) { - ctx := context.Background() - collectionID := UniqueID(0) node := newQueryNodeMock() @@ -127,10 +124,10 @@ func TestDataSyncService_Start(t *testing.T) { var insertMsgStream msgstream.MsgStream = insertStream insertMsgStream.Start() - err = insertMsgStream.Produce(ctx, &msgPack) + err = insertMsgStream.Produce(&msgPack) assert.NoError(t, err) - err = insertMsgStream.Broadcast(ctx, &timeTickMsgPack) + err = insertMsgStream.Broadcast(&timeTickMsgPack) assert.NoError(t, err) // dataSync diff --git a/internal/querynode/flow_graph_dd_node.go b/internal/querynode/flow_graph_dd_node.go index 66b80cf11cf7950de0a00ff384ddbf69287e144e..d65f05a8f2afce4e9e4bb02cf4fe46df4c765662 100644 --- a/internal/querynode/flow_graph_dd_node.go +++ b/internal/querynode/flow_graph_dd_node.go @@ -1,14 +1,15 @@ package querynode import ( - "context" - "github.com/golang/protobuf/proto" + "github.com/opentracing/opentracing-go" + "github.com/zilliztech/milvus-distributed/internal/util/trace" "go.uber.org/zap" "github.com/zilliztech/milvus-distributed/internal/log" "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" + "github.com/zilliztech/milvus-distributed/internal/util/flowgraph" ) type ddNode struct { @@ -21,7 +22,7 @@ func (ddNode *ddNode) Name() string { return "ddNode" } -func (ddNode *ddNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) { +func (ddNode *ddNode) Operate(in []flowgraph.Msg) []flowgraph.Msg { //log.Debug("Do filterDmNode operation") if len(in) != 1 { @@ -35,6 +36,13 @@ func (ddNode *ddNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Con // TODO: add error handling } + var spans []opentracing.Span + for _, msg := range msMsg.TsMessages() { + sp, ctx := trace.StartSpanFromContext(msg.TraceCtx()) + spans = append(spans, sp) + msg.SetTraceCtx(ctx) + } + var ddMsg = ddMsg{ collectionRecords: make(map[UniqueID][]metaOperateRecord), partitionRecords: make(map[UniqueID][]metaOperateRecord), @@ -74,7 +82,10 @@ func (ddNode *ddNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Con //} var res Msg = ddNode.ddMsg - return []Msg{res}, ctx + for _, span := range spans { + span.Finish() + } + return []Msg{res} } func (ddNode *ddNode) createCollection(msg *msgstream.CreateCollectionMsg) { diff --git a/internal/querynode/flow_graph_filter_dm_node.go b/internal/querynode/flow_graph_filter_dm_node.go index e2e705ef87f8f14c13a03aa51d24ad03971da184..c4106323f61bee24f185596fd85ba8b4df5a7b7e 100644 --- a/internal/querynode/flow_graph_filter_dm_node.go +++ b/internal/querynode/flow_graph_filter_dm_node.go @@ -1,14 +1,15 @@ package querynode import ( - "context" "fmt" - "go.uber.org/zap" - + "github.com/opentracing/opentracing-go" "github.com/zilliztech/milvus-distributed/internal/log" "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + 
"github.com/zilliztech/milvus-distributed/internal/util/flowgraph" + "github.com/zilliztech/milvus-distributed/internal/util/trace" + "go.uber.org/zap" ) type filterDmNode struct { @@ -21,7 +22,7 @@ func (fdmNode *filterDmNode) Name() string { return "fdmNode" } -func (fdmNode *filterDmNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) { +func (fdmNode *filterDmNode) Operate(in []flowgraph.Msg) []flowgraph.Msg { //log.Debug("Do filterDmNode operation") if len(in) != 1 { @@ -36,7 +37,14 @@ func (fdmNode *filterDmNode) Operate(ctx context.Context, in []Msg) ([]Msg, cont } if msgStreamMsg == nil { - return []Msg{}, ctx + return []Msg{} + } + + var spans []opentracing.Span + for _, msg := range msgStreamMsg.TsMessages() { + sp, ctx := trace.StartSpanFromContext(msg.TraceCtx()) + spans = append(spans, sp) + msg.SetTraceCtx(ctx) } var iMsg = insertMsg{ @@ -61,11 +69,16 @@ func (fdmNode *filterDmNode) Operate(ctx context.Context, in []Msg) ([]Msg, cont } var res Msg = &iMsg - - return []Msg{res}, ctx + for _, sp := range spans { + sp.Finish() + } + return []Msg{res} } func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg) *msgstream.InsertMsg { + sp, ctx := trace.StartSpanFromContext(msg.TraceCtx()) + msg.SetTraceCtx(ctx) + defer sp.Finish() // check if collection and partition exist collection := fdmNode.replica.hasCollection(msg.CollectionID) partition := fdmNode.replica.hasPartition(msg.PartitionID) diff --git a/internal/querynode/flow_graph_gc_node.go b/internal/querynode/flow_graph_gc_node.go index 07ef91b4e7d935669154b99b068641f6e9ba53d3..587ab416b6d462d1d2b0b15c1367fc8b6e39d9d3 100644 --- a/internal/querynode/flow_graph_gc_node.go +++ b/internal/querynode/flow_graph_gc_node.go @@ -1,7 +1,7 @@ package querynode import ( - "context" + "github.com/zilliztech/milvus-distributed/internal/util/flowgraph" "go.uber.org/zap" @@ -17,7 +17,7 @@ func (gcNode *gcNode) Name() string { return "gcNode" } -func (gcNode *gcNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) { +func (gcNode *gcNode) Operate(in []flowgraph.Msg) []flowgraph.Msg { //log.Debug("Do gcNode operation") if len(in) != 1 { @@ -51,7 +51,7 @@ func (gcNode *gcNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Con // } //} - return nil, ctx + return nil } func newGCNode(replica ReplicaInterface) *gcNode { diff --git a/internal/querynode/flow_graph_insert_node.go b/internal/querynode/flow_graph_insert_node.go index 91c4e08a82a7bd64b23aecc1a4ad16ba784c9b6e..5f32c0f999298ae09dbcae5bd4423b67c6b8a94d 100644 --- a/internal/querynode/flow_graph_insert_node.go +++ b/internal/querynode/flow_graph_insert_node.go @@ -4,10 +4,12 @@ import ( "context" "sync" - "go.uber.org/zap" - + "github.com/opentracing/opentracing-go" "github.com/zilliztech/milvus-distributed/internal/log" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/util/flowgraph" + "github.com/zilliztech/milvus-distributed/internal/util/trace" + "go.uber.org/zap" ) type insertNode struct { @@ -28,7 +30,7 @@ func (iNode *insertNode) Name() string { return "iNode" } -func (iNode *insertNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) { +func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg { //log.Debug("Do insertNode operation") if len(in) != 1 { @@ -50,7 +52,14 @@ func (iNode *insertNode) Operate(ctx context.Context, in []Msg) ([]Msg, context. 
} if iMsg == nil { - return []Msg{}, ctx + return []Msg{} + } + + var spans []opentracing.Span + for _, msg := range iMsg.insertMessages { + sp, ctx := trace.StartSpanFromContext(msg.TraceCtx()) + spans = append(spans, sp) + msg.SetTraceCtx(ctx) } // 1. hash insertMessages to insertData @@ -108,7 +117,10 @@ func (iNode *insertNode) Operate(ctx context.Context, in []Msg) ([]Msg, context. gcRecord: iMsg.gcRecord, timeRange: iMsg.timeRange, } - return []Msg{res}, ctx + for _, sp := range spans { + sp.Finish() + } + return []Msg{res} } func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *sync.WaitGroup) { diff --git a/internal/querynode/flow_graph_service_time_node.go b/internal/querynode/flow_graph_service_time_node.go index eb3572ce8dac6e8f8b148ec0e07bd32205a0fe06..cf673297be4ced234eeed39c2aa827c53aee9dfc 100644 --- a/internal/querynode/flow_graph_service_time_node.go +++ b/internal/querynode/flow_graph_service_time_node.go @@ -3,12 +3,12 @@ package querynode import ( "context" - "go.uber.org/zap" - "github.com/zilliztech/milvus-distributed/internal/log" "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" + "github.com/zilliztech/milvus-distributed/internal/util/flowgraph" + "go.uber.org/zap" ) type serviceTimeNode struct { @@ -22,7 +22,7 @@ func (stNode *serviceTimeNode) Name() string { return "stNode" } -func (stNode *serviceTimeNode) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) { +func (stNode *serviceTimeNode) Operate(in []flowgraph.Msg) []flowgraph.Msg { //log.Debug("Do serviceTimeNode operation") if len(in) != 1 { @@ -37,7 +37,7 @@ func (stNode *serviceTimeNode) Operate(ctx context.Context, in []Msg) ([]Msg, co } if serviceTimeMsg == nil { - return []Msg{}, ctx + return []Msg{} } // update service time @@ -57,7 +57,7 @@ func (stNode *serviceTimeNode) Operate(ctx context.Context, in []Msg) ([]Msg, co gcRecord: serviceTimeMsg.gcRecord, timeRange: serviceTimeMsg.timeRange, } - return []Msg{res}, ctx + return []Msg{res} } func (stNode *serviceTimeNode) sendTimeTick(ts Timestamp) error { @@ -78,7 +78,7 @@ func (stNode *serviceTimeNode) sendTimeTick(ts Timestamp) error { }, } msgPack.Msgs = append(msgPack.Msgs, &timeTickMsg) - return stNode.timeTickMsgStream.Produce(context.TODO(), &msgPack) + return stNode.timeTickMsgStream.Produce(&msgPack) } func newServiceTimeNode(ctx context.Context, replica ReplicaInterface, factory msgstream.Factory, collectionID UniqueID) *serviceTimeNode { diff --git a/internal/querynode/load_service_test.go b/internal/querynode/load_service_test.go index f22123ad8f495986298a8af151616819b719ad11..dd374f075f397f27a0c951af802e1dd26068296d 100644 --- a/internal/querynode/load_service_test.go +++ b/internal/querynode/load_service_test.go @@ -1038,16 +1038,16 @@ func doInsert(ctx context.Context, collectionID UniqueID, partitionID UniqueID, var ddMsgStream msgstream.MsgStream = ddStream ddMsgStream.Start() - err = insertMsgStream.Produce(ctx, &msgPack) + err = insertMsgStream.Produce(&msgPack) if err != nil { return err } - err = insertMsgStream.Broadcast(ctx, &timeTickMsgPack) + err = insertMsgStream.Broadcast(&timeTickMsgPack) if err != nil { return err } - err = ddMsgStream.Broadcast(ctx, &timeTickMsgPack) + err = ddMsgStream.Broadcast(&timeTickMsgPack) if err != nil { return err } @@ -1104,11 +1104,11 @@ func sentTimeTick(ctx context.Context) error { var ddMsgStream 
msgstream.MsgStream = ddStream ddMsgStream.Start() - err = insertMsgStream.Broadcast(ctx, &timeTickMsgPack) + err = insertMsgStream.Broadcast(&timeTickMsgPack) if err != nil { return err } - err = ddMsgStream.Broadcast(ctx, &timeTickMsgPack) + err = ddMsgStream.Broadcast(&timeTickMsgPack) if err != nil { return err } diff --git a/internal/querynode/search_collection.go b/internal/querynode/search_collection.go index 7ada659b3beedda5fff4f8fae9e6b3e0bf519070..a8a7978a93b9b56f7e3c3b94c4bbf0b00c8ca5d5 100644 --- a/internal/querynode/search_collection.go +++ b/internal/querynode/search_collection.go @@ -15,6 +15,7 @@ import ( "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" "github.com/zilliztech/milvus-distributed/internal/proto/milvuspb" + "github.com/zilliztech/milvus-distributed/internal/util/trace" ) type searchCollection struct { @@ -99,6 +100,9 @@ func (s *searchCollection) setServiceableTime(t Timestamp) { } func (s *searchCollection) emptySearch(searchMsg *msgstream.SearchMsg) { + sp, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx()) + defer sp.Finish() + searchMsg.SetTraceCtx(ctx) err := s.search(searchMsg) if err != nil { log.Error(err.Error()) @@ -164,6 +168,8 @@ func (s *searchCollection) doUnsolvedMsgSearch() { continue } for _, sm := range searchMsg { + sp, ctx := trace.StartSpanFromContext(sm.TraceCtx()) + sm.SetTraceCtx(ctx) err := s.search(sm) if err != nil { log.Error(err.Error()) @@ -172,6 +178,7 @@ func (s *searchCollection) doUnsolvedMsgSearch() { log.Error("publish FailedSearchResult failed", zap.Error(err2)) } } + sp.Finish() } log.Debug("doUnsolvedMsgSearch, do search done", zap.Int("num of searchMsg", len(searchMsg))) } @@ -181,6 +188,9 @@ func (s *searchCollection) doUnsolvedMsgSearch() { // TODO:: cache map[dsl]plan // TODO: reBatched search requests func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error { + sp, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx()) + defer sp.Finish() + searchMsg.SetTraceCtx(ctx) searchTimestamp := searchMsg.Base.Timestamp var queryBlob = searchMsg.Query.Value query := milvuspb.SearchRequest{} @@ -266,7 +276,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error { } resultChannelInt, _ := strconv.ParseInt(searchMsg.ResultChannelID, 10, 64) searchResultMsg := &msgstream.SearchResultMsg{ - BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(resultChannelInt)}}, + BaseMsg: msgstream.BaseMsg{Ctx: searchMsg.Ctx, HashValues: []uint32{uint32(resultChannelInt)}}, SearchResults: internalpb.SearchResults{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_SearchResult, @@ -328,7 +338,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error { } resultChannelInt, _ := strconv.ParseInt(searchMsg.ResultChannelID, 10, 64) searchResultMsg := &msgstream.SearchResultMsg{ - BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(resultChannelInt)}}, + BaseMsg: msgstream.BaseMsg{Ctx: searchMsg.Ctx, HashValues: []uint32{uint32(resultChannelInt)}}, SearchResults: internalpb.SearchResults{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_SearchResult, @@ -368,19 +378,19 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error { } func (s *searchCollection) publishSearchResult(msg msgstream.TsMsg) error { - // span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "publish search result") - // defer span.Finish() - // msg.SetMsgContext(ctx) + span, ctx := 
trace.StartSpanFromContext(msg.TraceCtx()) + defer span.Finish() + msg.SetTraceCtx(ctx) msgPack := msgstream.MsgPack{} msgPack.Msgs = append(msgPack.Msgs, msg) - err := s.searchResultMsgStream.Produce(context.TODO(), &msgPack) + err := s.searchResultMsgStream.Produce(&msgPack) return err } func (s *searchCollection) publishFailedSearchResult(searchMsg *msgstream.SearchMsg, errMsg string) error { - // span, ctx := opentracing.StartSpanFromContext(msg.GetMsgContext(), "receive search msg") - // defer span.Finish() - // msg.SetMsgContext(ctx) + span, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx()) + defer span.Finish() + searchMsg.SetTraceCtx(ctx) //log.Debug("Public fail SearchResult!") msgPack := msgstream.MsgPack{} @@ -401,7 +411,7 @@ func (s *searchCollection) publishFailedSearchResult(searchMsg *msgstream.Search } msgPack.Msgs = append(msgPack.Msgs, searchResultMsg) - err := s.searchResultMsgStream.Produce(context.TODO(), &msgPack) + err := s.searchResultMsgStream.Produce(&msgPack) if err != nil { return err } diff --git a/internal/querynode/search_service.go b/internal/querynode/search_service.go index 7cf7b33b80b45388e8c99b5cd60b3fa6227fa1d5..1393efcb790eded94f95a22944b3bb88d0c3da5b 100644 --- a/internal/querynode/search_service.go +++ b/internal/querynode/search_service.go @@ -6,6 +6,7 @@ import ( "errors" "github.com/zilliztech/milvus-distributed/internal/log" "github.com/zilliztech/milvus-distributed/internal/msgstream" + "github.com/zilliztech/milvus-distributed/internal/util/trace" "go.uber.org/zap" "strconv" "strings" @@ -77,7 +78,7 @@ func (s *searchService) consumeSearch() { case <-s.ctx.Done(): return default: - msgPack, _ := s.searchMsgStream.Consume() + msgPack := s.searchMsgStream.Consume() if msgPack == nil || len(msgPack.Msgs) <= 0 { continue } @@ -87,6 +88,8 @@ func (s *searchService) consumeSearch() { if !ok { continue } + sp, ctx := trace.StartSpanFromContext(sm.BaseMsg.Ctx) + sm.BaseMsg.Ctx = ctx err := s.collectionCheck(sm.CollectionID) if err != nil { s.emptySearchCollection.emptySearch(sm) @@ -98,6 +101,7 @@ func (s *searchService) consumeSearch() { s.startSearchCollection(sm.CollectionID) } sc.msgBuffer <- sm + sp.Finish() } log.Debug("do empty search done", zap.Int("num of searchMsg", emptySearchNum)) } diff --git a/internal/querynode/search_service_test.go b/internal/querynode/search_service_test.go index bf473cf85e89b815ac0104ffd42b28161e589f5a..bd211f7a3471a37b5a34edc9cbd9004bf9325d78 100644 --- a/internal/querynode/search_service_test.go +++ b/internal/querynode/search_service_test.go @@ -19,8 +19,6 @@ import ( ) func TestSearch_Search(t *testing.T) { - ctx := context.Background() - collectionID := UniqueID(0) node := newQueryNodeMock() @@ -108,7 +106,7 @@ func TestSearch_Search(t *testing.T) { searchStream, _ := msFactory.NewMsgStream(node.queryNodeLoopCtx) searchStream.AsProducer(searchProducerChannels) searchStream.Start() - err = searchStream.Produce(ctx, &msgPackSearch) + err = searchStream.Produce(&msgPackSearch) assert.NoError(t, err) node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica, msFactory) @@ -203,12 +201,12 @@ func TestSearch_Search(t *testing.T) { var ddMsgStream msgstream.MsgStream = ddStream ddMsgStream.Start() - err = insertMsgStream.Produce(ctx, &msgPack) + err = insertMsgStream.Produce(&msgPack) assert.NoError(t, err) - err = insertMsgStream.Broadcast(ctx, &timeTickMsgPack) + err = insertMsgStream.Broadcast(&timeTickMsgPack) assert.NoError(t, err) - err = ddMsgStream.Broadcast(ctx, &timeTickMsgPack) + err = 
ddMsgStream.Broadcast(&timeTickMsgPack) assert.NoError(t, err) // dataSync @@ -221,8 +219,6 @@ func TestSearch_Search(t *testing.T) { } func TestSearch_SearchMultiSegments(t *testing.T) { - ctx := context.Background() - collectionID := UniqueID(0) pulsarURL := Params.PulsarAddress @@ -310,7 +306,7 @@ func TestSearch_SearchMultiSegments(t *testing.T) { searchStream, _ := msFactory.NewMsgStream(node.queryNodeLoopCtx) searchStream.AsProducer(searchProducerChannels) searchStream.Start() - err = searchStream.Produce(ctx, &msgPackSearch) + err = searchStream.Produce(&msgPackSearch) assert.NoError(t, err) node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica, msFactory) @@ -409,12 +405,12 @@ func TestSearch_SearchMultiSegments(t *testing.T) { var ddMsgStream msgstream.MsgStream = ddStream ddMsgStream.Start() - err = insertMsgStream.Produce(ctx, &msgPack) + err = insertMsgStream.Produce(&msgPack) assert.NoError(t, err) - err = insertMsgStream.Broadcast(ctx, &timeTickMsgPack) + err = insertMsgStream.Broadcast(&timeTickMsgPack) assert.NoError(t, err) - err = ddMsgStream.Broadcast(ctx, &timeTickMsgPack) + err = ddMsgStream.Broadcast(&timeTickMsgPack) assert.NoError(t, err) // dataSync diff --git a/internal/querynode/stats_service.go b/internal/querynode/stats_service.go index e134945b0c4bfae673f07070d8f775f434845a10..f98c6b25bfee7e0a19ac65c4a8e83d50a06e0cfc 100644 --- a/internal/querynode/stats_service.go +++ b/internal/querynode/stats_service.go @@ -91,7 +91,7 @@ func (sService *statsService) publicStatistic(fieldStats []*internalpb.FieldStat var msgPack = msgstream.MsgPack{ Msgs: []msgstream.TsMsg{msg}, } - err := sService.statsStream.Produce(context.TODO(), &msgPack) + err := sService.statsStream.Produce(&msgPack) if err != nil { log.Error(err.Error()) } diff --git a/internal/timesync/timesync.go b/internal/timesync/timesync.go index de0a2ca4b4513cb77f7ebd5e6eee9385465d25f9..eee61ad1cd27cf76bf09ce079ee259ace11ea569 100644 --- a/internal/timesync/timesync.go +++ b/internal/timesync/timesync.go @@ -97,7 +97,7 @@ func (ttBarrier *softTimeTickBarrier) Start() { return default: } - ttmsgs, _ := ttBarrier.ttStream.Consume() + ttmsgs := ttBarrier.ttStream.Consume() if len(ttmsgs.Msgs) > 0 { for _, timetickmsg := range ttmsgs.Msgs { ttmsg := timetickmsg.(*ms.TimeTickMsg) @@ -161,7 +161,7 @@ func (ttBarrier *hardTimeTickBarrier) Start() { return default: } - ttmsgs, _ := ttBarrier.ttStream.Consume() + ttmsgs := ttBarrier.ttStream.Consume() if len(ttmsgs.Msgs) > 0 { log.Debug("receive tt msg") for _, timetickmsg := range ttmsgs.Msgs { diff --git a/internal/timesync/timetick_watcher.go b/internal/timesync/timetick_watcher.go index 5716a485c1bbc1bc82a36dbe7b7dbe70100a87b2..eca0aed6ce94143c2fce78988fe87ea5fed561bb 100644 --- a/internal/timesync/timetick_watcher.go +++ b/internal/timesync/timetick_watcher.go @@ -41,7 +41,7 @@ func (watcher *MsgTimeTickWatcher) StartBackgroundLoop(ctx context.Context) { msgPack := &ms.MsgPack{} msgPack.Msgs = append(msgPack.Msgs, msg) for _, stream := range watcher.streams { - if err := stream.Broadcast(ctx, msgPack); err != nil { + if err := stream.Broadcast(msgPack); err != nil { log.Warn("stream broadcast failed", zap.Error(err)) } } diff --git a/internal/util/flowgraph/flow_graph.go b/internal/util/flowgraph/flow_graph.go index 70b92d00aaee79aa567781cefdee44985a1d8736..653aa4d394641dd70f0dd05711f25ac3368408b5 100644 --- a/internal/util/flowgraph/flow_graph.go +++ b/internal/util/flowgraph/flow_graph.go @@ -17,7 +17,7 @@ func (fg *TimeTickedFlowGraph) 
AddNode(node Node) { nodeName := node.Name() nodeCtx := nodeCtx{ node: node, - inputChannels: make([]chan *MsgWithCtx, 0), + inputChannels: make([]chan Msg, 0), downstreamInputChanIdx: make(map[string]int), } fg.nodeCtx[nodeName] = &nodeCtx @@ -51,7 +51,7 @@ func (fg *TimeTickedFlowGraph) SetEdges(nodeName string, in []string, out []stri return errors.New(errMsg) } maxQueueLength := outNode.node.MaxQueueLength() - outNode.inputChannels = append(outNode.inputChannels, make(chan *MsgWithCtx, maxQueueLength)) + outNode.inputChannels = append(outNode.inputChannels, make(chan Msg, maxQueueLength)) currentNode.downstream[i] = outNode } diff --git a/internal/util/flowgraph/flow_graph_test.go b/internal/util/flowgraph/flow_graph_test.go index 8feed307f0e925e45e72c3abb711123c86745f30..8c4306c6a98cf12cdb98e2aa3c05263de75da094 100644 --- a/internal/util/flowgraph/flow_graph_test.go +++ b/internal/util/flowgraph/flow_graph_test.go @@ -68,43 +68,43 @@ func (a *nodeA) Name() string { return "NodeA" } -func (a *nodeA) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) { - return append(in, in...), nil +func (a *nodeA) Operate(in []Msg) []Msg { + return append(in, in...) } func (b *nodeB) Name() string { return "NodeB" } -func (b *nodeB) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) { +func (b *nodeB) Operate(in []Msg) []Msg { messages := make([]*intMsg, 0) for _, msg := range msg2IntMsg(in) { messages = append(messages, &intMsg{ num: math.Pow(msg.num, 2), }) } - return intMsg2Msg(messages), nil + return intMsg2Msg(messages) } func (c *nodeC) Name() string { return "NodeC" } -func (c *nodeC) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) { +func (c *nodeC) Operate(in []Msg) []Msg { messages := make([]*intMsg, 0) for _, msg := range msg2IntMsg(in) { messages = append(messages, &intMsg{ num: math.Sqrt(msg.num), }) } - return intMsg2Msg(messages), nil + return intMsg2Msg(messages) } func (d *nodeD) Name() string { return "NodeD" } -func (d *nodeD) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) { +func (d *nodeD) Operate(in []Msg) []Msg { messages := make([]*intMsg, 0) outLength := len(in) / 2 inMessages := msg2IntMsg(in) @@ -117,7 +117,7 @@ func (d *nodeD) Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) d.d = messages[0].num d.resChan <- d.d fmt.Println("flow graph result:", d.d) - return intMsg2Msg(messages), nil + return intMsg2Msg(messages) } func sendMsgFromCmd(ctx context.Context, fg *TimeTickedFlowGraph) { @@ -129,12 +129,8 @@ func sendMsgFromCmd(ctx context.Context, fg *TimeTickedFlowGraph) { time.Sleep(time.Millisecond * time.Duration(500)) var num = float64(rand.Int() % 100) var msg Msg = &intMsg{num: num} - var msgWithContext = &MsgWithCtx{ - ctx: ctx, - msg: msg, - } a := nodeA{} - fg.nodeCtx[a.Name()].inputChannels[0] <- msgWithContext + fg.nodeCtx[a.Name()].inputChannels[0] <- msg fmt.Println("send number", num, "to node", a.Name()) res, ok := receiveResult(ctx, fg) if !ok { @@ -254,7 +250,7 @@ func TestTimeTickedFlowGraph_Start(t *testing.T) { // init node A nodeCtxA := fg.nodeCtx[a.Name()] - nodeCtxA.inputChannels = []chan *MsgWithCtx{make(chan *MsgWithCtx, 10)} + nodeCtxA.inputChannels = []chan Msg{make(chan Msg, 10)} go fg.Start() diff --git a/internal/util/flowgraph/input_node.go b/internal/util/flowgraph/input_node.go index bb9b38eaac4b88a343339fd59409430e844f5696..042b282cf63abf321d8c46c6b8bebb67f6e466eb 100644 --- a/internal/util/flowgraph/input_node.go +++ b/internal/util/flowgraph/input_node.go @@ -1,10 
+1,7 @@ package flowgraph import ( - "context" - "github.com/opentracing/opentracing-go" - "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/util/trace" ) @@ -28,17 +25,20 @@ func (inNode *InputNode) InStream() *msgstream.MsgStream { } // empty input and return one *Msg -func (inNode *InputNode) Operate(ctx context.Context, msgs []Msg) ([]Msg, context.Context) { +func (inNode *InputNode) Operate(in []Msg) []Msg { //fmt.Println("Do InputNode operation") - msgPack, ctx := (*inNode.inStream).Consume() - - sp, ctx := trace.StartSpanFromContext(ctx, opentracing.Tag{Key: "NodeName", Value: inNode.Name()}) - defer sp.Finish() + msgPack := (*inNode.inStream).Consume() // TODO: add status if msgPack == nil { - return nil, ctx + return nil + } + var spans []opentracing.Span + for _, msg := range msgPack.Msgs { + sp, ctx := trace.StartSpanFromContext(msg.TraceCtx()) + spans = append(spans, sp) + msg.SetTraceCtx(ctx) } var msgStreamMsg Msg = &MsgStreamMsg{ @@ -49,7 +49,11 @@ func (inNode *InputNode) Operate(ctx context.Context, msgs []Msg) ([]Msg, contex endPositions: msgPack.EndPositions, } - return []Msg{msgStreamMsg}, ctx + for _, span := range spans { + span.Finish() + } + + return []Msg{msgStreamMsg} } func NewInputNode(inStream *msgstream.MsgStream, nodeName string, maxQueueLength int32, maxParallelism int32) *InputNode { diff --git a/internal/util/flowgraph/message.go b/internal/util/flowgraph/message.go index 10a0459b0405549c20e252f9850e2d71bba1618b..62785bdce1a20de7e5893d22ff15dcc41273a9e9 100644 --- a/internal/util/flowgraph/message.go +++ b/internal/util/flowgraph/message.go @@ -1,6 +1,8 @@ package flowgraph -import "github.com/zilliztech/milvus-distributed/internal/msgstream" +import ( + "github.com/zilliztech/milvus-distributed/internal/msgstream" +) type Msg interface { TimeTick() Timestamp diff --git a/internal/util/flowgraph/node.go b/internal/util/flowgraph/node.go index ec5192605d8f6ce7d7f5f32ba2e05cbfc1ebecff..6e57903db06ec3718a29857b839bdd272bd365e2 100644 --- a/internal/util/flowgraph/node.go +++ b/internal/util/flowgraph/node.go @@ -6,16 +6,13 @@ import ( "log" "sync" "time" - - "github.com/opentracing/opentracing-go" - "github.com/zilliztech/milvus-distributed/internal/util/trace" ) type Node interface { Name() string MaxQueueLength() int32 MaxParallelism() int32 - Operate(ctx context.Context, in []Msg) ([]Msg, context.Context) + Operate(in []Msg) []Msg IsInputNode() bool } @@ -26,7 +23,7 @@ type BaseNode struct { type nodeCtx struct { node Node - inputChannels []chan *MsgWithCtx + inputChannels []chan Msg inputMessages []Msg downstream []*nodeCtx downstreamInputChanIdx map[string]int @@ -35,11 +32,6 @@ type nodeCtx struct { NumCompletedTasks int64 } -type MsgWithCtx struct { - ctx context.Context - msg Msg -} - func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) { if nodeCtx.node.IsInputNode() { // fmt.Println("start InputNode.inStream") @@ -60,17 +52,13 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) { // inputs from inputsMessages for Operate inputs := make([]Msg, 0) - var msgCtx context.Context var res []Msg - var sp opentracing.Span if !nodeCtx.node.IsInputNode() { - msgCtx = nodeCtx.collectInputMessages(ctx) + nodeCtx.collectInputMessages(ctx) inputs = nodeCtx.inputMessages } n := nodeCtx.node - res, msgCtx = n.Operate(msgCtx, inputs) - sp, msgCtx = trace.StartSpanFromContext(msgCtx) - sp.SetTag("node name", n.Name()) + res = n.Operate(inputs) downstreamLength := 
len(nodeCtx.downstreamInputChanIdx) if len(nodeCtx.downstream) < downstreamLength { @@ -84,10 +72,9 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) { w := sync.WaitGroup{} for i := 0; i < downstreamLength; i++ { w.Add(1) - go nodeCtx.downstream[i].ReceiveMsg(msgCtx, &w, res[i], nodeCtx.downstreamInputChanIdx[nodeCtx.downstream[i].node.Name()]) + go nodeCtx.downstream[i].ReceiveMsg(&w, res[i], nodeCtx.downstreamInputChanIdx[nodeCtx.downstream[i].node.Name()]) } w.Wait() - sp.Finish() } } } @@ -99,18 +86,14 @@ func (nodeCtx *nodeCtx) Close() { } } -func (nodeCtx *nodeCtx) ReceiveMsg(ctx context.Context, wg *sync.WaitGroup, msg Msg, inputChanIdx int) { - sp, ctx := trace.StartSpanFromContext(ctx) - defer sp.Finish() - nodeCtx.inputChannels[inputChanIdx] <- &MsgWithCtx{ctx: ctx, msg: msg} +func (nodeCtx *nodeCtx) ReceiveMsg(wg *sync.WaitGroup, msg Msg, inputChanIdx int) { + nodeCtx.inputChannels[inputChanIdx] <- msg //fmt.Println((*nodeCtx.node).Name(), "receive to input channel ", inputChanIdx) wg.Done() } -func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) context.Context { - var opts []opentracing.StartSpanOption - +func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) { inputsNum := len(nodeCtx.inputChannels) nodeCtx.inputMessages = make([]Msg, inputsNum) @@ -121,29 +104,17 @@ func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) context.Co channel := nodeCtx.inputChannels[i] select { case <-exitCtx.Done(): - return nil - case msgWithCtx, ok := <-channel: + return + case msg, ok := <-channel: if !ok { // TODO: add status log.Println("input channel closed") - return nil - } - nodeCtx.inputMessages[i] = msgWithCtx.msg - if msgWithCtx.ctx != nil { - sp, _ := trace.StartSpanFromContext(msgWithCtx.ctx) - opts = append(opts, opentracing.ChildOf(sp.Context())) - sp.Finish() + return } + nodeCtx.inputMessages[i] = msg } } - var ctx context.Context - var sp opentracing.Span - if len(opts) != 0 { - sp, ctx = trace.StartSpanFromContext(context.Background(), opts...) - defer sp.Finish() - } - // timeTick alignment check if len(nodeCtx.inputMessages) > 1 { t := nodeCtx.inputMessages[0].TimeTick() @@ -169,7 +140,7 @@ func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) context.Co log.Println("input channel closed") return } - nodeCtx.inputMessages[i] = msg.msg + nodeCtx.inputMessages[i] = msg } } } @@ -183,7 +154,6 @@ func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) context.Co } } - return ctx } func (node *BaseNode) MaxQueueLength() int32 {
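(Editor's aside, not part of the patch: every flow-graph node touched above follows the same pattern, so here is a minimal sketch of a custom node against the new interface. Operate takes only []Msg and returns []Msg, and the node opens and finishes a span per message using the message's own trace context. The type passThroughNode is hypothetical, and BaseNode is assumed to supply the remaining Node methods.)

package example

import (
	"github.com/opentracing/opentracing-go"

	"github.com/zilliztech/milvus-distributed/internal/util/flowgraph"
	"github.com/zilliztech/milvus-distributed/internal/util/trace"
)

type passThroughNode struct {
	flowgraph.BaseNode // assumed to provide MaxQueueLength, MaxParallelism, IsInputNode
}

func (n *passThroughNode) Name() string { return "passThroughNode" }

// Operate matches the post-change Node interface: no context argument,
// a single []Msg return value.
func (n *passThroughNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
	if len(in) != 1 {
		return []flowgraph.Msg{}
	}
	msgStreamMsg, ok := in[0].(*flowgraph.MsgStreamMsg)
	if !ok {
		return []flowgraph.Msg{}
	}
	// Per-message tracing: restart a span from each message's trace context,
	// propagate the new context back onto the message, and finish the spans
	// before handing the result downstream.
	var spans []opentracing.Span
	for _, msg := range msgStreamMsg.TsMessages() {
		sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
		spans = append(spans, sp)
		msg.SetTraceCtx(ctx)
	}
	for _, sp := range spans {
		sp.Finish()
	}
	return in
}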