diff --git a/cluster/cluster_impl/available_cluster_invoker.go b/cluster/cluster_impl/available_cluster_invoker.go
index c59c0702c216fe5c58d190a023322aaa00ac9c17..bc6705c8156aaeb6a0a52e08b1aa539e179013ca 100644
--- a/cluster/cluster_impl/available_cluster_invoker.go
+++ b/cluster/cluster_impl/available_cluster_invoker.go
@@ -18,6 +18,7 @@ limitations under the License.
package cluster_impl
import (
+ "context"
"fmt"
)
@@ -40,7 +41,7 @@ func NewAvailableClusterInvoker(directory cluster.Directory) protocol.Invoker {
}
}
-func (invoker *availableClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (invoker *availableClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
invokers := invoker.directory.List(invocation)
err := invoker.checkInvokers(invokers, invocation)
if err != nil {
@@ -54,7 +55,7 @@ func (invoker *availableClusterInvoker) Invoke(invocation protocol.Invocation) p
for _, ivk := range invokers {
if ivk.IsAvailable() {
- return ivk.Invoke(invocation)
+ return ivk.Invoke(ctx, invocation)
}
}
return &protocol.RPCResult{Err: errors.New(fmt.Sprintf("no provider available in %v", invokers))}
diff --git a/cluster/cluster_impl/available_cluster_invoker_test.go b/cluster/cluster_impl/available_cluster_invoker_test.go
index 04032a7f24dec0e73acb15921f753921391f1515..de04db1da4e8e6df12960b1a2ee81b0044379d6f 100644
--- a/cluster/cluster_impl/available_cluster_invoker_test.go
+++ b/cluster/cluster_impl/available_cluster_invoker_test.go
@@ -66,7 +66,7 @@ func TestAvailableClusterInvokerSuccess(t *testing.T) {
invoker.EXPECT().IsAvailable().Return(true)
invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult)
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.Equal(t, mockResult, result)
}
@@ -80,7 +80,7 @@ func TestAvailableClusterInvokerNoAvail(t *testing.T) {
invoker.EXPECT().IsAvailable().Return(false)
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.TODO(), &invocation.RPCInvocation{})
assert.NotNil(t, result.Error())
assert.True(t, strings.Contains(result.Error().Error(), "no provider available"))
diff --git a/cluster/cluster_impl/broadcast_cluster_invoker.go b/cluster/cluster_impl/broadcast_cluster_invoker.go
index 238df0acfa7fb946e38bfbfd490bce7c0bb34e60..1b49e9a115252d4eca94bedd557ebcc21fee4cc7 100644
--- a/cluster/cluster_impl/broadcast_cluster_invoker.go
+++ b/cluster/cluster_impl/broadcast_cluster_invoker.go
@@ -17,6 +17,9 @@ limitations under the License.
package cluster_impl
+import (
+ "context"
+)
import (
"github.com/apache/dubbo-go/cluster"
"github.com/apache/dubbo-go/common/logger"
@@ -33,7 +36,7 @@ func newBroadcastClusterInvoker(directory cluster.Directory) protocol.Invoker {
}
}
-func (invoker *broadcastClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (invoker *broadcastClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
invokers := invoker.directory.List(invocation)
err := invoker.checkInvokers(invokers, invocation)
if err != nil {
@@ -46,7 +49,7 @@ func (invoker *broadcastClusterInvoker) Invoke(invocation protocol.Invocation) p
var result protocol.Result
for _, ivk := range invokers {
- result = ivk.Invoke(invocation)
+ result = ivk.Invoke(ctx, invocation)
if result.Error() != nil {
logger.Warnf("broadcast invoker invoke err: %v when use invoker: %v\n", result.Error(), ivk)
err = result.Error()
diff --git a/cluster/cluster_impl/broadcast_cluster_invoker_test.go b/cluster/cluster_impl/broadcast_cluster_invoker_test.go
index 565684a8ae25c648ff77aef71d2ced0665202fe7..b20d962e2cffb34d0a151488a1bdf63499e4de86 100644
--- a/cluster/cluster_impl/broadcast_cluster_invoker_test.go
+++ b/cluster/cluster_impl/broadcast_cluster_invoker_test.go
@@ -74,7 +74,7 @@ func Test_BroadcastInvokeSuccess(t *testing.T) {
clusterInvoker := registerBroadcast(t, invokers...)
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.Equal(t, mockResult, result)
}
@@ -104,6 +104,6 @@ func Test_BroadcastInvokeFailed(t *testing.T) {
clusterInvoker := registerBroadcast(t, invokers...)
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.Equal(t, mockFailedResult.Err, result.Error())
}
diff --git a/cluster/cluster_impl/failback_cluster_invoker.go b/cluster/cluster_impl/failback_cluster_invoker.go
index c8dbeda09f62e88b51dd4ad2b6b09d5715f0b224..46b0ff634e56c45223a5aeb5566b9b1401518960 100644
--- a/cluster/cluster_impl/failback_cluster_invoker.go
+++ b/cluster/cluster_impl/failback_cluster_invoker.go
@@ -18,6 +18,7 @@
package cluster_impl
import (
+ "context"
"strconv"
"sync"
"time"
@@ -71,7 +72,7 @@ func newFailbackClusterInvoker(directory cluster.Directory) protocol.Invoker {
return invoker
}
-func (invoker *failbackClusterInvoker) process() {
+func (invoker *failbackClusterInvoker) process(ctx context.Context) {
invoker.ticker = time.NewTicker(time.Second * 1)
for range invoker.ticker.C {
// check each timeout task and re-run
@@ -102,7 +103,7 @@ func (invoker *failbackClusterInvoker) process() {
retryInvoker := invoker.doSelect(retryTask.loadbalance, retryTask.invocation, retryTask.invokers, invoked)
var result protocol.Result
- result = retryInvoker.Invoke(retryTask.invocation)
+ result = retryInvoker.Invoke(ctx, retryTask.invocation)
if result.Error() != nil {
retryTask.lastInvoker = retryInvoker
invoker.checkRetry(retryTask, result.Error())
@@ -126,7 +127,7 @@ func (invoker *failbackClusterInvoker) checkRetry(retryTask *retryTimerTask, err
}
}
-func (invoker *failbackClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (invoker *failbackClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
invokers := invoker.directory.List(invocation)
err := invoker.checkInvokers(invokers, invocation)
if err != nil {
@@ -150,11 +151,11 @@ func (invoker *failbackClusterInvoker) Invoke(invocation protocol.Invocation) pr
ivk := invoker.doSelect(loadbalance, invocation, invokers, invoked)
//DO INVOKE
- result = ivk.Invoke(invocation)
+ result = ivk.Invoke(ctx, invocation)
if result.Error() != nil {
invoker.once.Do(func() {
invoker.taskList = queue.New(invoker.failbackTasks)
- go invoker.process()
+ go invoker.process(ctx)
})
taskLen := invoker.taskList.Len()
diff --git a/cluster/cluster_impl/failback_cluster_test.go b/cluster/cluster_impl/failback_cluster_test.go
index 1d2266cabebf591b09188fb723f02126a3f1e0ec..895077922a88abc05416e58459205b449831ac56 100644
--- a/cluster/cluster_impl/failback_cluster_test.go
+++ b/cluster/cluster_impl/failback_cluster_test.go
@@ -72,7 +72,7 @@ func Test_FailbackSuceess(t *testing.T) {
mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}}
invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult)
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.Equal(t, mockResult, result)
}
@@ -102,7 +102,7 @@ func Test_FailbackRetryOneSuccess(t *testing.T) {
return mockSuccResult
})
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
assert.Equal(t, 0, len(result.Attachments()))
@@ -150,7 +150,7 @@ func Test_FailbackRetryFailed(t *testing.T) {
}
// first call should failed.
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
assert.Equal(t, 0, len(result.Attachments()))
@@ -192,7 +192,7 @@ func Test_FailbackRetryFailed10Times(t *testing.T) {
}).Times(10)
for i := 0; i < 10; i++ {
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
assert.Equal(t, 0, len(result.Attachments()))
@@ -222,14 +222,14 @@ func Test_FailbackOutOfLimit(t *testing.T) {
invoker.EXPECT().Invoke(gomock.Any()).Return(mockFailedResult).Times(11)
// reached limit
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
assert.Equal(t, 0, len(result.Attachments()))
// all will be out of limit
for i := 0; i < 10; i++ {
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
assert.Equal(t, 0, len(result.Attachments()))
diff --git a/cluster/cluster_impl/failfast_cluster_invoker.go b/cluster/cluster_impl/failfast_cluster_invoker.go
index 734ea2c6cb19bf54a338a76a10c9cfcc59d3954b..49e7c7689f5a19a36154e092a6a83cc39da604ba 100644
--- a/cluster/cluster_impl/failfast_cluster_invoker.go
+++ b/cluster/cluster_impl/failfast_cluster_invoker.go
@@ -17,6 +17,9 @@ limitations under the License.
package cluster_impl
+import (
+ "context"
+)
import (
"github.com/apache/dubbo-go/cluster"
"github.com/apache/dubbo-go/protocol"
@@ -32,7 +35,7 @@ func newFailFastClusterInvoker(directory cluster.Directory) protocol.Invoker {
}
}
-func (invoker *failfastClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (invoker *failfastClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
invokers := invoker.directory.List(invocation)
err := invoker.checkInvokers(invokers, invocation)
if err != nil {
@@ -47,5 +50,5 @@ func (invoker *failfastClusterInvoker) Invoke(invocation protocol.Invocation) pr
}
ivk := invoker.doSelect(loadbalance, invocation, invokers, nil)
- return ivk.Invoke(invocation)
+ return ivk.Invoke(ctx, invocation)
}
diff --git a/cluster/cluster_impl/failfast_cluster_test.go b/cluster/cluster_impl/failfast_cluster_test.go
index 1a4342e6c2b74fd6b1359646eeb463bb6dc17d0a..9585f03b7fa8f45a19c7c47e04dcd57cc1e4bb11 100644
--- a/cluster/cluster_impl/failfast_cluster_test.go
+++ b/cluster/cluster_impl/failfast_cluster_test.go
@@ -69,7 +69,7 @@ func Test_FailfastInvokeSuccess(t *testing.T) {
mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}}
invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult)
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.NoError(t, result.Error())
res := result.Result().(rest)
@@ -89,7 +89,7 @@ func Test_FailfastInvokeFail(t *testing.T) {
mockResult := &protocol.RPCResult{Err: perrors.New("error")}
invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult)
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.NotNil(t, result.Error())
assert.Equal(t, "error", result.Error().Error())
diff --git a/cluster/cluster_impl/failover_cluster_invoker.go b/cluster/cluster_impl/failover_cluster_invoker.go
index dcce7369931a11f31fb6b9e4e1a6c0aa0ec7cdf6..6178a05a1226ba629d2456ad6886b02a26288e45 100644
--- a/cluster/cluster_impl/failover_cluster_invoker.go
+++ b/cluster/cluster_impl/failover_cluster_invoker.go
@@ -18,6 +18,7 @@
package cluster_impl
import (
+ "context"
"strconv"
)
@@ -43,7 +44,7 @@ func newFailoverClusterInvoker(directory cluster.Directory) protocol.Invoker {
}
}
-func (invoker *failoverClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (invoker *failoverClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
invokers := invoker.directory.List(invocation)
err := invoker.checkInvokers(invokers, invocation)
@@ -95,7 +96,7 @@ func (invoker *failoverClusterInvoker) Invoke(invocation protocol.Invocation) pr
}
invoked = append(invoked, ivk)
//DO INVOKE
- result = ivk.Invoke(invocation)
+ result = ivk.Invoke(ctx, invocation)
if result.Error() != nil {
providers = append(providers, ivk.GetUrl().Key())
continue
diff --git a/cluster/cluster_impl/failover_cluster_test.go b/cluster/cluster_impl/failover_cluster_test.go
index 78b799320dfa58d55e531c658ec5eb0e69306cff..7bde83ea66a49f9317732ec46da0f11800f846eb 100644
--- a/cluster/cluster_impl/failover_cluster_test.go
+++ b/cluster/cluster_impl/failover_cluster_test.go
@@ -77,7 +77,7 @@ type rest struct {
success bool
}
-func (bi *MockInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (bi *MockInvoker) Invoke(c context.Context, invocation protocol.Invocation) protocol.Result {
count++
var success bool
var err error = nil
@@ -112,9 +112,9 @@ func normalInvoke(t *testing.T, successCount int, urlParam url.Values, invocatio
staticDir := directory.NewStaticDirectory(invokers)
clusterInvoker := failoverCluster.Join(staticDir)
if len(invocations) > 0 {
- return clusterInvoker.Invoke(invocations[0])
+ return clusterInvoker.Invoke(context.Background(), invocations[0])
}
- return clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ return clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
}
func Test_FailoverInvokeSuccess(t *testing.T) {
urlParams := url.Values{}
@@ -155,14 +155,14 @@ func Test_FailoverDestroy(t *testing.T) {
invokers := []protocol.Invoker{}
for i := 0; i < 10; i++ {
- url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i))
+ url, _ := common.NewURL(context.Background(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i))
invokers = append(invokers, NewMockInvoker(url, 1))
}
staticDir := directory.NewStaticDirectory(invokers)
clusterInvoker := failoverCluster.Join(staticDir)
assert.Equal(t, true, clusterInvoker.IsAvailable())
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.NoError(t, result.Error())
count = 0
clusterInvoker.Destroy()
diff --git a/cluster/cluster_impl/failsafe_cluster_invoker.go b/cluster/cluster_impl/failsafe_cluster_invoker.go
index b95f997fef87cf466f07c4e506e41758e7998e52..4d8fe27719eb71fa287fe4142d8e92ca17acfba4 100644
--- a/cluster/cluster_impl/failsafe_cluster_invoker.go
+++ b/cluster/cluster_impl/failsafe_cluster_invoker.go
@@ -17,6 +17,9 @@
package cluster_impl
+import (
+ "context"
+)
import (
"github.com/apache/dubbo-go/cluster"
"github.com/apache/dubbo-go/common/constant"
@@ -42,7 +45,7 @@ func newFailsafeClusterInvoker(directory cluster.Directory) protocol.Invoker {
}
}
-func (invoker *failsafeClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (invoker *failsafeClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
invokers := invoker.directory.List(invocation)
err := invoker.checkInvokers(invokers, invocation)
@@ -65,7 +68,7 @@ func (invoker *failsafeClusterInvoker) Invoke(invocation protocol.Invocation) pr
ivk := invoker.doSelect(loadbalance, invocation, invokers, invoked)
//DO INVOKE
- result = ivk.Invoke(invocation)
+ result = ivk.Invoke(ctx, invocation)
if result.Error() != nil {
// ignore
logger.Errorf("Failsafe ignore exception: %v.\n", result.Error().Error())
diff --git a/cluster/cluster_impl/failsafe_cluster_test.go b/cluster/cluster_impl/failsafe_cluster_test.go
index 7888b97c3a02bd4679f8ec5267637b8d2a7c12e4..930b4bb16628e2b363659a65fc174543b7f2cf6e 100644
--- a/cluster/cluster_impl/failsafe_cluster_test.go
+++ b/cluster/cluster_impl/failsafe_cluster_test.go
@@ -69,7 +69,7 @@ func Test_FailSafeInvokeSuccess(t *testing.T) {
mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}}
invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult)
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.NoError(t, result.Error())
res := result.Result().(rest)
@@ -88,7 +88,7 @@ func Test_FailSafeInvokeFail(t *testing.T) {
mockResult := &protocol.RPCResult{Err: perrors.New("error")}
invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult)
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.NoError(t, result.Error())
assert.Nil(t, result.Result())
diff --git a/cluster/cluster_impl/forking_cluster_invoker.go b/cluster/cluster_impl/forking_cluster_invoker.go
index d6cf2f4b89ab4f322fa758deecae90c60742ef49..c830079ff6d3c29c3385eda289782f5e52877be2 100644
--- a/cluster/cluster_impl/forking_cluster_invoker.go
+++ b/cluster/cluster_impl/forking_cluster_invoker.go
@@ -18,6 +18,7 @@ limitations under the License.
package cluster_impl
import (
+ "context"
"errors"
"fmt"
"time"
@@ -44,7 +45,7 @@ func newForkingClusterInvoker(directory cluster.Directory) protocol.Invoker {
}
}
-func (invoker *forkingClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (invoker *forkingClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
err := invoker.checkWhetherDestroyed()
if err != nil {
return &protocol.RPCResult{Err: err}
@@ -75,7 +76,7 @@ func (invoker *forkingClusterInvoker) Invoke(invocation protocol.Invocation) pro
resultQ := queue.New(1)
for _, ivk := range selected {
go func(k protocol.Invoker) {
- result := k.Invoke(invocation)
+ result := k.Invoke(ctx, invocation)
err := resultQ.Put(result)
if err != nil {
logger.Errorf("resultQ put failed with exception: %v.\n", err)
diff --git a/cluster/cluster_impl/forking_cluster_test.go b/cluster/cluster_impl/forking_cluster_test.go
index 8603f8aedc4e28a3a4ca2f115355debc1a5ecc62..d819781eb23631e6b8eef76e5bdf7d7837f43d53 100644
--- a/cluster/cluster_impl/forking_cluster_test.go
+++ b/cluster/cluster_impl/forking_cluster_test.go
@@ -87,7 +87,7 @@ func Test_ForkingInvokeSuccess(t *testing.T) {
clusterInvoker := registerForking(t, invokers...)
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.Equal(t, mockResult, result)
wg.Wait()
}
@@ -117,7 +117,7 @@ func Test_ForkingInvokeTimeout(t *testing.T) {
clusterInvoker := registerForking(t, invokers...)
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.NotNil(t, result)
assert.NotNil(t, result.Error())
wg.Wait()
@@ -156,7 +156,7 @@ func Test_ForkingInvokeHalfTimeout(t *testing.T) {
clusterInvoker := registerForking(t, invokers...)
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.Equal(t, mockResult, result)
wg.Wait()
}
diff --git a/cluster/cluster_impl/registry_aware_cluster_invoker.go b/cluster/cluster_impl/registry_aware_cluster_invoker.go
index 5785c02489f95168d5419f0087f38b07c851a4a3..cded5bf16432e6b0c590e15b81c28369889a5f88 100644
--- a/cluster/cluster_impl/registry_aware_cluster_invoker.go
+++ b/cluster/cluster_impl/registry_aware_cluster_invoker.go
@@ -17,6 +17,9 @@
package cluster_impl
+import (
+ "context"
+)
import (
"github.com/apache/dubbo-go/cluster"
"github.com/apache/dubbo-go/common/constant"
@@ -33,19 +36,19 @@ func newRegistryAwareClusterInvoker(directory cluster.Directory) protocol.Invoke
}
}
-func (invoker *registryAwareClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (invoker *registryAwareClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
invokers := invoker.directory.List(invocation)
//First, pick the invoker (XXXClusterInvoker) that comes from the local registry, distinguish by a 'default' key.
for _, invoker := range invokers {
if invoker.IsAvailable() && invoker.GetUrl().GetParam(constant.REGISTRY_DEFAULT_KEY, "false") == "true" {
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
}
//If none of the invokers has a local signal, pick the first one available.
for _, invoker := range invokers {
if invoker.IsAvailable() {
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
}
return nil
diff --git a/cluster/cluster_impl/registry_aware_cluster_test.go b/cluster/cluster_impl/registry_aware_cluster_test.go
index 4ae15cc5066c70646dee66cf4ef601202653cb07..7f916c1aaa5609beb3d818e08f5b0950c3273e6d 100644
--- a/cluster/cluster_impl/registry_aware_cluster_test.go
+++ b/cluster/cluster_impl/registry_aware_cluster_test.go
@@ -39,13 +39,13 @@ func Test_RegAwareInvokeSuccess(t *testing.T) {
invokers := []protocol.Invoker{}
for i := 0; i < 10; i++ {
- url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i))
+ url, _ := common.NewURL(context.Background(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i))
invokers = append(invokers, NewMockInvoker(url, 1))
}
staticDir := directory.NewStaticDirectory(invokers)
clusterInvoker := regAwareCluster.Join(staticDir)
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.NoError(t, result.Error())
count = 0
}
@@ -55,14 +55,14 @@ func TestDestroy(t *testing.T) {
invokers := []protocol.Invoker{}
for i := 0; i < 10; i++ {
- url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i))
+ url, _ := common.NewURL(context.Background(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i))
invokers = append(invokers, NewMockInvoker(url, 1))
}
staticDir := directory.NewStaticDirectory(invokers)
clusterInvoker := regAwareCluster.Join(staticDir)
assert.Equal(t, true, clusterInvoker.IsAvailable())
- result := clusterInvoker.Invoke(&invocation.RPCInvocation{})
+ result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{})
assert.NoError(t, result.Error())
count = 0
clusterInvoker.Destroy()
diff --git a/cluster/router/condition/factory_test.go b/cluster/router/condition/factory_test.go
index 1bb6382898d9cc4434e59965af8f56bca0ae4233..054bd0b6890a210ca20805ed4b1977699cf3152e 100644
--- a/cluster/router/condition/factory_test.go
+++ b/cluster/router/condition/factory_test.go
@@ -93,7 +93,7 @@ type rest struct {
var count int
-func (bi *MockInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (bi *MockInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
count++
var success bool
var err error = nil
diff --git a/common/proxy/proxy.go b/common/proxy/proxy.go
index 0b5e2860495604207dc8e4a384225e2fca47df1a..8f3f8916100a1b7a3fe25a64ba7e5149cd3282dd 100644
--- a/common/proxy/proxy.go
+++ b/common/proxy/proxy.go
@@ -18,6 +18,7 @@
package proxy
import (
+ "context"
"reflect"
"sync"
)
@@ -129,7 +130,7 @@ func (p *Proxy) Implement(v common.RPCService) {
inv.SetAttachments(k, value)
}
- result := p.invoke.Invoke(inv)
+ result := p.invoke.Invoke(context.Background(), inv)
err = result.Error()
logger.Infof("[makeDubboCallProxy] result: %v, err: %v", result.Result(), err)
diff --git a/common/proxy/proxy_factory/default.go b/common/proxy/proxy_factory/default.go
index 06824fdc1e27cde5e1905be3277451dd4395049c..1f2f80e52abf7a262c84f15c7278946eded9ab26 100644
--- a/common/proxy/proxy_factory/default.go
+++ b/common/proxy/proxy_factory/default.go
@@ -18,6 +18,7 @@
package proxy_factory
import (
+ "context"
"reflect"
"strings"
)
@@ -75,7 +76,7 @@ type ProxyInvoker struct {
protocol.BaseInvoker
}
-func (pi *ProxyInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (pi *ProxyInvoker) Invoke(context context.Context, invocation protocol.Invocation) protocol.Result {
result := &protocol.RPCResult{}
result.SetAttachments(invocation.Attachments())
diff --git a/filter/filter.go b/filter/filter.go
index 5bd78998a76a1b0e8af99b0b3f0d7e6c103bb794..6c9e4455476b42d97718b5364d9687ac9671f687 100644
--- a/filter/filter.go
+++ b/filter/filter.go
@@ -17,12 +17,15 @@
package filter
+import (
+ "context"
+)
import (
"github.com/apache/dubbo-go/protocol"
)
// Extension - Filter
type Filter interface {
- Invoke(protocol.Invoker, protocol.Invocation) protocol.Result
- OnResponse(protocol.Result, protocol.Invoker, protocol.Invocation) protocol.Result
+ Invoke(context.Context, protocol.Invoker, protocol.Invocation) protocol.Result
+ OnResponse(context.Context, protocol.Result, protocol.Invoker, protocol.Invocation) protocol.Result
}
diff --git a/filter/filter_impl/access_log_filter.go b/filter/filter_impl/access_log_filter.go
index a1b022f27edef4a3bdb84c6117364394cd72aefe..468393ba5be0c0991b9ab218ebc440d699382c20 100644
--- a/filter/filter_impl/access_log_filter.go
+++ b/filter/filter_impl/access_log_filter.go
@@ -18,6 +18,7 @@
package filter_impl
import (
+ "context"
"os"
"reflect"
"strings"
@@ -66,13 +67,13 @@ type AccessLogFilter struct {
logChan chan AccessLogData
}
-func (ef *AccessLogFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (ef *AccessLogFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
accessLog := invoker.GetUrl().GetParam(constant.ACCESS_LOG_KEY, "")
if len(accessLog) > 0 {
accessLogData := AccessLogData{data: ef.buildAccessLogData(invoker, invocation), accessLog: accessLog}
ef.logIntoChannel(accessLogData)
}
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
// it won't block the invocation
@@ -119,7 +120,7 @@ func (ef *AccessLogFilter) buildAccessLogData(invoker protocol.Invoker, invocati
return dataMap
}
-func (ef *AccessLogFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (ef *AccessLogFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
diff --git a/filter/filter_impl/access_log_filter_test.go b/filter/filter_impl/access_log_filter_test.go
index 5076962486da1ca40c4bdf6d7ba4b75a05bb0f92..14b9166b0fc486638c77388c76b49423a8d4a83e 100644
--- a/filter/filter_impl/access_log_filter_test.go
+++ b/filter/filter_impl/access_log_filter_test.go
@@ -49,7 +49,7 @@ func TestAccessLogFilter_Invoke_Not_Config(t *testing.T) {
inv := invocation.NewRPCInvocation("MethodName", []interface{}{"OK", "Hello"}, attach)
accessLogFilter := GetAccessLogFilter()
- result := accessLogFilter.Invoke(invoker, inv)
+ result := accessLogFilter.Invoke(context.Background(), invoker, inv)
assert.Nil(t, result.Error())
}
@@ -70,13 +70,13 @@ func TestAccessLogFilter_Invoke_Default_Config(t *testing.T) {
inv := invocation.NewRPCInvocation("MethodName", []interface{}{"OK", "Hello"}, attach)
accessLogFilter := GetAccessLogFilter()
- result := accessLogFilter.Invoke(invoker, inv)
+ result := accessLogFilter.Invoke(context.Background(), invoker, inv)
assert.Nil(t, result.Error())
}
func TestAccessLogFilter_OnResponse(t *testing.T) {
result := &protocol.RPCResult{}
accessLogFilter := GetAccessLogFilter()
- response := accessLogFilter.OnResponse(result, nil, nil)
+ response := accessLogFilter.OnResponse(nil, result, nil, nil)
assert.Equal(t, result, response)
}
diff --git a/filter/filter_impl/active_filter.go b/filter/filter_impl/active_filter.go
index e23e4dde74fdeb7f56c5ccad9caa3202e92882a4..cc46fc9d8624f6e756ccfe5c491c3177450e10b5 100644
--- a/filter/filter_impl/active_filter.go
+++ b/filter/filter_impl/active_filter.go
@@ -18,6 +18,7 @@
package filter_impl
import (
+ "context"
"strconv"
)
@@ -41,15 +42,15 @@ func init() {
type ActiveFilter struct {
}
-func (ef *ActiveFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (ef *ActiveFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
logger.Infof("invoking active filter. %v,%v", invocation.MethodName(), len(invocation.Arguments()))
invocation.(*invocation2.RPCInvocation).SetAttachments(dubboInvokeStartTime, strconv.FormatInt(protocol.CurrentTimeMillis(), 10))
protocol.BeginCount(invoker.GetUrl(), invocation.MethodName())
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
-func (ef *ActiveFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (ef *ActiveFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
startTime, err := strconv.ParseInt(invocation.(*invocation2.RPCInvocation).AttachmentsByKey(dubboInvokeStartTime, "0"), 10, 64)
if err != nil {
diff --git a/filter/filter_impl/active_filter_test.go b/filter/filter_impl/active_filter_test.go
index acc4f9121641bfbbc484a711c0ea04dffeab55e3..7b355086f9d48b3fb864ed40d1cb5db999543d77 100644
--- a/filter/filter_impl/active_filter_test.go
+++ b/filter/filter_impl/active_filter_test.go
@@ -28,7 +28,7 @@ func TestActiveFilter_Invoke(t *testing.T) {
invoker := mock.NewMockInvoker(ctrl)
invoker.EXPECT().Invoke(gomock.Any()).Return(nil)
invoker.EXPECT().GetUrl().Return(url).Times(1)
- filter.Invoke(invoker, invoc)
+ filter.Invoke(context.Background(), invoker, invoc)
assert.True(t, invoc.AttachmentsByKey(dubboInvokeStartTime, "") != "")
}
@@ -48,7 +48,7 @@ func TestActiveFilter_OnResponse(t *testing.T) {
result := &protocol.RPCResult{
Err: errors.New("test"),
}
- filter.OnResponse(result, invoker, invoc)
+ filter.OnResponse(nil, result, invoker, invoc)
methodStatus := protocol.GetMethodStatus(url, "test")
urlStatus := protocol.GetURLStatus(url)
diff --git a/filter/filter_impl/echo_filter.go b/filter/filter_impl/echo_filter.go
index f67a47ac8704b1f6e10135bd24234cc0b8965dec..f6bdd4a4e8398c65303d426a48f104e12314ded3 100644
--- a/filter/filter_impl/echo_filter.go
+++ b/filter/filter_impl/echo_filter.go
@@ -17,6 +17,9 @@
package filter_impl
+import (
+ "context"
+)
import (
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/extension"
@@ -38,7 +41,7 @@ func init() {
// Echo func(ctx context.Context, arg interface{}, rsp *Xxx) error
type EchoFilter struct{}
-func (ef *EchoFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (ef *EchoFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
logger.Infof("invoking echo filter.")
logger.Debugf("%v,%v", invocation.MethodName(), len(invocation.Arguments()))
if invocation.MethodName() == constant.ECHO && len(invocation.Arguments()) == 1 {
@@ -48,10 +51,10 @@ func (ef *EchoFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invoc
}
}
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
-func (ef *EchoFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (ef *EchoFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
diff --git a/filter/filter_impl/echo_filter_test.go b/filter/filter_impl/echo_filter_test.go
index b75b9c19a1f073cc23dfccfa97a51e456e59d9cc..fc09bdce696c6be3c9e11d0ac864b187d1d85cde 100644
--- a/filter/filter_impl/echo_filter_test.go
+++ b/filter/filter_impl/echo_filter_test.go
@@ -18,6 +18,7 @@
package filter_impl
import (
+ "context"
"testing"
)
@@ -33,12 +34,10 @@ import (
func TestEchoFilter_Invoke(t *testing.T) {
filter := GetFilter()
- result := filter.Invoke(protocol.NewBaseInvoker(common.URL{}),
- invocation.NewRPCInvocation("$echo", []interface{}{"OK"}, nil))
+ result := filter.Invoke(context.Background(), protocol.NewBaseInvoker(common.URL{}), invocation.NewRPCInvocation("$echo", []interface{}{"OK"}, nil))
assert.Equal(t, "OK", result.Result())
- result = filter.Invoke(protocol.NewBaseInvoker(common.URL{}),
- invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, nil))
+ result = filter.Invoke(context.Background(), protocol.NewBaseInvoker(common.URL{}), invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, nil))
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
}
diff --git a/filter/filter_impl/execute_limit_filter.go b/filter/filter_impl/execute_limit_filter.go
index a192aede400b1d73b7e604b09126ae372a1e91db..f9ff87751b21979f9d794db88deb9f4d8527f0d1 100644
--- a/filter/filter_impl/execute_limit_filter.go
+++ b/filter/filter_impl/execute_limit_filter.go
@@ -18,6 +18,7 @@
package filter_impl
import (
+ "context"
"strconv"
"sync"
"sync/atomic"
@@ -75,7 +76,7 @@ type ExecuteState struct {
concurrentCount int64
}
-func (ef *ExecuteLimitFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (ef *ExecuteLimitFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
methodConfigPrefix := "methods." + invocation.MethodName() + "."
url := invoker.GetUrl()
limitTarget := url.ServiceKey()
@@ -97,7 +98,7 @@ func (ef *ExecuteLimitFilter) Invoke(invoker protocol.Invoker, invocation protoc
}
if limitRate < 0 {
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
state, _ := ef.executeState.LoadOrStore(limitTarget, &ExecuteState{
@@ -113,10 +114,10 @@ func (ef *ExecuteLimitFilter) Invoke(invoker protocol.Invoker, invocation protoc
return extension.GetRejectedExecutionHandler(rejectedHandlerConfig).RejectedExecution(url, invocation)
}
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
-func (ef *ExecuteLimitFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (ef *ExecuteLimitFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
diff --git a/filter/filter_impl/execute_limit_filter_test.go b/filter/filter_impl/execute_limit_filter_test.go
index e3836251df4ba78befcbb5720affb5dbc3cbdf1f..ae8641f2db0b98b59f9939cfc85f3ad096b1bc7f 100644
--- a/filter/filter_impl/execute_limit_filter_test.go
+++ b/filter/filter_impl/execute_limit_filter_test.go
@@ -18,6 +18,7 @@
package filter_impl
import (
+ "context"
"net/url"
"testing"
)
@@ -43,7 +44,7 @@ func TestExecuteLimitFilter_Invoke_Ignored(t *testing.T) {
limitFilter := GetExecuteLimitFilter()
- result := limitFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl), invoc)
+ result := limitFilter.Invoke(context.Background(), protocol.NewBaseInvoker(*invokeUrl), invoc)
assert.NotNil(t, result)
assert.Nil(t, result.Error())
}
@@ -60,7 +61,7 @@ func TestExecuteLimitFilter_Invoke_Configure_Error(t *testing.T) {
limitFilter := GetExecuteLimitFilter()
- result := limitFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl), invoc)
+ result := limitFilter.Invoke(context.Background(), protocol.NewBaseInvoker(*invokeUrl), invoc)
assert.NotNil(t, result)
assert.Nil(t, result.Error())
}
@@ -77,7 +78,7 @@ func TestExecuteLimitFilter_Invoke(t *testing.T) {
limitFilter := GetExecuteLimitFilter()
- result := limitFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl), invoc)
+ result := limitFilter.Invoke(context.Background(), protocol.NewBaseInvoker(*invokeUrl), invoc)
assert.NotNil(t, result)
assert.Nil(t, result.Error())
}
diff --git a/filter/filter_impl/generic_filter.go b/filter/filter_impl/generic_filter.go
index 3bfae1e35d76cd65289d5f100da621a8fa745d1b..9d3804d9434ce2ab108dfa8be4607a6425f2d29c 100644
--- a/filter/filter_impl/generic_filter.go
+++ b/filter/filter_impl/generic_filter.go
@@ -18,6 +18,7 @@
package filter_impl
import (
+ "context"
"reflect"
"strings"
)
@@ -44,7 +45,7 @@ func init() {
type GenericFilter struct{}
-func (ef *GenericFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (ef *GenericFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
if invocation.MethodName() == constant.GENERIC && len(invocation.Arguments()) == 3 {
oldArguments := invocation.Arguments()
@@ -60,13 +61,13 @@ func (ef *GenericFilter) Invoke(invoker protocol.Invoker, invocation protocol.In
}
newInvocation := invocation2.NewRPCInvocation(invocation.MethodName(), newArguments, invocation.Attachments())
newInvocation.SetReply(invocation.Reply())
- return invoker.Invoke(newInvocation)
+ return invoker.Invoke(ctx, newInvocation)
}
}
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
-func (ef *GenericFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (ef *GenericFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
diff --git a/filter/filter_impl/generic_service_filter.go b/filter/filter_impl/generic_service_filter.go
index da33f13e5ef29a7164c3776b65cc5cabd4b43888..6beebf4566b657d4d4ea0d2c737cdf3344bdcbe4 100644
--- a/filter/filter_impl/generic_service_filter.go
+++ b/filter/filter_impl/generic_service_filter.go
@@ -18,6 +18,7 @@
package filter_impl
import (
+ "context"
"reflect"
"strings"
)
@@ -49,12 +50,12 @@ func init() {
type GenericServiceFilter struct{}
-func (ef *GenericServiceFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (ef *GenericServiceFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
logger.Infof("invoking generic service filter.")
logger.Debugf("generic service filter methodName:%v,args:%v", invocation.MethodName(), len(invocation.Arguments()))
if invocation.MethodName() != constant.GENERIC || len(invocation.Arguments()) != 3 {
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
var (
@@ -107,10 +108,10 @@ func (ef *GenericServiceFilter) Invoke(invoker protocol.Invoker, invocation prot
}
newInvocation := invocation2.NewRPCInvocation(methodName, newParams, invocation.Attachments())
newInvocation.SetReply(invocation.Reply())
- return invoker.Invoke(newInvocation)
+ return invoker.Invoke(ctx, newInvocation)
}
-func (ef *GenericServiceFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (ef *GenericServiceFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
if invocation.MethodName() == constant.GENERIC && len(invocation.Arguments()) == 3 && result.Result() != nil {
v := reflect.ValueOf(result.Result())
if v.Kind() == reflect.Ptr {
diff --git a/filter/filter_impl/generic_service_filter_test.go b/filter/filter_impl/generic_service_filter_test.go
index e36ec5086ecacffbf56a0da6dd9249ffd6fec649..8211e717564465bba3009772715a3ab1cd3322dd 100644
--- a/filter/filter_impl/generic_service_filter_test.go
+++ b/filter/filter_impl/generic_service_filter_test.go
@@ -99,7 +99,7 @@ func TestGenericServiceFilter_Invoke(t *testing.T) {
rpcInvocation := invocation.NewRPCInvocation(methodName, aurguments, nil)
filter := GetGenericServiceFilter()
url, _ := common.NewURL(context.Background(), "testprotocol://127.0.0.1:20000/com.test.Path")
- result := filter.Invoke(&proxy_factory.ProxyInvoker{BaseInvoker: *protocol.NewBaseInvoker(url)}, rpcInvocation)
+ result := filter.Invoke(context.Background(), &proxy_factory.ProxyInvoker{BaseInvoker: *protocol.NewBaseInvoker(url)}, rpcInvocation)
assert.NotNil(t, result)
assert.Nil(t, result.Error())
}
@@ -124,7 +124,7 @@ func TestGenericServiceFilter_ResponseTestStruct(t *testing.T) {
filter := GetGenericServiceFilter()
methodName := "$invoke"
rpcInvocation := invocation.NewRPCInvocation(methodName, aurguments, nil)
- r := filter.OnResponse(result, nil, rpcInvocation)
+ r := filter.OnResponse(nil, result, nil, rpcInvocation)
assert.NotNil(t, r.Result())
assert.Equal(t, reflect.ValueOf(r.Result()).Kind(), reflect.Map)
}
@@ -142,7 +142,7 @@ func TestGenericServiceFilter_ResponseString(t *testing.T) {
filter := GetGenericServiceFilter()
methodName := "$invoke"
rpcInvocation := invocation.NewRPCInvocation(methodName, aurguments, nil)
- r := filter.OnResponse(result, nil, rpcInvocation)
+ r := filter.OnResponse(nil, result, nil, rpcInvocation)
assert.NotNil(t, r.Result())
assert.Equal(t, reflect.ValueOf(r.Result()).Kind(), reflect.String)
}
diff --git a/filter/filter_impl/graceful_shutdown_filter.go b/filter/filter_impl/graceful_shutdown_filter.go
index 1af7e1f8c32ea3924550399a7ff5e76c68368636..95e625b2d56895a4d57823e4e0e2e7d1d5e90a08 100644
--- a/filter/filter_impl/graceful_shutdown_filter.go
+++ b/filter/filter_impl/graceful_shutdown_filter.go
@@ -18,6 +18,7 @@
package filter_impl
import (
+ "context"
"sync/atomic"
)
@@ -52,16 +53,16 @@ type gracefulShutdownFilter struct {
shutdownConfig *config.ShutdownConfig
}
-func (gf *gracefulShutdownFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (gf *gracefulShutdownFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
if gf.rejectNewRequest() {
logger.Info("The application is closing, new request will be rejected.")
return gf.getRejectHandler().RejectedExecution(invoker.GetUrl(), invocation)
}
atomic.AddInt32(&gf.activeCount, 1)
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
-func (gf *gracefulShutdownFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (gf *gracefulShutdownFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
atomic.AddInt32(&gf.activeCount, -1)
// although this isn't thread safe, it won't be a problem if the gf.rejectNewRequest() is true.
if gf.shutdownConfig != nil && gf.activeCount <= 0 {
diff --git a/filter/filter_impl/graceful_shutdown_filter_test.go b/filter/filter_impl/graceful_shutdown_filter_test.go
index fc437c3557fa452273e770d3d50678401ba3b33b..4c670933e3dcec29ad9ae7bfef250b4236ae7c54 100644
--- a/filter/filter_impl/graceful_shutdown_filter_test.go
+++ b/filter/filter_impl/graceful_shutdown_filter_test.go
@@ -18,6 +18,7 @@
package filter_impl
import (
+ "context"
"net/url"
"testing"
)
@@ -53,7 +54,7 @@ func TestGenericFilter_Invoke(t *testing.T) {
assert.Equal(t, extension.GetRejectedExecutionHandler(constant.DEFAULT_KEY),
shutdownFilter.getRejectHandler())
- result := shutdownFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl), invoc)
+ result := shutdownFilter.Invoke(context.Background(), protocol.NewBaseInvoker(*invokeUrl), invoc)
assert.NotNil(t, result)
assert.Nil(t, result.Error())
@@ -64,7 +65,7 @@ func TestGenericFilter_Invoke(t *testing.T) {
shutdownFilter.shutdownConfig = providerConfig.ShutdownConfig
assert.True(t, shutdownFilter.rejectNewRequest())
- result = shutdownFilter.OnResponse(nil, protocol.NewBaseInvoker(*invokeUrl), invoc)
+ result = shutdownFilter.OnResponse(nil, nil, protocol.NewBaseInvoker(*invokeUrl), invoc)
rejectHandler := &common2.OnlyLogRejectedExecutionHandler{}
extension.SetRejectedExecutionHandler("mock", func() filter.RejectedExecutionHandler {
diff --git a/filter/filter_impl/hystrix_filter.go b/filter/filter_impl/hystrix_filter.go
index a6e07803046005b5ab31d7a02ea9e25f4b74da75..0f40d815ffbd4c199ad30cad44eb1a94e93cf916 100644
--- a/filter/filter_impl/hystrix_filter.go
+++ b/filter/filter_impl/hystrix_filter.go
@@ -17,6 +17,7 @@
package filter_impl
import (
+ "context"
"fmt"
"regexp"
"sync"
@@ -82,7 +83,7 @@ type HystrixFilter struct {
ifNewMap sync.Map
}
-func (hf *HystrixFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (hf *HystrixFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
cmdName := fmt.Sprintf("%s&method=%s", invoker.GetUrl().Key(), invocation.MethodName())
@@ -115,12 +116,12 @@ func (hf *HystrixFilter) Invoke(invoker protocol.Invoker, invocation protocol.In
configLoadMutex.RUnlock()
if err != nil {
logger.Errorf("[Hystrix Filter]Errors occurred getting circuit for %s , will invoke without hystrix, error is: ", cmdName, err)
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
logger.Infof("[Hystrix Filter]Using hystrix filter: %s", cmdName)
var result protocol.Result
_ = hystrix.Do(cmdName, func() error {
- result = invoker.Invoke(invocation)
+ result = invoker.Invoke(ctx, invocation)
err := result.Error()
if err != nil {
result.SetError(NewHystrixFilterError(err, false))
@@ -144,7 +145,7 @@ func (hf *HystrixFilter) Invoke(invoker protocol.Invoker, invocation protocol.In
return result
}
-func (hf *HystrixFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (hf *HystrixFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
func GetHystrixFilterConsumer() filter.Filter {
diff --git a/filter/filter_impl/hystrix_filter_test.go b/filter/filter_impl/hystrix_filter_test.go
index 2bbc3e079e7ae563db1efa18f82423931fd5919d..894573036ae6dd9edca88e8e4cdd92e7643abcb5 100644
--- a/filter/filter_impl/hystrix_filter_test.go
+++ b/filter/filter_impl/hystrix_filter_test.go
@@ -17,6 +17,7 @@
package filter_impl
import (
+ "context"
"regexp"
"testing"
)
@@ -125,7 +126,7 @@ type testMockSuccessInvoker struct {
protocol.BaseInvoker
}
-func (iv *testMockSuccessInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (iv *testMockSuccessInvoker) Invoke(context context.Context, invocation protocol.Invocation) protocol.Result {
return &protocol.RPCResult{
Rest: "Sucess",
Err: nil,
@@ -136,7 +137,7 @@ type testMockFailInvoker struct {
protocol.BaseInvoker
}
-func (iv *testMockFailInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (iv *testMockFailInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
return &protocol.RPCResult{
Err: errors.Errorf("exception"),
}
@@ -144,7 +145,7 @@ func (iv *testMockFailInvoker) Invoke(invocation protocol.Invocation) protocol.R
func TestHystrixFilter_Invoke_Success(t *testing.T) {
hf := &HystrixFilter{}
- result := hf.Invoke(&testMockSuccessInvoker{}, &invocation.RPCInvocation{})
+ result := hf.Invoke(context.Background(), &testMockSuccessInvoker{}, &invocation.RPCInvocation{})
assert.NotNil(t, result)
assert.NoError(t, result.Error())
assert.NotNil(t, result.Result())
@@ -152,7 +153,7 @@ func TestHystrixFilter_Invoke_Success(t *testing.T) {
func TestHystrixFilter_Invoke_Fail(t *testing.T) {
hf := &HystrixFilter{}
- result := hf.Invoke(&testMockFailInvoker{}, &invocation.RPCInvocation{})
+ result := hf.Invoke(context.Background(), &testMockFailInvoker{}, &invocation.RPCInvocation{})
assert.NotNil(t, result)
assert.Error(t, result.Error())
}
@@ -164,7 +165,7 @@ func TestHystricFilter_Invoke_CircuitBreak(t *testing.T) {
resChan := make(chan protocol.Result, 50)
for i := 0; i < 50; i++ {
go func() {
- result := hf.Invoke(&testMockFailInvoker{}, &invocation.RPCInvocation{})
+ result := hf.Invoke(context.Background(), &testMockFailInvoker{}, &invocation.RPCInvocation{})
resChan <- result
}()
}
@@ -189,7 +190,7 @@ func TestHystricFilter_Invoke_CircuitBreak_Omit_Exception(t *testing.T) {
resChan := make(chan protocol.Result, 50)
for i := 0; i < 50; i++ {
go func() {
- result := hf.Invoke(&testMockFailInvoker{}, &invocation.RPCInvocation{})
+ result := hf.Invoke(context.Background(), &testMockFailInvoker{}, &invocation.RPCInvocation{})
resChan <- result
}()
}
diff --git a/filter/filter_impl/token_filter.go b/filter/filter_impl/token_filter.go
index 180f3e6631a2fd0b317af3a4addd8d77287d82d5..702ee33d4d2e9756ab3b4dbb4bfc9b7c42907080 100644
--- a/filter/filter_impl/token_filter.go
+++ b/filter/filter_impl/token_filter.go
@@ -18,6 +18,7 @@ limitations under the License.
package filter_impl
import (
+ "context"
"strings"
)
@@ -42,22 +43,22 @@ func init() {
type TokenFilter struct{}
-func (tf *TokenFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (tf *TokenFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
invokerTkn := invoker.GetUrl().GetParam(constant.TOKEN_KEY, "")
if len(invokerTkn) > 0 {
attachs := invocation.Attachments()
remoteTkn, exist := attachs[constant.TOKEN_KEY]
if exist && strings.EqualFold(invokerTkn, remoteTkn) {
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
return &protocol.RPCResult{Err: perrors.Errorf("Invalid token! Forbid invoke remote service %v method %s ",
invoker, invocation.MethodName())}
}
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
-func (tf *TokenFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (tf *TokenFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
diff --git a/filter/filter_impl/token_filter_test.go b/filter/filter_impl/token_filter_test.go
index 675d33dc7d401b04f59037c1ec2eb44c8d6ecbe4..672082c729bc371a40573a66d13bc57a7024186b 100644
--- a/filter/filter_impl/token_filter_test.go
+++ b/filter/filter_impl/token_filter_test.go
@@ -18,6 +18,7 @@ limitations under the License.
package filter_impl
import (
+ "context"
"net/url"
"testing"
)
@@ -41,8 +42,10 @@ func TestTokenFilter_Invoke(t *testing.T) {
common.WithParamsValue(constant.TOKEN_KEY, "ori_key"))
attch := make(map[string]string, 0)
attch[constant.TOKEN_KEY] = "ori_key"
- result := filter.Invoke(protocol.NewBaseInvoker(*url),
- invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
+ result := filter.Invoke(context.Background(),
+ protocol.NewBaseInvoker(*url),
+ invocation.NewRPCInvocation("MethodName",
+ []interface{}{"OK"}, attch))
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
}
@@ -53,8 +56,7 @@ func TestTokenFilter_InvokeEmptyToken(t *testing.T) {
url := common.URL{}
attch := make(map[string]string, 0)
attch[constant.TOKEN_KEY] = "ori_key"
- result := filter.Invoke(protocol.NewBaseInvoker(url),
- invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
+ result := filter.Invoke(context.Background(), protocol.NewBaseInvoker(url), invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
}
@@ -66,8 +68,7 @@ func TestTokenFilter_InvokeEmptyAttach(t *testing.T) {
common.WithParams(url.Values{}),
common.WithParamsValue(constant.TOKEN_KEY, "ori_key"))
attch := make(map[string]string, 0)
- result := filter.Invoke(protocol.NewBaseInvoker(*url),
- invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
+ result := filter.Invoke(context.Background(), protocol.NewBaseInvoker(*url), invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
assert.NotNil(t, result.Error())
}
@@ -79,7 +80,7 @@ func TestTokenFilter_InvokeNotEqual(t *testing.T) {
common.WithParamsValue(constant.TOKEN_KEY, "ori_key"))
attch := make(map[string]string, 0)
attch[constant.TOKEN_KEY] = "err_key"
- result := filter.Invoke(protocol.NewBaseInvoker(*url),
- invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
+ result := filter.Invoke(context.Background(),
+ protocol.NewBaseInvoker(*url), invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
assert.NotNil(t, result.Error())
}
diff --git a/filter/filter_impl/tps_limit_filter.go b/filter/filter_impl/tps_limit_filter.go
index 77414a8ea70743983cadc609c875920cff525487..8852260e9e7b4b833728da97dc8f273d3e52dec7 100644
--- a/filter/filter_impl/tps_limit_filter.go
+++ b/filter/filter_impl/tps_limit_filter.go
@@ -17,6 +17,9 @@
package filter_impl
+import (
+ "context"
+)
import (
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/extension"
@@ -51,22 +54,22 @@ func init() {
type TpsLimitFilter struct {
}
-func (t TpsLimitFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (t TpsLimitFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
url := invoker.GetUrl()
tpsLimiter := url.GetParam(constant.TPS_LIMITER_KEY, "")
rejectedExeHandler := url.GetParam(constant.TPS_REJECTED_EXECUTION_HANDLER_KEY, constant.DEFAULT_KEY)
if len(tpsLimiter) > 0 {
allow := extension.GetTpsLimiter(tpsLimiter).IsAllowable(invoker.GetUrl(), invocation)
if allow {
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
logger.Errorf("The invocation was rejected due to over the tps limitation, url: %s ", url.String())
return extension.GetRejectedExecutionHandler(rejectedExeHandler).RejectedExecution(url, invocation)
}
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
-func (t TpsLimitFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (t TpsLimitFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
diff --git a/filter/filter_impl/tps_limit_filter_test.go b/filter/filter_impl/tps_limit_filter_test.go
index 5e04804aa23c4e6e417f6bb9975a3269a2118739..cc423ae1e5f3589dd60b0c8655f1123c290f0ffc 100644
--- a/filter/filter_impl/tps_limit_filter_test.go
+++ b/filter/filter_impl/tps_limit_filter_test.go
@@ -18,6 +18,7 @@
package filter_impl
import (
+ "context"
"net/url"
"testing"
)
@@ -45,8 +46,10 @@ func TestTpsLimitFilter_Invoke_With_No_TpsLimiter(t *testing.T) {
common.WithParamsValue(constant.TPS_LIMITER_KEY, ""))
attch := make(map[string]string, 0)
- result := tpsFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl),
- invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
+ result := tpsFilter.Invoke(context.Background(),
+ protocol.NewBaseInvoker(*invokeUrl),
+ invocation.NewRPCInvocation("MethodName",
+ []interface{}{"OK"}, attch))
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
@@ -67,8 +70,10 @@ func TestGenericFilter_Invoke_With_Default_TpsLimiter(t *testing.T) {
common.WithParamsValue(constant.TPS_LIMITER_KEY, constant.DEFAULT_KEY))
attch := make(map[string]string, 0)
- result := tpsFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl),
- invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
+ result := tpsFilter.Invoke(context.Background(),
+ protocol.NewBaseInvoker(*invokeUrl),
+ invocation.NewRPCInvocation("MethodName",
+ []interface{}{"OK"}, attch))
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
}
@@ -96,8 +101,8 @@ func TestGenericFilter_Invoke_With_Default_TpsLimiter_Not_Allow(t *testing.T) {
common.WithParamsValue(constant.TPS_LIMITER_KEY, constant.DEFAULT_KEY))
attch := make(map[string]string, 0)
- result := tpsFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl),
- invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
+ result := tpsFilter.Invoke(context.Background(),
+ protocol.NewBaseInvoker(*invokeUrl), invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
}
diff --git a/protocol/dubbo/dubbo_invoker.go b/protocol/dubbo/dubbo_invoker.go
index 6dcf2568fa8c88a864c567486a501c2ad7feb3f7..8100fbe3a6760e456a9eecedfd39e5230dd2c797 100644
--- a/protocol/dubbo/dubbo_invoker.go
+++ b/protocol/dubbo/dubbo_invoker.go
@@ -18,6 +18,7 @@
package dubbo
import (
+ "context"
"strconv"
"sync"
)
@@ -53,7 +54,7 @@ func NewDubboInvoker(url common.URL, client *Client) *DubboInvoker {
}
}
-func (di *DubboInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (di *DubboInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
var (
err error
diff --git a/protocol/dubbo/dubbo_invoker_test.go b/protocol/dubbo/dubbo_invoker_test.go
index 8a032d0ca9536e53dc18c43f8e699212f7c30ec4..e360d57b8cdd61674d35a665e8ee85e03421cc8f 100644
--- a/protocol/dubbo/dubbo_invoker_test.go
+++ b/protocol/dubbo/dubbo_invoker_test.go
@@ -18,6 +18,7 @@
package dubbo
import (
+ "context"
"sync"
"testing"
"time"
@@ -53,14 +54,14 @@ func TestDubboInvoker_Invoke(t *testing.T) {
invocation.WithReply(user), invocation.WithAttachments(map[string]string{"test_key": "test_value"}))
// Call
- res := invoker.Invoke(inv)
+ res := invoker.Invoke(context.Background(), inv)
assert.NoError(t, res.Error())
assert.Equal(t, User{Id: "1", Name: "username"}, *res.Result().(*User))
assert.Equal(t, "test_value", res.Attachments()["test_key"]) // test attachments for request/response
// CallOneway
inv.SetAttachments(constant.ASYNC_KEY, "true")
- res = invoker.Invoke(inv)
+ res = invoker.Invoke(context.Background(), inv)
assert.NoError(t, res.Error())
// AsyncCall
@@ -71,13 +72,13 @@ func TestDubboInvoker_Invoke(t *testing.T) {
assert.Equal(t, User{Id: "1", Name: "username"}, *r.Reply.(*Response).reply.(*User))
lock.Unlock()
})
- res = invoker.Invoke(inv)
+ res = invoker.Invoke(context.Background(), inv)
assert.NoError(t, res.Error())
// Err_No_Reply
inv.SetAttachments(constant.ASYNC_KEY, "false")
inv.SetReply(nil)
- res = invoker.Invoke(inv)
+ res = invoker.Invoke(context.Background(), inv)
assert.EqualError(t, res.Error(), "request need @response")
// destroy
diff --git a/protocol/dubbo/listener.go b/protocol/dubbo/listener.go
index 2e4b3999dfc08262a2cfb80f29c9a9e7bc2decf8..1ed6e9cf57f3399ce2a7a8134bad9924d6799460 100644
--- a/protocol/dubbo/listener.go
+++ b/protocol/dubbo/listener.go
@@ -18,6 +18,7 @@
package dubbo
import (
+ "context"
"fmt"
"net/url"
"sync"
@@ -258,7 +259,7 @@ func (h *RpcServerHandler) OnMessage(session getty.Session, pkg interface{}) {
args := p.Body.(map[string]interface{})["args"].([]interface{})
inv := invocation.NewRPCInvocation(p.Service.Method, args, attachments)
- result := invoker.Invoke(inv)
+ result := invoker.Invoke(context.Background(), inv)
if err := result.Error(); err != nil {
p.Header.ResponseStatus = hessian.Response_OK
p.Body = hessian.NewResponse(nil, err, result.Attachments())
diff --git a/protocol/grpc/common_test.go b/protocol/grpc/common_test.go
index 7f78bdc40d07a9089c1cf40f55803f04b39cb949..165b82fabc5703a720766b04659b158d2b3fdbdf 100644
--- a/protocol/grpc/common_test.go
+++ b/protocol/grpc/common_test.go
@@ -97,7 +97,7 @@ func _DUBBO_Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec f
invo := invocation.NewRPCInvocation("SayHello", args, nil)
if interceptor == nil {
- result := base.GetProxyImpl().Invoke(invo)
+ result := base.GetProxyImpl().Invoke(context.Background(), invo)
return result.Result(), result.Error()
}
info := &native_grpc.UnaryServerInfo{
@@ -105,7 +105,7 @@ func _DUBBO_Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec f
FullMethod: "/helloworld.Greeter/SayHello",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
- result := base.GetProxyImpl().Invoke(invo)
+ result := base.GetProxyImpl().Invoke(context.Background(), invo)
return result.Result(), result.Error()
}
return interceptor(ctx, in, info, handler)
diff --git a/protocol/grpc/grpc_invoker.go b/protocol/grpc/grpc_invoker.go
index b74612b896addb1ff08c3abe44198c147996a126..88149397e79aa435a6a9d41911ae0e603754534e 100644
--- a/protocol/grpc/grpc_invoker.go
+++ b/protocol/grpc/grpc_invoker.go
@@ -50,7 +50,7 @@ func NewGrpcInvoker(url common.URL, client *Client) *GrpcInvoker {
}
}
-func (gi *GrpcInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (gi *GrpcInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
var (
result protocol.RPCResult
)
diff --git a/protocol/grpc/grpc_invoker_test.go b/protocol/grpc/grpc_invoker_test.go
index 4f97e1063191692ce5f47e0d4f8242d95cc8a6fc..5d4b97051438f8404cd8fd89bcf73d24e0121868 100644
--- a/protocol/grpc/grpc_invoker_test.go
+++ b/protocol/grpc/grpc_invoker_test.go
@@ -49,7 +49,7 @@ func TestInvoke(t *testing.T) {
bizReply := &internal.HelloReply{}
invo := invocation.NewRPCInvocationWithOptions(invocation.WithMethodName("SayHello"),
invocation.WithParameterValues(args), invocation.WithReply(bizReply))
- res := invoker.Invoke(invo)
+ res := invoker.Invoke(context.Background(), invo)
assert.Nil(t, res.Error())
assert.NotNil(t, res.Result())
assert.Equal(t, "Hello request name", bizReply.Message)
diff --git a/protocol/grpc/protoc-gen-dubbo/examples/helloworld.pb.go b/protocol/grpc/protoc-gen-dubbo/examples/helloworld.pb.go
index 4ed55ab7612200d28816508e4c4fcb7de0a803c0..f5d3a49b0916050fc6b2e6373fde0b70df0a1c31 100644
--- a/protocol/grpc/protoc-gen-dubbo/examples/helloworld.pb.go
+++ b/protocol/grpc/protoc-gen-dubbo/examples/helloworld.pb.go
@@ -271,7 +271,7 @@ func _DUBBO_Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec f
args = append(args, in)
invo := invocation.NewRPCInvocation("SayHello", args, nil)
if interceptor == nil {
- result := base.GetProxyImpl().Invoke(invo)
+ result := base.GetProxyImpl().Invoke(context.Background(), invo)
return result.Result(), result.Error()
}
info := &grpc.UnaryServerInfo{
@@ -279,7 +279,7 @@ func _DUBBO_Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec f
FullMethod: "/main.Greeter/SayHello",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
- result := base.GetProxyImpl().Invoke(invo)
+ result := base.GetProxyImpl().Invoke(context.Background(), invo)
return result.Result(), result.Error()
}
return interceptor(ctx, in, info, handler)
diff --git a/protocol/invoker.go b/protocol/invoker.go
index f5d41a09ad2778c12c7e5e68167a4d0acc9e3f4c..a1cf6264ae2b9f631b1bb12f88e8378ad5857919 100644
--- a/protocol/invoker.go
+++ b/protocol/invoker.go
@@ -17,6 +17,9 @@
package protocol
+import (
+ "context"
+)
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/logger"
@@ -26,7 +29,7 @@ import (
// Extension - Invoker
type Invoker interface {
common.Node
- Invoke(Invocation) Result
+ Invoke(context.Context, Invocation) Result
}
/////////////////////////////
@@ -59,7 +62,7 @@ func (bi *BaseInvoker) IsDestroyed() bool {
return bi.destroyed
}
-func (bi *BaseInvoker) Invoke(invocation Invocation) Result {
+func (bi *BaseInvoker) Invoke(context context.Context, invocation Invocation) Result {
return &RPCResult{}
}
diff --git a/protocol/jsonrpc/jsonrpc_invoker.go b/protocol/jsonrpc/jsonrpc_invoker.go
index 2c130e0d7617e96a1724edc5b63f8e66f251446e..07e09b07c9e55e804a15cc79587b020230f087fa 100644
--- a/protocol/jsonrpc/jsonrpc_invoker.go
+++ b/protocol/jsonrpc/jsonrpc_invoker.go
@@ -41,7 +41,7 @@ func NewJsonrpcInvoker(url common.URL, client *HTTPClient) *JsonrpcInvoker {
}
}
-func (ji *JsonrpcInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (ji *JsonrpcInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
var (
result protocol.RPCResult
@@ -50,12 +50,12 @@ func (ji *JsonrpcInvoker) Invoke(invocation protocol.Invocation) protocol.Result
inv := invocation.(*invocation_impl.RPCInvocation)
url := ji.GetUrl()
req := ji.client.NewRequest(url, inv.MethodName(), inv.Arguments())
- ctx := context.WithValue(context.Background(), constant.DUBBOGO_CTX_KEY, map[string]string{
+ ctxNew := context.WithValue(context.Background(), constant.DUBBOGO_CTX_KEY, map[string]string{
"X-Proxy-Id": "dubbogo",
"X-Services": url.Path,
"X-Method": inv.MethodName(),
})
- result.Err = ji.client.Call(ctx, url, req, inv.Reply())
+ result.Err = ji.client.Call(ctxNew, url, req, inv.Reply())
if result.Err == nil {
result.Rest = inv.Reply()
}
diff --git a/protocol/jsonrpc/jsonrpc_invoker_test.go b/protocol/jsonrpc/jsonrpc_invoker_test.go
index 8c910339858f4960ad0e394ae6271863d7654adc..9eed22e67155f1b0915cbb398bcef55962258407 100644
--- a/protocol/jsonrpc/jsonrpc_invoker_test.go
+++ b/protocol/jsonrpc/jsonrpc_invoker_test.go
@@ -60,7 +60,7 @@ func TestJsonrpcInvoker_Invoke(t *testing.T) {
jsonInvoker := NewJsonrpcInvoker(url, client)
user := &User{}
- res := jsonInvoker.Invoke(invocation.NewRPCInvocationWithOptions(invocation.WithMethodName("GetUser"), invocation.WithArguments([]interface{}{"1", "username"}),
+ res := jsonInvoker.Invoke(context.Background(), invocation.NewRPCInvocationWithOptions(invocation.WithMethodName("GetUser"), invocation.WithArguments([]interface{}{"1", "username"}),
invocation.WithReply(user)))
assert.NoError(t, res.Error())
diff --git a/protocol/jsonrpc/server.go b/protocol/jsonrpc/server.go
index dc79e4a36bd6cce575d50588d11b003cb8e25abe..3decc8867474d99da8f50584e60375af0e54b225 100644
--- a/protocol/jsonrpc/server.go
+++ b/protocol/jsonrpc/server.go
@@ -330,7 +330,7 @@ func serveRequest(ctx context.Context,
exporter, _ := jsonrpcProtocol.ExporterMap().Load(path)
invoker := exporter.(*JsonrpcExporter).GetInvoker()
if invoker != nil {
- result := invoker.Invoke(invocation.NewRPCInvocation(methodName, args, map[string]string{
+ result := invoker.Invoke(context.Background(), invocation.NewRPCInvocation(methodName, args, map[string]string{
constant.PATH_KEY: path,
constant.VERSION_KEY: codec.req.Version,
}))
diff --git a/protocol/mock/mock_invoker.go b/protocol/mock/mock_invoker.go
index c509cef054f5a23fe504486e01d7cc0e8772711d..5c5b476b7b07f6c41a74a7ec8f51648aff84b1a3 100644
--- a/protocol/mock/mock_invoker.go
+++ b/protocol/mock/mock_invoker.go
@@ -21,6 +21,7 @@
package mock
import (
+ "context"
"reflect"
)
@@ -91,7 +92,7 @@ func (mr *MockInvokerMockRecorder) Destroy() *gomock.Call {
}
// Invoke mocks base method
-func (m *MockInvoker) Invoke(arg0 protocol.Invocation) protocol.Result {
+func (m *MockInvoker) Invoke(ctx context.Context, arg0 protocol.Invocation) protocol.Result {
ret := m.ctrl.Call(m, "Invoke", arg0)
ret0, _ := ret[0].(protocol.Result)
return ret0
diff --git a/protocol/protocolwrapper/protocol_filter_wrapper.go b/protocol/protocolwrapper/protocol_filter_wrapper.go
index 7c58fabea3cccf5a39e1622fedd4a3a297e05983..33ea38201251df3abc6639b416200611cc993e56 100644
--- a/protocol/protocolwrapper/protocol_filter_wrapper.go
+++ b/protocol/protocolwrapper/protocol_filter_wrapper.go
@@ -18,6 +18,7 @@
package protocolwrapper
import (
+ "context"
"strings"
)
@@ -102,9 +103,9 @@ func (fi *FilterInvoker) IsAvailable() bool {
return fi.invoker.IsAvailable()
}
-func (fi *FilterInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
- result := fi.filter.Invoke(fi.next, invocation)
- return fi.filter.OnResponse(result, fi.invoker, invocation)
+func (fi *FilterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
+ result := fi.filter.Invoke(ctx, fi.next, invocation)
+ return fi.filter.OnResponse(ctx, result, fi.invoker, invocation)
}
func (fi *FilterInvoker) Destroy() {
diff --git a/protocol/protocolwrapper/protocol_filter_wrapper_test.go b/protocol/protocolwrapper/protocol_filter_wrapper_test.go
index dc376313549c24da1cc6cb64a42e8445ef4fe346..8491d57462d47d6af72040d41b78dcb30e6da697 100644
--- a/protocol/protocolwrapper/protocol_filter_wrapper_test.go
+++ b/protocol/protocolwrapper/protocol_filter_wrapper_test.go
@@ -18,6 +18,7 @@
package protocolwrapper
import (
+ "context"
"net/url"
"testing"
)
@@ -66,7 +67,7 @@ func init() {
type EchoFilterForTest struct{}
-func (ef *EchoFilterForTest) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (ef *EchoFilterForTest) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
logger.Infof("invoking echo filter.")
logger.Debugf("%v,%v", invocation.MethodName(), len(invocation.Arguments()))
if invocation.MethodName() == constant.ECHO && len(invocation.Arguments()) == 1 {
@@ -75,10 +76,10 @@ func (ef *EchoFilterForTest) Invoke(invoker protocol.Invoker, invocation protoco
}
}
- return invoker.Invoke(invocation)
+ return invoker.Invoke(ctx, invocation)
}
-func (ef *EchoFilterForTest) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
+func (ef *EchoFilterForTest) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
diff --git a/registry/protocol/protocol.go b/registry/protocol/protocol.go
index 8655312a4eb508dfe5c910855ba5f3e3aacd666e..9e6b9999b976d5cfcc76560731f383a52c2642f4 100644
--- a/registry/protocol/protocol.go
+++ b/registry/protocol/protocol.go
@@ -18,6 +18,7 @@
package protocol
import (
+ "context"
"strings"
"sync"
)
@@ -356,10 +357,10 @@ func newWrappedInvoker(invoker protocol.Invoker, url *common.URL) *wrappedInvoke
}
}
-func (ivk *wrappedInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
+func (ivk *wrappedInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
// get right url
ivk.invoker.(*proxy_factory.ProxyInvoker).BaseInvoker = *protocol.NewBaseInvoker(ivk.GetUrl())
- return ivk.invoker.Invoke(invocation)
+ return ivk.invoker.Invoke(ctx, invocation)
}
type providerConfigurationListener struct {