diff --git a/protocol/rest/client/client_impl/resty_client.go b/protocol/rest/client/client_impl/resty_client.go index aa6c23137dc68492948b85a85555a5340572ac49..af9637ea91ca12b92ad94ce2616be7911876a928 100644 --- a/protocol/rest/client/client_impl/resty_client.go +++ b/protocol/rest/client/client_impl/resty_client.go @@ -65,7 +65,7 @@ func NewRestyClient(restOption *client.RestOptions) client.RestClient { } } -func (rc *RestyClient) Do(restRequest *client.RestRequest, res interface{}) error { +func (rc *RestyClient) Do(restRequest *client.RestClientRequest, res interface{}) error { r, err := rc.client.R(). SetHeader("Content-Type", restRequest.Consumes). SetHeader("Accept", restRequest.Produces). diff --git a/protocol/rest/client/rest_client.go b/protocol/rest/client/rest_client.go index 7d020abc81c2bd44473ed25ffec4b2b657e7bcc0..3acccb53ae187f5b09910403a16ad657316c93de 100644 --- a/protocol/rest/client/rest_client.go +++ b/protocol/rest/client/rest_client.go @@ -26,7 +26,7 @@ type RestOptions struct { ConnectTimeout time.Duration } -type RestRequest struct { +type RestClientRequest struct { Location string Path string Produces string @@ -39,5 +39,5 @@ type RestRequest struct { } type RestClient interface { - Do(request *RestRequest, res interface{}) error + Do(request *RestClientRequest, res interface{}) error } diff --git a/protocol/rest/rest_invoker.go b/protocol/rest/rest_invoker.go index 0c82035ac5eb9a52ab188baa971dbdf1b864e970..c8e3feaff6e37b2862b78a7d0da7ec06cfe09ebf 100644 --- a/protocol/rest/rest_invoker.go +++ b/protocol/rest/rest_invoker.go @@ -78,8 +78,7 @@ func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocatio if len(inv.Arguments()) > methodConfig.Body && methodConfig.Body >= 0 { body = inv.Arguments()[methodConfig.Body] } - - req := &client.RestRequest{ + req := &client.RestClientRequest{ Location: ri.GetUrl().Location, Produces: methodConfig.Produces, Consumes: methodConfig.Consumes, diff --git a/protocol/rest/rest_protocol.go b/protocol/rest/rest_protocol.go index 47ecb6093b4cfa12a1d3397fa45d59b1e173a93a..e15eeb39d72212eb9f1d0235313eba231d3b0a36 100644 --- a/protocol/rest/rest_protocol.go +++ b/protocol/rest/rest_protocol.go @@ -75,7 +75,9 @@ func (rp *RestProtocol) Export(invoker protocol.Invoker) protocol.Exporter { } rp.SetExporterMap(serviceKey, exporter) restServer := rp.getServer(url, restServiceConfig.Server) - restServer.Deploy(invoker, restServiceConfig.RestMethodConfigsMap) + for _, methodConfig := range restServiceConfig.RestMethodConfigsMap { + restServer.Deploy(methodConfig, server.GetRouteFunc(invoker, methodConfig)) + } return exporter } diff --git a/protocol/rest/server/rest_server.go b/protocol/rest/server/rest_server.go index c10c98a7b677d47c43b64643a69d5b3768a6c663..b7eb555625cedf3a96bb56829720e1bdd5ef84e6 100644 --- a/protocol/rest/server/rest_server.go +++ b/protocol/rest/server/rest_server.go @@ -17,15 +17,270 @@ package server +import ( + "context" + "net/http" + "reflect" + "strconv" + "strings" +) + +import ( + perrors "github.com/pkg/errors" +) + import ( "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/common/logger" "github.com/apache/dubbo-go/protocol" - "github.com/apache/dubbo-go/protocol/rest/config" + "github.com/apache/dubbo-go/protocol/invocation" + rest_config "github.com/apache/dubbo-go/protocol/rest/config" ) type RestServer interface { Start(url common.URL) - Deploy(invoker protocol.Invoker, restMethodConfig map[string]*config.RestMethodConfig) - UnDeploy(restMethodConfig map[string]*config.RestMethodConfig) + Deploy(restMethodConfig *rest_config.RestMethodConfig, routeFunc func(request RestServerRequest, response RestServerResponse)) + UnDeploy(restMethodConfig *rest_config.RestMethodConfig) Destroy() } + +// RestServerRequest interface +type RestServerRequest interface { + PathParameter(name string) string + PathParameters() map[string]string + QueryParameter(name string) string + QueryParameters(name string) []string + BodyParameter(name string) (string, error) + HeaderParameter(name string) string + ReadEntity(entityPointer interface{}) error +} + +// RestServerResponse interface +type RestServerResponse interface { + WriteError(httpStatus int, err error) (writeErr error) + WriteEntity(value interface{}) error +} + +func GetRouteFunc(invoker protocol.Invoker, methodConfig *rest_config.RestMethodConfig) func(req RestServerRequest, resp RestServerResponse) { + return func(req RestServerRequest, resp RestServerResponse) { + var ( + err error + args []interface{} + ) + svc := common.ServiceMap.GetService(invoker.GetUrl().Protocol, strings.TrimPrefix(invoker.GetUrl().Path, "/")) + // get method + method := svc.Method()[methodConfig.MethodName] + argsTypes := method.ArgsType() + replyType := method.ReplyType() + if (len(argsTypes) == 1 || len(argsTypes) == 2 && replyType == nil) && + argsTypes[0].String() == "[]interface {}" { + args = getArgsInterfaceFromRequest(req, methodConfig) + } else { + args = getArgsFromRequest(req, argsTypes, methodConfig) + } + result := invoker.Invoke(context.Background(), invocation.NewRPCInvocation(methodConfig.MethodName, args, make(map[string]string))) + if result.Error() != nil { + err = resp.WriteError(http.StatusInternalServerError, result.Error()) + if err != nil { + logger.Errorf("[Go Restful] WriteError error:%v", err) + } + return + } + err = resp.WriteEntity(result.Result()) + if err != nil { + logger.Errorf("[Go Restful] WriteEntity error:%v", err) + } + } +} + +// when service function like GetUser(req []interface{}, rsp *User) error +// use this method to get arguments +func getArgsInterfaceFromRequest(req RestServerRequest, methodConfig *rest_config.RestMethodConfig) []interface{} { + argsMap := make(map[int]interface{}, 8) + maxKey := 0 + for k, v := range methodConfig.PathParamsMap { + if maxKey < k { + maxKey = k + } + argsMap[k] = req.PathParameter(v) + } + for k, v := range methodConfig.QueryParamsMap { + if maxKey < k { + maxKey = k + } + params := req.QueryParameters(v) + if len(params) == 1 { + argsMap[k] = params[0] + } else { + argsMap[k] = params + } + } + for k, v := range methodConfig.HeadersMap { + if maxKey < k { + maxKey = k + } + argsMap[k] = req.HeaderParameter(v) + } + if methodConfig.Body >= 0 { + if maxKey < methodConfig.Body { + maxKey = methodConfig.Body + } + m := make(map[string]interface{}) + // TODO read as a slice + if err := req.ReadEntity(&m); err != nil { + logger.Warnf("[Go restful] Read body entity as map[string]interface{} error:%v", perrors.WithStack(err)) + } else { + argsMap[methodConfig.Body] = m + } + } + args := make([]interface{}, maxKey+1) + for k, v := range argsMap { + if k >= 0 { + args[k] = v + } + } + return args +} + +// get arguments from server.RestServerRequest +func getArgsFromRequest(req RestServerRequest, argsTypes []reflect.Type, methodConfig *rest_config.RestMethodConfig) []interface{} { + argsLength := len(argsTypes) + args := make([]interface{}, argsLength) + for i, t := range argsTypes { + args[i] = reflect.Zero(t).Interface() + } + assembleArgsFromPathParams(methodConfig, argsLength, argsTypes, req, args) + assembleArgsFromQueryParams(methodConfig, argsLength, argsTypes, req, args) + assembleArgsFromBody(methodConfig, argsTypes, req, args) + assembleArgsFromHeaders(methodConfig, req, argsLength, argsTypes, args) + return args +} + +// assemble arguments from headers +func assembleArgsFromHeaders(methodConfig *rest_config.RestMethodConfig, req RestServerRequest, argsLength int, argsTypes []reflect.Type, args []interface{}) { + for k, v := range methodConfig.HeadersMap { + param := req.HeaderParameter(v) + if k < 0 || k >= argsLength { + logger.Errorf("[Go restful] Header param parse error, the args:%v doesn't exist", k) + continue + } + t := argsTypes[k] + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.String { + args[k] = param + } else { + logger.Errorf("[Go restful] Header param parse error, the args:%v of type isn't string", k) + } + } +} + +// assemble arguments from body +func assembleArgsFromBody(methodConfig *rest_config.RestMethodConfig, argsTypes []reflect.Type, req RestServerRequest, args []interface{}) { + if methodConfig.Body >= 0 && methodConfig.Body < len(argsTypes) { + t := argsTypes[methodConfig.Body] + kind := t.Kind() + if kind == reflect.Ptr { + t = t.Elem() + } + var ni interface{} + if t.String() == "[]interface {}" { + ni = make([]map[string]interface{}, 0) + } else if t.String() == "interface {}" { + ni = make(map[string]interface{}) + } else { + n := reflect.New(t) + if n.CanInterface() { + ni = n.Interface() + } + } + if err := req.ReadEntity(&ni); err != nil { + logger.Errorf("[Go restful] Read body entity error:%v", err) + } else { + args[methodConfig.Body] = ni + } + } +} + +// assemble arguments from query params +func assembleArgsFromQueryParams(methodConfig *rest_config.RestMethodConfig, argsLength int, argsTypes []reflect.Type, req RestServerRequest, args []interface{}) { + var ( + err error + param interface{} + i64 int64 + ) + for k, v := range methodConfig.QueryParamsMap { + if k < 0 || k >= argsLength { + logger.Errorf("[Go restful] Query param parse error, the args:%v doesn't exist", k) + continue + } + t := argsTypes[k] + kind := t.Kind() + if kind == reflect.Ptr { + t = t.Elem() + } + if kind == reflect.Slice { + param = req.QueryParameters(v) + } else if kind == reflect.String { + param = req.QueryParameter(v) + } else if kind == reflect.Int { + param, err = strconv.Atoi(req.QueryParameter(v)) + } else if kind == reflect.Int32 { + i64, err = strconv.ParseInt(req.QueryParameter(v), 10, 32) + if err == nil { + param = int32(i64) + } + } else if kind == reflect.Int64 { + param, err = strconv.ParseInt(req.QueryParameter(v), 10, 64) + } else { + logger.Errorf("[Go restful] Query param parse error, the args:%v of type isn't int or string or slice", k) + continue + } + if err != nil { + logger.Errorf("[Go restful] Query param parse error, error is %v", err) + continue + } + args[k] = param + } +} + +// assemble arguments from path params +func assembleArgsFromPathParams(methodConfig *rest_config.RestMethodConfig, argsLength int, argsTypes []reflect.Type, req RestServerRequest, args []interface{}) { + var ( + err error + param interface{} + i64 int64 + ) + for k, v := range methodConfig.PathParamsMap { + if k < 0 || k >= argsLength { + logger.Errorf("[Go restful] Path param parse error, the args:%v doesn't exist", k) + continue + } + t := argsTypes[k] + kind := t.Kind() + if kind == reflect.Ptr { + t = t.Elem() + } + if kind == reflect.Int { + param, err = strconv.Atoi(req.PathParameter(v)) + } else if kind == reflect.Int32 { + i64, err = strconv.ParseInt(req.PathParameter(v), 10, 32) + if err == nil { + param = int32(i64) + } + } else if kind == reflect.Int64 { + param, err = strconv.ParseInt(req.PathParameter(v), 10, 64) + } else if kind == reflect.String { + param = req.PathParameter(v) + } else { + logger.Warnf("[Go restful] Path param parse error, the args:%v of type isn't int or string", k) + continue + } + if err != nil { + logger.Errorf("[Go restful] Path param parse error, error is %v", err) + continue + } + args[k] = param + } +} diff --git a/protocol/rest/server/server_impl/go_restful_server.go b/protocol/rest/server/server_impl/go_restful_server.go index 812699b3b66987ec895541d45419b0d6c9ba2b70..3c81fe0bfda901fc905ffb31be153d916ed9acff 100644 --- a/protocol/rest/server/server_impl/go_restful_server.go +++ b/protocol/rest/server/server_impl/go_restful_server.go @@ -22,8 +22,6 @@ import ( "fmt" "net" "net/http" - "reflect" - "strconv" "strings" "time" ) @@ -38,8 +36,6 @@ import ( "github.com/apache/dubbo-go/common/constant" "github.com/apache/dubbo-go/common/extension" "github.com/apache/dubbo-go/common/logger" - "github.com/apache/dubbo-go/protocol" - "github.com/apache/dubbo-go/protocol/invocation" "github.com/apache/dubbo-go/protocol/rest/config" "github.com/apache/dubbo-go/protocol/rest/server" ) @@ -80,57 +76,27 @@ func (grs *GoRestfulServer) Start(url common.URL) { }() } -func (grs *GoRestfulServer) Deploy(invoker protocol.Invoker, restMethodConfig map[string]*config.RestMethodConfig) { - for methodName, config := range restMethodConfig { - ws := new(restful.WebService) - ws.Path(config.Path). - Produces(strings.Split(config.Produces, ",")...). - Consumes(strings.Split(config.Consumes, ",")...). - Route(ws.Method(config.MethodType).To(getFunc(methodName, invoker, config))) - grs.container.Add(ws) +func (grs *GoRestfulServer) Deploy(restMethodConfig *config.RestMethodConfig, routeFunc func(request server.RestServerRequest, response server.RestServerResponse)) { + ws := new(restful.WebService) + rf := func(req *restful.Request, resp *restful.Response) { + respAdapter := NewGoRestfulResponseAdapter(resp) + reqAdapter := NewGoRestfulRequestAdapter(req) + routeFunc(reqAdapter, respAdapter) } + ws.Path(restMethodConfig.Path). + Produces(strings.Split(restMethodConfig.Produces, ",")...). + Consumes(strings.Split(restMethodConfig.Consumes, ",")...). + Route(ws.Method(restMethodConfig.MethodType).To(rf)) + grs.container.Add(ws) } -func getFunc(methodName string, invoker protocol.Invoker, config *config.RestMethodConfig) func(req *restful.Request, resp *restful.Response) { - return func(req *restful.Request, resp *restful.Response) { - var ( - err error - args []interface{} - ) - svc := common.ServiceMap.GetService(invoker.GetUrl().Protocol, strings.TrimPrefix(invoker.GetUrl().Path, "/")) - // get method - method := svc.Method()[methodName] - argsTypes := method.ArgsType() - replyType := method.ReplyType() - if (len(argsTypes) == 1 || len(argsTypes) == 2 && replyType == nil) && - argsTypes[0].String() == "[]interface {}" { - args = getArgsInterfaceFromRequest(req, config) - } else { - args = getArgsFromRequest(req, argsTypes, config) - } - result := invoker.Invoke(context.Background(), invocation.NewRPCInvocation(methodName, args, make(map[string]string))) - if result.Error() != nil { - err = resp.WriteError(http.StatusInternalServerError, result.Error()) - if err != nil { - logger.Errorf("[Go Restful] WriteError error:%v", err) - } - return - } - err = resp.WriteEntity(result.Result()) - if err != nil { - logger.Error("[Go Restful] WriteEntity error:%v", err) - } - } -} -func (grs *GoRestfulServer) UnDeploy(restMethodConfig map[string]*config.RestMethodConfig) { - for _, config := range restMethodConfig { - ws := new(restful.WebService) - ws.Path(config.Path) - err := grs.container.Remove(ws) - if err != nil { - logger.Warnf("[Go restful] Remove web service error:%v", err) - } +func (grs *GoRestfulServer) UnDeploy(restMethodConfig *config.RestMethodConfig) { + ws := new(restful.WebService) + ws.Path(restMethodConfig.Path) + err := grs.container.Remove(ws) + if err != nil { + logger.Warnf("[Go restful] Remove web service error:%v", err) } } @@ -143,179 +109,66 @@ func (grs *GoRestfulServer) Destroy() { logger.Infof("[Go Restful] Server exiting") } -func getArgsInterfaceFromRequest(req *restful.Request, config *config.RestMethodConfig) []interface{} { - argsMap := make(map[int]interface{}, 8) - maxKey := 0 - for k, v := range config.PathParamsMap { - if maxKey < k { - maxKey = k - } - argsMap[k] = req.PathParameter(v) - } - for k, v := range config.QueryParamsMap { - if maxKey < k { - maxKey = k - } - params := req.QueryParameters(v) - if len(params) == 1 { - argsMap[k] = params[0] - } else { - argsMap[k] = params - } - } - for k, v := range config.HeadersMap { - if maxKey < k { - maxKey = k - } - argsMap[k] = req.HeaderParameter(v) - } - if config.Body >= 0 { - if maxKey < config.Body { - maxKey = config.Body - } - m := make(map[string]interface{}) - // TODO read as a slice - if err := req.ReadEntity(&m); err != nil { - logger.Warnf("[Go restful] Read body entity as map[string]interface{} error:%v", perrors.WithStack(err)) - } else { - argsMap[config.Body] = m - } - } - args := make([]interface{}, maxKey+1) - for k, v := range argsMap { - if k >= 0 { - args[k] = v - } - } - return args +func GetNewGoRestfulServer() server.RestServer { + return NewGoRestfulServer() } -func getArgsFromRequest(req *restful.Request, argsTypes []reflect.Type, config *config.RestMethodConfig) []interface{} { - argsLength := len(argsTypes) - args := make([]interface{}, argsLength) - for i, t := range argsTypes { - args[i] = reflect.Zero(t).Interface() - } - var ( - err error - param interface{} - i64 int64 - ) - for k, v := range config.PathParamsMap { - if k < 0 || k >= argsLength { - logger.Errorf("[Go restful] Path param parse error, the args:%v doesn't exist", k) - continue - } - t := argsTypes[k] - kind := t.Kind() - if kind == reflect.Ptr { - t = t.Elem() - } - if kind == reflect.Int { - param, err = strconv.Atoi(req.PathParameter(v)) - } else if kind == reflect.Int32 { - i64, err = strconv.ParseInt(req.PathParameter(v), 10, 32) - if err == nil { - param = int32(i64) - } - } else if kind == reflect.Int64 { - param, err = strconv.ParseInt(req.PathParameter(v), 10, 64) - } else if kind == reflect.String { - param = req.PathParameter(v) - } else { - logger.Warnf("[Go restful] Path param parse error, the args:%v of type isn't int or string", k) - continue - } - if err != nil { - logger.Errorf("[Go restful] Path param parse error, error is %v", err) - continue - } - args[k] = param - } - for k, v := range config.QueryParamsMap { - if k < 0 || k >= argsLength { - logger.Errorf("[Go restful] Query param parse error, the args:%v doesn't exist", k) - continue - } - t := argsTypes[k] - kind := t.Kind() - if kind == reflect.Ptr { - t = t.Elem() - } - if kind == reflect.Slice { - param = req.QueryParameters(v) - } else if kind == reflect.String { - param = req.QueryParameter(v) - } else if kind == reflect.Int { - param, err = strconv.Atoi(req.QueryParameter(v)) - } else if kind == reflect.Int32 { - i64, err = strconv.ParseInt(req.QueryParameter(v), 10, 32) - if err == nil { - param = int32(i64) - } - } else if kind == reflect.Int64 { - param, err = strconv.ParseInt(req.QueryParameter(v), 10, 64) - } else { - logger.Errorf("[Go restful] Query param parse error, the args:%v of type isn't int or string or slice", k) - continue - } - if err != nil { - logger.Errorf("[Go restful] Query param parse error, error is %v", err) - continue - } - args[k] = param - } +// Let user addFilter +// addFilter should before config.Load() +func AddGoRestfulServerFilter(filterFuc restful.FilterFunction) { + filterSlice = append(filterSlice, filterFuc) +} - if config.Body >= 0 && config.Body < len(argsTypes) { - t := argsTypes[config.Body] - kind := t.Kind() - if kind == reflect.Ptr { - t = t.Elem() - } - var ni interface{} - if t.String() == "[]interface {}" { - ni = make([]map[string]interface{}, 0) - } else if t.String() == "interface {}" { - ni = make(map[string]interface{}) - } else { - n := reflect.New(t) - if n.CanInterface() { - ni = n.Interface() - } - } - if err := req.ReadEntity(&ni); err != nil { - logger.Errorf("[Go restful] Read body entity error:%v", err) - } else { - args[config.Body] = ni - } - } +// Go Restful Request adapt to RestServerRequest +type GoRestfulRequestAdapter struct { + rawRequest *restful.Request +} - for k, v := range config.HeadersMap { - param := req.HeaderParameter(v) - if k < 0 || k >= argsLength { - logger.Errorf("[Go restful] Header param parse error, the args:%v doesn't exist", k) - continue - } - t := argsTypes[k] - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - if t.Kind() == reflect.String { - args[k] = param - } else { - logger.Errorf("[Go restful] Header param parse error, the args:%v of type isn't string", k) - } - } +func (gra *GoRestfulRequestAdapter) PathParameter(name string) string { + return gra.rawRequest.PathParameter(name) +} - return args +func (gra *GoRestfulRequestAdapter) PathParameters() map[string]string { + return gra.rawRequest.PathParameters() } -func GetNewGoRestfulServer() server.RestServer { - return NewGoRestfulServer() +func (gra *GoRestfulRequestAdapter) QueryParameter(name string) string { + return gra.rawRequest.QueryParameter(name) } -// Let user addFilter -// addFilter should before config.Load() -func AddGoRestfulServerFilter(filterFuc restful.FilterFunction) { - filterSlice = append(filterSlice, filterFuc) +func (gra *GoRestfulRequestAdapter) QueryParameters(name string) []string { + return gra.rawRequest.QueryParameters(name) +} + +func (gra *GoRestfulRequestAdapter) BodyParameter(name string) (string, error) { + return gra.rawRequest.BodyParameter(name) +} + +func (gra *GoRestfulRequestAdapter) HeaderParameter(name string) string { + return gra.rawRequest.HeaderParameter(name) +} + +func (gra *GoRestfulRequestAdapter) ReadEntity(entityPointer interface{}) error { + return gra.rawRequest.ReadEntity(entityPointer) +} + +func NewGoRestfulRequestAdapter(rawRequest *restful.Request) server.RestServerRequest { + return &GoRestfulRequestAdapter{rawRequest: rawRequest} +} + +// Go Restful Request adapt to RestClientRequest +type GoRestfulResponseAdapter struct { + rawResponse *restful.Response +} + +func (grsa *GoRestfulResponseAdapter) WriteError(httpStatus int, err error) (writeErr error) { + return grsa.rawResponse.WriteError(httpStatus, err) +} + +func (grsa *GoRestfulResponseAdapter) WriteEntity(value interface{}) error { + return grsa.rawResponse.WriteEntity(value) +} + +func NewGoRestfulResponseAdapter(rawResponse *restful.Response) server.RestServerResponse { + return &GoRestfulResponseAdapter{rawResponse: rawResponse} }