diff --git a/pkg/common/morpc/codec.go b/pkg/common/morpc/codec.go index bf687ae306169c22efd38115afc24c5b81f06c18..b5fea2ff07ef3bcc7e2bdb282f30a11519e14c50 100644 --- a/pkg/common/morpc/codec.go +++ b/pkg/common/morpc/codec.go @@ -56,6 +56,7 @@ func newMessageCodec(messageFactory func() Message, payloadCopyBufSize int, enab } c := &messageCodec{codec: length.New(bc), bc: bc} c.AddHeaderCodec(&deadlineContextCodec{}) + c.AddHeaderCodec(&traceCodec{}) return c } diff --git a/pkg/common/morpc/codec_header.go b/pkg/common/morpc/codec_header.go index 5e60e831626cf034fec2d9faacfb6f1d7ca79e88..4c16a8d3d6cde8879138dcf174b5f6817296d241 100644 --- a/pkg/common/morpc/codec_header.go +++ b/pkg/common/morpc/codec_header.go @@ -20,6 +20,7 @@ import ( "time" "github.com/fagongzi/goetty/v2/buf" + "github.com/matrixorigin/matrixone/pkg/util/trace" ) type deadlineContextCodec struct { @@ -50,3 +51,43 @@ func (hc *deadlineContextCodec) Decode(msg *RPCMessage, data []byte) (int, error msg.Ctx, msg.cancel = context.WithTimeout(msg.Ctx, time.Duration(buf.Byte2Int64(data))) return 8, nil } + +type traceCodec struct { +} + +func (hc *traceCodec) Encode(msg *RPCMessage, out *buf.ByteBuf) (int, error) { + if msg.Ctx == nil { + return 0, nil + } + + span := trace.SpanFromContext(msg.Ctx) + c := span.SpanContext() + n := c.Size() + out.MustWriteByte(byte(n)) + out.Grow(n) + idx := out.GetWriteIndex() + out.SetWriteIndex(idx + n) + c.MarshalTo(out.RawSlice(idx, idx+n)) + return 1 + n, nil +} + +func (hc *traceCodec) Decode(msg *RPCMessage, data []byte) (int, error) { + if len(data) < 1 { + return 0, io.ErrShortBuffer + } + + if len(data) < int(data[0]) { + return 0, io.ErrShortBuffer + } + + c := &trace.SpanContext{} + if err := c.Unmarshal(data[1 : 1+data[0]]); err != nil { + return 0, err + } + + if msg.Ctx == nil { + msg.Ctx = context.Background() + } + msg.Ctx = trace.ContextWithSpanContext(msg.Ctx, *c) + return int(1 + data[0]), nil +} diff --git a/pkg/common/morpc/codec_header_test.go b/pkg/common/morpc/codec_header_test.go index 388cf3f2bc75dd90475c971018a75e61695dc73f..cdaa715e7ef71371a379e7937324ba54cf9e3cb6 100644 --- a/pkg/common/morpc/codec_header_test.go +++ b/pkg/common/morpc/codec_header_test.go @@ -20,6 +20,7 @@ import ( "time" "github.com/fagongzi/goetty/v2/buf" + "github.com/matrixorigin/matrixone/pkg/util/trace" "github.com/stretchr/testify/assert" ) @@ -47,3 +48,29 @@ func TestDecodeContext(t *testing.T) { assert.True(t, ok) assert.True(t, !ts.IsZero()) } + +func TestEncodeAndDecodeTrace(t *testing.T) { + hc := &traceCodec{} + out := buf.NewByteBuf(8) + span := trace.SpanContextWithID(trace.TraceID(1)) + n, err := hc.Encode(&RPCMessage{Ctx: trace.ContextWithSpanContext(context.Background(), span)}, out) + assert.Equal(t, 1+span.Size(), n) + assert.NoError(t, err) + + msg := &RPCMessage{} + _, data := out.ReadBytes(1 + span.Size()) + + n, err = hc.Decode(msg, nil) + assert.Equal(t, 0, n) + assert.Error(t, err) + + n, err = hc.Decode(msg, data[:1]) + assert.Equal(t, 0, n) + assert.Error(t, err) + + n, err = hc.Decode(msg, data) + assert.Equal(t, 1+span.Size(), n) + assert.NoError(t, err) + + assert.Equal(t, span, trace.SpanFromContext(msg.Ctx).SpanContext()) +}