diff --git a/common/extension/rest_client.go b/common/extension/rest_client.go index 35e6fb53818a8d877e2381e25f44690504b201f3..8b8e818fe1f658bafc1ce3a65d8a011d41798572 100644 --- a/common/extension/rest_client.go +++ b/common/extension/rest_client.go @@ -8,11 +8,11 @@ var ( restClients = make(map[string]func(restOptions *rest_interface.RestOptions) rest_interface.RestClient) ) -func SetRestClient(name string, fun func(restOptions *rest_interface.RestOptions) rest_interface.RestClient) { +func SetRestClientFunc(name string, fun func(restOptions *rest_interface.RestOptions) rest_interface.RestClient) { restClients[name] = fun } -func GetRestClient(name string, restOptions *rest_interface.RestOptions) rest_interface.RestClient { +func GetNewRestClient(name string, restOptions *rest_interface.RestOptions) rest_interface.RestClient { if restClients[name] == nil { panic("restClient for " + name + " is not existing, make sure you have import the package.") } diff --git a/common/extension/rest_config_reader.go b/common/extension/rest_config_reader.go index be76689b9b88b579b3c21321a9234c4de78dfd2f..5660ef29e17d2369a0da60caee67e6b56880d976 100644 --- a/common/extension/rest_config_reader.go +++ b/common/extension/rest_config_reader.go @@ -12,7 +12,7 @@ func SetRestConfigReader(name string, fun func() rest_interface.RestConfigReader restConfigReaders[name] = fun } -func GetRestConfigReader(name string) rest_interface.RestConfigReader { +func GetSingletonRestConfigReader(name string) rest_interface.RestConfigReader { if name == "" { name = "default" } diff --git a/common/extension/rest_server.go b/common/extension/rest_server.go new file mode 100644 index 0000000000000000000000000000000000000000..f983deac2c96deffcec4111fc5f9e49ee4c8f174 --- /dev/null +++ b/common/extension/rest_server.go @@ -0,0 +1,20 @@ +package extension + +import ( + "github.com/apache/dubbo-go/protocol/rest/rest_interface" +) + +var ( + restServers = make(map[string]func() rest_interface.RestServer) +) + +func SetRestServerFunc(name string, fun func() rest_interface.RestServer) { + restServers[name] = fun +} + +func GetNewRestServer(name string) rest_interface.RestServer { + if restServers[name] == nil { + panic("restServer for " + name + " is not existing, make sure you have import the package.") + } + return restServers[name]() +} diff --git a/protocol/rest/rest_client/resty_client.go b/protocol/rest/rest_client/resty_client.go index 296737a567ce72e95892830ff73d72ea9f31305e..e6c2741c2e20ccc4d8cf8fb3979c909e481959a4 100644 --- a/protocol/rest/rest_client/resty_client.go +++ b/protocol/rest/rest_client/resty_client.go @@ -16,7 +16,7 @@ import ( ) func init() { - extension.SetRestClient(constant.DEFAULT_REST_CLIENT, GetRestyClient) + extension.SetRestClientFunc(constant.DEFAULT_REST_CLIENT, GetRestyClient) } type RestyClient struct { diff --git a/protocol/rest/rest_config_initializer.go b/protocol/rest/rest_config_initializer.go index a9434773a7b28b07eb9485f6292299b7ecc65519..3480df98a8dbadd3048d80fd366c4b7bc95ae5ef 100644 --- a/protocol/rest/rest_config_initializer.go +++ b/protocol/rest/rest_config_initializer.go @@ -28,7 +28,7 @@ func init() { func initConsumerRestConfig() { consumerConfigType := config.GetConsumerConfig().RestConfigType - consumerConfigReader := extension.GetRestConfigReader(consumerConfigType) + consumerConfigReader := extension.GetSingletonRestConfigReader(consumerConfigType) restConsumerConfig = consumerConfigReader.ReadConsumerConfig() if restConsumerConfig == nil { return @@ -42,7 +42,7 @@ func initConsumerRestConfig() { func initProviderRestConfig() { providerConfigType := config.GetProviderConfig().RestConfigType - providerConfigReader := extension.GetRestConfigReader(providerConfigType) + providerConfigReader := extension.GetSingletonRestConfigReader(providerConfigType) restProviderConfig = providerConfigReader.ReadProviderConfig() if restProviderConfig == nil { return diff --git a/protocol/rest/rest_exporter.go b/protocol/rest/rest_exporter.go index 84a645d9b1246d2b6636ee373acf05f70a2734fa..39503d157f6ee0c2d831781a2c7d8887e1d758eb 100644 --- a/protocol/rest/rest_exporter.go +++ b/protocol/rest/rest_exporter.go @@ -1,16 +1,32 @@ package rest -import "github.com/apache/dubbo-go/protocol" +import ( + "sync" +) + +import ( + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/common/constant" + "github.com/apache/dubbo-go/common/logger" + "github.com/apache/dubbo-go/protocol" +) type RestExporter struct { protocol.BaseExporter } -func NewRestExporter() *RestExporter { - return &RestExporter{} +func NewRestExporter(key string, invoker protocol.Invoker, exporterMap *sync.Map) *RestExporter { + return &RestExporter{ + BaseExporter: *protocol.NewBaseExporter(key, invoker, exporterMap), + } } func (re *RestExporter) Unexport() { - // undeploy service脽脽 + serviceId := re.GetInvoker().GetUrl().GetParam(constant.BEAN_NAME_KEY, "") + re.BaseExporter.Unexport() + err := common.ServiceMap.UnRegister(REST, serviceId) + if err != nil { + logger.Errorf("[RestExporter.Unexport] error: %v", err) + } return } diff --git a/protocol/rest/rest_interface/rest_server.go b/protocol/rest/rest_interface/rest_server.go index c6aa79494f6773f27e2d1ee0c7164b54dade6d5c..cff0c94cfdc26fb7bfdab5e35012c502c41d736f 100644 --- a/protocol/rest/rest_interface/rest_server.go +++ b/protocol/rest/rest_interface/rest_server.go @@ -1,8 +1,13 @@ package rest_interface +import ( + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/protocol" +) + type RestServer interface { - Start() - Deploy() - Undeploy() + Start(url common.URL) + Deploy(invoker protocol.Invoker, restMethodConfig map[string]*RestMethodConfig) + Undeploy(restMethodConfig map[string]*RestMethodConfig) Destory() } diff --git a/protocol/rest/rest_invoker.go b/protocol/rest/rest_invoker.go index f87f655c63fec375887b089130dff79f45a6d02a..7bafcbb4d90414f99c74281352b5c25df93ed406 100644 --- a/protocol/rest/rest_invoker.go +++ b/protocol/rest/rest_invoker.go @@ -1,6 +1,7 @@ package rest import ( + "context" "fmt" ) @@ -18,15 +19,15 @@ type RestInvoker struct { restMethodConfigMap map[string]*rest_interface.RestMethodConfig } -func NewRestInvoker(url common.URL, client rest_interface.RestClient, restMethodConfig map[string]*rest_interface.RestMethodConfig) *RestInvoker { +func NewRestInvoker(url common.URL, client *rest_interface.RestClient, restMethodConfig map[string]*rest_interface.RestMethodConfig) *RestInvoker { return &RestInvoker{ BaseInvoker: *protocol.NewBaseInvoker(url), - client: client, + client: *client, restMethodConfigMap: restMethodConfig, } } -func (ri *RestInvoker) Invoke(invocation protocol.Invocation) protocol.Result { +func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result { inv := invocation.(*invocation_impl.RPCInvocation) methodConfig := ri.restMethodConfigMap[inv.MethodName()] var result protocol.RPCResult diff --git a/protocol/rest/rest_invoker_test.go b/protocol/rest/rest_invoker_test.go index ee8ba7b6f3ef4b398623a1efcd59a618b26cf410..4c9fd151f6c1fda7048ff917a6691a9381434ef6 100644 --- a/protocol/rest/rest_invoker_test.go +++ b/protocol/rest/rest_invoker_test.go @@ -51,11 +51,11 @@ func TestRestInvoker_Invoke(t *testing.T) { RestMethodConfigsMap: methodConfigMap, } restClient := rest_client.GetRestyClient(&rest_interface.RestOptions{ConnectTimeout: 5 * time.Second, RequestTimeout: 5 * time.Second}) - invoker := NewRestInvoker(url, restClient, methodConfigMap) + invoker := NewRestInvoker(url, &restClient, methodConfigMap) user := &User{} inv := invocation.NewRPCInvocationWithOptions(invocation.WithMethodName("GetUser"), invocation.WithArguments([]interface{}{"1", "username"}), invocation.WithReply(user)) - invoker.Invoke(inv) + invoker.Invoke(context.Background(), inv) // make sure url eq := invoker.GetUrl().URLEqual(url) diff --git a/protocol/rest/rest_protocol.go b/protocol/rest/rest_protocol.go index 8801f1efee2d12a3a4a22dba454c8d36ab5d5ce7..c0c1bc92bcad216c0224b474c9c682610e0a9ea4 100644 --- a/protocol/rest/rest_protocol.go +++ b/protocol/rest/rest_protocol.go @@ -7,6 +7,8 @@ import ( "github.com/apache/dubbo-go/config" "github.com/apache/dubbo-go/protocol" "github.com/apache/dubbo-go/protocol/rest/rest_interface" + "strings" + "sync" "time" ) @@ -22,6 +24,10 @@ func init() { type RestProtocol struct { protocol.BaseProtocol + serverMap map[string]rest_interface.RestServer + clientMap map[rest_interface.RestOptions]rest_interface.RestClient + serverLock sync.Mutex + clientLock sync.Mutex } func NewRestProtocol() *RestProtocol { @@ -29,32 +35,76 @@ func NewRestProtocol() *RestProtocol { } func (rp *RestProtocol) Export(invoker protocol.Invoker) protocol.Exporter { - // TODO 褰撶敤鎴锋敞鍐屼竴涓湇鍔$殑鏃跺€欙紝鏍规嵁ExporterConfig鍜屾湇鍔″疄鐜帮紝瀹屾垚Service -> Rest鐨勭粦瀹氥€傛敞鎰忔澶勬槸Service -> Rest锛屽洜涓烘鏃舵垜浠槸鏆撮湶鏈嶅姟銆傚綋鏀跺埌璇锋眰鐨勬椂鍊欙紝鎭板ソ鏄毚闇叉湇鍔$殑鍙嶅悜锛屽嵆Rest -> Service锛� - // Server鍦‥xport鐨勬椂鍊欏苟涓嶅仛浠€涔堜簨鎯呫€備絾鏄湪鎺ュ彈鍒拌姹傜殑鏃跺€欙紝瀹冮渶瑕佽礋璐f墽琛屽弽搴忓垪鍖栫殑杩囩▼; - // http server鏄竴涓娊璞¢殧绂诲眰銆傚畠鍐呴儴鍏佽浣跨敤beego鎴栬€単in鏉ヤ綔涓簑eb鏈嶅姟鍣紝鎺ユ敹璇锋眰锛岀敤鎴峰彲浠ユ墿灞曡嚜宸辩殑瀹炵幇锛� - - return nil + url := invoker.GetUrl() + serviceKey := strings.TrimPrefix(url.Path, "/") + exporter := NewRestExporter(serviceKey, invoker, rp.ExporterMap()) + restConfig := GetRestProviderServiceConfig(url.Service()) + rp.SetExporterMap(serviceKey, exporter) + restServer := rp.getServer(url, restConfig) + restServer.Deploy(invoker, restConfig.RestMethodConfigsMap) + return exporter } func (rp *RestProtocol) Refer(url common.URL) protocol.Invoker { // create rest_invoker var requestTimeout = config.GetConsumerConfig().RequestTimeout - requestTimeoutStr := url.GetParam(constant.TIMEOUT_KEY, config.GetConsumerConfig().Request_Timeout) connectTimeout := config.GetConsumerConfig().ConnectTimeout if t, err := time.ParseDuration(requestTimeoutStr); err == nil { requestTimeout = t } restConfig := GetRestConsumerServiceConfig(url.Service()) - restClient := extension.GetRestClient(restConfig.Client, &rest_interface.RestOptions{RequestTimeout: requestTimeout, ConnectTimeout: connectTimeout}) - invoker := NewRestInvoker(url, restClient, restConfig.RestMethodConfigsMap) + restOptions := rest_interface.RestOptions{RequestTimeout: requestTimeout, ConnectTimeout: connectTimeout} + restClient := rp.getClient(restOptions, restConfig) + invoker := NewRestInvoker(url, &restClient, restConfig.RestMethodConfigsMap) rp.SetInvokers(invoker) return invoker } +func (rp *RestProtocol) getServer(url common.URL, restConfig *rest_interface.RestConfig) rest_interface.RestServer { + restServer, ok := rp.serverMap[url.Location] + if !ok { + _, ok := rp.ExporterMap().Load(strings.TrimPrefix(url.Path, "/")) + if !ok { + panic("[RestProtocol]" + url.Key() + "is not existing") + } + rp.serverLock.Lock() + restServer, ok = rp.serverMap[url.Location] + if !ok { + restServer = extension.GetNewRestServer(restConfig.Server) + restServer.Start(url) + rp.serverMap[url.Location] = restServer + } + rp.serverLock.Unlock() + + } + return restServer +} + +func (rp *RestProtocol) getClient(restOptions rest_interface.RestOptions, restConfig *rest_interface.RestConfig) rest_interface.RestClient { + restClient, ok := rp.clientMap[restOptions] + rp.clientLock.Lock() + if !ok { + restClient, ok = rp.clientMap[restOptions] + if !ok { + restClient = extension.GetNewRestClient(restConfig.Client, &restOptions) + rp.clientMap[restOptions] = restClient + } + } + rp.clientLock.Unlock() + return restClient +} + func (rp *RestProtocol) Destroy() { // destroy rest_server rp.BaseProtocol.Destroy() + for key, server := range rp.serverMap { + server.Destory() + delete(rp.serverMap, key) + } + for key := range rp.clientMap { + delete(rp.clientMap, key) + } } func GetRestProtocol() protocol.Protocol {