From fa27aa2386d688f29dbf2153dfd831e3ff6e237f Mon Sep 17 00:00:00 2001
From: Patrick <dreamlike.sky@foxmail.com>
Date: Mon, 9 Mar 2020 10:49:04 +0800
Subject: [PATCH] modify rest protocol

---
 protocol/rest/rest_protocol.go | 29 +++++++++++++++++++----------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/protocol/rest/rest_protocol.go b/protocol/rest/rest_protocol.go
index 20aaee5dc..fdf941a0d 100644
--- a/protocol/rest/rest_protocol.go
+++ b/protocol/rest/rest_protocol.go
@@ -27,6 +27,7 @@ import (
 	"github.com/apache/dubbo-go/common"
 	"github.com/apache/dubbo-go/common/constant"
 	"github.com/apache/dubbo-go/common/extension"
+	"github.com/apache/dubbo-go/common/logger"
 	"github.com/apache/dubbo-go/config"
 	"github.com/apache/dubbo-go/protocol"
 	"github.com/apache/dubbo-go/protocol/rest/rest_interface"
@@ -63,10 +64,14 @@ func (rp *RestProtocol) Export(invoker protocol.Invoker) protocol.Exporter {
 	url := invoker.GetUrl()
 	serviceKey := url.ServiceKey()
 	exporter := NewRestExporter(serviceKey, invoker, rp.ExporterMap())
-	restConfig := GetRestProviderServiceConfig(strings.TrimPrefix(url.Path, "/"))
+	restServiceConfig := GetRestProviderServiceConfig(strings.TrimPrefix(url.Path, "/"))
+	if restServiceConfig == nil {
+		logger.Errorf("%s service doesn't has provider config", url.Path)
+		return nil
+	}
 	rp.SetExporterMap(serviceKey, exporter)
-	restServer := rp.getServer(url, restConfig)
-	restServer.Deploy(invoker, restConfig.RestMethodConfigsMap)
+	restServer := rp.getServer(url, restServiceConfig.Server)
+	restServer.Deploy(invoker, restServiceConfig.RestMethodConfigsMap)
 	return exporter
 }
 
@@ -78,15 +83,19 @@ func (rp *RestProtocol) Refer(url common.URL) protocol.Invoker {
 	if t, err := time.ParseDuration(requestTimeoutStr); err == nil {
 		requestTimeout = t
 	}
-	restConfig := GetRestConsumerServiceConfig(strings.TrimPrefix(url.Path, "/"))
+	restServiceConfig := GetRestConsumerServiceConfig(strings.TrimPrefix(url.Path, "/"))
+	if restServiceConfig == nil {
+		logger.Errorf("%s service doesn't has consumer config", url.Path)
+		return nil
+	}
 	restOptions := rest_interface.RestOptions{RequestTimeout: requestTimeout, ConnectTimeout: connectTimeout}
-	restClient := rp.getClient(restOptions, restConfig)
-	invoker := NewRestInvoker(url, &restClient, restConfig.RestMethodConfigsMap)
+	restClient := rp.getClient(restOptions, restServiceConfig.Client)
+	invoker := NewRestInvoker(url, &restClient, restServiceConfig.RestMethodConfigsMap)
 	rp.SetInvokers(invoker)
 	return invoker
 }
 
-func (rp *RestProtocol) getServer(url common.URL, restConfig *rest_interface.RestServiceConfig) rest_interface.RestServer {
+func (rp *RestProtocol) getServer(url common.URL, serverType string) rest_interface.RestServer {
 	restServer, ok := rp.serverMap[url.Location]
 	if !ok {
 		_, ok := rp.ExporterMap().Load(url.ServiceKey())
@@ -96,7 +105,7 @@ func (rp *RestProtocol) getServer(url common.URL, restConfig *rest_interface.Res
 		rp.serverLock.Lock()
 		restServer, ok = rp.serverMap[url.Location]
 		if !ok {
-			restServer = extension.GetNewRestServer(restConfig.Server)
+			restServer = extension.GetNewRestServer(serverType)
 			restServer.Start(url)
 			rp.serverMap[url.Location] = restServer
 		}
@@ -106,13 +115,13 @@ func (rp *RestProtocol) getServer(url common.URL, restConfig *rest_interface.Res
 	return restServer
 }
 
-func (rp *RestProtocol) getClient(restOptions rest_interface.RestOptions, restConfig *rest_interface.RestServiceConfig) rest_interface.RestClient {
+func (rp *RestProtocol) getClient(restOptions rest_interface.RestOptions, clientType string) rest_interface.RestClient {
 	restClient, ok := rp.clientMap[restOptions]
 	rp.clientLock.Lock()
 	if !ok {
 		restClient, ok = rp.clientMap[restOptions]
 		if !ok {
-			restClient = extension.GetNewRestClient(restConfig.Client, &restOptions)
+			restClient = extension.GetNewRestClient(clientType, &restOptions)
 			rp.clientMap[restOptions] = restClient
 		}
 	}
-- 
GitLab