Skip to content
Snippets Groups Projects
Commit bdd57ef5 authored by vito.he's avatar vito.he Committed by GitHub
Browse files

Merge pull request #470 from apache/develop

Develop 2 feature/dubbo-2.7.5
parents 2e0d7102 f5bc9ed0
No related branches found
No related tags found
No related merge requests found
Showing
with 464 additions and 294 deletions
...@@ -50,7 +50,7 @@ type ( ...@@ -50,7 +50,7 @@ type (
func TestHTTPClient_Call(t *testing.T) { func TestHTTPClient_Call(t *testing.T) {
methods, err := common.ServiceMap.Register("jsonrpc", &UserProvider{}) methods, err := common.ServiceMap.Register("com.ikurento.user.UserProvider", "jsonrpc", &UserProvider{})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "GetUser,GetUser0,GetUser1,GetUser2,GetUser3,GetUser4", methods) assert.Equal(t, "GetUser,GetUser0,GetUser1,GetUser2,GetUser3,GetUser4", methods)
......
...@@ -43,8 +43,9 @@ func NewJsonrpcExporter(key string, invoker protocol.Invoker, exporterMap *sync. ...@@ -43,8 +43,9 @@ func NewJsonrpcExporter(key string, invoker protocol.Invoker, exporterMap *sync.
// Unexport ... // Unexport ...
func (je *JsonrpcExporter) Unexport() { func (je *JsonrpcExporter) Unexport() {
serviceId := je.GetInvoker().GetUrl().GetParam(constant.BEAN_NAME_KEY, "") serviceId := je.GetInvoker().GetUrl().GetParam(constant.BEAN_NAME_KEY, "")
interfaceName := je.GetInvoker().GetUrl().GetParam(constant.INTERFACE_KEY, "")
je.BaseExporter.Unexport() je.BaseExporter.Unexport()
err := common.ServiceMap.UnRegister(JSONRPC, serviceId) err := common.ServiceMap.UnRegister(interfaceName, JSONRPC, serviceId)
if err != nil { if err != nil {
logger.Errorf("[JsonrpcExporter.Unexport] error: %v", err) logger.Errorf("[JsonrpcExporter.Unexport] error: %v", err)
} }
......
...@@ -36,7 +36,7 @@ import ( ...@@ -36,7 +36,7 @@ import (
func TestJsonrpcInvoker_Invoke(t *testing.T) { func TestJsonrpcInvoker_Invoke(t *testing.T) {
methods, err := common.ServiceMap.Register("jsonrpc", &UserProvider{}) methods, err := common.ServiceMap.Register("UserProvider", "jsonrpc", &UserProvider{})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "GetUser,GetUser0,GetUser1,GetUser2,GetUser3,GetUser4", methods) assert.Equal(t, "GetUser,GetUser0,GetUser1,GetUser2,GetUser3,GetUser4", methods)
......
// Licensed to the Apache Software Foundation (ASF) under one or more /*
// contributor license agreements. See the NOTICE file distributed with * Licensed to the Apache Software Foundation (ASF) under one or more
// this work for additional information regarding copyright ownership. * contributor license agreements. See the NOTICE file distributed with
// The ASF licenses this file to You under the Apache License, Version 2.0 * this work for additional information regarding copyright ownership.
// (the "License"); you may not use this file except in compliance with * The ASF licenses this file to You under the Apache License, Version 2.0
// the License. You may obtain a copy of the License at * (the "License"); you may not use this file except in compliance with
// * the License. You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0 *
// * http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software *
// distributed under the License is distributed on an "AS IS" BASIS, * Unless required by applicable law or agreed to in writing, software
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// limitations under the License. * See the License for the specific language governing permissions and
// * limitations under the License.
*/
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: invoker.go // Source: invoker.go
......
...@@ -40,10 +40,12 @@ func init() { ...@@ -40,10 +40,12 @@ func init() {
extension.SetRestClient(constant.DEFAULT_REST_CLIENT, NewRestyClient) extension.SetRestClient(constant.DEFAULT_REST_CLIENT, NewRestyClient)
} }
// RestyClient a rest client implement by Resty
type RestyClient struct { type RestyClient struct {
client *resty.Client client *resty.Client
} }
// NewRestyClient a constructor of RestyClient
func NewRestyClient(restOption *client.RestOptions) client.RestClient { func NewRestyClient(restOption *client.RestOptions) client.RestClient {
client := resty.New() client := resty.New()
client.SetTransport( client.SetTransport(
...@@ -65,21 +67,21 @@ func NewRestyClient(restOption *client.RestOptions) client.RestClient { ...@@ -65,21 +67,21 @@ func NewRestyClient(restOption *client.RestOptions) client.RestClient {
} }
} }
func (rc *RestyClient) Do(restRequest *client.RestRequest, res interface{}) error { // Do send request by RestyClient
r, err := rc.client.R(). func (rc *RestyClient) Do(restRequest *client.RestClientRequest, res interface{}) error {
SetHeader("Content-Type", restRequest.Consumes). req := rc.client.R()
SetHeader("Accept", restRequest.Produces). req.Header = restRequest.Header
resp, err := req.
SetPathParams(restRequest.PathParams). SetPathParams(restRequest.PathParams).
SetQueryParams(restRequest.QueryParams). SetQueryParams(restRequest.QueryParams).
SetHeaders(restRequest.Headers).
SetBody(restRequest.Body). SetBody(restRequest.Body).
SetResult(res). SetResult(res).
Execute(restRequest.Method, "http://"+path.Join(restRequest.Location, restRequest.Path)) Execute(restRequest.Method, "http://"+path.Join(restRequest.Location, restRequest.Path))
if err != nil { if err != nil {
return perrors.WithStack(err) return perrors.WithStack(err)
} }
if r.IsError() { if resp.IsError() {
return perrors.New(r.String()) return perrors.New(resp.String())
} }
return nil return nil
} }
...@@ -18,26 +18,28 @@ ...@@ -18,26 +18,28 @@
package client package client
import ( import (
"net/http"
"time" "time"
) )
// RestOptions
type RestOptions struct { type RestOptions struct {
RequestTimeout time.Duration RequestTimeout time.Duration
ConnectTimeout time.Duration ConnectTimeout time.Duration
} }
type RestRequest struct { // RestClientRequest
type RestClientRequest struct {
Header http.Header
Location string Location string
Path string Path string
Produces string
Consumes string
Method string Method string
PathParams map[string]string PathParams map[string]string
QueryParams map[string]string QueryParams map[string]string
Body interface{} Body interface{}
Headers map[string]string
} }
// RestClient user can implement this client interface to send request
type RestClient interface { type RestClient interface {
Do(request *RestRequest, res interface{}) error Do(request *RestClientRequest, res interface{}) error
} }
...@@ -40,8 +40,9 @@ func NewRestExporter(key string, invoker protocol.Invoker, exporterMap *sync.Map ...@@ -40,8 +40,9 @@ func NewRestExporter(key string, invoker protocol.Invoker, exporterMap *sync.Map
func (re *RestExporter) Unexport() { func (re *RestExporter) Unexport() {
serviceId := re.GetInvoker().GetUrl().GetParam(constant.BEAN_NAME_KEY, "") serviceId := re.GetInvoker().GetUrl().GetParam(constant.BEAN_NAME_KEY, "")
interfaceName := re.GetInvoker().GetUrl().GetParam(constant.INTERFACE_KEY, "")
re.BaseExporter.Unexport() re.BaseExporter.Unexport()
err := common.ServiceMap.UnRegister(REST, serviceId) err := common.ServiceMap.UnRegister(interfaceName, REST, serviceId)
if err != nil { if err != nil {
logger.Errorf("[RestExporter.Unexport] error: %v", err) logger.Errorf("[RestExporter.Unexport] error: %v", err)
} }
......
...@@ -20,6 +20,7 @@ package rest ...@@ -20,6 +20,7 @@ package rest
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
) )
import ( import (
...@@ -56,7 +57,7 @@ func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocatio ...@@ -56,7 +57,7 @@ func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocatio
body interface{} body interface{}
pathParams map[string]string pathParams map[string]string
queryParams map[string]string queryParams map[string]string
headers map[string]string header http.Header
err error err error
) )
if methodConfig == nil { if methodConfig == nil {
...@@ -71,24 +72,21 @@ func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocatio ...@@ -71,24 +72,21 @@ func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocatio
result.Err = err result.Err = err
return &result return &result
} }
if headers, err = restStringMapTransform(methodConfig.HeadersMap, inv.Arguments()); err != nil { if header, err = getRestHttpHeader(methodConfig, inv.Arguments()); err != nil {
result.Err = err result.Err = err
return &result return &result
} }
if len(inv.Arguments()) > methodConfig.Body && methodConfig.Body >= 0 { if len(inv.Arguments()) > methodConfig.Body && methodConfig.Body >= 0 {
body = inv.Arguments()[methodConfig.Body] body = inv.Arguments()[methodConfig.Body]
} }
req := &client.RestClientRequest{
req := &client.RestRequest{
Location: ri.GetUrl().Location, Location: ri.GetUrl().Location,
Produces: methodConfig.Produces,
Consumes: methodConfig.Consumes,
Method: methodConfig.MethodType, Method: methodConfig.MethodType,
Path: methodConfig.Path, Path: methodConfig.Path,
PathParams: pathParams, PathParams: pathParams,
QueryParams: queryParams, QueryParams: queryParams,
Body: body, Body: body,
Headers: headers, Header: header,
} }
result.Err = ri.client.Do(req, inv.Reply()) result.Err = ri.client.Do(req, inv.Reply())
if result.Err == nil { if result.Err == nil {
...@@ -107,3 +105,17 @@ func restStringMapTransform(paramsMap map[int]string, args []interface{}) (map[s ...@@ -107,3 +105,17 @@ func restStringMapTransform(paramsMap map[int]string, args []interface{}) (map[s
} }
return resMap, nil return resMap, nil
} }
func getRestHttpHeader(methodConfig *config.RestMethodConfig, args []interface{}) (http.Header, error) {
header := http.Header{}
headersMap := methodConfig.HeadersMap
header.Set("Content-Type", methodConfig.Consumes)
header.Set("Accept", methodConfig.Produces)
for k, v := range headersMap {
if k >= len(args) || k < 0 {
return nil, perrors.Errorf("[Rest Invoke] Index %v is out of bundle", k)
}
header.Set(v, fmt.Sprint(args[k]))
}
return header, nil
}
...@@ -61,7 +61,7 @@ func TestRestInvoker_Invoke(t *testing.T) { ...@@ -61,7 +61,7 @@ func TestRestInvoker_Invoke(t *testing.T) {
"module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&" + "module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&" +
"side=provider&timeout=3000&timestamp=1556509797245") "side=provider&timeout=3000&timestamp=1556509797245")
assert.NoError(t, err) assert.NoError(t, err)
_, err = common.ServiceMap.Register(url.Protocol, &UserProvider{}) _, err = common.ServiceMap.Register("UserProvider", url.Protocol, &UserProvider{})
assert.NoError(t, err) assert.NoError(t, err)
con := config.ProviderConfig{} con := config.ProviderConfig{}
config.SetProviderConfig(con) config.SetProviderConfig(con)
...@@ -206,6 +206,6 @@ func TestRestInvoker_Invoke(t *testing.T) { ...@@ -206,6 +206,6 @@ func TestRestInvoker_Invoke(t *testing.T) {
assert.Error(t, res.Error(), "test error") assert.Error(t, res.Error(), "test error")
assert.Equal(t, filterNum, 12) assert.Equal(t, filterNum, 12)
err = common.ServiceMap.UnRegister(url.Protocol, "com.ikurento.user.UserProvider") err = common.ServiceMap.UnRegister("UserProvider", url.Protocol, "com.ikurento.user.UserProvider")
assert.NoError(t, err) assert.NoError(t, err)
} }
...@@ -75,7 +75,9 @@ func (rp *RestProtocol) Export(invoker protocol.Invoker) protocol.Exporter { ...@@ -75,7 +75,9 @@ func (rp *RestProtocol) Export(invoker protocol.Invoker) protocol.Exporter {
} }
rp.SetExporterMap(serviceKey, exporter) rp.SetExporterMap(serviceKey, exporter)
restServer := rp.getServer(url, restServiceConfig.Server) restServer := rp.getServer(url, restServiceConfig.Server)
restServer.Deploy(invoker, restServiceConfig.RestMethodConfigsMap) for _, methodConfig := range restServiceConfig.RestMethodConfigsMap {
restServer.Deploy(methodConfig, server.GetRouteFunc(invoker, methodConfig))
}
return exporter return exporter
} }
......
...@@ -80,7 +80,7 @@ func TestRestProtocol_Export(t *testing.T) { ...@@ -80,7 +80,7 @@ func TestRestProtocol_Export(t *testing.T) {
"module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&" + "module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&" +
"side=provider&timeout=3000&timestamp=1556509797245") "side=provider&timeout=3000&timestamp=1556509797245")
assert.NoError(t, err) assert.NoError(t, err)
_, err = common.ServiceMap.Register(url.Protocol, &UserProvider{}) _, err = common.ServiceMap.Register("UserProvider", url.Protocol, &UserProvider{})
assert.NoError(t, err) assert.NoError(t, err)
con := config.ProviderConfig{} con := config.ProviderConfig{}
config.SetProviderConfig(con) config.SetProviderConfig(con)
...@@ -128,7 +128,7 @@ func TestRestProtocol_Export(t *testing.T) { ...@@ -128,7 +128,7 @@ func TestRestProtocol_Export(t *testing.T) {
proto.Destroy() proto.Destroy()
_, ok = proto.(*RestProtocol).serverMap[url.Location] _, ok = proto.(*RestProtocol).serverMap[url.Location]
assert.False(t, ok) assert.False(t, ok)
err = common.ServiceMap.UnRegister(url.Protocol, "com.ikurento.user.UserProvider") err = common.ServiceMap.UnRegister("UserProvider", url.Protocol, "com.ikurento.user.UserProvider")
assert.NoError(t, err) assert.NoError(t, err)
} }
......
...@@ -17,15 +17,306 @@ ...@@ -17,15 +17,306 @@
package server package server
import (
"context"
"errors"
"net/http"
"reflect"
"strconv"
"strings"
)
import (
perrors "github.com/pkg/errors"
)
import ( import (
"github.com/apache/dubbo-go/common" "github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/logger"
"github.com/apache/dubbo-go/protocol" "github.com/apache/dubbo-go/protocol"
"github.com/apache/dubbo-go/protocol/rest/config" "github.com/apache/dubbo-go/protocol/invocation"
rest_config "github.com/apache/dubbo-go/protocol/rest/config"
) )
const parseParameterErrorStr = "An error occurred while parsing parameters on the server"
// RestServer user can implement this server interface
type RestServer interface { type RestServer interface {
// Start rest server
Start(url common.URL) Start(url common.URL)
Deploy(invoker protocol.Invoker, restMethodConfig map[string]*config.RestMethodConfig) // Deploy a http api
UnDeploy(restMethodConfig map[string]*config.RestMethodConfig) Deploy(restMethodConfig *rest_config.RestMethodConfig, routeFunc func(request RestServerRequest, response RestServerResponse))
// UnDeploy a http api
UnDeploy(restMethodConfig *rest_config.RestMethodConfig)
// Destroy rest server
Destroy() Destroy()
} }
// RestServerRequest interface
type RestServerRequest interface {
// RawRequest get the Ptr of http.Request
RawRequest() *http.Request
// PathParameter get the path parameter by name
PathParameter(name string) string
// PathParameters get the map of the path parameters
PathParameters() map[string]string
// QueryParameter get the query parameter by name
QueryParameter(name string) string
// QueryParameters get the map of query parameters
QueryParameters(name string) []string
// BodyParameter get the body parameter of name
BodyParameter(name string) (string, error)
// HeaderParameter get the header parameter of name
HeaderParameter(name string) string
// ReadEntity checks the Accept header and reads the content into the entityPointer.
ReadEntity(entityPointer interface{}) error
}
// RestServerResponse interface
type RestServerResponse interface {
http.ResponseWriter
// WriteError writes the http status and the error string on the response. err can be nil.
// Return an error if writing was not successful.
WriteError(httpStatus int, err error) (writeErr error)
// WriteEntity marshals the value using the representation denoted by the Accept Header.
WriteEntity(value interface{}) error
}
// GetRouteFunc
// A route function will be invoked by http server
func GetRouteFunc(invoker protocol.Invoker, methodConfig *rest_config.RestMethodConfig) func(req RestServerRequest, resp RestServerResponse) {
return func(req RestServerRequest, resp RestServerResponse) {
var (
err error
args []interface{}
)
svc := common.ServiceMap.GetService(invoker.GetUrl().Protocol, strings.TrimPrefix(invoker.GetUrl().Path, "/"))
// get method
method := svc.Method()[methodConfig.MethodName]
argsTypes := method.ArgsType()
replyType := method.ReplyType()
// two ways to prepare arguments
// if method like this 'func1(req []interface{}, rsp *User) error'
// we don't have arguments type
if (len(argsTypes) == 1 || len(argsTypes) == 2 && replyType == nil) &&
argsTypes[0].String() == "[]interface {}" {
args, err = getArgsInterfaceFromRequest(req, methodConfig)
} else {
args, err = getArgsFromRequest(req, argsTypes, methodConfig)
}
if err != nil {
logger.Errorf("[Go Restful] parsing http parameters error:%v", err)
err = resp.WriteError(http.StatusInternalServerError, errors.New(parseParameterErrorStr))
if err != nil {
logger.Errorf("[Go Restful] WriteErrorString error:%v", err)
}
}
result := invoker.Invoke(context.Background(), invocation.NewRPCInvocation(methodConfig.MethodName, args, make(map[string]string)))
if result.Error() != nil {
err = resp.WriteError(http.StatusInternalServerError, result.Error())
if err != nil {
logger.Errorf("[Go Restful] WriteError error:%v", err)
}
return
}
err = resp.WriteEntity(result.Result())
if err != nil {
logger.Errorf("[Go Restful] WriteEntity error:%v", err)
}
}
}
// getArgsInterfaceFromRequest when service function like GetUser(req []interface{}, rsp *User) error
// use this method to get arguments
func getArgsInterfaceFromRequest(req RestServerRequest, methodConfig *rest_config.RestMethodConfig) ([]interface{}, error) {
argsMap := make(map[int]interface{}, 8)
maxKey := 0
for k, v := range methodConfig.PathParamsMap {
if maxKey < k {
maxKey = k
}
argsMap[k] = req.PathParameter(v)
}
for k, v := range methodConfig.QueryParamsMap {
if maxKey < k {
maxKey = k
}
params := req.QueryParameters(v)
if len(params) == 1 {
argsMap[k] = params[0]
} else {
argsMap[k] = params
}
}
for k, v := range methodConfig.HeadersMap {
if maxKey < k {
maxKey = k
}
argsMap[k] = req.HeaderParameter(v)
}
if methodConfig.Body >= 0 {
if maxKey < methodConfig.Body {
maxKey = methodConfig.Body
}
m := make(map[string]interface{})
// TODO read as a slice
if err := req.ReadEntity(&m); err != nil {
return nil, perrors.Errorf("[Go restful] Read body entity as map[string]interface{} error:%v", err)
}
argsMap[methodConfig.Body] = m
}
args := make([]interface{}, maxKey+1)
for k, v := range argsMap {
if k >= 0 {
args[k] = v
}
}
return args, nil
}
// getArgsFromRequest get arguments from server.RestServerRequest
func getArgsFromRequest(req RestServerRequest, argsTypes []reflect.Type, methodConfig *rest_config.RestMethodConfig) ([]interface{}, error) {
argsLength := len(argsTypes)
args := make([]interface{}, argsLength)
for i, t := range argsTypes {
args[i] = reflect.Zero(t).Interface()
}
if err := assembleArgsFromPathParams(methodConfig, argsLength, argsTypes, req, args); err != nil {
return nil, err
}
if err := assembleArgsFromQueryParams(methodConfig, argsLength, argsTypes, req, args); err != nil {
return nil, err
}
if err := assembleArgsFromBody(methodConfig, argsTypes, req, args); err != nil {
return nil, err
}
if err := assembleArgsFromHeaders(methodConfig, req, argsLength, argsTypes, args); err != nil {
return nil, err
}
return args, nil
}
// assembleArgsFromHeaders assemble arguments from headers
func assembleArgsFromHeaders(methodConfig *rest_config.RestMethodConfig, req RestServerRequest, argsLength int, argsTypes []reflect.Type, args []interface{}) error {
for k, v := range methodConfig.HeadersMap {
param := req.HeaderParameter(v)
if k < 0 || k >= argsLength {
return perrors.Errorf("[Go restful] Header param parse error, the index %v args of method:%v doesn't exist", k, methodConfig.MethodName)
}
t := argsTypes[k]
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() == reflect.String {
args[k] = param
} else {
return perrors.Errorf("[Go restful] Header param parse error, the index %v args's type isn't string", k)
}
}
return nil
}
// assembleArgsFromBody assemble arguments from body
func assembleArgsFromBody(methodConfig *rest_config.RestMethodConfig, argsTypes []reflect.Type, req RestServerRequest, args []interface{}) error {
if methodConfig.Body >= 0 && methodConfig.Body < len(argsTypes) {
t := argsTypes[methodConfig.Body]
kind := t.Kind()
if kind == reflect.Ptr {
t = t.Elem()
}
var ni interface{}
if t.String() == "[]interface {}" {
ni = make([]map[string]interface{}, 0)
} else if t.String() == "interface {}" {
ni = make(map[string]interface{})
} else {
n := reflect.New(t)
if n.CanInterface() {
ni = n.Interface()
}
}
if err := req.ReadEntity(&ni); err != nil {
return perrors.Errorf("[Go restful] Read body entity error, error is %v", perrors.WithStack(err))
}
args[methodConfig.Body] = ni
}
return nil
}
// assembleArgsFromQueryParams assemble arguments from query params
func assembleArgsFromQueryParams(methodConfig *rest_config.RestMethodConfig, argsLength int, argsTypes []reflect.Type, req RestServerRequest, args []interface{}) error {
var (
err error
param interface{}
i64 int64
)
for k, v := range methodConfig.QueryParamsMap {
if k < 0 || k >= argsLength {
return perrors.Errorf("[Go restful] Query param parse error, the index %v args of method:%v doesn't exist", k, methodConfig.MethodName)
}
t := argsTypes[k]
kind := t.Kind()
if kind == reflect.Ptr {
t = t.Elem()
}
if kind == reflect.Slice {
param = req.QueryParameters(v)
} else if kind == reflect.String {
param = req.QueryParameter(v)
} else if kind == reflect.Int {
param, err = strconv.Atoi(req.QueryParameter(v))
} else if kind == reflect.Int32 {
i64, err = strconv.ParseInt(req.QueryParameter(v), 10, 32)
if err == nil {
param = int32(i64)
}
} else if kind == reflect.Int64 {
param, err = strconv.ParseInt(req.QueryParameter(v), 10, 64)
} else {
return perrors.Errorf("[Go restful] Query param parse error, the index %v args's type isn't int or string or slice", k)
}
if err != nil {
return perrors.Errorf("[Go restful] Query param parse error, error:%v", perrors.WithStack(err))
}
args[k] = param
}
return nil
}
// assembleArgsFromPathParams assemble arguments from path params
func assembleArgsFromPathParams(methodConfig *rest_config.RestMethodConfig, argsLength int, argsTypes []reflect.Type, req RestServerRequest, args []interface{}) error {
var (
err error
param interface{}
i64 int64
)
for k, v := range methodConfig.PathParamsMap {
if k < 0 || k >= argsLength {
return perrors.Errorf("[Go restful] Path param parse error, the index %v args of method:%v doesn't exist", k, methodConfig.MethodName)
}
t := argsTypes[k]
kind := t.Kind()
if kind == reflect.Ptr {
t = t.Elem()
}
if kind == reflect.Int {
param, err = strconv.Atoi(req.PathParameter(v))
} else if kind == reflect.Int32 {
i64, err = strconv.ParseInt(req.PathParameter(v), 10, 32)
if err == nil {
param = int32(i64)
}
} else if kind == reflect.Int64 {
param, err = strconv.ParseInt(req.PathParameter(v), 10, 64)
} else if kind == reflect.String {
param = req.PathParameter(v)
} else {
return perrors.Errorf("[Go restful] Path param parse error, the index %v args's type isn't int or string", k)
}
if err != nil {
return perrors.Errorf("[Go restful] Path param parse error, error is %v", perrors.WithStack(err))
}
args[k] = param
}
return nil
}
...@@ -22,8 +22,6 @@ import ( ...@@ -22,8 +22,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"reflect"
"strconv"
"strings" "strings"
"time" "time"
) )
...@@ -38,27 +36,29 @@ import ( ...@@ -38,27 +36,29 @@ import (
"github.com/apache/dubbo-go/common/constant" "github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/extension" "github.com/apache/dubbo-go/common/extension"
"github.com/apache/dubbo-go/common/logger" "github.com/apache/dubbo-go/common/logger"
"github.com/apache/dubbo-go/protocol"
"github.com/apache/dubbo-go/protocol/invocation"
"github.com/apache/dubbo-go/protocol/rest/config" "github.com/apache/dubbo-go/protocol/rest/config"
"github.com/apache/dubbo-go/protocol/rest/server" "github.com/apache/dubbo-go/protocol/rest/server"
) )
func init() { func init() {
extension.SetRestServer(constant.DEFAULT_REST_SERVER, GetNewGoRestfulServer) extension.SetRestServer(constant.DEFAULT_REST_SERVER, NewGoRestfulServer)
} }
var filterSlice []restful.FilterFunction var filterSlice []restful.FilterFunction
// GoRestfulServer a rest server implement by go-restful
type GoRestfulServer struct { type GoRestfulServer struct {
srv *http.Server srv *http.Server
container *restful.Container container *restful.Container
} }
func NewGoRestfulServer() *GoRestfulServer { // NewGoRestfulServer a constructor of GoRestfulServer
func NewGoRestfulServer() server.RestServer {
return &GoRestfulServer{} return &GoRestfulServer{}
} }
// Start go-restful server
// It will add all go-restful filters
func (grs *GoRestfulServer) Start(url common.URL) { func (grs *GoRestfulServer) Start(url common.URL) {
grs.container = restful.NewContainer() grs.container = restful.NewContainer()
for _, filter := range filterSlice { for _, filter := range filterSlice {
...@@ -80,61 +80,32 @@ func (grs *GoRestfulServer) Start(url common.URL) { ...@@ -80,61 +80,32 @@ func (grs *GoRestfulServer) Start(url common.URL) {
}() }()
} }
func (grs *GoRestfulServer) Deploy(invoker protocol.Invoker, restMethodConfig map[string]*config.RestMethodConfig) { // Publish a http api in go-restful server
svc := common.ServiceMap.GetService(invoker.GetUrl().Protocol, strings.TrimPrefix(invoker.GetUrl().Path, "/")) // The routeFunc should be invoked when the server receive a request
for methodName, config := range restMethodConfig { func (grs *GoRestfulServer) Deploy(restMethodConfig *config.RestMethodConfig, routeFunc func(request server.RestServerRequest, response server.RestServerResponse)) {
// get method ws := &restful.WebService{}
method := svc.Method()[methodName] rf := func(req *restful.Request, resp *restful.Response) {
argsTypes := method.ArgsType() routeFunc(NewGoRestfulRequestAdapter(req), resp)
replyType := method.ReplyType()
ws := new(restful.WebService)
ws.Path(config.Path).
Produces(strings.Split(config.Produces, ",")...).
Consumes(strings.Split(config.Consumes, ",")...).
Route(ws.Method(config.MethodType).To(getFunc(methodName, invoker, argsTypes, replyType, config)))
grs.container.Add(ws)
} }
ws.Path(restMethodConfig.Path).
Produces(strings.Split(restMethodConfig.Produces, ",")...).
Consumes(strings.Split(restMethodConfig.Consumes, ",")...).
Route(ws.Method(restMethodConfig.MethodType).To(rf))
grs.container.Add(ws)
} }
func getFunc(methodName string, invoker protocol.Invoker, argsTypes []reflect.Type, // Delete a http api in go-restful server
replyType reflect.Type, config *config.RestMethodConfig) func(req *restful.Request, resp *restful.Response) { func (grs *GoRestfulServer) UnDeploy(restMethodConfig *config.RestMethodConfig) {
return func(req *restful.Request, resp *restful.Response) { ws := new(restful.WebService)
var ( ws.Path(restMethodConfig.Path)
err error err := grs.container.Remove(ws)
args []interface{} if err != nil {
) logger.Warnf("[Go restful] Remove web service error:%v", err)
if (len(argsTypes) == 1 || len(argsTypes) == 2 && replyType == nil) &&
argsTypes[0].String() == "[]interface {}" {
args = getArgsInterfaceFromRequest(req, config)
} else {
args = getArgsFromRequest(req, argsTypes, config)
}
result := invoker.Invoke(context.Background(), invocation.NewRPCInvocation(methodName, args, make(map[string]string)))
if result.Error() != nil {
err = resp.WriteError(http.StatusInternalServerError, result.Error())
if err != nil {
logger.Errorf("[Go Restful] WriteError error:%v", err)
}
return
}
err = resp.WriteEntity(result.Result())
if err != nil {
logger.Error("[Go Restful] WriteEntity error:%v", err)
}
}
}
func (grs *GoRestfulServer) UnDeploy(restMethodConfig map[string]*config.RestMethodConfig) {
for _, config := range restMethodConfig {
ws := new(restful.WebService)
ws.Path(config.Path)
err := grs.container.Remove(ws)
if err != nil {
logger.Warnf("[Go restful] Remove web service error:%v", err)
}
} }
} }
// Destroy the go-restful server
func (grs *GoRestfulServer) Destroy() { func (grs *GoRestfulServer) Destroy() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
...@@ -144,179 +115,59 @@ func (grs *GoRestfulServer) Destroy() { ...@@ -144,179 +115,59 @@ func (grs *GoRestfulServer) Destroy() {
logger.Infof("[Go Restful] Server exiting") logger.Infof("[Go Restful] Server exiting")
} }
func getArgsInterfaceFromRequest(req *restful.Request, config *config.RestMethodConfig) []interface{} { // AddGoRestfulServerFilter let user add the http server of go-restful
argsMap := make(map[int]interface{}, 8) // addFilter should before config.Load()
maxKey := 0 func AddGoRestfulServerFilter(filterFuc restful.FilterFunction) {
for k, v := range config.PathParamsMap { filterSlice = append(filterSlice, filterFuc)
if maxKey < k {
maxKey = k
}
argsMap[k] = req.PathParameter(v)
}
for k, v := range config.QueryParamsMap {
if maxKey < k {
maxKey = k
}
params := req.QueryParameters(v)
if len(params) == 1 {
argsMap[k] = params[0]
} else {
argsMap[k] = params
}
}
for k, v := range config.HeadersMap {
if maxKey < k {
maxKey = k
}
argsMap[k] = req.HeaderParameter(v)
}
if config.Body >= 0 {
if maxKey < config.Body {
maxKey = config.Body
}
m := make(map[string]interface{})
// TODO read as a slice
if err := req.ReadEntity(&m); err != nil {
logger.Warnf("[Go restful] Read body entity as map[string]interface{} error:%v", perrors.WithStack(err))
} else {
argsMap[config.Body] = m
}
}
args := make([]interface{}, maxKey+1)
for k, v := range argsMap {
if k >= 0 {
args[k] = v
}
}
return args
} }
func getArgsFromRequest(req *restful.Request, argsTypes []reflect.Type, config *config.RestMethodConfig) []interface{} { // GoRestfulRequestAdapter a adapter struct about RestServerRequest
argsLength := len(argsTypes) type GoRestfulRequestAdapter struct {
args := make([]interface{}, argsLength) server.RestServerRequest
for i, t := range argsTypes { request *restful.Request
args[i] = reflect.Zero(t).Interface() }
}
var (
err error
param interface{}
i64 int64
)
for k, v := range config.PathParamsMap {
if k < 0 || k >= argsLength {
logger.Errorf("[Go restful] Path param parse error, the args:%v doesn't exist", k)
continue
}
t := argsTypes[k]
kind := t.Kind()
if kind == reflect.Ptr {
t = t.Elem()
}
if kind == reflect.Int {
param, err = strconv.Atoi(req.PathParameter(v))
} else if kind == reflect.Int32 {
i64, err = strconv.ParseInt(req.PathParameter(v), 10, 32)
if err == nil {
param = int32(i64)
}
} else if kind == reflect.Int64 {
param, err = strconv.ParseInt(req.PathParameter(v), 10, 64)
} else if kind == reflect.String {
param = req.PathParameter(v)
} else {
logger.Warnf("[Go restful] Path param parse error, the args:%v of type isn't int or string", k)
continue
}
if err != nil {
logger.Errorf("[Go restful] Path param parse error, error is %v", err)
continue
}
args[k] = param
}
for k, v := range config.QueryParamsMap {
if k < 0 || k >= argsLength {
logger.Errorf("[Go restful] Query param parse error, the args:%v doesn't exist", k)
continue
}
t := argsTypes[k]
kind := t.Kind()
if kind == reflect.Ptr {
t = t.Elem()
}
if kind == reflect.Slice {
param = req.QueryParameters(v)
} else if kind == reflect.String {
param = req.QueryParameter(v)
} else if kind == reflect.Int {
param, err = strconv.Atoi(req.QueryParameter(v))
} else if kind == reflect.Int32 {
i64, err = strconv.ParseInt(req.QueryParameter(v), 10, 32)
if err == nil {
param = int32(i64)
}
} else if kind == reflect.Int64 {
param, err = strconv.ParseInt(req.QueryParameter(v), 10, 64)
} else {
logger.Errorf("[Go restful] Query param parse error, the args:%v of type isn't int or string or slice", k)
continue
}
if err != nil {
logger.Errorf("[Go restful] Query param parse error, error is %v", err)
continue
}
args[k] = param
}
if config.Body >= 0 && config.Body < len(argsTypes) { // NewGoRestfulRequestAdapter a constructor of GoRestfulRequestAdapter
t := argsTypes[config.Body] func NewGoRestfulRequestAdapter(request *restful.Request) *GoRestfulRequestAdapter {
kind := t.Kind() return &GoRestfulRequestAdapter{request: request}
if kind == reflect.Ptr { }
t = t.Elem()
}
var ni interface{}
if t.String() == "[]interface {}" {
ni = make([]map[string]interface{}, 0)
} else if t.String() == "interface {}" {
ni = make(map[string]interface{})
} else {
n := reflect.New(t)
if n.CanInterface() {
ni = n.Interface()
}
}
if err := req.ReadEntity(&ni); err != nil {
logger.Errorf("[Go restful] Read body entity error:%v", err)
} else {
args[config.Body] = ni
}
}
for k, v := range config.HeadersMap { // RawRequest a adapter function of server.RestServerRequest's RawRequest
param := req.HeaderParameter(v) func (grra *GoRestfulRequestAdapter) RawRequest() *http.Request {
if k < 0 || k >= argsLength { return grra.request.Request
logger.Errorf("[Go restful] Header param parse error, the args:%v doesn't exist", k) }
continue
}
t := argsTypes[k]
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() == reflect.String {
args[k] = param
} else {
logger.Errorf("[Go restful] Header param parse error, the args:%v of type isn't string", k)
}
}
return args // PathParameter a adapter function of server.RestServerRequest's PathParameter
func (grra *GoRestfulRequestAdapter) PathParameter(name string) string {
return grra.request.PathParameter(name)
} }
func GetNewGoRestfulServer() server.RestServer { // PathParameters a adapter function of server.RestServerRequest's QueryParameter
return NewGoRestfulServer() func (grra *GoRestfulRequestAdapter) PathParameters() map[string]string {
return grra.request.PathParameters()
} }
// Let user addFilter // QueryParameter a adapter function of server.RestServerRequest's QueryParameters
// addFilter should before config.Load() func (grra *GoRestfulRequestAdapter) QueryParameter(name string) string {
func AddGoRestfulServerFilter(filterFuc restful.FilterFunction) { return grra.request.QueryParameter(name)
filterSlice = append(filterSlice, filterFuc) }
// QueryParameters a adapter function of server.RestServerRequest's QueryParameters
func (grra *GoRestfulRequestAdapter) QueryParameters(name string) []string {
return grra.request.QueryParameters(name)
}
// BodyParameter a adapter function of server.RestServerRequest's BodyParameter
func (grra *GoRestfulRequestAdapter) BodyParameter(name string) (string, error) {
return grra.request.BodyParameter(name)
}
// HeaderParameter a adapter function of server.RestServerRequest's HeaderParameter
func (grra *GoRestfulRequestAdapter) HeaderParameter(name string) string {
return grra.request.HeaderParameter(name)
}
// ReadEntity a adapter func of server.RestServerRequest's ReadEntity
func (grra *GoRestfulRequestAdapter) ReadEntity(entityPointer interface{}) error {
return grra.request.ReadEntity(entityPointer)
} }
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package protocol package protocol
import ( import (
......
...@@ -121,6 +121,7 @@ func (r *BaseRegistry) Destroy() { ...@@ -121,6 +121,7 @@ func (r *BaseRegistry) Destroy() {
close(r.done) close(r.done)
// wait waitgroup done (wait listeners outside close over) // wait waitgroup done (wait listeners outside close over)
r.wg.Wait() r.wg.Wait()
//close registry client //close registry client
r.closeRegisters() r.closeRegisters()
} }
...@@ -178,7 +179,10 @@ func (r *BaseRegistry) RestartCallBack() bool { ...@@ -178,7 +179,10 @@ func (r *BaseRegistry) RestartCallBack() bool {
} }
logger.Infof("success to re-register service :%v", confIf.Key()) logger.Infof("success to re-register service :%v", confIf.Key())
} }
r.facadeBasedRegistry.InitListeners()
if flag {
r.facadeBasedRegistry.InitListeners()
}
return flag return flag
} }
...@@ -245,19 +249,13 @@ func (r *BaseRegistry) providerRegistry(c common.URL, params url.Values) (string ...@@ -245,19 +249,13 @@ func (r *BaseRegistry) providerRegistry(c common.URL, params url.Values) (string
logger.Errorf("facadeBasedRegistry.CreatePath(path{%s}) = error{%#v}", dubboPath, perrors.WithStack(err)) logger.Errorf("facadeBasedRegistry.CreatePath(path{%s}) = error{%#v}", dubboPath, perrors.WithStack(err))
return "", "", perrors.WithMessagef(err, "facadeBasedRegistry.CreatePath(path:%s)", dubboPath) return "", "", perrors.WithMessagef(err, "facadeBasedRegistry.CreatePath(path:%s)", dubboPath)
} }
params.Add("anyhost", "true") params.Add(constant.ANYHOST_KEY, "true")
// Dubbo java consumer to start looking for the provider url,because the category does not match, // Dubbo java consumer to start looking for the provider url,because the category does not match,
// the provider will not find, causing the consumer can not start, so we use consumers. // the provider will not find, causing the consumer can not start, so we use consumers.
// DubboRole = [...]string{"consumer", "", "", "provider"}
// params.Add("category", (RoleType(PROVIDER)).Role())
params.Add("category", (common.RoleType(common.PROVIDER)).String())
params.Add("dubbo", "dubbo-provider-golang-"+constant.Version)
params.Add("side", (common.RoleType(common.PROVIDER)).Role())
if len(c.Methods) == 0 { if len(c.Methods) == 0 {
params.Add("methods", strings.Join(c.Methods, ",")) params.Add(constant.METHODS_KEY, strings.Join(c.Methods, ","))
} }
logger.Debugf("provider url params:%#v", params) logger.Debugf("provider url params:%#v", params)
var host string var host string
...@@ -308,9 +306,6 @@ func (r *BaseRegistry) consumerRegistry(c common.URL, params url.Values) (string ...@@ -308,9 +306,6 @@ func (r *BaseRegistry) consumerRegistry(c common.URL, params url.Values) (string
} }
params.Add("protocol", c.Protocol) params.Add("protocol", c.Protocol)
params.Add("category", (common.RoleType(common.CONSUMER)).String())
params.Add("dubbo", "dubbogo-consumer-"+constant.Version)
rawURL = fmt.Sprintf("consumer://%s%s?%s", localIP, c.Path, params.Encode()) rawURL = fmt.Sprintf("consumer://%s%s?%s", localIP, c.Path, params.Encode())
dubboPath = fmt.Sprintf("/dubbo/%s/%s", r.service(c), (common.RoleType(common.CONSUMER)).String()) dubboPath = fmt.Sprintf("/dubbo/%s/%s", r.service(c), (common.RoleType(common.CONSUMER)).String())
......
...@@ -127,12 +127,13 @@ func (dir *registryDirectory) refreshInvokers(res *registry.ServiceEvent) { ...@@ -127,12 +127,13 @@ func (dir *registryDirectory) refreshInvokers(res *registry.ServiceEvent) {
} else if url.Protocol == constant.ROUTER_PROTOCOL || //2.for router } else if url.Protocol == constant.ROUTER_PROTOCOL || //2.for router
url.GetParam(constant.CATEGORY_KEY, constant.DEFAULT_CATEGORY) == constant.ROUTER_CATEGORY { url.GetParam(constant.CATEGORY_KEY, constant.DEFAULT_CATEGORY) == constant.ROUTER_CATEGORY {
url = nil url = nil
} }
switch res.Action { switch res.Action {
case remoting.EventTypeAdd, remoting.EventTypeUpdate: case remoting.EventTypeAdd, remoting.EventTypeUpdate:
logger.Infof("selector add service url{%s}", res.Service) logger.Infof("selector add service url{%s}", res.Service)
var urls []*common.URL
var urls []*common.URL
for _, v := range directory.GetRouterURLSet().Values() { for _, v := range directory.GetRouterURLSet().Values() {
urls = append(urls, v.(*common.URL)) urls = append(urls, v.(*common.URL))
} }
...@@ -140,8 +141,6 @@ func (dir *registryDirectory) refreshInvokers(res *registry.ServiceEvent) { ...@@ -140,8 +141,6 @@ func (dir *registryDirectory) refreshInvokers(res *registry.ServiceEvent) {
if len(urls) > 0 { if len(urls) > 0 {
dir.SetRouters(urls) dir.SetRouters(urls)
} }
//dir.cacheService.EventTypeAdd(res.Path, dir.serviceTTL)
oldInvoker = dir.cacheInvoker(url) oldInvoker = dir.cacheInvoker(url)
case remoting.EventTypeDel: case remoting.EventTypeDel:
oldInvoker = dir.uncacheInvoker(url) oldInvoker = dir.uncacheInvoker(url)
......
...@@ -164,9 +164,7 @@ func (r *etcdV3Registry) DoSubscribe(svc *common.URL) (registry.Listener, error) ...@@ -164,9 +164,7 @@ func (r *etcdV3Registry) DoSubscribe(svc *common.URL) (registry.Listener, error)
//register the svc to dataListener //register the svc to dataListener
r.dataListener.AddInterestedURL(svc) r.dataListener.AddInterestedURL(svc)
for _, v := range strings.Split(svc.GetParam(constant.CATEGORY_KEY, constant.DEFAULT_CATEGORY), ",") { go r.listener.ListenServiceEvent(fmt.Sprintf("/dubbo/%s/"+constant.DEFAULT_CATEGORY, svc.Service()), r.dataListener)
go r.listener.ListenServiceEvent(fmt.Sprintf("/dubbo/%s/"+v, svc.Service()), r.dataListener)
}
return configListener, nil return configListener, nil
} }
...@@ -63,7 +63,7 @@ func (suite *RegistryTestSuite) TestRegister() { ...@@ -63,7 +63,7 @@ func (suite *RegistryTestSuite) TestRegister() {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
assert.Regexp(t, ".*dubbo%3A%2F%2F127.0.0.1%3A20000%2Fcom.ikurento.user.UserProvider%3Fanyhost%3Dtrue%26category%3Dproviders%26cluster%3Dmock%26dubbo%3Ddubbo-provider-golang-1.3.0%26.*provider", children) assert.Regexp(t, ".*dubbo%3A%2F%2F127.0.0.1%3A20000%2Fcom.ikurento.user.UserProvider%3Fanyhost%3Dtrue%26cluster%3Dmock", children)
assert.NoError(t, err) assert.NoError(t, err)
} }
......
...@@ -21,7 +21,6 @@ import ( ...@@ -21,7 +21,6 @@ import (
"fmt" "fmt"
"os" "os"
"path" "path"
"strings"
"sync" "sync"
"time" "time"
) )
...@@ -135,9 +134,7 @@ func (r *kubernetesRegistry) DoSubscribe(svc *common.URL) (registry.Listener, er ...@@ -135,9 +134,7 @@ func (r *kubernetesRegistry) DoSubscribe(svc *common.URL) (registry.Listener, er
//register the svc to dataListener //register the svc to dataListener
r.dataListener.AddInterestedURL(svc) r.dataListener.AddInterestedURL(svc)
for _, v := range strings.Split(svc.GetParam(constant.CATEGORY_KEY, constant.DEFAULT_CATEGORY), ",") { go r.listener.ListenServiceEvent(fmt.Sprintf("/dubbo/%s/"+constant.DEFAULT_CATEGORY, svc.Service()), r.dataListener)
go r.listener.ListenServiceEvent(fmt.Sprintf("/dubbo/%s/"+v, svc.Service()), r.dataListener)
}
return configListener, nil return configListener, nil
} }
......
...@@ -45,6 +45,7 @@ import ( ...@@ -45,6 +45,7 @@ import (
var ( var (
regProtocol *registryProtocol regProtocol *registryProtocol
once sync.Once
) )
type registryProtocol struct { type registryProtocol struct {
...@@ -346,12 +347,12 @@ func setProviderUrl(regURL *common.URL, providerURL *common.URL) { ...@@ -346,12 +347,12 @@ func setProviderUrl(regURL *common.URL, providerURL *common.URL) {
regURL.SubURL = providerURL regURL.SubURL = providerURL
} }
// GetProtocol ... // GetProtocol return the singleton RegistryProtocol
func GetProtocol() protocol.Protocol { func GetProtocol() protocol.Protocol {
if regProtocol != nil { once.Do(func() {
return regProtocol regProtocol = newRegistryProtocol()
} })
return newRegistryProtocol() return regProtocol
} }
type wrappedInvoker struct { type wrappedInvoker struct {
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment