From 08b86ada50981009d89e513ab0605dca27fabfc3 Mon Sep 17 00:00:00 2001
From: "vito.he" <hxmhlt@163.com>
Date: Mon, 9 Sep 2019 13:32:19 +0800
Subject: [PATCH] Mod:bug in directory cache invoker

---
 registry/directory/directory.go      | 13 ++++++++++--
 registry/directory/directory_test.go | 31 ++++++++++++++++++++--------
 2 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/registry/directory/directory.go b/registry/directory/directory.go
index 86bb74f76..54f0acd84 100644
--- a/registry/directory/directory.go
+++ b/registry/directory/directory.go
@@ -117,8 +117,10 @@ func (dir *registryDirectory) refreshInvokers(res *registry.ServiceEvent) {
 		if url.Protocol == constant.OVERRIDE_PROTOCOL ||
 			url.GetParam(constant.CATEGORY_KEY, constant.DEFAULT_CATEGORY) == constant.CONFIGURATORS_CATEGORY {
 			dir.configurators = append(dir.configurators, extension.GetDefaultConfigurator(url))
+			url = nil
 		} else if url.Protocol == constant.ROUTER_PROTOCOL || //2.for router
 			url.GetParam(constant.CATEGORY_KEY, constant.DEFAULT_CATEGORY) == constant.ROUTER_CATEGORY {
+			url = nil
 			//TODO: router
 		}
 	}
@@ -198,12 +200,19 @@ func (dir *registryDirectory) cacheInvoker(url *common.URL) {
 	if url.Protocol == referenceUrl.Protocol || referenceUrl.Protocol == "" {
 		newUrl := common.MergeUrl(url, referenceUrl)
 		dir.overrideUrl(newUrl)
-		if _, ok := dir.cacheInvokersMap.Load(newUrl.Key()); !ok {
-			logger.Debugf("service will be added in cache invokers: invokers key is  %s!", url.Key())
+		if cacheInvoker, ok := dir.cacheInvokersMap.Load(newUrl.Key()); !ok {
+			logger.Infof("service will be added in cache invokers: invokers url is  %s!", newUrl)
 			newInvoker := extension.GetProtocol(protocolwrapper.FILTER).Refer(*newUrl)
 			if newInvoker != nil {
 				dir.cacheInvokersMap.Store(newUrl.Key(), newInvoker)
 			}
+		} else {
+			logger.Infof("service will be updated in cache invokers: new invoker url is %s, old invoker url is %s", newUrl, cacheInvoker.(protocol.Invoker).GetUrl())
+			newInvoker := extension.GetProtocol(protocolwrapper.FILTER).Refer(*newUrl)
+			if newInvoker != nil {
+				dir.cacheInvokersMap.Store(newUrl.Key(), newInvoker)
+				cacheInvoker.(protocol.Invoker).Destroy()
+			}
 		}
 	}
 }
diff --git a/registry/directory/directory_test.go b/registry/directory/directory_test.go
index d36257888..e2e18020e 100644
--- a/registry/directory/directory_test.go
+++ b/registry/directory/directory_test.go
@@ -144,15 +144,28 @@ func Test_MergeOverrideUrl(t *testing.T) {
 		common.WithParamsValue(constant.GROUP_KEY, "group"),
 		common.WithParamsValue(constant.VERSION_KEY, "1.0.0"))
 	mockRegistry.MockEvent(&registry.ServiceEvent{Action: remoting.EventTypeAdd, Service: providerUrl})
-	overrideUrl, _ := common.NewURL(context.TODO(), "override://0.0.0.0:20000/org.apache.dubbo-go.mockService",
-		common.WithParamsValue(constant.CLUSTER_KEY, "mock1"),
-		common.WithParamsValue(constant.GROUP_KEY, "group"),
-		common.WithParamsValue(constant.VERSION_KEY, "1.0.0"))
-	mockRegistry.MockEvent(&registry.ServiceEvent{Action: remoting.EventTypeAdd, Service: overrideUrl})
-	time.Sleep(1e9)
-	assert.Len(t, registryDirectory.cacheInvokers, 1)
-	if len(registryDirectory.cacheInvokers) > 0 {
-		assert.Equal(t, "mock1", registryDirectory.cacheInvokers[0].GetUrl().GetParam(constant.CLUSTER_KEY, ""))
+Loop1:
+	for {
+		if len(registryDirectory.cacheInvokers) > 0 {
+			overrideUrl, _ := common.NewURL(context.TODO(), "override://0.0.0.0:20000/org.apache.dubbo-go.mockService",
+				common.WithParamsValue(constant.CLUSTER_KEY, "mock1"),
+				common.WithParamsValue(constant.GROUP_KEY, "group"),
+				common.WithParamsValue(constant.VERSION_KEY, "1.0.0"))
+			mockRegistry.MockEvent(&registry.ServiceEvent{Action: remoting.EventTypeAdd, Service: overrideUrl})
+		Loop2:
+			for {
+				if len(registryDirectory.cacheInvokers) > 0 {
+					if "mock1" == registryDirectory.cacheInvokers[0].GetUrl().GetParam(constant.CLUSTER_KEY, "") {
+						assert.Len(t, registryDirectory.cacheInvokers, 1)
+						assert.True(t, true)
+						break Loop2
+					} else {
+						time.Sleep(500 * time.Millisecond)
+					}
+				}
+			}
+			break Loop1
+		}
 	}
 
 }
-- 
GitLab