diff --git a/protocol/rest/rest_interface/rest_client.go b/protocol/rest/rest_interface/rest_client.go index 6d0515daab76a9b4d5eed0deae154b875f0a2e70..da48e7ea762f769cb367e9d8ad85562523e011fc 100644 --- a/protocol/rest/rest_interface/rest_client.go +++ b/protocol/rest/rest_interface/rest_client.go @@ -18,6 +18,7 @@ type RestRequest struct { PathParams map[string]string QueryParams map[string]string Body map[string]interface{} + Headers map[string]string } type RestClient interface { diff --git a/protocol/rest/rest_interface/rest_config.go b/protocol/rest/rest_interface/rest_config.go index e0bb53bb8ce8d5dd35aafff0c42e40c74e9ca6a4..442177d6b02f985cc4cb667fff289bd7e7e1bbf5 100644 --- a/protocol/rest/rest_interface/rest_config.go +++ b/protocol/rest/rest_interface/rest_config.go @@ -41,4 +41,6 @@ type RestMethodConfig struct { QueryParamsMap map[int]string Body string `yaml:"rest_body" json:"rest_body,omitempty" property:"rest_body"` BodyMap map[int]string + Headers string `yaml:"rest_headers" json:"rest_headers,omitempty" property:"rest_headers"` + HeadersMap map[int]string } diff --git a/protocol/rest/rest_invoker.go b/protocol/rest/rest_invoker.go index 7bafcbb4d90414f99c74281352b5c25df93ed406..67e2c9f735c00d32623ff7c5fa5bd707f558bf0b 100644 --- a/protocol/rest/rest_invoker.go +++ b/protocol/rest/rest_invoker.go @@ -35,18 +35,10 @@ func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocatio logger.Errorf("[RestInvoker]Rest methodConfig:%s is nil", inv.MethodName()) return nil } - pathParams := make(map[string]string) - queryParams := make(map[string]string) - bodyParams := make(map[string]interface{}) - for key, value := range methodConfig.PathParamsMap { - pathParams[value] = fmt.Sprintf("%v", inv.Arguments()[key]) - } - for key, value := range methodConfig.QueryParamsMap { - queryParams[value] = fmt.Sprintf("%v", inv.Arguments()[key]) - } - for key, value := range methodConfig.BodyMap { - bodyParams[value] = inv.Arguments()[key] - } + pathParams := restStringMapTransform(methodConfig.PathParamsMap, inv.Arguments()) + queryParams := restStringMapTransform(methodConfig.QueryParamsMap, inv.Arguments()) + headers := restStringMapTransform(methodConfig.HeadersMap, inv.Arguments()) + bodyParams := restInterfaceMapTransform(methodConfig.BodyMap, inv.Arguments()) req := &rest_interface.RestRequest{ Location: ri.GetUrl().Location, Produces: methodConfig.Produces, @@ -56,11 +48,27 @@ func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocatio PathParams: pathParams, QueryParams: queryParams, Body: bodyParams, + Headers: headers, } result.Err = ri.client.Do(req, inv.Reply()) if result.Err == nil { result.Rest = inv.Reply() } return &result +} +func restStringMapTransform(paramsMap map[int]string, args []interface{}) map[string]string { + resMap := make(map[string]string, len(paramsMap)) + for key, value := range paramsMap { + resMap[value] = fmt.Sprintf("%v", args[key]) + } + return resMap +} + +func restInterfaceMapTransform(paramsMap map[int]string, args []interface{}) map[string]interface{} { + resMap := make(map[string]interface{}, len(paramsMap)) + for key, value := range paramsMap { + resMap[value] = args[key] + } + return resMap } diff --git a/protocol/rest/rest_invoker_test.go b/protocol/rest/rest_invoker_test.go index 4c9fd151f6c1fda7048ff917a6691a9381434ef6..18828dd189afdf076942251c07a82dde1d922854 100644 --- a/protocol/rest/rest_invoker_test.go +++ b/protocol/rest/rest_invoker_test.go @@ -18,7 +18,6 @@ type User struct { func TestRestInvoker_Invoke(t *testing.T) { // Refer - proto := GetRestProtocol() url, err := common.NewURL(context.Background(), "rest://127.0.0.1:8888/com.ikurento.user.UserProvider?anyhost=true&"+ "application=BDTService&category=providers&default.timeout=10000&dubbo=dubbo-provider-golang-1.0.0&"+ "environment=dev&interface=com.ikurento.user.UserProvider&ip=192.168.56.1&methods=GetUser%2C&"+ @@ -57,14 +56,4 @@ func TestRestInvoker_Invoke(t *testing.T) { invocation.WithArguments([]interface{}{"1", "username"}), invocation.WithReply(user)) invoker.Invoke(context.Background(), inv) - // make sure url - eq := invoker.GetUrl().URLEqual(url) - assert.True(t, eq) - - // make sure invokers after 'Destroy' - invokersLen := len(proto.(*RestProtocol).Invokers()) - assert.Equal(t, 1, invokersLen) - proto.Destroy() - invokersLen = len(proto.(*RestProtocol).Invokers()) - assert.Equal(t, 0, invokersLen) }