diff --git a/protocol/rest/rest_invoker_test.go b/protocol/rest/rest_invoker_test.go index 42a4fbd358955e8bd2d78a80818090baf62fa784..e44c5d9a21026992178bd432676c99bc837c361b 100644 --- a/protocol/rest/rest_invoker_test.go +++ b/protocol/rest/rest_invoker_test.go @@ -24,6 +24,7 @@ import ( ) import ( + "github.com/emicklei/go-restful/v3" "github.com/stretchr/testify/assert" ) @@ -35,12 +36,25 @@ import ( "github.com/apache/dubbo-go/protocol/rest/client" "github.com/apache/dubbo-go/protocol/rest/client/client_impl" rest_config "github.com/apache/dubbo-go/protocol/rest/config" + "github.com/apache/dubbo-go/protocol/rest/server/server_impl" ) func TestRestInvoker_Invoke(t *testing.T) { // Refer proto := GetRestProtocol() defer proto.Destroy() + var filterNum int + server_impl.AddGoRestfulServerFilter(func(request *restful.Request, response *restful.Response, chain *restful.FilterChain) { + println(request.SelectedRoutePath()) + filterNum = filterNum + 1 + chain.ProcessFilter(request, response) + }) + server_impl.AddGoRestfulServerFilter(func(request *restful.Request, response *restful.Response, chain *restful.FilterChain) { + println("filter2") + filterNum = filterNum + 1 + chain.ProcessFilter(request, response) + }) + url, err := common.NewURL("rest://127.0.0.1:8877/com.ikurento.user.UserProvider?anyhost=true&" + "application=BDTService&category=providers&default.timeout=10000&dubbo=dubbo-provider-golang-1.0.0&" + "environment=dev&interface=com.ikurento.user.UserProvider&ip=192.168.56.1&methods=GetUser%2C&" + @@ -191,6 +205,7 @@ func TestRestInvoker_Invoke(t *testing.T) { res = invoker.Invoke(context.Background(), inv) assert.Error(t, res.Error(), "test error") + assert.Equal(t, filterNum, 12) err = common.ServiceMap.UnRegister(url.Protocol, "com.ikurento.user.UserProvider") assert.NoError(t, err) } diff --git a/protocol/rest/server/server_impl/go_restful_server.go b/protocol/rest/server/server_impl/go_restful_server.go index 3ea25531d62f5bd5fdb3b4be3e0fd3892b6b6b54..69f36a5c80aa51f52dfcfabc5a1bd4003f4cd727 100644 --- a/protocol/rest/server/server_impl/go_restful_server.go +++ b/protocol/rest/server/server_impl/go_restful_server.go @@ -48,6 +48,8 @@ func init() { extension.SetRestServer(constant.DEFAULT_REST_SERVER, GetNewGoRestfulServer) } +var filterSlice []restful.FilterFunction + type GoRestfulServer struct { srv *http.Server container *restful.Container @@ -59,6 +61,9 @@ func NewGoRestfulServer() *GoRestfulServer { func (grs *GoRestfulServer) Start(url common.URL) { grs.container = restful.NewContainer() + for _, filter := range filterSlice { + grs.container.Filter(filter) + } grs.srv = &http.Server{ Handler: grs.container, } @@ -309,3 +314,9 @@ func getArgsFromRequest(req *restful.Request, argsTypes []reflect.Type, config * func GetNewGoRestfulServer() server.RestServer { return NewGoRestfulServer() } + +// Let user addFilter +// addFilter should before config.Load() +func AddGoRestfulServerFilter(filterFuc restful.FilterFunction) { + filterSlice = append(filterSlice, filterFuc) +}