Skip to content
Snippets Groups Projects
Commit bc696d56 authored by Patrick's avatar Patrick
Browse files

init rest exporter and modify rest invoker

parent afee53c2
No related branches found
No related tags found
No related merge requests found
......@@ -8,11 +8,11 @@ var (
restClients = make(map[string]func(restOptions *rest_interface.RestOptions) rest_interface.RestClient)
)
func SetRestClient(name string, fun func(restOptions *rest_interface.RestOptions) rest_interface.RestClient) {
func SetRestClientFunc(name string, fun func(restOptions *rest_interface.RestOptions) rest_interface.RestClient) {
restClients[name] = fun
}
func GetRestClient(name string, restOptions *rest_interface.RestOptions) rest_interface.RestClient {
func GetNewRestClient(name string, restOptions *rest_interface.RestOptions) rest_interface.RestClient {
if restClients[name] == nil {
panic("restClient for " + name + " is not existing, make sure you have import the package.")
}
......
......@@ -12,7 +12,7 @@ func SetRestConfigReader(name string, fun func() rest_interface.RestConfigReader
restConfigReaders[name] = fun
}
func GetRestConfigReader(name string) rest_interface.RestConfigReader {
func GetSingletonRestConfigReader(name string) rest_interface.RestConfigReader {
if name == "" {
name = "default"
}
......
package extension
import (
"github.com/apache/dubbo-go/protocol/rest/rest_interface"
)
var (
restServers = make(map[string]func() rest_interface.RestServer)
)
func SetRestServerFunc(name string, fun func() rest_interface.RestServer) {
restServers[name] = fun
}
func GetNewRestServer(name string) rest_interface.RestServer {
if restServers[name] == nil {
panic("restServer for " + name + " is not existing, make sure you have import the package.")
}
return restServers[name]()
}
......@@ -16,7 +16,7 @@ import (
)
func init() {
extension.SetRestClient(constant.DEFAULT_REST_CLIENT, GetRestyClient)
extension.SetRestClientFunc(constant.DEFAULT_REST_CLIENT, GetRestyClient)
}
type RestyClient struct {
......
......@@ -28,7 +28,7 @@ func init() {
func initConsumerRestConfig() {
consumerConfigType := config.GetConsumerConfig().RestConfigType
consumerConfigReader := extension.GetRestConfigReader(consumerConfigType)
consumerConfigReader := extension.GetSingletonRestConfigReader(consumerConfigType)
restConsumerConfig = consumerConfigReader.ReadConsumerConfig()
if restConsumerConfig == nil {
return
......@@ -42,7 +42,7 @@ func initConsumerRestConfig() {
func initProviderRestConfig() {
providerConfigType := config.GetProviderConfig().RestConfigType
providerConfigReader := extension.GetRestConfigReader(providerConfigType)
providerConfigReader := extension.GetSingletonRestConfigReader(providerConfigType)
restProviderConfig = providerConfigReader.ReadProviderConfig()
if restProviderConfig == nil {
return
......
package rest
import "github.com/apache/dubbo-go/protocol"
import (
"sync"
)
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/logger"
"github.com/apache/dubbo-go/protocol"
)
type RestExporter struct {
protocol.BaseExporter
}
func NewRestExporter() *RestExporter {
return &RestExporter{}
func NewRestExporter(key string, invoker protocol.Invoker, exporterMap *sync.Map) *RestExporter {
return &RestExporter{
BaseExporter: *protocol.NewBaseExporter(key, invoker, exporterMap),
}
}
func (re *RestExporter) Unexport() {
// undeploy serviceßß
serviceId := re.GetInvoker().GetUrl().GetParam(constant.BEAN_NAME_KEY, "")
re.BaseExporter.Unexport()
err := common.ServiceMap.UnRegister(REST, serviceId)
if err != nil {
logger.Errorf("[RestExporter.Unexport] error: %v", err)
}
return
}
package rest_interface
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/protocol"
)
type RestServer interface {
Start()
Deploy()
Undeploy()
Start(url common.URL)
Deploy(invoker protocol.Invoker, restMethodConfig map[string]*RestMethodConfig)
Undeploy(restMethodConfig map[string]*RestMethodConfig)
Destory()
}
package rest
import (
"context"
"fmt"
)
......@@ -18,15 +19,15 @@ type RestInvoker struct {
restMethodConfigMap map[string]*rest_interface.RestMethodConfig
}
func NewRestInvoker(url common.URL, client rest_interface.RestClient, restMethodConfig map[string]*rest_interface.RestMethodConfig) *RestInvoker {
func NewRestInvoker(url common.URL, client *rest_interface.RestClient, restMethodConfig map[string]*rest_interface.RestMethodConfig) *RestInvoker {
return &RestInvoker{
BaseInvoker: *protocol.NewBaseInvoker(url),
client: client,
client: *client,
restMethodConfigMap: restMethodConfig,
}
}
func (ri *RestInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
inv := invocation.(*invocation_impl.RPCInvocation)
methodConfig := ri.restMethodConfigMap[inv.MethodName()]
var result protocol.RPCResult
......
......@@ -51,11 +51,11 @@ func TestRestInvoker_Invoke(t *testing.T) {
RestMethodConfigsMap: methodConfigMap,
}
restClient := rest_client.GetRestyClient(&rest_interface.RestOptions{ConnectTimeout: 5 * time.Second, RequestTimeout: 5 * time.Second})
invoker := NewRestInvoker(url, restClient, methodConfigMap)
invoker := NewRestInvoker(url, &restClient, methodConfigMap)
user := &User{}
inv := invocation.NewRPCInvocationWithOptions(invocation.WithMethodName("GetUser"),
invocation.WithArguments([]interface{}{"1", "username"}), invocation.WithReply(user))
invoker.Invoke(inv)
invoker.Invoke(context.Background(), inv)
// make sure url
eq := invoker.GetUrl().URLEqual(url)
......
......@@ -7,6 +7,8 @@ import (
"github.com/apache/dubbo-go/config"
"github.com/apache/dubbo-go/protocol"
"github.com/apache/dubbo-go/protocol/rest/rest_interface"
"strings"
"sync"
"time"
)
......@@ -22,6 +24,10 @@ func init() {
type RestProtocol struct {
protocol.BaseProtocol
serverMap map[string]rest_interface.RestServer
clientMap map[rest_interface.RestOptions]rest_interface.RestClient
serverLock sync.Mutex
clientLock sync.Mutex
}
func NewRestProtocol() *RestProtocol {
......@@ -29,32 +35,76 @@ func NewRestProtocol() *RestProtocol {
}
func (rp *RestProtocol) Export(invoker protocol.Invoker) protocol.Exporter {
// TODO 当用户注册一个服务的时候,根据ExporterConfig和服务实现,完成Service -> Rest的绑定。注意此处是Service -> Rest,因为此时我们是暴露服务。当收到请求的时候,恰好是暴露服务的反向,即Rest -> Service;
// Server在Export的时候并不做什么事情。但是在接受到请求的时候,它需要负责执行反序列化的过程;
// http server是一个抽象隔离层。它内部允许使用beego或者gin来作为web服务器,接收请求,用户可以扩展自己的实现;
return nil
url := invoker.GetUrl()
serviceKey := strings.TrimPrefix(url.Path, "/")
exporter := NewRestExporter(serviceKey, invoker, rp.ExporterMap())
restConfig := GetRestProviderServiceConfig(url.Service())
rp.SetExporterMap(serviceKey, exporter)
restServer := rp.getServer(url, restConfig)
restServer.Deploy(invoker, restConfig.RestMethodConfigsMap)
return exporter
}
func (rp *RestProtocol) Refer(url common.URL) protocol.Invoker {
// create rest_invoker
var requestTimeout = config.GetConsumerConfig().RequestTimeout
requestTimeoutStr := url.GetParam(constant.TIMEOUT_KEY, config.GetConsumerConfig().Request_Timeout)
connectTimeout := config.GetConsumerConfig().ConnectTimeout
if t, err := time.ParseDuration(requestTimeoutStr); err == nil {
requestTimeout = t
}
restConfig := GetRestConsumerServiceConfig(url.Service())
restClient := extension.GetRestClient(restConfig.Client, &rest_interface.RestOptions{RequestTimeout: requestTimeout, ConnectTimeout: connectTimeout})
invoker := NewRestInvoker(url, restClient, restConfig.RestMethodConfigsMap)
restOptions := rest_interface.RestOptions{RequestTimeout: requestTimeout, ConnectTimeout: connectTimeout}
restClient := rp.getClient(restOptions, restConfig)
invoker := NewRestInvoker(url, &restClient, restConfig.RestMethodConfigsMap)
rp.SetInvokers(invoker)
return invoker
}
func (rp *RestProtocol) getServer(url common.URL, restConfig *rest_interface.RestConfig) rest_interface.RestServer {
restServer, ok := rp.serverMap[url.Location]
if !ok {
_, ok := rp.ExporterMap().Load(strings.TrimPrefix(url.Path, "/"))
if !ok {
panic("[RestProtocol]" + url.Key() + "is not existing")
}
rp.serverLock.Lock()
restServer, ok = rp.serverMap[url.Location]
if !ok {
restServer = extension.GetNewRestServer(restConfig.Server)
restServer.Start(url)
rp.serverMap[url.Location] = restServer
}
rp.serverLock.Unlock()
}
return restServer
}
func (rp *RestProtocol) getClient(restOptions rest_interface.RestOptions, restConfig *rest_interface.RestConfig) rest_interface.RestClient {
restClient, ok := rp.clientMap[restOptions]
rp.clientLock.Lock()
if !ok {
restClient, ok = rp.clientMap[restOptions]
if !ok {
restClient = extension.GetNewRestClient(restConfig.Client, &restOptions)
rp.clientMap[restOptions] = restClient
}
}
rp.clientLock.Unlock()
return restClient
}
func (rp *RestProtocol) Destroy() {
// destroy rest_server
rp.BaseProtocol.Destroy()
for key, server := range rp.serverMap {
server.Destory()
delete(rp.serverMap, key)
}
for key := range rp.clientMap {
delete(rp.clientMap, key)
}
}
func GetRestProtocol() protocol.Protocol {
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment