From bbe18f76c877f31df5dfec7386788e4f105c1bb6 Mon Sep 17 00:00:00 2001
From: fangyincheng <fangyincheng@sina.com>
Date: Tue, 24 Mar 2020 02:25:31 +0800
Subject: [PATCH] Add: GetInterface for rpc_service.go

---
 common/rpc_service.go                         | 45 ++++++++++++++++---
 common/rpc_service_test.go                    | 22 ++++-----
 config/config_loader.go                       |  2 +-
 config/config_loader_test.go                  |  7 +--
 config/service_config.go                      |  2 +-
 .../generic_service_filter_test.go            |  2 +-
 protocol/dubbo/client_test.go                 |  2 +-
 protocol/dubbo/dubbo_exporter.go              |  3 +-
 protocol/grpc/grpc_exporter.go                |  3 +-
 protocol/jsonrpc/http_test.go                 |  2 +-
 protocol/jsonrpc/jsonrpc_exporter.go          |  3 +-
 protocol/jsonrpc/jsonrpc_invoker_test.go      |  2 +-
 protocol/rest/rest_exporter.go                |  3 +-
 protocol/rest/rest_invoker_test.go            |  4 +-
 protocol/rest/rest_protocol_test.go           |  4 +-
 15 files changed, 73 insertions(+), 33 deletions(-)

diff --git a/common/rpc_service.go b/common/rpc_service.go
index b235c32ab..d05f527a6 100644
--- a/common/rpc_service.go
+++ b/common/rpc_service.go
@@ -71,7 +71,8 @@ var (
 	// ServiceMap ...
 	// todo: lowerecas?
 	ServiceMap = &serviceMap{
-		serviceMap: make(map[string]map[string]*Service),
+		serviceMap:   make(map[string]map[string]*Service),
+		interfaceMap: make(map[string][]*Service),
 	}
 )
 
@@ -147,8 +148,9 @@ func (s *Service) Rcvr() reflect.Value {
 //////////////////////////
 
 type serviceMap struct {
-	mutex      sync.RWMutex                   // protects the serviceMap
-	serviceMap map[string]map[string]*Service // protocol -> service name -> service
+	mutex        sync.RWMutex                   // protects the serviceMap
+	serviceMap   map[string]map[string]*Service // protocol -> service name -> service
+	interfaceMap map[string][]*Service          // interface -> service
 }
 
 func (sm *serviceMap) GetService(protocol, name string) *Service {
@@ -163,10 +165,23 @@ func (sm *serviceMap) GetService(protocol, name string) *Service {
 	return nil
 }
 
-func (sm *serviceMap) Register(protocol string, rcvr RPCService) (string, error) {
+// GetInterface get an interface defination by interface name
+func (sm *serviceMap) GetInterface(interfaceName string) []*Service {
+	sm.mutex.RLock()
+	defer sm.mutex.RUnlock()
+	if s, ok := sm.interfaceMap[interfaceName]; ok {
+		return s
+	}
+	return nil
+}
+
+func (sm *serviceMap) Register(interfaceName, protocol string, rcvr RPCService) (string, error) {
 	if sm.serviceMap[protocol] == nil {
 		sm.serviceMap[protocol] = make(map[string]*Service)
 	}
+	if sm.interfaceMap[interfaceName] == nil {
+		sm.interfaceMap[interfaceName] = make([]*Service, 0, 16)
+	}
 
 	s := new(Service)
 	s.rcvrType = reflect.TypeOf(rcvr)
@@ -201,12 +216,13 @@ func (sm *serviceMap) Register(protocol string, rcvr RPCService) (string, error)
 	}
 	sm.mutex.Lock()
 	sm.serviceMap[protocol][s.name] = s
+	sm.interfaceMap[interfaceName] = append(sm.interfaceMap[interfaceName], s)
 	sm.mutex.Unlock()
 
 	return strings.TrimSuffix(methods, ","), nil
 }
 
-func (sm *serviceMap) UnRegister(protocol, serviceId string) error {
+func (sm *serviceMap) UnRegister(interfaceName, protocol, serviceId string) error {
 	if protocol == "" || serviceId == "" {
 		return perrors.New("protocol or serviceName is nil")
 	}
@@ -216,15 +232,32 @@ func (sm *serviceMap) UnRegister(protocol, serviceId string) error {
 		sm.mutex.RUnlock()
 		return perrors.New("no services for " + protocol)
 	}
-	_, ok = svcs[serviceId]
+	s, ok := svcs[serviceId]
 	if !ok {
 		sm.mutex.RUnlock()
 		return perrors.New("no service for " + serviceId)
 	}
+	svrs, ok := sm.interfaceMap[interfaceName]
+	if !ok {
+		sm.mutex.RUnlock()
+		return perrors.New("no service for " + interfaceName)
+	}
+	index := -1
+	for i, svr := range svrs {
+		if svr == s {
+			index = i
+		}
+	}
 	sm.mutex.RUnlock()
 
 	sm.mutex.Lock()
 	defer sm.mutex.Unlock()
+	sm.interfaceMap[interfaceName] = make([]*Service, 0, len(svrs))
+	for i, _ := range svrs {
+		if i != index {
+			sm.interfaceMap[interfaceName] = append(sm.interfaceMap[interfaceName], svrs[i])
+		}
+	}
 	delete(svcs, serviceId)
 	delete(sm.serviceMap, protocol)
 
diff --git a/common/rpc_service_test.go b/common/rpc_service_test.go
index 8c9b9d15c..2311205d0 100644
--- a/common/rpc_service_test.go
+++ b/common/rpc_service_test.go
@@ -77,46 +77,48 @@ func TestServiceMap_Register(t *testing.T) {
 	// lowercase
 	s0 := &testService{}
 	// methods, err := ServiceMap.Register("testporotocol", s0)
-	_, err := ServiceMap.Register("testporotocol", s0)
+	_, err := ServiceMap.Register("testService", "testporotocol", s0)
 	assert.EqualError(t, err, "type testService is not exported")
 
 	// succ
 	s := &TestService{}
-	methods, err := ServiceMap.Register("testporotocol", s)
+	methods, err := ServiceMap.Register("testService", "testporotocol", s)
 	assert.NoError(t, err)
 	assert.Equal(t, "MethodOne,MethodThree,methodTwo", methods)
 
 	// repeat
-	_, err = ServiceMap.Register("testporotocol", s)
+	_, err = ServiceMap.Register("testService", "testporotocol", s)
 	assert.EqualError(t, err, "service already defined: com.test.Path")
 
 	// no method
 	s1 := &TestService1{}
-	_, err = ServiceMap.Register("testporotocol", s1)
+	_, err = ServiceMap.Register("testService", "testporotocol", s1)
 	assert.EqualError(t, err, "type com.test.Path1 has no exported methods of suitable type")
 
 	ServiceMap = &serviceMap{
-		serviceMap: make(map[string]map[string]*Service),
+		serviceMap:   make(map[string]map[string]*Service),
+		interfaceMap: make(map[string][]*Service),
 	}
 }
 
 func TestServiceMap_UnRegister(t *testing.T) {
 	s := &TestService{}
-	_, err := ServiceMap.Register("testprotocol", s)
+	_, err := ServiceMap.Register("TestService", "testprotocol", s)
 	assert.NoError(t, err)
 	assert.NotNil(t, ServiceMap.GetService("testprotocol", "com.test.Path"))
+	assert.Equal(t, 1, len(ServiceMap.GetInterface("TestService")))
 
-	err = ServiceMap.UnRegister("", "com.test.Path")
+	err = ServiceMap.UnRegister("", "", "com.test.Path")
 	assert.EqualError(t, err, "protocol or serviceName is nil")
 
-	err = ServiceMap.UnRegister("protocol", "com.test.Path")
+	err = ServiceMap.UnRegister("", "protocol", "com.test.Path")
 	assert.EqualError(t, err, "no services for protocol")
 
-	err = ServiceMap.UnRegister("testprotocol", "com.test.Path1")
+	err = ServiceMap.UnRegister("", "testprotocol", "com.test.Path1")
 	assert.EqualError(t, err, "no service for com.test.Path1")
 
 	// succ
-	err = ServiceMap.UnRegister("testprotocol", "com.test.Path")
+	err = ServiceMap.UnRegister("TestService", "testprotocol", "com.test.Path")
 	assert.NoError(t, err)
 }
 
diff --git a/config/config_loader.go b/config/config_loader.go
index c0687d8fc..61cb49457 100644
--- a/config/config_loader.go
+++ b/config/config_loader.go
@@ -199,7 +199,7 @@ func Load() {
 			svs.id = key
 			svs.Implement(rpcService)
 			if err := svs.Export(); err != nil {
-				panic(fmt.Sprintf("service %s export failed! ", key))
+				panic(fmt.Sprintf("service %s export failed! err: %#v", key, err))
 			}
 		}
 	}
diff --git a/config/config_loader_test.go b/config/config_loader_test.go
index 498f82678..6368fcbd2 100644
--- a/config/config_loader_test.go
+++ b/config/config_loader_test.go
@@ -82,7 +82,8 @@ func TestLoad(t *testing.T) {
 
 	conServices = map[string]common.RPCService{}
 	proServices = map[string]common.RPCService{}
-	common.ServiceMap.UnRegister("mock", "MockService")
+	err := common.ServiceMap.UnRegister("com.MockService", "mock", "MockService")
+	assert.Nil(t, err)
 	consumerConfig = nil
 	providerConfig = nil
 }
@@ -110,7 +111,7 @@ func TestLoadWithSingleReg(t *testing.T) {
 
 	conServices = map[string]common.RPCService{}
 	proServices = map[string]common.RPCService{}
-	common.ServiceMap.UnRegister("mock", "MockService")
+	common.ServiceMap.UnRegister("com.MockService", "mock", "MockService")
 	consumerConfig = nil
 	providerConfig = nil
 }
@@ -139,7 +140,7 @@ func TestWithNoRegLoad(t *testing.T) {
 
 	conServices = map[string]common.RPCService{}
 	proServices = map[string]common.RPCService{}
-	common.ServiceMap.UnRegister("mock", "MockService")
+	common.ServiceMap.UnRegister("com.MockService", "mock", "MockService")
 	consumerConfig = nil
 	providerConfig = nil
 }
diff --git a/config/service_config.go b/config/service_config.go
index faa8dc9f8..61d24ced7 100644
--- a/config/service_config.go
+++ b/config/service_config.go
@@ -128,7 +128,7 @@ func (c *ServiceConfig) Export() error {
 	}
 	for _, proto := range protocolConfigs {
 		// registry the service reflect
-		methods, err := common.ServiceMap.Register(proto.Name, c.rpcService)
+		methods, err := common.ServiceMap.Register(c.InterfaceName, proto.Name, c.rpcService)
 		if err != nil {
 			err := perrors.Errorf("The service %v  export the protocol %v error! Error message is %v .", c.InterfaceName, proto.Name, err.Error())
 			logger.Errorf(err.Error())
diff --git a/filter/filter_impl/generic_service_filter_test.go b/filter/filter_impl/generic_service_filter_test.go
index 37c6af745..2a911659f 100644
--- a/filter/filter_impl/generic_service_filter_test.go
+++ b/filter/filter_impl/generic_service_filter_test.go
@@ -96,7 +96,7 @@ func TestGenericServiceFilter_Invoke(t *testing.T) {
 			hessian.Object("222")},
 	}
 	s := &TestService{}
-	_, _ = common.ServiceMap.Register("testprotocol", s)
+	_, _ = common.ServiceMap.Register("TestService", "testprotocol", s)
 	rpcInvocation := invocation.NewRPCInvocation(methodName, aurguments, nil)
 	filter := GetGenericServiceFilter()
 	url, _ := common.NewURL("testprotocol://127.0.0.1:20000/com.test.Path")
diff --git a/protocol/dubbo/client_test.go b/protocol/dubbo/client_test.go
index 1e0a73fac..744ffa80d 100644
--- a/protocol/dubbo/client_test.go
+++ b/protocol/dubbo/client_test.go
@@ -162,7 +162,7 @@ func InitTest(t *testing.T) (protocol.Protocol, common.URL) {
 
 	hessian.RegisterPOJO(&User{})
 
-	methods, err := common.ServiceMap.Register("dubbo", &UserProvider{})
+	methods, err := common.ServiceMap.Register("com.ikurento.user.UserProvider", "dubbo", &UserProvider{})
 	assert.NoError(t, err)
 	assert.Equal(t, "GetBigPkg,GetUser,GetUser0,GetUser1,GetUser2,GetUser3,GetUser4,GetUser5,GetUser6", methods)
 
diff --git a/protocol/dubbo/dubbo_exporter.go b/protocol/dubbo/dubbo_exporter.go
index f4cd0cc12..1c45c4005 100644
--- a/protocol/dubbo/dubbo_exporter.go
+++ b/protocol/dubbo/dubbo_exporter.go
@@ -43,8 +43,9 @@ func NewDubboExporter(key string, invoker protocol.Invoker, exporterMap *sync.Ma
 // Unexport ...
 func (de *DubboExporter) Unexport() {
 	serviceId := de.GetInvoker().GetUrl().GetParam(constant.BEAN_NAME_KEY, "")
+	interfaceName := de.GetInvoker().GetUrl().GetParam(constant.INTERFACE_KEY, "")
 	de.BaseExporter.Unexport()
-	err := common.ServiceMap.UnRegister(DUBBO, serviceId)
+	err := common.ServiceMap.UnRegister(interfaceName, DUBBO, serviceId)
 	if err != nil {
 		logger.Errorf("[DubboExporter.Unexport] error: %v", err)
 	}
diff --git a/protocol/grpc/grpc_exporter.go b/protocol/grpc/grpc_exporter.go
index 3c38ef974..5b7ff36c1 100644
--- a/protocol/grpc/grpc_exporter.go
+++ b/protocol/grpc/grpc_exporter.go
@@ -43,8 +43,9 @@ func NewGrpcExporter(key string, invoker protocol.Invoker, exporterMap *sync.Map
 // Unexport ...
 func (gg *GrpcExporter) Unexport() {
 	serviceId := gg.GetInvoker().GetUrl().GetParam(constant.BEAN_NAME_KEY, "")
+	interfaceName := gg.GetInvoker().GetUrl().GetParam(constant.INTERFACE_KEY, "")
 	gg.BaseExporter.Unexport()
-	err := common.ServiceMap.UnRegister(GRPC, serviceId)
+	err := common.ServiceMap.UnRegister(interfaceName, GRPC, serviceId)
 	if err != nil {
 		logger.Errorf("[GrpcExporter.Unexport] error: %v", err)
 	}
diff --git a/protocol/jsonrpc/http_test.go b/protocol/jsonrpc/http_test.go
index 0cb88b36a..f8480bf32 100644
--- a/protocol/jsonrpc/http_test.go
+++ b/protocol/jsonrpc/http_test.go
@@ -50,7 +50,7 @@ type (
 
 func TestHTTPClient_Call(t *testing.T) {
 
-	methods, err := common.ServiceMap.Register("jsonrpc", &UserProvider{})
+	methods, err := common.ServiceMap.Register("com.ikurento.user.UserProvider", "jsonrpc", &UserProvider{})
 	assert.NoError(t, err)
 	assert.Equal(t, "GetUser,GetUser0,GetUser1,GetUser2,GetUser3,GetUser4", methods)
 
diff --git a/protocol/jsonrpc/jsonrpc_exporter.go b/protocol/jsonrpc/jsonrpc_exporter.go
index 7f8fd4918..c61cf9ada 100644
--- a/protocol/jsonrpc/jsonrpc_exporter.go
+++ b/protocol/jsonrpc/jsonrpc_exporter.go
@@ -43,8 +43,9 @@ func NewJsonrpcExporter(key string, invoker protocol.Invoker, exporterMap *sync.
 // Unexport ...
 func (je *JsonrpcExporter) Unexport() {
 	serviceId := je.GetInvoker().GetUrl().GetParam(constant.BEAN_NAME_KEY, "")
+	interfaceName := je.GetInvoker().GetUrl().GetParam(constant.INTERFACE_KEY, "")
 	je.BaseExporter.Unexport()
-	err := common.ServiceMap.UnRegister(JSONRPC, serviceId)
+	err := common.ServiceMap.UnRegister(interfaceName, JSONRPC, serviceId)
 	if err != nil {
 		logger.Errorf("[JsonrpcExporter.Unexport] error: %v", err)
 	}
diff --git a/protocol/jsonrpc/jsonrpc_invoker_test.go b/protocol/jsonrpc/jsonrpc_invoker_test.go
index 9e08eed2b..0f14ba11e 100644
--- a/protocol/jsonrpc/jsonrpc_invoker_test.go
+++ b/protocol/jsonrpc/jsonrpc_invoker_test.go
@@ -36,7 +36,7 @@ import (
 
 func TestJsonrpcInvoker_Invoke(t *testing.T) {
 
-	methods, err := common.ServiceMap.Register("jsonrpc", &UserProvider{})
+	methods, err := common.ServiceMap.Register("UserProvider", "jsonrpc", &UserProvider{})
 	assert.NoError(t, err)
 	assert.Equal(t, "GetUser,GetUser0,GetUser1,GetUser2,GetUser3,GetUser4", methods)
 
diff --git a/protocol/rest/rest_exporter.go b/protocol/rest/rest_exporter.go
index 470d525ad..1ee208615 100644
--- a/protocol/rest/rest_exporter.go
+++ b/protocol/rest/rest_exporter.go
@@ -40,8 +40,9 @@ func NewRestExporter(key string, invoker protocol.Invoker, exporterMap *sync.Map
 
 func (re *RestExporter) Unexport() {
 	serviceId := re.GetInvoker().GetUrl().GetParam(constant.BEAN_NAME_KEY, "")
+	interfaceName := re.GetInvoker().GetUrl().GetParam(constant.INTERFACE_KEY, "")
 	re.BaseExporter.Unexport()
-	err := common.ServiceMap.UnRegister(REST, serviceId)
+	err := common.ServiceMap.UnRegister(interfaceName, REST, serviceId)
 	if err != nil {
 		logger.Errorf("[RestExporter.Unexport] error: %v", err)
 	}
diff --git a/protocol/rest/rest_invoker_test.go b/protocol/rest/rest_invoker_test.go
index e44c5d9a2..2ea260c58 100644
--- a/protocol/rest/rest_invoker_test.go
+++ b/protocol/rest/rest_invoker_test.go
@@ -61,7 +61,7 @@ func TestRestInvoker_Invoke(t *testing.T) {
 		"module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&" +
 		"side=provider&timeout=3000&timestamp=1556509797245")
 	assert.NoError(t, err)
-	_, err = common.ServiceMap.Register(url.Protocol, &UserProvider{})
+	_, err = common.ServiceMap.Register("UserProvider", url.Protocol, &UserProvider{})
 	assert.NoError(t, err)
 	con := config.ProviderConfig{}
 	config.SetProviderConfig(con)
@@ -206,6 +206,6 @@ func TestRestInvoker_Invoke(t *testing.T) {
 	assert.Error(t, res.Error(), "test error")
 
 	assert.Equal(t, filterNum, 12)
-	err = common.ServiceMap.UnRegister(url.Protocol, "com.ikurento.user.UserProvider")
+	err = common.ServiceMap.UnRegister("UserProvider", url.Protocol, "com.ikurento.user.UserProvider")
 	assert.NoError(t, err)
 }
diff --git a/protocol/rest/rest_protocol_test.go b/protocol/rest/rest_protocol_test.go
index 8af73a183..911714877 100644
--- a/protocol/rest/rest_protocol_test.go
+++ b/protocol/rest/rest_protocol_test.go
@@ -80,7 +80,7 @@ func TestRestProtocol_Export(t *testing.T) {
 		"module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&" +
 		"side=provider&timeout=3000&timestamp=1556509797245")
 	assert.NoError(t, err)
-	_, err = common.ServiceMap.Register(url.Protocol, &UserProvider{})
+	_, err = common.ServiceMap.Register("UserProvider", url.Protocol, &UserProvider{})
 	assert.NoError(t, err)
 	con := config.ProviderConfig{}
 	config.SetProviderConfig(con)
@@ -128,7 +128,7 @@ func TestRestProtocol_Export(t *testing.T) {
 	proto.Destroy()
 	_, ok = proto.(*RestProtocol).serverMap[url.Location]
 	assert.False(t, ok)
-	err = common.ServiceMap.UnRegister(url.Protocol, "com.ikurento.user.UserProvider")
+	err = common.ServiceMap.UnRegister("UserProvider", url.Protocol, "com.ikurento.user.UserProvider")
 	assert.NoError(t, err)
 }
 
-- 
GitLab