diff --git a/protocol/rest/server/server_impl/go_restful_server.go b/protocol/rest/server/server_impl/go_restful_server.go index a630a34edb55ef39b356ea04a940b0120e38912c..a6d223b5bb9ded77b41a71fa69d296980564f61e 100644 --- a/protocol/rest/server/server_impl/go_restful_server.go +++ b/protocol/rest/server/server_impl/go_restful_server.go @@ -20,7 +20,6 @@ package server_impl import ( "context" "fmt" - "github.com/apache/dubbo-go/protocol/rest/server" "net" "net/http" "reflect" @@ -42,12 +41,15 @@ import ( "github.com/apache/dubbo-go/common/logger" "github.com/apache/dubbo-go/protocol" "github.com/apache/dubbo-go/protocol/invocation" + "github.com/apache/dubbo-go/protocol/rest/server" ) func init() { extension.SetRestServer(constant.DEFAULT_REST_SERVER, GetNewGoRestfulServer) } +var filterSlice []restful.FilterFunction + type GoRestfulServer struct { srv *http.Server container *restful.Container @@ -59,6 +61,11 @@ func NewGoRestfulServer() *GoRestfulServer { func (grs *GoRestfulServer) Start(url common.URL) { grs.container = restful.NewContainer() + if len(filterSlice) > 0 { + for _, filter := range filterSlice { + grs.container.Filter(filter) + } + } grs.srv = &http.Server{ Handler: grs.container, } @@ -309,3 +316,9 @@ func getArgsFromRequest(req *restful.Request, argsTypes []reflect.Type, config * func GetNewGoRestfulServer() server.RestServer { return NewGoRestfulServer() } + +// Let user addFilter +// addFilter should before config.Load() +func AddGoRestfulServerFilter(filterFuc restful.FilterFunction) { + filterSlice = append(filterSlice, filterFuc) +}