diff --git a/cluster/cluster_impl/available_cluster_invoker.go b/cluster/cluster_impl/available_cluster_invoker.go index c59c0702c216fe5c58d190a023322aaa00ac9c17..bc6705c8156aaeb6a0a52e08b1aa539e179013ca 100644 --- a/cluster/cluster_impl/available_cluster_invoker.go +++ b/cluster/cluster_impl/available_cluster_invoker.go @@ -18,6 +18,7 @@ limitations under the License. package cluster_impl import ( + "context" "fmt" ) @@ -40,7 +41,7 @@ func NewAvailableClusterInvoker(directory cluster.Directory) protocol.Invoker { } } -func (invoker *availableClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (invoker *availableClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { invokers := invoker.directory.List(invocation) err := invoker.checkInvokers(invokers, invocation) if err != nil { @@ -54,7 +55,7 @@ func (invoker *availableClusterInvoker) Invoke(invocation protocol.Invocation) p for _, ivk := range invokers { if ivk.IsAvailable() { - return ivk.Invoke(invocation) + return ivk.Invoke(ctx, invocation) } } return &protocol.RPCResult{Err: errors.New(fmt.Sprintf("no provider available in %v", invokers))} diff --git a/cluster/cluster_impl/available_cluster_invoker_test.go b/cluster/cluster_impl/available_cluster_invoker_test.go index 04032a7f24dec0e73acb15921f753921391f1515..de04db1da4e8e6df12960b1a2ee81b0044379d6f 100644 --- a/cluster/cluster_impl/available_cluster_invoker_test.go +++ b/cluster/cluster_impl/available_cluster_invoker_test.go @@ -66,7 +66,7 @@ func TestAvailableClusterInvokerSuccess(t *testing.T) { invoker.EXPECT().IsAvailable().Return(true) invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.Equal(t, mockResult, result) } @@ -80,7 +80,7 @@ func TestAvailableClusterInvokerNoAvail(t *testing.T) { invoker.EXPECT().IsAvailable().Return(false) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.TODO(), &invocation.RPCInvocation{}) assert.NotNil(t, result.Error()) assert.True(t, strings.Contains(result.Error().Error(), "no provider available")) diff --git a/cluster/cluster_impl/broadcast_cluster_invoker.go b/cluster/cluster_impl/broadcast_cluster_invoker.go index 238df0acfa7fb946e38bfbfd490bce7c0bb34e60..1b49e9a115252d4eca94bedd557ebcc21fee4cc7 100644 --- a/cluster/cluster_impl/broadcast_cluster_invoker.go +++ b/cluster/cluster_impl/broadcast_cluster_invoker.go @@ -17,6 +17,9 @@ limitations under the License. package cluster_impl +import ( + "context" +) import ( "github.com/apache/dubbo-go/cluster" "github.com/apache/dubbo-go/common/logger" @@ -33,7 +36,7 @@ func newBroadcastClusterInvoker(directory cluster.Directory) protocol.Invoker { } } -func (invoker *broadcastClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (invoker *broadcastClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { invokers := invoker.directory.List(invocation) err := invoker.checkInvokers(invokers, invocation) if err != nil { @@ -46,7 +49,7 @@ func (invoker *broadcastClusterInvoker) Invoke(invocation protocol.Invocation) p var result protocol.Result for _, ivk := range invokers { - result = ivk.Invoke(invocation) + result = ivk.Invoke(ctx, invocation) if result.Error() != nil { logger.Warnf("broadcast invoker invoke err: %v when use invoker: %v\n", result.Error(), ivk) err = result.Error() diff --git a/cluster/cluster_impl/broadcast_cluster_invoker_test.go b/cluster/cluster_impl/broadcast_cluster_invoker_test.go index 565684a8ae25c648ff77aef71d2ced0665202fe7..b20d962e2cffb34d0a151488a1bdf63499e4de86 100644 --- a/cluster/cluster_impl/broadcast_cluster_invoker_test.go +++ b/cluster/cluster_impl/broadcast_cluster_invoker_test.go @@ -74,7 +74,7 @@ func Test_BroadcastInvokeSuccess(t *testing.T) { clusterInvoker := registerBroadcast(t, invokers...) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.Equal(t, mockResult, result) } @@ -104,6 +104,6 @@ func Test_BroadcastInvokeFailed(t *testing.T) { clusterInvoker := registerBroadcast(t, invokers...) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.Equal(t, mockFailedResult.Err, result.Error()) } diff --git a/cluster/cluster_impl/failback_cluster_invoker.go b/cluster/cluster_impl/failback_cluster_invoker.go index c8dbeda09f62e88b51dd4ad2b6b09d5715f0b224..46b0ff634e56c45223a5aeb5566b9b1401518960 100644 --- a/cluster/cluster_impl/failback_cluster_invoker.go +++ b/cluster/cluster_impl/failback_cluster_invoker.go @@ -18,6 +18,7 @@ package cluster_impl import ( + "context" "strconv" "sync" "time" @@ -71,7 +72,7 @@ func newFailbackClusterInvoker(directory cluster.Directory) protocol.Invoker { return invoker } -func (invoker *failbackClusterInvoker) process() { +func (invoker *failbackClusterInvoker) process(ctx context.Context) { invoker.ticker = time.NewTicker(time.Second * 1) for range invoker.ticker.C { // check each timeout task and re-run @@ -102,7 +103,7 @@ func (invoker *failbackClusterInvoker) process() { retryInvoker := invoker.doSelect(retryTask.loadbalance, retryTask.invocation, retryTask.invokers, invoked) var result protocol.Result - result = retryInvoker.Invoke(retryTask.invocation) + result = retryInvoker.Invoke(ctx, retryTask.invocation) if result.Error() != nil { retryTask.lastInvoker = retryInvoker invoker.checkRetry(retryTask, result.Error()) @@ -126,7 +127,7 @@ func (invoker *failbackClusterInvoker) checkRetry(retryTask *retryTimerTask, err } } -func (invoker *failbackClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (invoker *failbackClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { invokers := invoker.directory.List(invocation) err := invoker.checkInvokers(invokers, invocation) if err != nil { @@ -150,11 +151,11 @@ func (invoker *failbackClusterInvoker) Invoke(invocation protocol.Invocation) pr ivk := invoker.doSelect(loadbalance, invocation, invokers, invoked) //DO INVOKE - result = ivk.Invoke(invocation) + result = ivk.Invoke(ctx, invocation) if result.Error() != nil { invoker.once.Do(func() { invoker.taskList = queue.New(invoker.failbackTasks) - go invoker.process() + go invoker.process(ctx) }) taskLen := invoker.taskList.Len() diff --git a/cluster/cluster_impl/failback_cluster_test.go b/cluster/cluster_impl/failback_cluster_test.go index 1d2266cabebf591b09188fb723f02126a3f1e0ec..895077922a88abc05416e58459205b449831ac56 100644 --- a/cluster/cluster_impl/failback_cluster_test.go +++ b/cluster/cluster_impl/failback_cluster_test.go @@ -72,7 +72,7 @@ func Test_FailbackSuceess(t *testing.T) { mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}} invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.Equal(t, mockResult, result) } @@ -102,7 +102,7 @@ func Test_FailbackRetryOneSuccess(t *testing.T) { return mockSuccResult }) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.Nil(t, result.Error()) assert.Nil(t, result.Result()) assert.Equal(t, 0, len(result.Attachments())) @@ -150,7 +150,7 @@ func Test_FailbackRetryFailed(t *testing.T) { } // first call should failed. - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.Nil(t, result.Error()) assert.Nil(t, result.Result()) assert.Equal(t, 0, len(result.Attachments())) @@ -192,7 +192,7 @@ func Test_FailbackRetryFailed10Times(t *testing.T) { }).Times(10) for i := 0; i < 10; i++ { - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.Nil(t, result.Error()) assert.Nil(t, result.Result()) assert.Equal(t, 0, len(result.Attachments())) @@ -222,14 +222,14 @@ func Test_FailbackOutOfLimit(t *testing.T) { invoker.EXPECT().Invoke(gomock.Any()).Return(mockFailedResult).Times(11) // reached limit - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.Nil(t, result.Error()) assert.Nil(t, result.Result()) assert.Equal(t, 0, len(result.Attachments())) // all will be out of limit for i := 0; i < 10; i++ { - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.Nil(t, result.Error()) assert.Nil(t, result.Result()) assert.Equal(t, 0, len(result.Attachments())) diff --git a/cluster/cluster_impl/failfast_cluster_invoker.go b/cluster/cluster_impl/failfast_cluster_invoker.go index 734ea2c6cb19bf54a338a76a10c9cfcc59d3954b..49e7c7689f5a19a36154e092a6a83cc39da604ba 100644 --- a/cluster/cluster_impl/failfast_cluster_invoker.go +++ b/cluster/cluster_impl/failfast_cluster_invoker.go @@ -17,6 +17,9 @@ limitations under the License. package cluster_impl +import ( + "context" +) import ( "github.com/apache/dubbo-go/cluster" "github.com/apache/dubbo-go/protocol" @@ -32,7 +35,7 @@ func newFailFastClusterInvoker(directory cluster.Directory) protocol.Invoker { } } -func (invoker *failfastClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (invoker *failfastClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { invokers := invoker.directory.List(invocation) err := invoker.checkInvokers(invokers, invocation) if err != nil { @@ -47,5 +50,5 @@ func (invoker *failfastClusterInvoker) Invoke(invocation protocol.Invocation) pr } ivk := invoker.doSelect(loadbalance, invocation, invokers, nil) - return ivk.Invoke(invocation) + return ivk.Invoke(ctx, invocation) } diff --git a/cluster/cluster_impl/failfast_cluster_test.go b/cluster/cluster_impl/failfast_cluster_test.go index 1a4342e6c2b74fd6b1359646eeb463bb6dc17d0a..9585f03b7fa8f45a19c7c47e04dcd57cc1e4bb11 100644 --- a/cluster/cluster_impl/failfast_cluster_test.go +++ b/cluster/cluster_impl/failfast_cluster_test.go @@ -69,7 +69,7 @@ func Test_FailfastInvokeSuccess(t *testing.T) { mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}} invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.NoError(t, result.Error()) res := result.Result().(rest) @@ -89,7 +89,7 @@ func Test_FailfastInvokeFail(t *testing.T) { mockResult := &protocol.RPCResult{Err: perrors.New("error")} invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.NotNil(t, result.Error()) assert.Equal(t, "error", result.Error().Error()) diff --git a/cluster/cluster_impl/failover_cluster_invoker.go b/cluster/cluster_impl/failover_cluster_invoker.go index dcce7369931a11f31fb6b9e4e1a6c0aa0ec7cdf6..6178a05a1226ba629d2456ad6886b02a26288e45 100644 --- a/cluster/cluster_impl/failover_cluster_invoker.go +++ b/cluster/cluster_impl/failover_cluster_invoker.go @@ -18,6 +18,7 @@ package cluster_impl import ( + "context" "strconv" ) @@ -43,7 +44,7 @@ func newFailoverClusterInvoker(directory cluster.Directory) protocol.Invoker { } } -func (invoker *failoverClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (invoker *failoverClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { invokers := invoker.directory.List(invocation) err := invoker.checkInvokers(invokers, invocation) @@ -95,7 +96,7 @@ func (invoker *failoverClusterInvoker) Invoke(invocation protocol.Invocation) pr } invoked = append(invoked, ivk) //DO INVOKE - result = ivk.Invoke(invocation) + result = ivk.Invoke(ctx, invocation) if result.Error() != nil { providers = append(providers, ivk.GetUrl().Key()) continue diff --git a/cluster/cluster_impl/failover_cluster_test.go b/cluster/cluster_impl/failover_cluster_test.go index 78b799320dfa58d55e531c658ec5eb0e69306cff..7bde83ea66a49f9317732ec46da0f11800f846eb 100644 --- a/cluster/cluster_impl/failover_cluster_test.go +++ b/cluster/cluster_impl/failover_cluster_test.go @@ -77,7 +77,7 @@ type rest struct { success bool } -func (bi *MockInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (bi *MockInvoker) Invoke(c context.Context, invocation protocol.Invocation) protocol.Result { count++ var success bool var err error = nil @@ -112,9 +112,9 @@ func normalInvoke(t *testing.T, successCount int, urlParam url.Values, invocatio staticDir := directory.NewStaticDirectory(invokers) clusterInvoker := failoverCluster.Join(staticDir) if len(invocations) > 0 { - return clusterInvoker.Invoke(invocations[0]) + return clusterInvoker.Invoke(context.Background(), invocations[0]) } - return clusterInvoker.Invoke(&invocation.RPCInvocation{}) + return clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) } func Test_FailoverInvokeSuccess(t *testing.T) { urlParams := url.Values{} @@ -155,14 +155,14 @@ func Test_FailoverDestroy(t *testing.T) { invokers := []protocol.Invoker{} for i := 0; i < 10; i++ { - url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i)) + url, _ := common.NewURL(context.Background(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i)) invokers = append(invokers, NewMockInvoker(url, 1)) } staticDir := directory.NewStaticDirectory(invokers) clusterInvoker := failoverCluster.Join(staticDir) assert.Equal(t, true, clusterInvoker.IsAvailable()) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.NoError(t, result.Error()) count = 0 clusterInvoker.Destroy() diff --git a/cluster/cluster_impl/failsafe_cluster_invoker.go b/cluster/cluster_impl/failsafe_cluster_invoker.go index b95f997fef87cf466f07c4e506e41758e7998e52..4d8fe27719eb71fa287fe4142d8e92ca17acfba4 100644 --- a/cluster/cluster_impl/failsafe_cluster_invoker.go +++ b/cluster/cluster_impl/failsafe_cluster_invoker.go @@ -17,6 +17,9 @@ package cluster_impl +import ( + "context" +) import ( "github.com/apache/dubbo-go/cluster" "github.com/apache/dubbo-go/common/constant" @@ -42,7 +45,7 @@ func newFailsafeClusterInvoker(directory cluster.Directory) protocol.Invoker { } } -func (invoker *failsafeClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (invoker *failsafeClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { invokers := invoker.directory.List(invocation) err := invoker.checkInvokers(invokers, invocation) @@ -65,7 +68,7 @@ func (invoker *failsafeClusterInvoker) Invoke(invocation protocol.Invocation) pr ivk := invoker.doSelect(loadbalance, invocation, invokers, invoked) //DO INVOKE - result = ivk.Invoke(invocation) + result = ivk.Invoke(ctx, invocation) if result.Error() != nil { // ignore logger.Errorf("Failsafe ignore exception: %v.\n", result.Error().Error()) diff --git a/cluster/cluster_impl/failsafe_cluster_test.go b/cluster/cluster_impl/failsafe_cluster_test.go index 7888b97c3a02bd4679f8ec5267637b8d2a7c12e4..930b4bb16628e2b363659a65fc174543b7f2cf6e 100644 --- a/cluster/cluster_impl/failsafe_cluster_test.go +++ b/cluster/cluster_impl/failsafe_cluster_test.go @@ -69,7 +69,7 @@ func Test_FailSafeInvokeSuccess(t *testing.T) { mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}} invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.NoError(t, result.Error()) res := result.Result().(rest) @@ -88,7 +88,7 @@ func Test_FailSafeInvokeFail(t *testing.T) { mockResult := &protocol.RPCResult{Err: perrors.New("error")} invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.NoError(t, result.Error()) assert.Nil(t, result.Result()) diff --git a/cluster/cluster_impl/forking_cluster_invoker.go b/cluster/cluster_impl/forking_cluster_invoker.go index d6cf2f4b89ab4f322fa758deecae90c60742ef49..c830079ff6d3c29c3385eda289782f5e52877be2 100644 --- a/cluster/cluster_impl/forking_cluster_invoker.go +++ b/cluster/cluster_impl/forking_cluster_invoker.go @@ -18,6 +18,7 @@ limitations under the License. package cluster_impl import ( + "context" "errors" "fmt" "time" @@ -44,7 +45,7 @@ func newForkingClusterInvoker(directory cluster.Directory) protocol.Invoker { } } -func (invoker *forkingClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (invoker *forkingClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { err := invoker.checkWhetherDestroyed() if err != nil { return &protocol.RPCResult{Err: err} @@ -75,7 +76,7 @@ func (invoker *forkingClusterInvoker) Invoke(invocation protocol.Invocation) pro resultQ := queue.New(1) for _, ivk := range selected { go func(k protocol.Invoker) { - result := k.Invoke(invocation) + result := k.Invoke(ctx, invocation) err := resultQ.Put(result) if err != nil { logger.Errorf("resultQ put failed with exception: %v.\n", err) diff --git a/cluster/cluster_impl/forking_cluster_test.go b/cluster/cluster_impl/forking_cluster_test.go index 8603f8aedc4e28a3a4ca2f115355debc1a5ecc62..d819781eb23631e6b8eef76e5bdf7d7837f43d53 100644 --- a/cluster/cluster_impl/forking_cluster_test.go +++ b/cluster/cluster_impl/forking_cluster_test.go @@ -87,7 +87,7 @@ func Test_ForkingInvokeSuccess(t *testing.T) { clusterInvoker := registerForking(t, invokers...) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.Equal(t, mockResult, result) wg.Wait() } @@ -117,7 +117,7 @@ func Test_ForkingInvokeTimeout(t *testing.T) { clusterInvoker := registerForking(t, invokers...) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.NotNil(t, result) assert.NotNil(t, result.Error()) wg.Wait() @@ -156,7 +156,7 @@ func Test_ForkingInvokeHalfTimeout(t *testing.T) { clusterInvoker := registerForking(t, invokers...) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.Equal(t, mockResult, result) wg.Wait() } diff --git a/cluster/cluster_impl/registry_aware_cluster_invoker.go b/cluster/cluster_impl/registry_aware_cluster_invoker.go index 5785c02489f95168d5419f0087f38b07c851a4a3..cded5bf16432e6b0c590e15b81c28369889a5f88 100644 --- a/cluster/cluster_impl/registry_aware_cluster_invoker.go +++ b/cluster/cluster_impl/registry_aware_cluster_invoker.go @@ -17,6 +17,9 @@ package cluster_impl +import ( + "context" +) import ( "github.com/apache/dubbo-go/cluster" "github.com/apache/dubbo-go/common/constant" @@ -33,19 +36,19 @@ func newRegistryAwareClusterInvoker(directory cluster.Directory) protocol.Invoke } } -func (invoker *registryAwareClusterInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (invoker *registryAwareClusterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { invokers := invoker.directory.List(invocation) //First, pick the invoker (XXXClusterInvoker) that comes from the local registry, distinguish by a 'default' key. for _, invoker := range invokers { if invoker.IsAvailable() && invoker.GetUrl().GetParam(constant.REGISTRY_DEFAULT_KEY, "false") == "true" { - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } } //If none of the invokers has a local signal, pick the first one available. for _, invoker := range invokers { if invoker.IsAvailable() { - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } } return nil diff --git a/cluster/cluster_impl/registry_aware_cluster_test.go b/cluster/cluster_impl/registry_aware_cluster_test.go index 4ae15cc5066c70646dee66cf4ef601202653cb07..7f916c1aaa5609beb3d818e08f5b0950c3273e6d 100644 --- a/cluster/cluster_impl/registry_aware_cluster_test.go +++ b/cluster/cluster_impl/registry_aware_cluster_test.go @@ -39,13 +39,13 @@ func Test_RegAwareInvokeSuccess(t *testing.T) { invokers := []protocol.Invoker{} for i := 0; i < 10; i++ { - url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i)) + url, _ := common.NewURL(context.Background(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i)) invokers = append(invokers, NewMockInvoker(url, 1)) } staticDir := directory.NewStaticDirectory(invokers) clusterInvoker := regAwareCluster.Join(staticDir) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.NoError(t, result.Error()) count = 0 } @@ -55,14 +55,14 @@ func TestDestroy(t *testing.T) { invokers := []protocol.Invoker{} for i := 0; i < 10; i++ { - url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i)) + url, _ := common.NewURL(context.Background(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i)) invokers = append(invokers, NewMockInvoker(url, 1)) } staticDir := directory.NewStaticDirectory(invokers) clusterInvoker := regAwareCluster.Join(staticDir) assert.Equal(t, true, clusterInvoker.IsAvailable()) - result := clusterInvoker.Invoke(&invocation.RPCInvocation{}) + result := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) assert.NoError(t, result.Error()) count = 0 clusterInvoker.Destroy() diff --git a/cluster/router/condition_router_test.go b/cluster/router/condition_router_test.go index 7d8b0d88cab688e6ea10d1562a27de4609d51f58..7acbdabc9b6c1976664fce7596ce22c187f48068 100644 --- a/cluster/router/condition_router_test.go +++ b/cluster/router/condition_router_test.go @@ -93,7 +93,7 @@ type rest struct { var count int -func (bi *MockInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (bi *MockInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { count++ var success bool var err error = nil diff --git a/common/proxy/proxy.go b/common/proxy/proxy.go index 0b5e2860495604207dc8e4a384225e2fca47df1a..8f3f8916100a1b7a3fe25a64ba7e5149cd3282dd 100644 --- a/common/proxy/proxy.go +++ b/common/proxy/proxy.go @@ -18,6 +18,7 @@ package proxy import ( + "context" "reflect" "sync" ) @@ -129,7 +130,7 @@ func (p *Proxy) Implement(v common.RPCService) { inv.SetAttachments(k, value) } - result := p.invoke.Invoke(inv) + result := p.invoke.Invoke(context.Background(), inv) err = result.Error() logger.Infof("[makeDubboCallProxy] result: %v, err: %v", result.Result(), err) diff --git a/common/proxy/proxy_factory/default.go b/common/proxy/proxy_factory/default.go index 06824fdc1e27cde5e1905be3277451dd4395049c..1f2f80e52abf7a262c84f15c7278946eded9ab26 100644 --- a/common/proxy/proxy_factory/default.go +++ b/common/proxy/proxy_factory/default.go @@ -18,6 +18,7 @@ package proxy_factory import ( + "context" "reflect" "strings" ) @@ -75,7 +76,7 @@ type ProxyInvoker struct { protocol.BaseInvoker } -func (pi *ProxyInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (pi *ProxyInvoker) Invoke(context context.Context, invocation protocol.Invocation) protocol.Result { result := &protocol.RPCResult{} result.SetAttachments(invocation.Attachments()) diff --git a/filter/filter.go b/filter/filter.go index 5bd78998a76a1b0e8af99b0b3f0d7e6c103bb794..6c9e4455476b42d97718b5364d9687ac9671f687 100644 --- a/filter/filter.go +++ b/filter/filter.go @@ -17,12 +17,15 @@ package filter +import ( + "context" +) import ( "github.com/apache/dubbo-go/protocol" ) // Extension - Filter type Filter interface { - Invoke(protocol.Invoker, protocol.Invocation) protocol.Result - OnResponse(protocol.Result, protocol.Invoker, protocol.Invocation) protocol.Result + Invoke(context.Context, protocol.Invoker, protocol.Invocation) protocol.Result + OnResponse(context.Context, protocol.Result, protocol.Invoker, protocol.Invocation) protocol.Result } diff --git a/filter/filter_impl/access_log_filter.go b/filter/filter_impl/access_log_filter.go index a1b022f27edef4a3bdb84c6117364394cd72aefe..468393ba5be0c0991b9ab218ebc440d699382c20 100644 --- a/filter/filter_impl/access_log_filter.go +++ b/filter/filter_impl/access_log_filter.go @@ -18,6 +18,7 @@ package filter_impl import ( + "context" "os" "reflect" "strings" @@ -66,13 +67,13 @@ type AccessLogFilter struct { logChan chan AccessLogData } -func (ef *AccessLogFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (ef *AccessLogFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { accessLog := invoker.GetUrl().GetParam(constant.ACCESS_LOG_KEY, "") if len(accessLog) > 0 { accessLogData := AccessLogData{data: ef.buildAccessLogData(invoker, invocation), accessLog: accessLog} ef.logIntoChannel(accessLogData) } - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } // it won't block the invocation @@ -119,7 +120,7 @@ func (ef *AccessLogFilter) buildAccessLogData(invoker protocol.Invoker, invocati return dataMap } -func (ef *AccessLogFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (ef *AccessLogFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { return result } diff --git a/filter/filter_impl/access_log_filter_test.go b/filter/filter_impl/access_log_filter_test.go index 5076962486da1ca40c4bdf6d7ba4b75a05bb0f92..14b9166b0fc486638c77388c76b49423a8d4a83e 100644 --- a/filter/filter_impl/access_log_filter_test.go +++ b/filter/filter_impl/access_log_filter_test.go @@ -49,7 +49,7 @@ func TestAccessLogFilter_Invoke_Not_Config(t *testing.T) { inv := invocation.NewRPCInvocation("MethodName", []interface{}{"OK", "Hello"}, attach) accessLogFilter := GetAccessLogFilter() - result := accessLogFilter.Invoke(invoker, inv) + result := accessLogFilter.Invoke(context.Background(), invoker, inv) assert.Nil(t, result.Error()) } @@ -70,13 +70,13 @@ func TestAccessLogFilter_Invoke_Default_Config(t *testing.T) { inv := invocation.NewRPCInvocation("MethodName", []interface{}{"OK", "Hello"}, attach) accessLogFilter := GetAccessLogFilter() - result := accessLogFilter.Invoke(invoker, inv) + result := accessLogFilter.Invoke(context.Background(), invoker, inv) assert.Nil(t, result.Error()) } func TestAccessLogFilter_OnResponse(t *testing.T) { result := &protocol.RPCResult{} accessLogFilter := GetAccessLogFilter() - response := accessLogFilter.OnResponse(result, nil, nil) + response := accessLogFilter.OnResponse(nil, result, nil, nil) assert.Equal(t, result, response) } diff --git a/filter/filter_impl/active_filter.go b/filter/filter_impl/active_filter.go index e23e4dde74fdeb7f56c5ccad9caa3202e92882a4..cc46fc9d8624f6e756ccfe5c491c3177450e10b5 100644 --- a/filter/filter_impl/active_filter.go +++ b/filter/filter_impl/active_filter.go @@ -18,6 +18,7 @@ package filter_impl import ( + "context" "strconv" ) @@ -41,15 +42,15 @@ func init() { type ActiveFilter struct { } -func (ef *ActiveFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (ef *ActiveFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { logger.Infof("invoking active filter. %v,%v", invocation.MethodName(), len(invocation.Arguments())) invocation.(*invocation2.RPCInvocation).SetAttachments(dubboInvokeStartTime, strconv.FormatInt(protocol.CurrentTimeMillis(), 10)) protocol.BeginCount(invoker.GetUrl(), invocation.MethodName()) - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } -func (ef *ActiveFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (ef *ActiveFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { startTime, err := strconv.ParseInt(invocation.(*invocation2.RPCInvocation).AttachmentsByKey(dubboInvokeStartTime, "0"), 10, 64) if err != nil { diff --git a/filter/filter_impl/active_filter_test.go b/filter/filter_impl/active_filter_test.go index acc4f9121641bfbbc484a711c0ea04dffeab55e3..7b355086f9d48b3fb864ed40d1cb5db999543d77 100644 --- a/filter/filter_impl/active_filter_test.go +++ b/filter/filter_impl/active_filter_test.go @@ -28,7 +28,7 @@ func TestActiveFilter_Invoke(t *testing.T) { invoker := mock.NewMockInvoker(ctrl) invoker.EXPECT().Invoke(gomock.Any()).Return(nil) invoker.EXPECT().GetUrl().Return(url).Times(1) - filter.Invoke(invoker, invoc) + filter.Invoke(context.Background(), invoker, invoc) assert.True(t, invoc.AttachmentsByKey(dubboInvokeStartTime, "") != "") } @@ -48,7 +48,7 @@ func TestActiveFilter_OnResponse(t *testing.T) { result := &protocol.RPCResult{ Err: errors.New("test"), } - filter.OnResponse(result, invoker, invoc) + filter.OnResponse(nil, result, invoker, invoc) methodStatus := protocol.GetMethodStatus(url, "test") urlStatus := protocol.GetURLStatus(url) diff --git a/filter/filter_impl/echo_filter.go b/filter/filter_impl/echo_filter.go index f67a47ac8704b1f6e10135bd24234cc0b8965dec..f6bdd4a4e8398c65303d426a48f104e12314ded3 100644 --- a/filter/filter_impl/echo_filter.go +++ b/filter/filter_impl/echo_filter.go @@ -17,6 +17,9 @@ package filter_impl +import ( + "context" +) import ( "github.com/apache/dubbo-go/common/constant" "github.com/apache/dubbo-go/common/extension" @@ -38,7 +41,7 @@ func init() { // Echo func(ctx context.Context, arg interface{}, rsp *Xxx) error type EchoFilter struct{} -func (ef *EchoFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (ef *EchoFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { logger.Infof("invoking echo filter.") logger.Debugf("%v,%v", invocation.MethodName(), len(invocation.Arguments())) if invocation.MethodName() == constant.ECHO && len(invocation.Arguments()) == 1 { @@ -48,10 +51,10 @@ func (ef *EchoFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invoc } } - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } -func (ef *EchoFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (ef *EchoFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { return result } diff --git a/filter/filter_impl/echo_filter_test.go b/filter/filter_impl/echo_filter_test.go index b75b9c19a1f073cc23dfccfa97a51e456e59d9cc..fc09bdce696c6be3c9e11d0ac864b187d1d85cde 100644 --- a/filter/filter_impl/echo_filter_test.go +++ b/filter/filter_impl/echo_filter_test.go @@ -18,6 +18,7 @@ package filter_impl import ( + "context" "testing" ) @@ -33,12 +34,10 @@ import ( func TestEchoFilter_Invoke(t *testing.T) { filter := GetFilter() - result := filter.Invoke(protocol.NewBaseInvoker(common.URL{}), - invocation.NewRPCInvocation("$echo", []interface{}{"OK"}, nil)) + result := filter.Invoke(context.Background(), protocol.NewBaseInvoker(common.URL{}), invocation.NewRPCInvocation("$echo", []interface{}{"OK"}, nil)) assert.Equal(t, "OK", result.Result()) - result = filter.Invoke(protocol.NewBaseInvoker(common.URL{}), - invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, nil)) + result = filter.Invoke(context.Background(), protocol.NewBaseInvoker(common.URL{}), invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, nil)) assert.Nil(t, result.Error()) assert.Nil(t, result.Result()) } diff --git a/filter/filter_impl/execute_limit_filter.go b/filter/filter_impl/execute_limit_filter.go index a192aede400b1d73b7e604b09126ae372a1e91db..f9ff87751b21979f9d794db88deb9f4d8527f0d1 100644 --- a/filter/filter_impl/execute_limit_filter.go +++ b/filter/filter_impl/execute_limit_filter.go @@ -18,6 +18,7 @@ package filter_impl import ( + "context" "strconv" "sync" "sync/atomic" @@ -75,7 +76,7 @@ type ExecuteState struct { concurrentCount int64 } -func (ef *ExecuteLimitFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (ef *ExecuteLimitFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { methodConfigPrefix := "methods." + invocation.MethodName() + "." url := invoker.GetUrl() limitTarget := url.ServiceKey() @@ -97,7 +98,7 @@ func (ef *ExecuteLimitFilter) Invoke(invoker protocol.Invoker, invocation protoc } if limitRate < 0 { - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } state, _ := ef.executeState.LoadOrStore(limitTarget, &ExecuteState{ @@ -113,10 +114,10 @@ func (ef *ExecuteLimitFilter) Invoke(invoker protocol.Invoker, invocation protoc return extension.GetRejectedExecutionHandler(rejectedHandlerConfig).RejectedExecution(url, invocation) } - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } -func (ef *ExecuteLimitFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (ef *ExecuteLimitFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { return result } diff --git a/filter/filter_impl/execute_limit_filter_test.go b/filter/filter_impl/execute_limit_filter_test.go index e3836251df4ba78befcbb5720affb5dbc3cbdf1f..ae8641f2db0b98b59f9939cfc85f3ad096b1bc7f 100644 --- a/filter/filter_impl/execute_limit_filter_test.go +++ b/filter/filter_impl/execute_limit_filter_test.go @@ -18,6 +18,7 @@ package filter_impl import ( + "context" "net/url" "testing" ) @@ -43,7 +44,7 @@ func TestExecuteLimitFilter_Invoke_Ignored(t *testing.T) { limitFilter := GetExecuteLimitFilter() - result := limitFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl), invoc) + result := limitFilter.Invoke(context.Background(), protocol.NewBaseInvoker(*invokeUrl), invoc) assert.NotNil(t, result) assert.Nil(t, result.Error()) } @@ -60,7 +61,7 @@ func TestExecuteLimitFilter_Invoke_Configure_Error(t *testing.T) { limitFilter := GetExecuteLimitFilter() - result := limitFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl), invoc) + result := limitFilter.Invoke(context.Background(), protocol.NewBaseInvoker(*invokeUrl), invoc) assert.NotNil(t, result) assert.Nil(t, result.Error()) } @@ -77,7 +78,7 @@ func TestExecuteLimitFilter_Invoke(t *testing.T) { limitFilter := GetExecuteLimitFilter() - result := limitFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl), invoc) + result := limitFilter.Invoke(context.Background(), protocol.NewBaseInvoker(*invokeUrl), invoc) assert.NotNil(t, result) assert.Nil(t, result.Error()) } diff --git a/filter/filter_impl/generic_filter.go b/filter/filter_impl/generic_filter.go index 3bfae1e35d76cd65289d5f100da621a8fa745d1b..9d3804d9434ce2ab108dfa8be4607a6425f2d29c 100644 --- a/filter/filter_impl/generic_filter.go +++ b/filter/filter_impl/generic_filter.go @@ -18,6 +18,7 @@ package filter_impl import ( + "context" "reflect" "strings" ) @@ -44,7 +45,7 @@ func init() { type GenericFilter struct{} -func (ef *GenericFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (ef *GenericFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { if invocation.MethodName() == constant.GENERIC && len(invocation.Arguments()) == 3 { oldArguments := invocation.Arguments() @@ -60,13 +61,13 @@ func (ef *GenericFilter) Invoke(invoker protocol.Invoker, invocation protocol.In } newInvocation := invocation2.NewRPCInvocation(invocation.MethodName(), newArguments, invocation.Attachments()) newInvocation.SetReply(invocation.Reply()) - return invoker.Invoke(newInvocation) + return invoker.Invoke(ctx, newInvocation) } } - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } -func (ef *GenericFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (ef *GenericFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { return result } diff --git a/filter/filter_impl/generic_service_filter.go b/filter/filter_impl/generic_service_filter.go index da33f13e5ef29a7164c3776b65cc5cabd4b43888..6beebf4566b657d4d4ea0d2c737cdf3344bdcbe4 100644 --- a/filter/filter_impl/generic_service_filter.go +++ b/filter/filter_impl/generic_service_filter.go @@ -18,6 +18,7 @@ package filter_impl import ( + "context" "reflect" "strings" ) @@ -49,12 +50,12 @@ func init() { type GenericServiceFilter struct{} -func (ef *GenericServiceFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (ef *GenericServiceFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { logger.Infof("invoking generic service filter.") logger.Debugf("generic service filter methodName:%v,args:%v", invocation.MethodName(), len(invocation.Arguments())) if invocation.MethodName() != constant.GENERIC || len(invocation.Arguments()) != 3 { - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } var ( @@ -107,10 +108,10 @@ func (ef *GenericServiceFilter) Invoke(invoker protocol.Invoker, invocation prot } newInvocation := invocation2.NewRPCInvocation(methodName, newParams, invocation.Attachments()) newInvocation.SetReply(invocation.Reply()) - return invoker.Invoke(newInvocation) + return invoker.Invoke(ctx, newInvocation) } -func (ef *GenericServiceFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (ef *GenericServiceFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { if invocation.MethodName() == constant.GENERIC && len(invocation.Arguments()) == 3 && result.Result() != nil { v := reflect.ValueOf(result.Result()) if v.Kind() == reflect.Ptr { diff --git a/filter/filter_impl/generic_service_filter_test.go b/filter/filter_impl/generic_service_filter_test.go index e36ec5086ecacffbf56a0da6dd9249ffd6fec649..8211e717564465bba3009772715a3ab1cd3322dd 100644 --- a/filter/filter_impl/generic_service_filter_test.go +++ b/filter/filter_impl/generic_service_filter_test.go @@ -99,7 +99,7 @@ func TestGenericServiceFilter_Invoke(t *testing.T) { rpcInvocation := invocation.NewRPCInvocation(methodName, aurguments, nil) filter := GetGenericServiceFilter() url, _ := common.NewURL(context.Background(), "testprotocol://127.0.0.1:20000/com.test.Path") - result := filter.Invoke(&proxy_factory.ProxyInvoker{BaseInvoker: *protocol.NewBaseInvoker(url)}, rpcInvocation) + result := filter.Invoke(context.Background(), &proxy_factory.ProxyInvoker{BaseInvoker: *protocol.NewBaseInvoker(url)}, rpcInvocation) assert.NotNil(t, result) assert.Nil(t, result.Error()) } @@ -124,7 +124,7 @@ func TestGenericServiceFilter_ResponseTestStruct(t *testing.T) { filter := GetGenericServiceFilter() methodName := "$invoke" rpcInvocation := invocation.NewRPCInvocation(methodName, aurguments, nil) - r := filter.OnResponse(result, nil, rpcInvocation) + r := filter.OnResponse(nil, result, nil, rpcInvocation) assert.NotNil(t, r.Result()) assert.Equal(t, reflect.ValueOf(r.Result()).Kind(), reflect.Map) } @@ -142,7 +142,7 @@ func TestGenericServiceFilter_ResponseString(t *testing.T) { filter := GetGenericServiceFilter() methodName := "$invoke" rpcInvocation := invocation.NewRPCInvocation(methodName, aurguments, nil) - r := filter.OnResponse(result, nil, rpcInvocation) + r := filter.OnResponse(nil, result, nil, rpcInvocation) assert.NotNil(t, r.Result()) assert.Equal(t, reflect.ValueOf(r.Result()).Kind(), reflect.String) } diff --git a/filter/filter_impl/graceful_shutdown_filter.go b/filter/filter_impl/graceful_shutdown_filter.go index 1af7e1f8c32ea3924550399a7ff5e76c68368636..95e625b2d56895a4d57823e4e0e2e7d1d5e90a08 100644 --- a/filter/filter_impl/graceful_shutdown_filter.go +++ b/filter/filter_impl/graceful_shutdown_filter.go @@ -18,6 +18,7 @@ package filter_impl import ( + "context" "sync/atomic" ) @@ -52,16 +53,16 @@ type gracefulShutdownFilter struct { shutdownConfig *config.ShutdownConfig } -func (gf *gracefulShutdownFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (gf *gracefulShutdownFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { if gf.rejectNewRequest() { logger.Info("The application is closing, new request will be rejected.") return gf.getRejectHandler().RejectedExecution(invoker.GetUrl(), invocation) } atomic.AddInt32(&gf.activeCount, 1) - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } -func (gf *gracefulShutdownFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (gf *gracefulShutdownFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { atomic.AddInt32(&gf.activeCount, -1) // although this isn't thread safe, it won't be a problem if the gf.rejectNewRequest() is true. if gf.shutdownConfig != nil && gf.activeCount <= 0 { diff --git a/filter/filter_impl/graceful_shutdown_filter_test.go b/filter/filter_impl/graceful_shutdown_filter_test.go index fc437c3557fa452273e770d3d50678401ba3b33b..4c670933e3dcec29ad9ae7bfef250b4236ae7c54 100644 --- a/filter/filter_impl/graceful_shutdown_filter_test.go +++ b/filter/filter_impl/graceful_shutdown_filter_test.go @@ -18,6 +18,7 @@ package filter_impl import ( + "context" "net/url" "testing" ) @@ -53,7 +54,7 @@ func TestGenericFilter_Invoke(t *testing.T) { assert.Equal(t, extension.GetRejectedExecutionHandler(constant.DEFAULT_KEY), shutdownFilter.getRejectHandler()) - result := shutdownFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl), invoc) + result := shutdownFilter.Invoke(context.Background(), protocol.NewBaseInvoker(*invokeUrl), invoc) assert.NotNil(t, result) assert.Nil(t, result.Error()) @@ -64,7 +65,7 @@ func TestGenericFilter_Invoke(t *testing.T) { shutdownFilter.shutdownConfig = providerConfig.ShutdownConfig assert.True(t, shutdownFilter.rejectNewRequest()) - result = shutdownFilter.OnResponse(nil, protocol.NewBaseInvoker(*invokeUrl), invoc) + result = shutdownFilter.OnResponse(nil, nil, protocol.NewBaseInvoker(*invokeUrl), invoc) rejectHandler := &common2.OnlyLogRejectedExecutionHandler{} extension.SetRejectedExecutionHandler("mock", func() filter.RejectedExecutionHandler { diff --git a/filter/filter_impl/hystrix_filter.go b/filter/filter_impl/hystrix_filter.go index a6e07803046005b5ab31d7a02ea9e25f4b74da75..0f40d815ffbd4c199ad30cad44eb1a94e93cf916 100644 --- a/filter/filter_impl/hystrix_filter.go +++ b/filter/filter_impl/hystrix_filter.go @@ -17,6 +17,7 @@ package filter_impl import ( + "context" "fmt" "regexp" "sync" @@ -82,7 +83,7 @@ type HystrixFilter struct { ifNewMap sync.Map } -func (hf *HystrixFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (hf *HystrixFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { cmdName := fmt.Sprintf("%s&method=%s", invoker.GetUrl().Key(), invocation.MethodName()) @@ -115,12 +116,12 @@ func (hf *HystrixFilter) Invoke(invoker protocol.Invoker, invocation protocol.In configLoadMutex.RUnlock() if err != nil { logger.Errorf("[Hystrix Filter]Errors occurred getting circuit for %s , will invoke without hystrix, error is: ", cmdName, err) - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } logger.Infof("[Hystrix Filter]Using hystrix filter: %s", cmdName) var result protocol.Result _ = hystrix.Do(cmdName, func() error { - result = invoker.Invoke(invocation) + result = invoker.Invoke(ctx, invocation) err := result.Error() if err != nil { result.SetError(NewHystrixFilterError(err, false)) @@ -144,7 +145,7 @@ func (hf *HystrixFilter) Invoke(invoker protocol.Invoker, invocation protocol.In return result } -func (hf *HystrixFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (hf *HystrixFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { return result } func GetHystrixFilterConsumer() filter.Filter { diff --git a/filter/filter_impl/hystrix_filter_test.go b/filter/filter_impl/hystrix_filter_test.go index 2bbc3e079e7ae563db1efa18f82423931fd5919d..894573036ae6dd9edca88e8e4cdd92e7643abcb5 100644 --- a/filter/filter_impl/hystrix_filter_test.go +++ b/filter/filter_impl/hystrix_filter_test.go @@ -17,6 +17,7 @@ package filter_impl import ( + "context" "regexp" "testing" ) @@ -125,7 +126,7 @@ type testMockSuccessInvoker struct { protocol.BaseInvoker } -func (iv *testMockSuccessInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (iv *testMockSuccessInvoker) Invoke(context context.Context, invocation protocol.Invocation) protocol.Result { return &protocol.RPCResult{ Rest: "Sucess", Err: nil, @@ -136,7 +137,7 @@ type testMockFailInvoker struct { protocol.BaseInvoker } -func (iv *testMockFailInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (iv *testMockFailInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { return &protocol.RPCResult{ Err: errors.Errorf("exception"), } @@ -144,7 +145,7 @@ func (iv *testMockFailInvoker) Invoke(invocation protocol.Invocation) protocol.R func TestHystrixFilter_Invoke_Success(t *testing.T) { hf := &HystrixFilter{} - result := hf.Invoke(&testMockSuccessInvoker{}, &invocation.RPCInvocation{}) + result := hf.Invoke(context.Background(), &testMockSuccessInvoker{}, &invocation.RPCInvocation{}) assert.NotNil(t, result) assert.NoError(t, result.Error()) assert.NotNil(t, result.Result()) @@ -152,7 +153,7 @@ func TestHystrixFilter_Invoke_Success(t *testing.T) { func TestHystrixFilter_Invoke_Fail(t *testing.T) { hf := &HystrixFilter{} - result := hf.Invoke(&testMockFailInvoker{}, &invocation.RPCInvocation{}) + result := hf.Invoke(context.Background(), &testMockFailInvoker{}, &invocation.RPCInvocation{}) assert.NotNil(t, result) assert.Error(t, result.Error()) } @@ -164,7 +165,7 @@ func TestHystricFilter_Invoke_CircuitBreak(t *testing.T) { resChan := make(chan protocol.Result, 50) for i := 0; i < 50; i++ { go func() { - result := hf.Invoke(&testMockFailInvoker{}, &invocation.RPCInvocation{}) + result := hf.Invoke(context.Background(), &testMockFailInvoker{}, &invocation.RPCInvocation{}) resChan <- result }() } @@ -189,7 +190,7 @@ func TestHystricFilter_Invoke_CircuitBreak_Omit_Exception(t *testing.T) { resChan := make(chan protocol.Result, 50) for i := 0; i < 50; i++ { go func() { - result := hf.Invoke(&testMockFailInvoker{}, &invocation.RPCInvocation{}) + result := hf.Invoke(context.Background(), &testMockFailInvoker{}, &invocation.RPCInvocation{}) resChan <- result }() } diff --git a/filter/filter_impl/token_filter.go b/filter/filter_impl/token_filter.go index 180f3e6631a2fd0b317af3a4addd8d77287d82d5..702ee33d4d2e9756ab3b4dbb4bfc9b7c42907080 100644 --- a/filter/filter_impl/token_filter.go +++ b/filter/filter_impl/token_filter.go @@ -18,6 +18,7 @@ limitations under the License. package filter_impl import ( + "context" "strings" ) @@ -42,22 +43,22 @@ func init() { type TokenFilter struct{} -func (tf *TokenFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (tf *TokenFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { invokerTkn := invoker.GetUrl().GetParam(constant.TOKEN_KEY, "") if len(invokerTkn) > 0 { attachs := invocation.Attachments() remoteTkn, exist := attachs[constant.TOKEN_KEY] if exist && strings.EqualFold(invokerTkn, remoteTkn) { - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } return &protocol.RPCResult{Err: perrors.Errorf("Invalid token! Forbid invoke remote service %v method %s ", invoker, invocation.MethodName())} } - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } -func (tf *TokenFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (tf *TokenFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { return result } diff --git a/filter/filter_impl/token_filter_test.go b/filter/filter_impl/token_filter_test.go index 675d33dc7d401b04f59037c1ec2eb44c8d6ecbe4..672082c729bc371a40573a66d13bc57a7024186b 100644 --- a/filter/filter_impl/token_filter_test.go +++ b/filter/filter_impl/token_filter_test.go @@ -18,6 +18,7 @@ limitations under the License. package filter_impl import ( + "context" "net/url" "testing" ) @@ -41,8 +42,10 @@ func TestTokenFilter_Invoke(t *testing.T) { common.WithParamsValue(constant.TOKEN_KEY, "ori_key")) attch := make(map[string]string, 0) attch[constant.TOKEN_KEY] = "ori_key" - result := filter.Invoke(protocol.NewBaseInvoker(*url), - invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch)) + result := filter.Invoke(context.Background(), + protocol.NewBaseInvoker(*url), + invocation.NewRPCInvocation("MethodName", + []interface{}{"OK"}, attch)) assert.Nil(t, result.Error()) assert.Nil(t, result.Result()) } @@ -53,8 +56,7 @@ func TestTokenFilter_InvokeEmptyToken(t *testing.T) { url := common.URL{} attch := make(map[string]string, 0) attch[constant.TOKEN_KEY] = "ori_key" - result := filter.Invoke(protocol.NewBaseInvoker(url), - invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch)) + result := filter.Invoke(context.Background(), protocol.NewBaseInvoker(url), invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch)) assert.Nil(t, result.Error()) assert.Nil(t, result.Result()) } @@ -66,8 +68,7 @@ func TestTokenFilter_InvokeEmptyAttach(t *testing.T) { common.WithParams(url.Values{}), common.WithParamsValue(constant.TOKEN_KEY, "ori_key")) attch := make(map[string]string, 0) - result := filter.Invoke(protocol.NewBaseInvoker(*url), - invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch)) + result := filter.Invoke(context.Background(), protocol.NewBaseInvoker(*url), invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch)) assert.NotNil(t, result.Error()) } @@ -79,7 +80,7 @@ func TestTokenFilter_InvokeNotEqual(t *testing.T) { common.WithParamsValue(constant.TOKEN_KEY, "ori_key")) attch := make(map[string]string, 0) attch[constant.TOKEN_KEY] = "err_key" - result := filter.Invoke(protocol.NewBaseInvoker(*url), - invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch)) + result := filter.Invoke(context.Background(), + protocol.NewBaseInvoker(*url), invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch)) assert.NotNil(t, result.Error()) } diff --git a/filter/filter_impl/tps_limit_filter.go b/filter/filter_impl/tps_limit_filter.go index 77414a8ea70743983cadc609c875920cff525487..8852260e9e7b4b833728da97dc8f273d3e52dec7 100644 --- a/filter/filter_impl/tps_limit_filter.go +++ b/filter/filter_impl/tps_limit_filter.go @@ -17,6 +17,9 @@ package filter_impl +import ( + "context" +) import ( "github.com/apache/dubbo-go/common/constant" "github.com/apache/dubbo-go/common/extension" @@ -51,22 +54,22 @@ func init() { type TpsLimitFilter struct { } -func (t TpsLimitFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (t TpsLimitFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { url := invoker.GetUrl() tpsLimiter := url.GetParam(constant.TPS_LIMITER_KEY, "") rejectedExeHandler := url.GetParam(constant.TPS_REJECTED_EXECUTION_HANDLER_KEY, constant.DEFAULT_KEY) if len(tpsLimiter) > 0 { allow := extension.GetTpsLimiter(tpsLimiter).IsAllowable(invoker.GetUrl(), invocation) if allow { - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } logger.Errorf("The invocation was rejected due to over the tps limitation, url: %s ", url.String()) return extension.GetRejectedExecutionHandler(rejectedExeHandler).RejectedExecution(url, invocation) } - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } -func (t TpsLimitFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (t TpsLimitFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { return result } diff --git a/filter/filter_impl/tps_limit_filter_test.go b/filter/filter_impl/tps_limit_filter_test.go index 5e04804aa23c4e6e417f6bb9975a3269a2118739..cc423ae1e5f3589dd60b0c8655f1123c290f0ffc 100644 --- a/filter/filter_impl/tps_limit_filter_test.go +++ b/filter/filter_impl/tps_limit_filter_test.go @@ -18,6 +18,7 @@ package filter_impl import ( + "context" "net/url" "testing" ) @@ -45,8 +46,10 @@ func TestTpsLimitFilter_Invoke_With_No_TpsLimiter(t *testing.T) { common.WithParamsValue(constant.TPS_LIMITER_KEY, "")) attch := make(map[string]string, 0) - result := tpsFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl), - invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch)) + result := tpsFilter.Invoke(context.Background(), + protocol.NewBaseInvoker(*invokeUrl), + invocation.NewRPCInvocation("MethodName", + []interface{}{"OK"}, attch)) assert.Nil(t, result.Error()) assert.Nil(t, result.Result()) @@ -67,8 +70,10 @@ func TestGenericFilter_Invoke_With_Default_TpsLimiter(t *testing.T) { common.WithParamsValue(constant.TPS_LIMITER_KEY, constant.DEFAULT_KEY)) attch := make(map[string]string, 0) - result := tpsFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl), - invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch)) + result := tpsFilter.Invoke(context.Background(), + protocol.NewBaseInvoker(*invokeUrl), + invocation.NewRPCInvocation("MethodName", + []interface{}{"OK"}, attch)) assert.Nil(t, result.Error()) assert.Nil(t, result.Result()) } @@ -96,8 +101,8 @@ func TestGenericFilter_Invoke_With_Default_TpsLimiter_Not_Allow(t *testing.T) { common.WithParamsValue(constant.TPS_LIMITER_KEY, constant.DEFAULT_KEY)) attch := make(map[string]string, 0) - result := tpsFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl), - invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch)) + result := tpsFilter.Invoke(context.Background(), + protocol.NewBaseInvoker(*invokeUrl), invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch)) assert.Nil(t, result.Error()) assert.Nil(t, result.Result()) } diff --git a/protocol/dubbo/dubbo_invoker.go b/protocol/dubbo/dubbo_invoker.go index 6dcf2568fa8c88a864c567486a501c2ad7feb3f7..8100fbe3a6760e456a9eecedfd39e5230dd2c797 100644 --- a/protocol/dubbo/dubbo_invoker.go +++ b/protocol/dubbo/dubbo_invoker.go @@ -18,6 +18,7 @@ package dubbo import ( + "context" "strconv" "sync" ) @@ -53,7 +54,7 @@ func NewDubboInvoker(url common.URL, client *Client) *DubboInvoker { } } -func (di *DubboInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (di *DubboInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { var ( err error diff --git a/protocol/dubbo/dubbo_invoker_test.go b/protocol/dubbo/dubbo_invoker_test.go index 8a032d0ca9536e53dc18c43f8e699212f7c30ec4..e360d57b8cdd61674d35a665e8ee85e03421cc8f 100644 --- a/protocol/dubbo/dubbo_invoker_test.go +++ b/protocol/dubbo/dubbo_invoker_test.go @@ -18,6 +18,7 @@ package dubbo import ( + "context" "sync" "testing" "time" @@ -53,14 +54,14 @@ func TestDubboInvoker_Invoke(t *testing.T) { invocation.WithReply(user), invocation.WithAttachments(map[string]string{"test_key": "test_value"})) // Call - res := invoker.Invoke(inv) + res := invoker.Invoke(context.Background(), inv) assert.NoError(t, res.Error()) assert.Equal(t, User{Id: "1", Name: "username"}, *res.Result().(*User)) assert.Equal(t, "test_value", res.Attachments()["test_key"]) // test attachments for request/response // CallOneway inv.SetAttachments(constant.ASYNC_KEY, "true") - res = invoker.Invoke(inv) + res = invoker.Invoke(context.Background(), inv) assert.NoError(t, res.Error()) // AsyncCall @@ -71,13 +72,13 @@ func TestDubboInvoker_Invoke(t *testing.T) { assert.Equal(t, User{Id: "1", Name: "username"}, *r.Reply.(*Response).reply.(*User)) lock.Unlock() }) - res = invoker.Invoke(inv) + res = invoker.Invoke(context.Background(), inv) assert.NoError(t, res.Error()) // Err_No_Reply inv.SetAttachments(constant.ASYNC_KEY, "false") inv.SetReply(nil) - res = invoker.Invoke(inv) + res = invoker.Invoke(context.Background(), inv) assert.EqualError(t, res.Error(), "request need @response") // destroy diff --git a/protocol/dubbo/listener.go b/protocol/dubbo/listener.go index 2e4b3999dfc08262a2cfb80f29c9a9e7bc2decf8..1ed6e9cf57f3399ce2a7a8134bad9924d6799460 100644 --- a/protocol/dubbo/listener.go +++ b/protocol/dubbo/listener.go @@ -18,6 +18,7 @@ package dubbo import ( + "context" "fmt" "net/url" "sync" @@ -258,7 +259,7 @@ func (h *RpcServerHandler) OnMessage(session getty.Session, pkg interface{}) { args := p.Body.(map[string]interface{})["args"].([]interface{}) inv := invocation.NewRPCInvocation(p.Service.Method, args, attachments) - result := invoker.Invoke(inv) + result := invoker.Invoke(context.Background(), inv) if err := result.Error(); err != nil { p.Header.ResponseStatus = hessian.Response_OK p.Body = hessian.NewResponse(nil, err, result.Attachments()) diff --git a/protocol/grpc/common_test.go b/protocol/grpc/common_test.go index 7f78bdc40d07a9089c1cf40f55803f04b39cb949..165b82fabc5703a720766b04659b158d2b3fdbdf 100644 --- a/protocol/grpc/common_test.go +++ b/protocol/grpc/common_test.go @@ -97,7 +97,7 @@ func _DUBBO_Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec f invo := invocation.NewRPCInvocation("SayHello", args, nil) if interceptor == nil { - result := base.GetProxyImpl().Invoke(invo) + result := base.GetProxyImpl().Invoke(context.Background(), invo) return result.Result(), result.Error() } info := &native_grpc.UnaryServerInfo{ @@ -105,7 +105,7 @@ func _DUBBO_Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec f FullMethod: "/helloworld.Greeter/SayHello", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - result := base.GetProxyImpl().Invoke(invo) + result := base.GetProxyImpl().Invoke(context.Background(), invo) return result.Result(), result.Error() } return interceptor(ctx, in, info, handler) diff --git a/protocol/grpc/grpc_invoker.go b/protocol/grpc/grpc_invoker.go index b74612b896addb1ff08c3abe44198c147996a126..88149397e79aa435a6a9d41911ae0e603754534e 100644 --- a/protocol/grpc/grpc_invoker.go +++ b/protocol/grpc/grpc_invoker.go @@ -50,7 +50,7 @@ func NewGrpcInvoker(url common.URL, client *Client) *GrpcInvoker { } } -func (gi *GrpcInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (gi *GrpcInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { var ( result protocol.RPCResult ) diff --git a/protocol/grpc/grpc_invoker_test.go b/protocol/grpc/grpc_invoker_test.go index 4f97e1063191692ce5f47e0d4f8242d95cc8a6fc..5d4b97051438f8404cd8fd89bcf73d24e0121868 100644 --- a/protocol/grpc/grpc_invoker_test.go +++ b/protocol/grpc/grpc_invoker_test.go @@ -49,7 +49,7 @@ func TestInvoke(t *testing.T) { bizReply := &internal.HelloReply{} invo := invocation.NewRPCInvocationWithOptions(invocation.WithMethodName("SayHello"), invocation.WithParameterValues(args), invocation.WithReply(bizReply)) - res := invoker.Invoke(invo) + res := invoker.Invoke(context.Background(), invo) assert.Nil(t, res.Error()) assert.NotNil(t, res.Result()) assert.Equal(t, "Hello request name", bizReply.Message) diff --git a/protocol/grpc/protoc-gen-dubbo/examples/helloworld.pb.go b/protocol/grpc/protoc-gen-dubbo/examples/helloworld.pb.go index 4ed55ab7612200d28816508e4c4fcb7de0a803c0..f5d3a49b0916050fc6b2e6373fde0b70df0a1c31 100644 --- a/protocol/grpc/protoc-gen-dubbo/examples/helloworld.pb.go +++ b/protocol/grpc/protoc-gen-dubbo/examples/helloworld.pb.go @@ -271,7 +271,7 @@ func _DUBBO_Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec f args = append(args, in) invo := invocation.NewRPCInvocation("SayHello", args, nil) if interceptor == nil { - result := base.GetProxyImpl().Invoke(invo) + result := base.GetProxyImpl().Invoke(context.Background(), invo) return result.Result(), result.Error() } info := &grpc.UnaryServerInfo{ @@ -279,7 +279,7 @@ func _DUBBO_Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec f FullMethod: "/main.Greeter/SayHello", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - result := base.GetProxyImpl().Invoke(invo) + result := base.GetProxyImpl().Invoke(context.Background(), invo) return result.Result(), result.Error() } return interceptor(ctx, in, info, handler) diff --git a/protocol/invoker.go b/protocol/invoker.go index f5d41a09ad2778c12c7e5e68167a4d0acc9e3f4c..a1cf6264ae2b9f631b1bb12f88e8378ad5857919 100644 --- a/protocol/invoker.go +++ b/protocol/invoker.go @@ -17,6 +17,9 @@ package protocol +import ( + "context" +) import ( "github.com/apache/dubbo-go/common" "github.com/apache/dubbo-go/common/logger" @@ -26,7 +29,7 @@ import ( // Extension - Invoker type Invoker interface { common.Node - Invoke(Invocation) Result + Invoke(context.Context, Invocation) Result } ///////////////////////////// @@ -59,7 +62,7 @@ func (bi *BaseInvoker) IsDestroyed() bool { return bi.destroyed } -func (bi *BaseInvoker) Invoke(invocation Invocation) Result { +func (bi *BaseInvoker) Invoke(context context.Context, invocation Invocation) Result { return &RPCResult{} } diff --git a/protocol/jsonrpc/jsonrpc_invoker.go b/protocol/jsonrpc/jsonrpc_invoker.go index 2c130e0d7617e96a1724edc5b63f8e66f251446e..07e09b07c9e55e804a15cc79587b020230f087fa 100644 --- a/protocol/jsonrpc/jsonrpc_invoker.go +++ b/protocol/jsonrpc/jsonrpc_invoker.go @@ -41,7 +41,7 @@ func NewJsonrpcInvoker(url common.URL, client *HTTPClient) *JsonrpcInvoker { } } -func (ji *JsonrpcInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (ji *JsonrpcInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { var ( result protocol.RPCResult @@ -50,12 +50,12 @@ func (ji *JsonrpcInvoker) Invoke(invocation protocol.Invocation) protocol.Result inv := invocation.(*invocation_impl.RPCInvocation) url := ji.GetUrl() req := ji.client.NewRequest(url, inv.MethodName(), inv.Arguments()) - ctx := context.WithValue(context.Background(), constant.DUBBOGO_CTX_KEY, map[string]string{ + ctxNew := context.WithValue(context.Background(), constant.DUBBOGO_CTX_KEY, map[string]string{ "X-Proxy-Id": "dubbogo", "X-Services": url.Path, "X-Method": inv.MethodName(), }) - result.Err = ji.client.Call(ctx, url, req, inv.Reply()) + result.Err = ji.client.Call(ctxNew, url, req, inv.Reply()) if result.Err == nil { result.Rest = inv.Reply() } diff --git a/protocol/jsonrpc/jsonrpc_invoker_test.go b/protocol/jsonrpc/jsonrpc_invoker_test.go index 8c910339858f4960ad0e394ae6271863d7654adc..9eed22e67155f1b0915cbb398bcef55962258407 100644 --- a/protocol/jsonrpc/jsonrpc_invoker_test.go +++ b/protocol/jsonrpc/jsonrpc_invoker_test.go @@ -60,7 +60,7 @@ func TestJsonrpcInvoker_Invoke(t *testing.T) { jsonInvoker := NewJsonrpcInvoker(url, client) user := &User{} - res := jsonInvoker.Invoke(invocation.NewRPCInvocationWithOptions(invocation.WithMethodName("GetUser"), invocation.WithArguments([]interface{}{"1", "username"}), + res := jsonInvoker.Invoke(context.Background(), invocation.NewRPCInvocationWithOptions(invocation.WithMethodName("GetUser"), invocation.WithArguments([]interface{}{"1", "username"}), invocation.WithReply(user))) assert.NoError(t, res.Error()) diff --git a/protocol/jsonrpc/server.go b/protocol/jsonrpc/server.go index dc79e4a36bd6cce575d50588d11b003cb8e25abe..3decc8867474d99da8f50584e60375af0e54b225 100644 --- a/protocol/jsonrpc/server.go +++ b/protocol/jsonrpc/server.go @@ -330,7 +330,7 @@ func serveRequest(ctx context.Context, exporter, _ := jsonrpcProtocol.ExporterMap().Load(path) invoker := exporter.(*JsonrpcExporter).GetInvoker() if invoker != nil { - result := invoker.Invoke(invocation.NewRPCInvocation(methodName, args, map[string]string{ + result := invoker.Invoke(context.Background(), invocation.NewRPCInvocation(methodName, args, map[string]string{ constant.PATH_KEY: path, constant.VERSION_KEY: codec.req.Version, })) diff --git a/protocol/mock/mock_invoker.go b/protocol/mock/mock_invoker.go index c509cef054f5a23fe504486e01d7cc0e8772711d..5c5b476b7b07f6c41a74a7ec8f51648aff84b1a3 100644 --- a/protocol/mock/mock_invoker.go +++ b/protocol/mock/mock_invoker.go @@ -21,6 +21,7 @@ package mock import ( + "context" "reflect" ) @@ -91,7 +92,7 @@ func (mr *MockInvokerMockRecorder) Destroy() *gomock.Call { } // Invoke mocks base method -func (m *MockInvoker) Invoke(arg0 protocol.Invocation) protocol.Result { +func (m *MockInvoker) Invoke(ctx context.Context, arg0 protocol.Invocation) protocol.Result { ret := m.ctrl.Call(m, "Invoke", arg0) ret0, _ := ret[0].(protocol.Result) return ret0 diff --git a/protocol/protocolwrapper/protocol_filter_wrapper.go b/protocol/protocolwrapper/protocol_filter_wrapper.go index 7c58fabea3cccf5a39e1622fedd4a3a297e05983..33ea38201251df3abc6639b416200611cc993e56 100644 --- a/protocol/protocolwrapper/protocol_filter_wrapper.go +++ b/protocol/protocolwrapper/protocol_filter_wrapper.go @@ -18,6 +18,7 @@ package protocolwrapper import ( + "context" "strings" ) @@ -102,9 +103,9 @@ func (fi *FilterInvoker) IsAvailable() bool { return fi.invoker.IsAvailable() } -func (fi *FilterInvoker) Invoke(invocation protocol.Invocation) protocol.Result { - result := fi.filter.Invoke(fi.next, invocation) - return fi.filter.OnResponse(result, fi.invoker, invocation) +func (fi *FilterInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { + result := fi.filter.Invoke(ctx, fi.next, invocation) + return fi.filter.OnResponse(ctx, result, fi.invoker, invocation) } func (fi *FilterInvoker) Destroy() { diff --git a/protocol/protocolwrapper/protocol_filter_wrapper_test.go b/protocol/protocolwrapper/protocol_filter_wrapper_test.go index dc376313549c24da1cc6cb64a42e8445ef4fe346..8491d57462d47d6af72040d41b78dcb30e6da697 100644 --- a/protocol/protocolwrapper/protocol_filter_wrapper_test.go +++ b/protocol/protocolwrapper/protocol_filter_wrapper_test.go @@ -18,6 +18,7 @@ package protocolwrapper import ( + "context" "net/url" "testing" ) @@ -66,7 +67,7 @@ func init() { type EchoFilterForTest struct{} -func (ef *EchoFilterForTest) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (ef *EchoFilterForTest) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { logger.Infof("invoking echo filter.") logger.Debugf("%v,%v", invocation.MethodName(), len(invocation.Arguments())) if invocation.MethodName() == constant.ECHO && len(invocation.Arguments()) == 1 { @@ -75,10 +76,10 @@ func (ef *EchoFilterForTest) Invoke(invoker protocol.Invoker, invocation protoco } } - return invoker.Invoke(invocation) + return invoker.Invoke(ctx, invocation) } -func (ef *EchoFilterForTest) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { +func (ef *EchoFilterForTest) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { return result } diff --git a/registry/protocol/protocol.go b/registry/protocol/protocol.go index 8655312a4eb508dfe5c910855ba5f3e3aacd666e..9e6b9999b976d5cfcc76560731f383a52c2642f4 100644 --- a/registry/protocol/protocol.go +++ b/registry/protocol/protocol.go @@ -18,6 +18,7 @@ package protocol import ( + "context" "strings" "sync" ) @@ -356,10 +357,10 @@ func newWrappedInvoker(invoker protocol.Invoker, url *common.URL) *wrappedInvoke } } -func (ivk *wrappedInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (ivk *wrappedInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { // get right url ivk.invoker.(*proxy_factory.ProxyInvoker).BaseInvoker = *protocol.NewBaseInvoker(ivk.GetUrl()) - return ivk.invoker.Invoke(invocation) + return ivk.invoker.Invoke(ctx, invocation) } type providerConfigurationListener struct {