diff --git a/common/proxy/proxy.go b/common/proxy/proxy.go
index 1c079f6bca52bf8f6e8c5ebb168da82ab8ccb5f2..d13646dba86eea04adb3726d33ee9d20457276b6 100644
--- a/common/proxy/proxy.go
+++ b/common/proxy/proxy.go
@@ -181,3 +181,7 @@ func (p *Proxy) Implement(v common.RPCService) {
func (p *Proxy) Get() common.RPCService {
return p.rpc
}
+
+func (p *Proxy) GetCallback() interface{} {
+ return p.callBack
+}
diff --git a/common/proxy/proxy_factory.go b/common/proxy/proxy_factory.go
index 2567e0ee09cf7fa5aef7fde46872eb88205d8e45..116cfe06693b6923ca10e0df6964317dabd91d0e 100644
--- a/common/proxy/proxy_factory.go
+++ b/common/proxy/proxy_factory.go
@@ -24,6 +24,7 @@ import (
type ProxyFactory interface {
GetProxy(invoker protocol.Invoker, url *common.URL) *Proxy
+ GetAsyncProxy(invoker protocol.Invoker, callBack interface{}, url *common.URL) *Proxy
GetInvoker(url common.URL) protocol.Invoker
}
diff --git a/common/proxy/proxy_factory/default.go b/common/proxy/proxy_factory/default.go
index bafba60b400ec59d99e2d68ecf4d067c906ba6fb..06824fdc1e27cde5e1905be3277451dd4395049c 100644
--- a/common/proxy/proxy_factory/default.go
+++ b/common/proxy/proxy_factory/default.go
@@ -55,11 +55,16 @@ func NewDefaultProxyFactory(options ...proxy.Option) proxy.ProxyFactory {
return &DefaultProxyFactory{}
}
func (factory *DefaultProxyFactory) GetProxy(invoker protocol.Invoker, url *common.URL) *proxy.Proxy {
+ return factory.GetAsyncProxy(invoker, nil, url)
+}
+
+func (factory *DefaultProxyFactory) GetAsyncProxy(invoker protocol.Invoker, callBack interface{}, url *common.URL) *proxy.Proxy {
//create proxy
attachments := map[string]string{}
attachments[constant.ASYNC_KEY] = url.GetParam(constant.ASYNC_KEY, "false")
- return proxy.NewProxy(invoker, nil, attachments)
+ return proxy.NewProxy(invoker, callBack, attachments)
}
+
func (factory *DefaultProxyFactory) GetInvoker(url common.URL) protocol.Invoker {
return &ProxyInvoker{
BaseInvoker: *protocol.NewBaseInvoker(url),
diff --git a/common/proxy/proxy_factory/default_test.go b/common/proxy/proxy_factory/default_test.go
index b6a6b675baf992b2d64ffd19291ee2dc009bd1e3..7159b4b00eb2fcddb0f20f701f56b3179e57c4a0 100644
--- a/common/proxy/proxy_factory/default_test.go
+++ b/common/proxy/proxy_factory/default_test.go
@@ -18,6 +18,7 @@
package proxy_factory
import (
+ "fmt"
"testing"
)
@@ -37,6 +38,21 @@ func Test_GetProxy(t *testing.T) {
assert.NotNil(t, proxy)
}
+type TestAsync struct {
+}
+
+func (u *TestAsync) CallBack(res common.CallbackResponse) {
+ fmt.Println("CallBack res:", res)
+}
+
+func Test_GetAsyncProxy(t *testing.T) {
+ proxyFactory := NewDefaultProxyFactory()
+ url := common.NewURLWithOptions()
+ async := &TestAsync{}
+ proxy := proxyFactory.GetAsyncProxy(protocol.NewBaseInvoker(*url), async.CallBack, url)
+ assert.NotNil(t, proxy)
+}
+
func Test_GetInvoker(t *testing.T) {
proxyFactory := NewDefaultProxyFactory()
url := common.NewURLWithOptions()
diff --git a/common/rpc_service.go b/common/rpc_service.go
index 4741a6fa3c0daef97f044f639a5e64a38fe4a187..4c9f083dd0850c3f110491ef820c7b677c8009aa 100644
--- a/common/rpc_service.go
+++ b/common/rpc_service.go
@@ -39,6 +39,18 @@ type RPCService interface {
Reference() string // rpc service id or reference id
}
+//AsyncCallbackService callback interface for async
+type AsyncCallbackService interface {
+ CallBack(response CallbackResponse) // callback
+}
+
+//CallbackResponse for different protocol
+type CallbackResponse interface {
+}
+
+//AsyncCallback async callback method
+type AsyncCallback func(response CallbackResponse)
+
// for lowercase func
// func MethodMapper() map[string][string] {
// return map[string][string]{}
diff --git a/config/reference_config.go b/config/reference_config.go
index 8703c459bab306f98beb1668a1f9438126586f24..6b34f5535964a98516fbb215312575c9d3cfeb86 100644
--- a/config/reference_config.go
+++ b/config/reference_config.go
@@ -55,7 +55,7 @@ type ReferenceConfig struct {
Group string `yaml:"group" json:"group,omitempty" property:"group"`
Version string `yaml:"version" json:"version,omitempty" property:"version"`
Methods []*MethodConfig `yaml:"methods" json:"methods,omitempty" property:"methods"`
- async bool `yaml:"async" json:"async,omitempty" property:"async"`
+ Async bool `yaml:"async" json:"async,omitempty" property:"async"`
Params map[string]string `yaml:"params" json:"params,omitempty" property:"params"`
invoker protocol.Invoker
urls []*common.URL
@@ -141,7 +141,12 @@ func (refconfig *ReferenceConfig) Refer() {
}
//create proxy
- refconfig.pxy = extension.GetProxyFactory(consumerConfig.ProxyFactory).GetProxy(refconfig.invoker, url)
+ if refconfig.Async {
+ callback := GetCallback(refconfig.id)
+ refconfig.pxy = extension.GetProxyFactory(consumerConfig.ProxyFactory).GetAsyncProxy(refconfig.invoker, callback, url)
+ } else {
+ refconfig.pxy = extension.GetProxyFactory(consumerConfig.ProxyFactory).GetProxy(refconfig.invoker, url)
+ }
}
// @v is service provider implemented RPCService
@@ -169,7 +174,7 @@ func (refconfig *ReferenceConfig) getUrlMap() url.Values {
urlMap.Set(constant.GENERIC_KEY, strconv.FormatBool(refconfig.Generic))
urlMap.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
//getty invoke async or sync
- urlMap.Set(constant.ASYNC_KEY, strconv.FormatBool(refconfig.async))
+ urlMap.Set(constant.ASYNC_KEY, strconv.FormatBool(refconfig.Async))
//application info
urlMap.Set(constant.APPLICATION_KEY, consumerConfig.ApplicationConfig.Name)
diff --git a/config/reference_config_test.go b/config/reference_config_test.go
index a81dbf06cef7d275cf6af4a7f651ff8d1600a3c9..a7af925cabcf6b4e7db9213f2bb6953bea965699 100644
--- a/config/reference_config_test.go
+++ b/config/reference_config_test.go
@@ -81,6 +81,7 @@ func doInitConsumer() {
},
References: map[string]*ReferenceConfig{
"MockService": {
+ id: "MockProvider",
Params: map[string]string{
"serviceid": "soa.mock",
"forks": "5",
@@ -110,6 +111,26 @@ func doInitConsumer() {
}
}
+var mockProvider = new(MockProvider)
+
+type MockProvider struct {
+}
+
+func (m *MockProvider) Reference() string {
+ return "MockProvider"
+}
+
+func (m *MockProvider) CallBack(res common.CallbackResponse) {
+}
+
+func doInitConsumerAsync() {
+ doInitConsumer()
+ SetConsumerService(mockProvider)
+ for _, v := range consumerConfig.References {
+ v.Async = true
+ }
+}
+
func doInitConsumerWithSingleRegistry() {
consumerConfig = &ConsumerConfig{
ApplicationConfig: &ApplicationConfig{
@@ -181,6 +202,22 @@ func Test_Refer(t *testing.T) {
}
consumerConfig = nil
}
+
+func Test_ReferAsync(t *testing.T) {
+ doInitConsumerAsync()
+ extension.SetProtocol("registry", GetProtocol)
+ extension.SetCluster("registryAware", cluster_impl.NewRegistryAwareCluster)
+
+ for _, reference := range consumerConfig.References {
+ reference.Refer()
+ assert.Equal(t, "soa.mock", reference.Params["serviceid"])
+ assert.NotNil(t, reference.invoker)
+ assert.NotNil(t, reference.pxy)
+ assert.NotNil(t, reference.pxy.GetCallback())
+ }
+ consumerConfig = nil
+}
+
func Test_ReferP2P(t *testing.T) {
doInitConsumer()
extension.SetProtocol("dubbo", GetProtocol)
diff --git a/config/service.go b/config/service.go
index 2bceac4a8c20bb598dc2607c90c8206e4a448808..f1b51790ca13df0534882837397181e45e56ffa3 100644
--- a/config/service.go
+++ b/config/service.go
@@ -43,3 +43,11 @@ func GetConsumerService(name string) common.RPCService {
func GetProviderService(name string) common.RPCService {
return proServices[name]
}
+
+func GetCallback(name string) func(response common.CallbackResponse) {
+ service := GetConsumerService(name)
+ if sv, ok := service.(common.AsyncCallbackService); ok {
+ return sv.CallBack
+ }
+ return nil
+}
diff --git a/protocol/dubbo/client.go b/protocol/dubbo/client.go
index 817dbab572f4c75ae8679a407e60e51a8010a5ef..81f392565f701d990dc1783d5d467814a0fba5bf 100644
--- a/protocol/dubbo/client.go
+++ b/protocol/dubbo/client.go
@@ -113,7 +113,9 @@ type Options struct {
RequestTimeout time.Duration
}
-type CallResponse struct {
+//AsyncCallbackResponse async response for dubbo
+type AsyncCallbackResponse struct {
+ common.CallbackResponse
Opts Options
Cause error
Start time.Time // invoke(call) start time == write start time
@@ -121,8 +123,6 @@ type CallResponse struct {
Reply interface{}
}
-type AsyncCallback func(response CallResponse)
-
type Client struct {
opts Options
conf ClientConfig
@@ -199,12 +199,12 @@ func (c *Client) Call(request *Request, response *Response) error {
return perrors.WithStack(c.call(ct, request, response, nil))
}
-func (c *Client) AsyncCall(request *Request, callback AsyncCallback, response *Response) error {
+func (c *Client) AsyncCall(request *Request, callback common.AsyncCallback, response *Response) error {
return perrors.WithStack(c.call(CT_TwoWay, request, response, callback))
}
-func (c *Client) call(ct CallType, request *Request, response *Response, callback AsyncCallback) error {
+func (c *Client) call(ct CallType, request *Request, response *Response, callback common.AsyncCallback) error {
p := &DubboPackage{}
p.Service.Path = strings.TrimPrefix(request.svcUrl.Path, "/")
diff --git a/protocol/dubbo/client_test.go b/protocol/dubbo/client_test.go
index eb1f15c862a910120e118c06bf9b572e93f58832..3f8a8ee98c3b2d8b87e2d5469a18d1792578d1d6 100644
--- a/protocol/dubbo/client_test.go
+++ b/protocol/dubbo/client_test.go
@@ -144,8 +144,9 @@ func TestClient_AsyncCall(t *testing.T) {
user := &User{}
lock := sync.Mutex{}
lock.Lock()
- err := c.AsyncCall(NewRequest("127.0.0.1:20000", url, "GetUser", []interface{}{"1", "username"}, nil), func(response CallResponse) {
- assert.Equal(t, User{Id: "1", Name: "username"}, *response.Reply.(*Response).reply.(*User))
+ err := c.AsyncCall(NewRequest("127.0.0.1:20000", url, "GetUser", []interface{}{"1", "username"}, nil), func(response common.CallbackResponse) {
+ r := response.(AsyncCallbackResponse)
+ assert.Equal(t, User{Id: "1", Name: "username"}, *r.Reply.(*Response).reply.(*User))
lock.Unlock()
}, NewResponse(user, nil))
assert.NoError(t, err)
diff --git a/protocol/dubbo/codec.go b/protocol/dubbo/codec.go
index a878ffd91e29d6949870ec25fed9481f301b435a..758363117f1720a7fe89eb9745b415e506315db8 100644
--- a/protocol/dubbo/codec.go
+++ b/protocol/dubbo/codec.go
@@ -26,6 +26,7 @@ import (
import (
"github.com/apache/dubbo-go-hessian2"
+ "github.com/apache/dubbo-go/common"
perrors "github.com/pkg/errors"
)
@@ -109,7 +110,7 @@ type PendingResponse struct {
err error
start time.Time
readStart time.Time
- callback AsyncCallback
+ callback common.AsyncCallback
response *Response
done chan struct{}
}
@@ -122,8 +123,8 @@ func NewPendingResponse() *PendingResponse {
}
}
-func (r PendingResponse) GetCallResponse() CallResponse {
- return CallResponse{
+func (r PendingResponse) GetCallResponse() common.CallbackResponse {
+ return AsyncCallbackResponse{
Cause: r.err,
Start: r.start,
ReadStart: r.readStart,
diff --git a/protocol/dubbo/dubbo_invoker.go b/protocol/dubbo/dubbo_invoker.go
index 4582e54c2158a43509b26138a8d414d2f34e052a..6dcf2568fa8c88a864c567486a501c2ad7feb3f7 100644
--- a/protocol/dubbo/dubbo_invoker.go
+++ b/protocol/dubbo/dubbo_invoker.go
@@ -75,7 +75,7 @@ func (di *DubboInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
}
response := NewResponse(inv.Reply(), nil)
if async {
- if callBack, ok := inv.CallBack().(func(response CallResponse)); ok {
+ if callBack, ok := inv.CallBack().(func(response common.CallbackResponse)); ok {
result.Err = di.client.AsyncCall(NewRequest(url.Location, url, inv.MethodName(), inv.Arguments(), inv.Attachments()), callBack, response)
} else {
result.Err = di.client.CallOneway(NewRequest(url.Location, url, inv.MethodName(), inv.Arguments(), inv.Attachments()))
diff --git a/protocol/dubbo/dubbo_invoker_test.go b/protocol/dubbo/dubbo_invoker_test.go
index 0a765356f7353829c8486fddba986e3a444441a1..7d60090e2d81bcb750d1e6d79a08059687c7937d 100644
--- a/protocol/dubbo/dubbo_invoker_test.go
+++ b/protocol/dubbo/dubbo_invoker_test.go
@@ -28,6 +28,7 @@ import (
)
import (
+ "github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/protocol/invocation"
)
@@ -65,8 +66,9 @@ func TestDubboInvoker_Invoke(t *testing.T) {
// AsyncCall
lock := sync.Mutex{}
lock.Lock()
- inv.SetCallBack(func(response CallResponse) {
- assert.Equal(t, User{Id: "1", Name: "username"}, *response.Reply.(*Response).reply.(*User))
+ inv.SetCallBack(func(response common.CallbackResponse) {
+ r := response.(AsyncCallbackResponse)
+ assert.Equal(t, User{Id: "1", Name: "username"}, *r.Reply.(*Response).reply.(*User))
lock.Unlock()
})
res = invoker.Invoke(inv)