Skip to content
Snippets Groups Projects
Commit ef0eb7dc authored by Ming Deng's avatar Ming Deng Committed by GitHub
Browse files

Merge pull request #454 from fangyincheng/improve-rpcservice

Add: GetInterface for rpc_service.go
parents 575d360f 5d0327ab
No related branches found
No related tags found
No related merge requests found
Showing with 97 additions and 42 deletions
......@@ -71,7 +71,8 @@ var (
// ServiceMap ...
// todo: lowerecas?
ServiceMap = &serviceMap{
serviceMap: make(map[string]map[string]*Service),
serviceMap: make(map[string]map[string]*Service),
interfaceMap: make(map[string][]*Service),
}
)
......@@ -147,10 +148,12 @@ func (s *Service) Rcvr() reflect.Value {
//////////////////////////
type serviceMap struct {
mutex sync.RWMutex // protects the serviceMap
serviceMap map[string]map[string]*Service // protocol -> service name -> service
mutex sync.RWMutex // protects the serviceMap
serviceMap map[string]map[string]*Service // protocol -> service name -> service
interfaceMap map[string][]*Service // interface -> service
}
// GetService get a service defination by protocol and name
func (sm *serviceMap) GetService(protocol, name string) *Service {
sm.mutex.RLock()
defer sm.mutex.RUnlock()
......@@ -163,10 +166,24 @@ func (sm *serviceMap) GetService(protocol, name string) *Service {
return nil
}
func (sm *serviceMap) Register(protocol string, rcvr RPCService) (string, error) {
// GetInterface get an interface defination by interface name
func (sm *serviceMap) GetInterface(interfaceName string) []*Service {
sm.mutex.RLock()
defer sm.mutex.RUnlock()
if s, ok := sm.interfaceMap[interfaceName]; ok {
return s
}
return nil
}
// Register register a service by @interfaceName and @protocol
func (sm *serviceMap) Register(interfaceName, protocol string, rcvr RPCService) (string, error) {
if sm.serviceMap[protocol] == nil {
sm.serviceMap[protocol] = make(map[string]*Service)
}
if sm.interfaceMap[interfaceName] == nil {
sm.interfaceMap[interfaceName] = make([]*Service, 0, 16)
}
s := new(Service)
s.rcvrType = reflect.TypeOf(rcvr)
......@@ -201,30 +218,61 @@ func (sm *serviceMap) Register(protocol string, rcvr RPCService) (string, error)
}
sm.mutex.Lock()
sm.serviceMap[protocol][s.name] = s
sm.interfaceMap[interfaceName] = append(sm.interfaceMap[interfaceName], s)
sm.mutex.Unlock()
return strings.TrimSuffix(methods, ","), nil
}
func (sm *serviceMap) UnRegister(protocol, serviceId string) error {
// UnRegister cancel a service by @interfaceName, @protocol and @serviceId
func (sm *serviceMap) UnRegister(interfaceName, protocol, serviceId string) error {
if protocol == "" || serviceId == "" {
return perrors.New("protocol or serviceName is nil")
}
sm.mutex.RLock()
svcs, ok := sm.serviceMap[protocol]
if !ok {
sm.mutex.RUnlock()
return perrors.New("no services for " + protocol)
var (
err error
index = -1
svcs map[string]*Service
svrs []*Service
ok bool
)
f := func() error {
sm.mutex.RLock()
defer sm.mutex.RUnlock()
svcs, ok = sm.serviceMap[protocol]
if !ok {
return perrors.New("no services for " + protocol)
}
s, ok := svcs[serviceId]
if !ok {
return perrors.New("no service for " + serviceId)
}
svrs, ok = sm.interfaceMap[interfaceName]
if !ok {
return perrors.New("no service for " + interfaceName)
}
for i, svr := range svrs {
if svr == s {
index = i
}
}
return nil
}
_, ok = svcs[serviceId]
if !ok {
sm.mutex.RUnlock()
return perrors.New("no service for " + serviceId)
if err = f(); err != nil {
return err
}
sm.mutex.RUnlock()
sm.mutex.Lock()
defer sm.mutex.Unlock()
sm.interfaceMap[interfaceName] = make([]*Service, 0, len(svrs))
for i, _ := range svrs {
if i != index {
sm.interfaceMap[interfaceName] = append(sm.interfaceMap[interfaceName], svrs[i])
}
}
delete(svcs, serviceId)
delete(sm.serviceMap, protocol)
......
......@@ -77,46 +77,48 @@ func TestServiceMap_Register(t *testing.T) {
// lowercase
s0 := &testService{}
// methods, err := ServiceMap.Register("testporotocol", s0)
_, err := ServiceMap.Register("testporotocol", s0)
_, err := ServiceMap.Register("testService", "testporotocol", s0)
assert.EqualError(t, err, "type testService is not exported")
// succ
s := &TestService{}
methods, err := ServiceMap.Register("testporotocol", s)
methods, err := ServiceMap.Register("testService", "testporotocol", s)
assert.NoError(t, err)
assert.Equal(t, "MethodOne,MethodThree,methodTwo", methods)
// repeat
_, err = ServiceMap.Register("testporotocol", s)
_, err = ServiceMap.Register("testService", "testporotocol", s)
assert.EqualError(t, err, "service already defined: com.test.Path")
// no method
s1 := &TestService1{}
_, err = ServiceMap.Register("testporotocol", s1)
_, err = ServiceMap.Register("testService", "testporotocol", s1)
assert.EqualError(t, err, "type com.test.Path1 has no exported methods of suitable type")
ServiceMap = &serviceMap{
serviceMap: make(map[string]map[string]*Service),
serviceMap: make(map[string]map[string]*Service),
interfaceMap: make(map[string][]*Service),
}
}
func TestServiceMap_UnRegister(t *testing.T) {
s := &TestService{}
_, err := ServiceMap.Register("testprotocol", s)
_, err := ServiceMap.Register("TestService", "testprotocol", s)
assert.NoError(t, err)
assert.NotNil(t, ServiceMap.GetService("testprotocol", "com.test.Path"))
assert.Equal(t, 1, len(ServiceMap.GetInterface("TestService")))
err = ServiceMap.UnRegister("", "com.test.Path")
err = ServiceMap.UnRegister("", "", "com.test.Path")
assert.EqualError(t, err, "protocol or serviceName is nil")
err = ServiceMap.UnRegister("protocol", "com.test.Path")
err = ServiceMap.UnRegister("", "protocol", "com.test.Path")
assert.EqualError(t, err, "no services for protocol")
err = ServiceMap.UnRegister("testprotocol", "com.test.Path1")
err = ServiceMap.UnRegister("", "testprotocol", "com.test.Path1")
assert.EqualError(t, err, "no service for com.test.Path1")
// succ
err = ServiceMap.UnRegister("testprotocol", "com.test.Path")
err = ServiceMap.UnRegister("TestService", "testprotocol", "com.test.Path")
assert.NoError(t, err)
}
......
......@@ -199,7 +199,7 @@ func Load() {
svs.id = key
svs.Implement(rpcService)
if err := svs.Export(); err != nil {
panic(fmt.Sprintf("service %s export failed! ", key))
panic(fmt.Sprintf("service %s export failed! err: %#v", key, err))
}
}
}
......
......@@ -82,7 +82,8 @@ func TestLoad(t *testing.T) {
conServices = map[string]common.RPCService{}
proServices = map[string]common.RPCService{}
common.ServiceMap.UnRegister("mock", "MockService")
err := common.ServiceMap.UnRegister("com.MockService", "mock", "MockService")
assert.Nil(t, err)
consumerConfig = nil
providerConfig = nil
}
......@@ -110,7 +111,7 @@ func TestLoadWithSingleReg(t *testing.T) {
conServices = map[string]common.RPCService{}
proServices = map[string]common.RPCService{}
common.ServiceMap.UnRegister("mock", "MockService")
common.ServiceMap.UnRegister("com.MockService", "mock", "MockService")
consumerConfig = nil
providerConfig = nil
}
......@@ -139,7 +140,7 @@ func TestWithNoRegLoad(t *testing.T) {
conServices = map[string]common.RPCService{}
proServices = map[string]common.RPCService{}
common.ServiceMap.UnRegister("mock", "MockService")
common.ServiceMap.UnRegister("com.MockService", "mock", "MockService")
consumerConfig = nil
providerConfig = nil
}
......
......@@ -129,7 +129,7 @@ func (c *ServiceConfig) Export() error {
}
for _, proto := range protocolConfigs {
// registry the service reflect
methods, err := common.ServiceMap.Register(proto.Name, c.rpcService)
methods, err := common.ServiceMap.Register(c.InterfaceName, proto.Name, c.rpcService)
if err != nil {
err := perrors.Errorf("The service %v export the protocol %v error! Error message is %v .", c.InterfaceName, proto.Name, err.Error())
logger.Errorf(err.Error())
......
......@@ -96,7 +96,7 @@ func TestGenericServiceFilter_Invoke(t *testing.T) {
hessian.Object("222")},
}
s := &TestService{}
_, _ = common.ServiceMap.Register("testprotocol", s)
_, _ = common.ServiceMap.Register("TestService", "testprotocol", s)
rpcInvocation := invocation.NewRPCInvocation(methodName, aurguments, nil)
filter := GetGenericServiceFilter()
url, _ := common.NewURL("testprotocol://127.0.0.1:20000/com.test.Path")
......
......@@ -162,7 +162,7 @@ func InitTest(t *testing.T) (protocol.Protocol, common.URL) {
hessian.RegisterPOJO(&User{})
methods, err := common.ServiceMap.Register("dubbo", &UserProvider{})
methods, err := common.ServiceMap.Register("com.ikurento.user.UserProvider", "dubbo", &UserProvider{})
assert.NoError(t, err)
assert.Equal(t, "GetBigPkg,GetUser,GetUser0,GetUser1,GetUser2,GetUser3,GetUser4,GetUser5,GetUser6", methods)
......
......@@ -43,8 +43,9 @@ func NewDubboExporter(key string, invoker protocol.Invoker, exporterMap *sync.Ma
// Unexport ...
func (de *DubboExporter) Unexport() {
serviceId := de.GetInvoker().GetUrl().GetParam(constant.BEAN_NAME_KEY, "")
interfaceName := de.GetInvoker().GetUrl().GetParam(constant.INTERFACE_KEY, "")
de.BaseExporter.Unexport()
err := common.ServiceMap.UnRegister(DUBBO, serviceId)
err := common.ServiceMap.UnRegister(interfaceName, DUBBO, serviceId)
if err != nil {
logger.Errorf("[DubboExporter.Unexport] error: %v", err)
}
......
......@@ -43,8 +43,9 @@ func NewGrpcExporter(key string, invoker protocol.Invoker, exporterMap *sync.Map
// Unexport ...
func (gg *GrpcExporter) Unexport() {
serviceId := gg.GetInvoker().GetUrl().GetParam(constant.BEAN_NAME_KEY, "")
interfaceName := gg.GetInvoker().GetUrl().GetParam(constant.INTERFACE_KEY, "")
gg.BaseExporter.Unexport()
err := common.ServiceMap.UnRegister(GRPC, serviceId)
err := common.ServiceMap.UnRegister(interfaceName, GRPC, serviceId)
if err != nil {
logger.Errorf("[GrpcExporter.Unexport] error: %v", err)
}
......
......@@ -50,7 +50,7 @@ type (
func TestHTTPClient_Call(t *testing.T) {
methods, err := common.ServiceMap.Register("jsonrpc", &UserProvider{})
methods, err := common.ServiceMap.Register("com.ikurento.user.UserProvider", "jsonrpc", &UserProvider{})
assert.NoError(t, err)
assert.Equal(t, "GetUser,GetUser0,GetUser1,GetUser2,GetUser3,GetUser4", methods)
......
......@@ -43,8 +43,9 @@ func NewJsonrpcExporter(key string, invoker protocol.Invoker, exporterMap *sync.
// Unexport ...
func (je *JsonrpcExporter) Unexport() {
serviceId := je.GetInvoker().GetUrl().GetParam(constant.BEAN_NAME_KEY, "")
interfaceName := je.GetInvoker().GetUrl().GetParam(constant.INTERFACE_KEY, "")
je.BaseExporter.Unexport()
err := common.ServiceMap.UnRegister(JSONRPC, serviceId)
err := common.ServiceMap.UnRegister(interfaceName, JSONRPC, serviceId)
if err != nil {
logger.Errorf("[JsonrpcExporter.Unexport] error: %v", err)
}
......
......@@ -36,7 +36,7 @@ import (
func TestJsonrpcInvoker_Invoke(t *testing.T) {
methods, err := common.ServiceMap.Register("jsonrpc", &UserProvider{})
methods, err := common.ServiceMap.Register("UserProvider", "jsonrpc", &UserProvider{})
assert.NoError(t, err)
assert.Equal(t, "GetUser,GetUser0,GetUser1,GetUser2,GetUser3,GetUser4", methods)
......
......@@ -40,8 +40,9 @@ func NewRestExporter(key string, invoker protocol.Invoker, exporterMap *sync.Map
func (re *RestExporter) Unexport() {
serviceId := re.GetInvoker().GetUrl().GetParam(constant.BEAN_NAME_KEY, "")
interfaceName := re.GetInvoker().GetUrl().GetParam(constant.INTERFACE_KEY, "")
re.BaseExporter.Unexport()
err := common.ServiceMap.UnRegister(REST, serviceId)
err := common.ServiceMap.UnRegister(interfaceName, REST, serviceId)
if err != nil {
logger.Errorf("[RestExporter.Unexport] error: %v", err)
}
......
......@@ -61,7 +61,7 @@ func TestRestInvoker_Invoke(t *testing.T) {
"module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&" +
"side=provider&timeout=3000&timestamp=1556509797245")
assert.NoError(t, err)
_, err = common.ServiceMap.Register(url.Protocol, &UserProvider{})
_, err = common.ServiceMap.Register("UserProvider", url.Protocol, &UserProvider{})
assert.NoError(t, err)
con := config.ProviderConfig{}
config.SetProviderConfig(con)
......@@ -206,6 +206,6 @@ func TestRestInvoker_Invoke(t *testing.T) {
assert.Error(t, res.Error(), "test error")
assert.Equal(t, filterNum, 12)
err = common.ServiceMap.UnRegister(url.Protocol, "com.ikurento.user.UserProvider")
err = common.ServiceMap.UnRegister("UserProvider", url.Protocol, "com.ikurento.user.UserProvider")
assert.NoError(t, err)
}
......@@ -80,7 +80,7 @@ func TestRestProtocol_Export(t *testing.T) {
"module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&" +
"side=provider&timeout=3000&timestamp=1556509797245")
assert.NoError(t, err)
_, err = common.ServiceMap.Register(url.Protocol, &UserProvider{})
_, err = common.ServiceMap.Register("UserProvider", url.Protocol, &UserProvider{})
assert.NoError(t, err)
con := config.ProviderConfig{}
config.SetProviderConfig(con)
......@@ -128,7 +128,7 @@ func TestRestProtocol_Export(t *testing.T) {
proto.Destroy()
_, ok = proto.(*RestProtocol).serverMap[url.Location]
assert.False(t, ok)
err = common.ServiceMap.UnRegister(url.Protocol, "com.ikurento.user.UserProvider")
err = common.ServiceMap.UnRegister("UserProvider", url.Protocol, "com.ikurento.user.UserProvider")
assert.NoError(t, err)
}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment