Skip to content
Snippets Groups Projects
Commit 9ba91fa5 authored by Ming Deng's avatar Ming Deng Committed by GitHub
Browse files

Merge pull request #330 from flycash/feature/context

Ftr: Context support
parents cea4597d 812c73f9
No related branches found
No related tags found
No related merge requests found
Showing
with 97 additions and 74 deletions
......@@ -18,6 +18,7 @@
package filter_impl
import (
"context"
"os"
"reflect"
"strings"
......@@ -66,13 +67,13 @@ type AccessLogFilter struct {
logChan chan AccessLogData
}
func (ef *AccessLogFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (ef *AccessLogFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
accessLog := invoker.GetUrl().GetParam(constant.ACCESS_LOG_KEY, "")
if len(accessLog) > 0 {
accessLogData := AccessLogData{data: ef.buildAccessLogData(invoker, invocation), accessLog: accessLog}
ef.logIntoChannel(accessLogData)
}
return invoker.Invoke(invocation)
return invoker.Invoke(ctx, invocation)
}
// it won't block the invocation
......@@ -119,7 +120,7 @@ func (ef *AccessLogFilter) buildAccessLogData(invoker protocol.Invoker, invocati
return dataMap
}
func (ef *AccessLogFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (ef *AccessLogFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
......
......@@ -49,7 +49,7 @@ func TestAccessLogFilter_Invoke_Not_Config(t *testing.T) {
inv := invocation.NewRPCInvocation("MethodName", []interface{}{"OK", "Hello"}, attach)
accessLogFilter := GetAccessLogFilter()
result := accessLogFilter.Invoke(invoker, inv)
result := accessLogFilter.Invoke(context.Background(), invoker, inv)
assert.Nil(t, result.Error())
}
......@@ -70,13 +70,13 @@ func TestAccessLogFilter_Invoke_Default_Config(t *testing.T) {
inv := invocation.NewRPCInvocation("MethodName", []interface{}{"OK", "Hello"}, attach)
accessLogFilter := GetAccessLogFilter()
result := accessLogFilter.Invoke(invoker, inv)
result := accessLogFilter.Invoke(context.Background(), invoker, inv)
assert.Nil(t, result.Error())
}
func TestAccessLogFilter_OnResponse(t *testing.T) {
result := &protocol.RPCResult{}
accessLogFilter := GetAccessLogFilter()
response := accessLogFilter.OnResponse(result, nil, nil)
response := accessLogFilter.OnResponse(nil, result, nil, nil)
assert.Equal(t, result, response)
}
......@@ -18,6 +18,7 @@
package filter_impl
import (
"context"
"strconv"
)
......@@ -41,15 +42,15 @@ func init() {
type ActiveFilter struct {
}
func (ef *ActiveFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (ef *ActiveFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
logger.Infof("invoking active filter. %v,%v", invocation.MethodName(), len(invocation.Arguments()))
invocation.(*invocation2.RPCInvocation).SetAttachments(dubboInvokeStartTime, strconv.FormatInt(protocol.CurrentTimeMillis(), 10))
protocol.BeginCount(invoker.GetUrl(), invocation.MethodName())
return invoker.Invoke(invocation)
return invoker.Invoke(ctx, invocation)
}
func (ef *ActiveFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (ef *ActiveFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
startTime, err := strconv.ParseInt(invocation.(*invocation2.RPCInvocation).AttachmentsByKey(dubboInvokeStartTime, "0"), 10, 64)
if err != nil {
......
......@@ -28,7 +28,7 @@ func TestActiveFilter_Invoke(t *testing.T) {
invoker := mock.NewMockInvoker(ctrl)
invoker.EXPECT().Invoke(gomock.Any()).Return(nil)
invoker.EXPECT().GetUrl().Return(url).Times(1)
filter.Invoke(invoker, invoc)
filter.Invoke(context.Background(), invoker, invoc)
assert.True(t, invoc.AttachmentsByKey(dubboInvokeStartTime, "") != "")
}
......@@ -48,7 +48,7 @@ func TestActiveFilter_OnResponse(t *testing.T) {
result := &protocol.RPCResult{
Err: errors.New("test"),
}
filter.OnResponse(result, invoker, invoc)
filter.OnResponse(nil, result, invoker, invoc)
methodStatus := protocol.GetMethodStatus(url, "test")
urlStatus := protocol.GetURLStatus(url)
......
......@@ -17,6 +17,9 @@
package filter_impl
import (
"context"
)
import (
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/extension"
......@@ -38,7 +41,7 @@ func init() {
// Echo func(ctx context.Context, arg interface{}, rsp *Xxx) error
type EchoFilter struct{}
func (ef *EchoFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (ef *EchoFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
logger.Infof("invoking echo filter.")
logger.Debugf("%v,%v", invocation.MethodName(), len(invocation.Arguments()))
if invocation.MethodName() == constant.ECHO && len(invocation.Arguments()) == 1 {
......@@ -48,10 +51,10 @@ func (ef *EchoFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invoc
}
}
return invoker.Invoke(invocation)
return invoker.Invoke(ctx, invocation)
}
func (ef *EchoFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (ef *EchoFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
......
......@@ -18,6 +18,7 @@
package filter_impl
import (
"context"
"testing"
)
......@@ -33,12 +34,10 @@ import (
func TestEchoFilter_Invoke(t *testing.T) {
filter := GetFilter()
result := filter.Invoke(protocol.NewBaseInvoker(common.URL{}),
invocation.NewRPCInvocation("$echo", []interface{}{"OK"}, nil))
result := filter.Invoke(context.Background(), protocol.NewBaseInvoker(common.URL{}), invocation.NewRPCInvocation("$echo", []interface{}{"OK"}, nil))
assert.Equal(t, "OK", result.Result())
result = filter.Invoke(protocol.NewBaseInvoker(common.URL{}),
invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, nil))
result = filter.Invoke(context.Background(), protocol.NewBaseInvoker(common.URL{}), invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, nil))
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
}
......@@ -18,6 +18,7 @@
package filter_impl
import (
"context"
"strconv"
"sync"
"sync/atomic"
......@@ -75,7 +76,7 @@ type ExecuteState struct {
concurrentCount int64
}
func (ef *ExecuteLimitFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (ef *ExecuteLimitFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
methodConfigPrefix := "methods." + invocation.MethodName() + "."
url := invoker.GetUrl()
limitTarget := url.ServiceKey()
......@@ -97,7 +98,7 @@ func (ef *ExecuteLimitFilter) Invoke(invoker protocol.Invoker, invocation protoc
}
if limitRate < 0 {
return invoker.Invoke(invocation)
return invoker.Invoke(ctx, invocation)
}
state, _ := ef.executeState.LoadOrStore(limitTarget, &ExecuteState{
......@@ -113,10 +114,10 @@ func (ef *ExecuteLimitFilter) Invoke(invoker protocol.Invoker, invocation protoc
return extension.GetRejectedExecutionHandler(rejectedHandlerConfig).RejectedExecution(url, invocation)
}
return invoker.Invoke(invocation)
return invoker.Invoke(ctx, invocation)
}
func (ef *ExecuteLimitFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (ef *ExecuteLimitFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
......
......@@ -18,6 +18,7 @@
package filter_impl
import (
"context"
"net/url"
"testing"
)
......@@ -43,7 +44,7 @@ func TestExecuteLimitFilter_Invoke_Ignored(t *testing.T) {
limitFilter := GetExecuteLimitFilter()
result := limitFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl), invoc)
result := limitFilter.Invoke(context.Background(), protocol.NewBaseInvoker(*invokeUrl), invoc)
assert.NotNil(t, result)
assert.Nil(t, result.Error())
}
......@@ -60,7 +61,7 @@ func TestExecuteLimitFilter_Invoke_Configure_Error(t *testing.T) {
limitFilter := GetExecuteLimitFilter()
result := limitFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl), invoc)
result := limitFilter.Invoke(context.Background(), protocol.NewBaseInvoker(*invokeUrl), invoc)
assert.NotNil(t, result)
assert.Nil(t, result.Error())
}
......@@ -77,7 +78,7 @@ func TestExecuteLimitFilter_Invoke(t *testing.T) {
limitFilter := GetExecuteLimitFilter()
result := limitFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl), invoc)
result := limitFilter.Invoke(context.Background(), protocol.NewBaseInvoker(*invokeUrl), invoc)
assert.NotNil(t, result)
assert.Nil(t, result.Error())
}
......@@ -18,6 +18,7 @@
package filter_impl
import (
"context"
"reflect"
"strings"
)
......@@ -44,7 +45,7 @@ func init() {
type GenericFilter struct{}
func (ef *GenericFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (ef *GenericFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
if invocation.MethodName() == constant.GENERIC && len(invocation.Arguments()) == 3 {
oldArguments := invocation.Arguments()
......@@ -60,13 +61,13 @@ func (ef *GenericFilter) Invoke(invoker protocol.Invoker, invocation protocol.In
}
newInvocation := invocation2.NewRPCInvocation(invocation.MethodName(), newArguments, invocation.Attachments())
newInvocation.SetReply(invocation.Reply())
return invoker.Invoke(newInvocation)
return invoker.Invoke(ctx, newInvocation)
}
}
return invoker.Invoke(invocation)
return invoker.Invoke(ctx, invocation)
}
func (ef *GenericFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (ef *GenericFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
......
......@@ -18,6 +18,7 @@
package filter_impl
import (
"context"
"reflect"
"strings"
)
......@@ -49,12 +50,12 @@ func init() {
type GenericServiceFilter struct{}
func (ef *GenericServiceFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (ef *GenericServiceFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
logger.Infof("invoking generic service filter.")
logger.Debugf("generic service filter methodName:%v,args:%v", invocation.MethodName(), len(invocation.Arguments()))
if invocation.MethodName() != constant.GENERIC || len(invocation.Arguments()) != 3 {
return invoker.Invoke(invocation)
return invoker.Invoke(ctx, invocation)
}
var (
......@@ -107,10 +108,10 @@ func (ef *GenericServiceFilter) Invoke(invoker protocol.Invoker, invocation prot
}
newInvocation := invocation2.NewRPCInvocation(methodName, newParams, invocation.Attachments())
newInvocation.SetReply(invocation.Reply())
return invoker.Invoke(newInvocation)
return invoker.Invoke(ctx, newInvocation)
}
func (ef *GenericServiceFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (ef *GenericServiceFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
if invocation.MethodName() == constant.GENERIC && len(invocation.Arguments()) == 3 && result.Result() != nil {
v := reflect.ValueOf(result.Result())
if v.Kind() == reflect.Ptr {
......
......@@ -99,7 +99,7 @@ func TestGenericServiceFilter_Invoke(t *testing.T) {
rpcInvocation := invocation.NewRPCInvocation(methodName, aurguments, nil)
filter := GetGenericServiceFilter()
url, _ := common.NewURL(context.Background(), "testprotocol://127.0.0.1:20000/com.test.Path")
result := filter.Invoke(&proxy_factory.ProxyInvoker{BaseInvoker: *protocol.NewBaseInvoker(url)}, rpcInvocation)
result := filter.Invoke(context.Background(), &proxy_factory.ProxyInvoker{BaseInvoker: *protocol.NewBaseInvoker(url)}, rpcInvocation)
assert.NotNil(t, result)
assert.Nil(t, result.Error())
}
......@@ -124,7 +124,7 @@ func TestGenericServiceFilter_ResponseTestStruct(t *testing.T) {
filter := GetGenericServiceFilter()
methodName := "$invoke"
rpcInvocation := invocation.NewRPCInvocation(methodName, aurguments, nil)
r := filter.OnResponse(result, nil, rpcInvocation)
r := filter.OnResponse(nil, result, nil, rpcInvocation)
assert.NotNil(t, r.Result())
assert.Equal(t, reflect.ValueOf(r.Result()).Kind(), reflect.Map)
}
......@@ -142,7 +142,7 @@ func TestGenericServiceFilter_ResponseString(t *testing.T) {
filter := GetGenericServiceFilter()
methodName := "$invoke"
rpcInvocation := invocation.NewRPCInvocation(methodName, aurguments, nil)
r := filter.OnResponse(result, nil, rpcInvocation)
r := filter.OnResponse(nil, result, nil, rpcInvocation)
assert.NotNil(t, r.Result())
assert.Equal(t, reflect.ValueOf(r.Result()).Kind(), reflect.String)
}
......@@ -18,6 +18,7 @@
package filter_impl
import (
"context"
"sync/atomic"
)
......@@ -52,16 +53,16 @@ type gracefulShutdownFilter struct {
shutdownConfig *config.ShutdownConfig
}
func (gf *gracefulShutdownFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (gf *gracefulShutdownFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
if gf.rejectNewRequest() {
logger.Info("The application is closing, new request will be rejected.")
return gf.getRejectHandler().RejectedExecution(invoker.GetUrl(), invocation)
}
atomic.AddInt32(&gf.activeCount, 1)
return invoker.Invoke(invocation)
return invoker.Invoke(ctx, invocation)
}
func (gf *gracefulShutdownFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (gf *gracefulShutdownFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
atomic.AddInt32(&gf.activeCount, -1)
// although this isn't thread safe, it won't be a problem if the gf.rejectNewRequest() is true.
if gf.shutdownConfig != nil && gf.activeCount <= 0 {
......
......@@ -18,6 +18,7 @@
package filter_impl
import (
"context"
"net/url"
"testing"
)
......@@ -53,7 +54,7 @@ func TestGenericFilter_Invoke(t *testing.T) {
assert.Equal(t, extension.GetRejectedExecutionHandler(constant.DEFAULT_KEY),
shutdownFilter.getRejectHandler())
result := shutdownFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl), invoc)
result := shutdownFilter.Invoke(context.Background(), protocol.NewBaseInvoker(*invokeUrl), invoc)
assert.NotNil(t, result)
assert.Nil(t, result.Error())
......@@ -64,7 +65,7 @@ func TestGenericFilter_Invoke(t *testing.T) {
shutdownFilter.shutdownConfig = providerConfig.ShutdownConfig
assert.True(t, shutdownFilter.rejectNewRequest())
result = shutdownFilter.OnResponse(nil, protocol.NewBaseInvoker(*invokeUrl), invoc)
result = shutdownFilter.OnResponse(nil, nil, protocol.NewBaseInvoker(*invokeUrl), invoc)
rejectHandler := &common2.OnlyLogRejectedExecutionHandler{}
extension.SetRejectedExecutionHandler("mock", func() filter.RejectedExecutionHandler {
......
......@@ -17,6 +17,7 @@
package filter_impl
import (
"context"
"fmt"
"regexp"
"sync"
......@@ -82,7 +83,7 @@ type HystrixFilter struct {
ifNewMap sync.Map
}
func (hf *HystrixFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (hf *HystrixFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
cmdName := fmt.Sprintf("%s&method=%s", invoker.GetUrl().Key(), invocation.MethodName())
......@@ -115,12 +116,12 @@ func (hf *HystrixFilter) Invoke(invoker protocol.Invoker, invocation protocol.In
configLoadMutex.RUnlock()
if err != nil {
logger.Errorf("[Hystrix Filter]Errors occurred getting circuit for %s , will invoke without hystrix, error is: ", cmdName, err)
return invoker.Invoke(invocation)
return invoker.Invoke(ctx, invocation)
}
logger.Infof("[Hystrix Filter]Using hystrix filter: %s", cmdName)
var result protocol.Result
_ = hystrix.Do(cmdName, func() error {
result = invoker.Invoke(invocation)
result = invoker.Invoke(ctx, invocation)
err := result.Error()
if err != nil {
result.SetError(NewHystrixFilterError(err, false))
......@@ -144,7 +145,7 @@ func (hf *HystrixFilter) Invoke(invoker protocol.Invoker, invocation protocol.In
return result
}
func (hf *HystrixFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (hf *HystrixFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
func GetHystrixFilterConsumer() filter.Filter {
......
......@@ -17,6 +17,7 @@
package filter_impl
import (
"context"
"regexp"
"testing"
)
......@@ -125,7 +126,7 @@ type testMockSuccessInvoker struct {
protocol.BaseInvoker
}
func (iv *testMockSuccessInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
func (iv *testMockSuccessInvoker) Invoke(context context.Context, invocation protocol.Invocation) protocol.Result {
return &protocol.RPCResult{
Rest: "Sucess",
Err: nil,
......@@ -136,7 +137,7 @@ type testMockFailInvoker struct {
protocol.BaseInvoker
}
func (iv *testMockFailInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
func (iv *testMockFailInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
return &protocol.RPCResult{
Err: errors.Errorf("exception"),
}
......@@ -144,7 +145,7 @@ func (iv *testMockFailInvoker) Invoke(invocation protocol.Invocation) protocol.R
func TestHystrixFilter_Invoke_Success(t *testing.T) {
hf := &HystrixFilter{}
result := hf.Invoke(&testMockSuccessInvoker{}, &invocation.RPCInvocation{})
result := hf.Invoke(context.Background(), &testMockSuccessInvoker{}, &invocation.RPCInvocation{})
assert.NotNil(t, result)
assert.NoError(t, result.Error())
assert.NotNil(t, result.Result())
......@@ -152,7 +153,7 @@ func TestHystrixFilter_Invoke_Success(t *testing.T) {
func TestHystrixFilter_Invoke_Fail(t *testing.T) {
hf := &HystrixFilter{}
result := hf.Invoke(&testMockFailInvoker{}, &invocation.RPCInvocation{})
result := hf.Invoke(context.Background(), &testMockFailInvoker{}, &invocation.RPCInvocation{})
assert.NotNil(t, result)
assert.Error(t, result.Error())
}
......@@ -164,7 +165,7 @@ func TestHystricFilter_Invoke_CircuitBreak(t *testing.T) {
resChan := make(chan protocol.Result, 50)
for i := 0; i < 50; i++ {
go func() {
result := hf.Invoke(&testMockFailInvoker{}, &invocation.RPCInvocation{})
result := hf.Invoke(context.Background(), &testMockFailInvoker{}, &invocation.RPCInvocation{})
resChan <- result
}()
}
......@@ -189,7 +190,7 @@ func TestHystricFilter_Invoke_CircuitBreak_Omit_Exception(t *testing.T) {
resChan := make(chan protocol.Result, 50)
for i := 0; i < 50; i++ {
go func() {
result := hf.Invoke(&testMockFailInvoker{}, &invocation.RPCInvocation{})
result := hf.Invoke(context.Background(), &testMockFailInvoker{}, &invocation.RPCInvocation{})
resChan <- result
}()
}
......
......@@ -18,6 +18,7 @@ limitations under the License.
package filter_impl
import (
"context"
"strings"
)
......@@ -42,22 +43,22 @@ func init() {
type TokenFilter struct{}
func (tf *TokenFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (tf *TokenFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
invokerTkn := invoker.GetUrl().GetParam(constant.TOKEN_KEY, "")
if len(invokerTkn) > 0 {
attachs := invocation.Attachments()
remoteTkn, exist := attachs[constant.TOKEN_KEY]
if exist && strings.EqualFold(invokerTkn, remoteTkn) {
return invoker.Invoke(invocation)
return invoker.Invoke(ctx, invocation)
}
return &protocol.RPCResult{Err: perrors.Errorf("Invalid token! Forbid invoke remote service %v method %s ",
invoker, invocation.MethodName())}
}
return invoker.Invoke(invocation)
return invoker.Invoke(ctx, invocation)
}
func (tf *TokenFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (tf *TokenFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
......
......@@ -18,6 +18,7 @@ limitations under the License.
package filter_impl
import (
"context"
"net/url"
"testing"
)
......@@ -41,8 +42,10 @@ func TestTokenFilter_Invoke(t *testing.T) {
common.WithParamsValue(constant.TOKEN_KEY, "ori_key"))
attch := make(map[string]string, 0)
attch[constant.TOKEN_KEY] = "ori_key"
result := filter.Invoke(protocol.NewBaseInvoker(*url),
invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
result := filter.Invoke(context.Background(),
protocol.NewBaseInvoker(*url),
invocation.NewRPCInvocation("MethodName",
[]interface{}{"OK"}, attch))
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
}
......@@ -53,8 +56,7 @@ func TestTokenFilter_InvokeEmptyToken(t *testing.T) {
url := common.URL{}
attch := make(map[string]string, 0)
attch[constant.TOKEN_KEY] = "ori_key"
result := filter.Invoke(protocol.NewBaseInvoker(url),
invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
result := filter.Invoke(context.Background(), protocol.NewBaseInvoker(url), invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
}
......@@ -66,8 +68,7 @@ func TestTokenFilter_InvokeEmptyAttach(t *testing.T) {
common.WithParams(url.Values{}),
common.WithParamsValue(constant.TOKEN_KEY, "ori_key"))
attch := make(map[string]string, 0)
result := filter.Invoke(protocol.NewBaseInvoker(*url),
invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
result := filter.Invoke(context.Background(), protocol.NewBaseInvoker(*url), invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
assert.NotNil(t, result.Error())
}
......@@ -79,7 +80,7 @@ func TestTokenFilter_InvokeNotEqual(t *testing.T) {
common.WithParamsValue(constant.TOKEN_KEY, "ori_key"))
attch := make(map[string]string, 0)
attch[constant.TOKEN_KEY] = "err_key"
result := filter.Invoke(protocol.NewBaseInvoker(*url),
invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
result := filter.Invoke(context.Background(),
protocol.NewBaseInvoker(*url), invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
assert.NotNil(t, result.Error())
}
......@@ -17,6 +17,9 @@
package filter_impl
import (
"context"
)
import (
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/extension"
......@@ -51,22 +54,22 @@ func init() {
type TpsLimitFilter struct {
}
func (t TpsLimitFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (t TpsLimitFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
url := invoker.GetUrl()
tpsLimiter := url.GetParam(constant.TPS_LIMITER_KEY, "")
rejectedExeHandler := url.GetParam(constant.TPS_REJECTED_EXECUTION_HANDLER_KEY, constant.DEFAULT_KEY)
if len(tpsLimiter) > 0 {
allow := extension.GetTpsLimiter(tpsLimiter).IsAllowable(invoker.GetUrl(), invocation)
if allow {
return invoker.Invoke(invocation)
return invoker.Invoke(ctx, invocation)
}
logger.Errorf("The invocation was rejected due to over the tps limitation, url: %s ", url.String())
return extension.GetRejectedExecutionHandler(rejectedExeHandler).RejectedExecution(url, invocation)
}
return invoker.Invoke(invocation)
return invoker.Invoke(ctx, invocation)
}
func (t TpsLimitFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
func (t TpsLimitFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
......
......@@ -18,6 +18,7 @@
package filter_impl
import (
"context"
"net/url"
"testing"
)
......@@ -45,8 +46,10 @@ func TestTpsLimitFilter_Invoke_With_No_TpsLimiter(t *testing.T) {
common.WithParamsValue(constant.TPS_LIMITER_KEY, ""))
attch := make(map[string]string, 0)
result := tpsFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl),
invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
result := tpsFilter.Invoke(context.Background(),
protocol.NewBaseInvoker(*invokeUrl),
invocation.NewRPCInvocation("MethodName",
[]interface{}{"OK"}, attch))
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
......@@ -67,8 +70,10 @@ func TestGenericFilter_Invoke_With_Default_TpsLimiter(t *testing.T) {
common.WithParamsValue(constant.TPS_LIMITER_KEY, constant.DEFAULT_KEY))
attch := make(map[string]string, 0)
result := tpsFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl),
invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
result := tpsFilter.Invoke(context.Background(),
protocol.NewBaseInvoker(*invokeUrl),
invocation.NewRPCInvocation("MethodName",
[]interface{}{"OK"}, attch))
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
}
......@@ -96,8 +101,8 @@ func TestGenericFilter_Invoke_With_Default_TpsLimiter_Not_Allow(t *testing.T) {
common.WithParamsValue(constant.TPS_LIMITER_KEY, constant.DEFAULT_KEY))
attch := make(map[string]string, 0)
result := tpsFilter.Invoke(protocol.NewBaseInvoker(*invokeUrl),
invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
result := tpsFilter.Invoke(context.Background(),
protocol.NewBaseInvoker(*invokeUrl), invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, attch))
assert.Nil(t, result.Error())
assert.Nil(t, result.Result())
}
......@@ -18,6 +18,7 @@
package dubbo
import (
"context"
"strconv"
"sync"
)
......@@ -53,7 +54,7 @@ func NewDubboInvoker(url common.URL, client *Client) *DubboInvoker {
}
}
func (di *DubboInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
func (di *DubboInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
var (
err error
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment