From f9b8a2789384113e5d23c27fc92a15b470a80521 Mon Sep 17 00:00:00 2001
From: fangyincheng <fangyincheng@sina.com>
Date: Thu, 5 Sep 2019 22:50:38 +0800
Subject: [PATCH] Fix: get attachments for comsumer

---
 protocol/dubbo/client.go             | 64 ++++++++++++++++++++--------
 protocol/dubbo/client_test.go        | 26 +++++------
 protocol/dubbo/codec.go              |  7 ++-
 protocol/dubbo/dubbo_invoker.go      | 10 +++--
 protocol/dubbo/dubbo_invoker_test.go |  7 +--
 protocol/dubbo/listener.go           |  2 +
 protocol/dubbo/readwriter.go         |  2 +-
 7 files changed, 75 insertions(+), 43 deletions(-)

diff --git a/protocol/dubbo/client.go b/protocol/dubbo/client.go
index e0cfc0b0d..2cec5aa8f 100644
--- a/protocol/dubbo/client.go
+++ b/protocol/dubbo/client.go
@@ -150,46 +150,74 @@ func NewClient(opt Options) *Client {
 	return c
 }
 
+type Request struct {
+	addr   string
+	svcUrl common.URL
+	method string
+	args   interface{}
+	atta   map[string]string
+}
+
+func NewRequest(addr string, svcUrl common.URL, method string, args interface{}, atta map[string]string) *Request {
+	return &Request{
+		addr:   addr,
+		svcUrl: svcUrl,
+		method: method,
+		args:   args,
+		atta:   atta,
+	}
+}
+
+type Response struct {
+	reply interface{}
+	atta  map[string]string
+}
+
+func NewResponse(reply interface{}, atta map[string]string) *Response {
+	return &Response{
+		reply: reply,
+		atta:  atta,
+	}
+}
+
 // call one way
-func (c *Client) CallOneway(addr string, svcUrl common.URL, method string, args interface{}, atta map[string]string) error {
+func (c *Client) CallOneway(request *Request) error {
 
-	return perrors.WithStack(c.call(CT_OneWay, addr, svcUrl, method, args, nil, nil, atta))
+	return perrors.WithStack(c.call(CT_OneWay, request, NewResponse(nil, nil), nil))
 }
 
-// if @reply is nil, the transport layer will get the response without notify the invoker.
-func (c *Client) Call(addr string, svcUrl common.URL, method string, args, reply interface{}, atta map[string]string) error {
+// if @response is nil, the transport layer will get the response without notify the invoker.
+func (c *Client) Call(request *Request, response *Response) error {
 
 	ct := CT_TwoWay
-	if reply == nil {
+	if response.reply == nil {
 		ct = CT_OneWay
 	}
 
-	return perrors.WithStack(c.call(ct, addr, svcUrl, method, args, reply, nil, atta))
+	return perrors.WithStack(c.call(ct, request, response, nil))
 }
 
-func (c *Client) AsyncCall(addr string, svcUrl common.URL, method string, args interface{},
-	callback AsyncCallback, reply interface{}, atta map[string]string) error {
+func (c *Client) AsyncCall(request *Request, callback AsyncCallback, response *Response) error {
 
-	return perrors.WithStack(c.call(CT_TwoWay, addr, svcUrl, method, args, reply, callback, atta))
+	return perrors.WithStack(c.call(CT_TwoWay, request, response, callback))
 }
 
-func (c *Client) call(ct CallType, addr string, svcUrl common.URL, method string,
-	args, reply interface{}, callback AsyncCallback, atta map[string]string) error {
+func (c *Client) call(ct CallType, request *Request, response *Response, callback AsyncCallback) error {
 
 	p := &DubboPackage{}
-	p.Service.Path = strings.TrimPrefix(svcUrl.Path, "/")
-	p.Service.Interface = svcUrl.GetParam(constant.INTERFACE_KEY, "")
-	p.Service.Version = svcUrl.GetParam(constant.VERSION_KEY, "")
-	p.Service.Method = method
+	p.Service.Path = strings.TrimPrefix(request.svcUrl.Path, "/")
+	p.Service.Interface = request.svcUrl.GetParam(constant.INTERFACE_KEY, "")
+	p.Service.Version = request.svcUrl.GetParam(constant.VERSION_KEY, "")
+	p.Service.Method = request.method
 	p.Service.Timeout = c.opts.RequestTimeout
 	p.Header.SerialID = byte(S_Dubbo)
-	p.Body = hessian.NewRequest(args, atta)
+	p.Body = hessian.NewRequest(request.args, request.atta)
 
 	var rsp *PendingResponse
 	if ct != CT_OneWay {
 		p.Header.Type = hessian.PackageRequest_TwoWay
 		rsp = NewPendingResponse()
-		rsp.reply = reply
+		rsp.response = response
 		rsp.callback = callback
 	} else {
 		p.Header.Type = hessian.PackageRequest
@@ -200,7 +228,7 @@ func (c *Client) call(ct CallType, addr string, svcUrl common.URL, method string
 		session getty.Session
 		conn    *gettyRPCClient
 	)
-	conn, session, err = c.selectSession(addr)
+	conn, session, err = c.selectSession(request.addr)
 	if err != nil {
 		return perrors.WithStack(err)
 	}
diff --git a/protocol/dubbo/client_test.go b/protocol/dubbo/client_test.go
index e8bf02e0e..ed986cf87 100644
--- a/protocol/dubbo/client_test.go
+++ b/protocol/dubbo/client_test.go
@@ -51,7 +51,7 @@ func TestClient_CallOneway(t *testing.T) {
 	c.pool = newGettyRPCClientConnPool(c, clientConf.PoolSize, time.Duration(int(time.Second)*clientConf.PoolTTL))
 
 	//user := &User{}
-	err := c.CallOneway("127.0.0.1:20000", url, "GetUser", []interface{}{"1", "username"}, nil)
+	err := c.CallOneway(NewRequest("127.0.0.1:20000", url, "GetUser", []interface{}{"1", "username"}, nil))
 	assert.NoError(t, err)
 
 	// destroy
@@ -77,50 +77,50 @@ func TestClient_Call(t *testing.T) {
 	)
 
 	user = &User{}
-	err = c.Call("127.0.0.1:20000", url, "GetBigPkg", []interface{}{nil}, user, nil)
+	err = c.Call(NewRequest("127.0.0.1:20000", url, "GetBigPkg", []interface{}{nil}, nil), NewResponse(user, nil))
 	assert.NoError(t, err)
 	assert.NotEqual(t, "", user.Id)
 	assert.NotEqual(t, "", user.Name)
 
 	user = &User{}
-	err = c.Call("127.0.0.1:20000", url, "GetUser", []interface{}{"1", "username"}, user, nil)
+	err = c.Call(NewRequest("127.0.0.1:20000", url, "GetUser", []interface{}{"1", "username"}, nil), NewResponse(user, nil))
 	assert.NoError(t, err)
 	assert.Equal(t, User{Id: "1", Name: "username"}, *user)
 
 	user = &User{}
-	err = c.Call("127.0.0.1:20000", url, "GetUser0", []interface{}{"1", nil, "username"}, user, nil)
+	err = c.Call(NewRequest("127.0.0.1:20000", url, "GetUser0", []interface{}{"1", nil, "username"}, nil), NewResponse(user, nil))
 	assert.NoError(t, err)
 	assert.Equal(t, User{Id: "1", Name: "username"}, *user)
 
-	err = c.Call("127.0.0.1:20000", url, "GetUser1", []interface{}{}, user, nil)
+	err = c.Call(NewRequest("127.0.0.1:20000", url, "GetUser1", []interface{}{}, nil), NewResponse(user, nil))
 	assert.NoError(t, err)
 
-	err = c.Call("127.0.0.1:20000", url, "GetUser2", []interface{}{}, user, nil)
+	err = c.Call(NewRequest("127.0.0.1:20000", url, "GetUser2", []interface{}{}, nil), NewResponse(user, nil))
 	assert.EqualError(t, err, "error")
 
 	user2 := []interface{}{}
-	err = c.Call("127.0.0.1:20000", url, "GetUser3", []interface{}{}, &user2, nil)
+	err = c.Call(NewRequest("127.0.0.1:20000", url, "GetUser3", []interface{}{}, nil), NewResponse(&user2, nil))
 	assert.NoError(t, err)
 	assert.Equal(t, &User{Id: "1", Name: "username"}, user2[0])
 
 	user2 = []interface{}{}
-	err = c.Call("127.0.0.1:20000", url, "GetUser4", []interface{}{[]interface{}{"1", "username"}}, &user2, nil)
+	err = c.Call(NewRequest("127.0.0.1:20000", url, "GetUser4", []interface{}{[]interface{}{"1", "username"}}, nil), NewResponse(&user2, nil))
 	assert.NoError(t, err)
 	assert.Equal(t, &User{Id: "1", Name: "username"}, user2[0])
 
 	user3 := map[interface{}]interface{}{}
-	err = c.Call("127.0.0.1:20000", url, "GetUser5", []interface{}{map[interface{}]interface{}{"id": "1", "name": "username"}}, &user3, nil)
+	err = c.Call(NewRequest("127.0.0.1:20000", url, "GetUser5", []interface{}{map[interface{}]interface{}{"id": "1", "name": "username"}}, nil), NewResponse(&user3, nil))
 	assert.NoError(t, err)
 	assert.NotNil(t, user3)
 	assert.Equal(t, &User{Id: "1", Name: "username"}, user3["key"])
 
 	user = &User{}
-	err = c.Call("127.0.0.1:20000", url, "GetUser6", []interface{}{0}, user, nil)
+	err = c.Call(NewRequest("127.0.0.1:20000", url, "GetUser6", []interface{}{0}, nil), NewResponse(user, nil))
 	assert.NoError(t, err)
 	assert.Equal(t, User{Id: "", Name: ""}, *user)
 
 	user = &User{}
-	err = c.Call("127.0.0.1:20000", url, "GetUser6", []interface{}{1}, user, nil)
+	err = c.Call(NewRequest("127.0.0.1:20000", url, "GetUser6", []interface{}{1}, nil), NewResponse(user, nil))
 	assert.NoError(t, err)
 	assert.Equal(t, User{Id: "1", Name: ""}, *user)
 
@@ -144,10 +144,10 @@ func TestClient_AsyncCall(t *testing.T) {
 	user := &User{}
 	lock := sync.Mutex{}
 	lock.Lock()
-	err := c.AsyncCall("127.0.0.1:20000", url, "GetUser", []interface{}{"1", "username"}, func(response CallResponse) {
+	err := c.AsyncCall(NewRequest("127.0.0.1:20000", url, "GetUser", []interface{}{"1", "username"}, nil), func(response CallResponse) {
 		assert.Equal(t, User{Id: "1", Name: "username"}, *response.Reply.(*User))
 		lock.Unlock()
-	}, user, nil)
+	}, NewResponse(user, nil))
 	assert.NoError(t, err)
 	assert.Equal(t, User{}, *user)
 
diff --git a/protocol/dubbo/codec.go b/protocol/dubbo/codec.go
index 98c29a4e5..5e8ad23a6 100644
--- a/protocol/dubbo/codec.go
+++ b/protocol/dubbo/codec.go
@@ -91,9 +91,8 @@ func (p *DubboPackage) Unmarshal(buf *bytes.Buffer, opts ...interface{}) error {
 		pendingRsp, ok := client.pendingResponses.Load(SequenceType(p.Header.ID))
 		if !ok {
 			return perrors.Errorf("client.GetPendingResponse(%v) = nil", p.Header.ID)
-		} else {
-			p.Body = &hessian.Response{RspObj: pendingRsp.(*PendingResponse).reply}
 		}
+		p.Body = &hessian.Response{RspObj: pendingRsp.(*PendingResponse).response.reply}
 	}
 
 	// read body
@@ -111,7 +110,7 @@ type PendingResponse struct {
 	start     time.Time
 	readStart time.Time
 	callback  AsyncCallback
-	reply     interface{}
+	response  *Response
 	done      chan struct{}
 }
 
@@ -127,6 +126,6 @@ func (r PendingResponse) GetCallResponse() CallResponse {
 		Cause:     r.err,
 		Start:     r.start,
 		ReadStart: r.readStart,
-		Reply:     r.reply,
+		Reply:     r.response,
 	}
 }
diff --git a/protocol/dubbo/dubbo_invoker.go b/protocol/dubbo/dubbo_invoker.go
index 6528ef67c..edddc7ad5 100644
--- a/protocol/dubbo/dubbo_invoker.go
+++ b/protocol/dubbo/dubbo_invoker.go
@@ -34,7 +34,7 @@ import (
 	invocation_impl "github.com/apache/dubbo-go/protocol/invocation"
 )
 
-var Err_No_Reply = perrors.New("request need @reply")
+var Err_No_Reply = perrors.New("request need @response")
 
 type DubboInvoker struct {
 	protocol.BaseInvoker
@@ -68,21 +68,23 @@ func (di *DubboInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
 		logger.Errorf("ParseBool - error: %v", err)
 		async = false
 	}
+	response := NewResponse(inv.Reply(), nil)
 	if async {
 		if callBack, ok := inv.CallBack().(func(response CallResponse)); ok {
-			result.Err = di.client.AsyncCall(url.Location, url, inv.MethodName(), inv.Arguments(), callBack, inv.Reply(), inv.Attachments())
+			result.Err = di.client.AsyncCall(NewRequest(url.Location, url, inv.MethodName(), inv.Arguments(), inv.Attachments()), callBack, response)
 		} else {
-			result.Err = di.client.CallOneway(url.Location, url, inv.MethodName(), inv.Arguments(), inv.Attachments())
+			result.Err = di.client.CallOneway(NewRequest(url.Location, url, inv.MethodName(), inv.Arguments(), inv.Attachments()))
 		}
 	} else {
 		if inv.Reply() == nil {
 			result.Err = Err_No_Reply
 		} else {
-			result.Err = di.client.Call(url.Location, url, inv.MethodName(), inv.Arguments(), inv.Reply(), inv.Attachments())
+			result.Err = di.client.Call(NewRequest(url.Location, url, inv.MethodName(), inv.Arguments(), inv.Attachments()), response)
 		}
 	}
 	if result.Err == nil {
 		result.Rest = inv.Reply()
+		result.Attrs = response.atta
 	}
 	logger.Debugf("result.Err: %v, result.Rest: %v", result.Err, result.Rest)
 
diff --git a/protocol/dubbo/dubbo_invoker_test.go b/protocol/dubbo/dubbo_invoker_test.go
index 09a4c128b..0a765356f 100644
--- a/protocol/dubbo/dubbo_invoker_test.go
+++ b/protocol/dubbo/dubbo_invoker_test.go
@@ -49,12 +49,13 @@ func TestDubboInvoker_Invoke(t *testing.T) {
 	user := &User{}
 
 	inv := invocation.NewRPCInvocationWithOptions(invocation.WithMethodName("GetUser"), invocation.WithArguments([]interface{}{"1", "username"}),
-		invocation.WithReply(user))
+		invocation.WithReply(user), invocation.WithAttachments(map[string]string{"test_key": "test_value"}))
 
 	// Call
 	res := invoker.Invoke(inv)
 	assert.NoError(t, res.Error())
 	assert.Equal(t, User{Id: "1", Name: "username"}, *res.Result().(*User))
+	assert.Equal(t, "test_value", res.Attachments()["test_key"]) // test attachments for request/response
 
 	// CallOneway
 	inv.SetAttachments(constant.ASYNC_KEY, "true")
@@ -65,7 +66,7 @@ func TestDubboInvoker_Invoke(t *testing.T) {
 	lock := sync.Mutex{}
 	lock.Lock()
 	inv.SetCallBack(func(response CallResponse) {
-		assert.Equal(t, User{Id: "1", Name: "username"}, *response.Reply.(*User))
+		assert.Equal(t, User{Id: "1", Name: "username"}, *response.Reply.(*Response).reply.(*User))
 		lock.Unlock()
 	})
 	res = invoker.Invoke(inv)
@@ -75,7 +76,7 @@ func TestDubboInvoker_Invoke(t *testing.T) {
 	inv.SetAttachments(constant.ASYNC_KEY, "false")
 	inv.SetReply(nil)
 	res = invoker.Invoke(inv)
-	assert.EqualError(t, res.Error(), "request need @reply")
+	assert.EqualError(t, res.Error(), "request need @response")
 
 	// destroy
 	lock.Lock()
diff --git a/protocol/dubbo/listener.go b/protocol/dubbo/listener.go
index 5f46d522c..1a7b00281 100644
--- a/protocol/dubbo/listener.go
+++ b/protocol/dubbo/listener.go
@@ -105,6 +105,8 @@ func (h *RpcClientHandler) OnMessage(session getty.Session, pkg interface{}) {
 		pendingResponse.err = p.Err
 	}
 
+	pendingResponse.response.atta = p.Body.(*Response).atta
+
 	if pendingResponse.callback == nil {
 		pendingResponse.done <- struct{}{}
 	} else {
diff --git a/protocol/dubbo/readwriter.go b/protocol/dubbo/readwriter.go
index 34a66f616..e43619ebd 100644
--- a/protocol/dubbo/readwriter.go
+++ b/protocol/dubbo/readwriter.go
@@ -63,7 +63,7 @@ func (p *RpcClientPackageHandler) Read(ss getty.Session, data []byte) (interface
 	}
 
 	pkg.Err = pkg.Body.(*hessian.Response).Exception
-	pkg.Body = pkg.Body.(*hessian.Response).RspObj
+	pkg.Body = NewResponse(pkg.Body.(*hessian.Response).RspObj, pkg.Body.(*hessian.Response).Attachments)
 
 	return pkg, hessian.HEADER_LENGTH + pkg.Header.BodyLen, nil
 }
-- 
GitLab