From f29c788c9d0824f7978326cb6a5e4e6d6cba756a Mon Sep 17 00:00:00 2001
From: scott <scottwangsxll@gmail.com>
Date: Sat, 14 Mar 2020 12:22:20 +0800
Subject: [PATCH] Fix ci client close race condition

---
 registry/kubernetes/listener_test.go | 10 +---
 registry/kubernetes/registry_test.go | 31 ++++++----
 remoting/kubernetes/client_test.go   | 86 +++++++++++++++-------------
 3 files changed, 69 insertions(+), 58 deletions(-)

diff --git a/registry/kubernetes/listener_test.go b/registry/kubernetes/listener_test.go
index 16bbbf8c7..c9ff62608 100644
--- a/registry/kubernetes/listener_test.go
+++ b/registry/kubernetes/listener_test.go
@@ -187,14 +187,11 @@ type KubernetesRegistryTestSuite struct {
 	suite.Suite
 
 	currentPod v1.Pod
-
-	registry *kubernetesRegistry
 }
 
-func (s *KubernetesRegistryTestSuite) SetupTest() {
+func (s *KubernetesRegistryTestSuite) initRegistry() *kubernetesRegistry {
 
 	t := s.T()
-	var err error
 
 	regurl, err := common.NewURL("registry://127.0.0.1:443", common.WithParamsValue(constant.ROLE_KEY, strconv.Itoa(common.PROVIDER)))
 	if err != nil {
@@ -215,7 +212,7 @@ func (s *KubernetesRegistryTestSuite) SetupTest() {
 		t.Fatal(err)
 	}
 
-	s.registry = mock.(*kubernetesRegistry)
+	return mock.(*kubernetesRegistry)
 }
 
 func (s *KubernetesRegistryTestSuite) SetupSuite() {
@@ -242,9 +239,6 @@ func (s *KubernetesRegistryTestSuite) SetupSuite() {
 	}
 }
 
-// stop etcd server
-func (s *KubernetesRegistryTestSuite) TearDownSuite() {}
-
 func (s *KubernetesRegistryTestSuite) TestDataChange() {
 
 	t := s.T()
diff --git a/registry/kubernetes/registry_test.go b/registry/kubernetes/registry_test.go
index 2bd744d92..cc5ccbb14 100644
--- a/registry/kubernetes/registry_test.go
+++ b/registry/kubernetes/registry_test.go
@@ -34,11 +34,14 @@ func (s *KubernetesRegistryTestSuite) TestRegister() {
 
 	t := s.T()
 
+	r := s.initRegistry()
+	defer r.Destroy()
+
 	url, _ := common.NewURL("dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider", common.WithParamsValue(constant.CLUSTER_KEY, "mock"), common.WithMethods([]string{"GetUser", "AddUser"}))
 
-	err := s.registry.Register(url)
+	err := r.Register(url)
 	assert.NoError(t, err)
-	_, _, err = s.registry.client.GetChildren("/dubbo/com.ikurento.user.UserProvider/providers")
+	_, _, err = r.client.GetChildren("/dubbo/com.ikurento.user.UserProvider/providers")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -48,15 +51,18 @@ func (s *KubernetesRegistryTestSuite) TestSubscribe() {
 
 	t := s.T()
 
+	r := s.initRegistry()
+	defer r.Destroy()
+
 	url, _ := common.NewURL("dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider", common.WithParamsValue(constant.CLUSTER_KEY, "mock"), common.WithMethods([]string{"GetUser", "AddUser"}))
 
-	listener, err := s.registry.DoSubscribe(&url)
+	listener, err := r.DoSubscribe(&url)
 	if err != nil {
 		t.Fatal(err)
 	}
 
 	go func() {
-		err := s.registry.Register(url)
+		err := r.Register(url)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -73,18 +79,21 @@ func (s *KubernetesRegistryTestSuite) TestSubscribe() {
 func (s *KubernetesRegistryTestSuite) TestConsumerDestroy() {
 
 	t := s.T()
+
+	r := s.initRegistry()
+
 	url, _ := common.NewURL("dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider", common.WithParamsValue(constant.CLUSTER_KEY, "mock"), common.WithMethods([]string{"GetUser", "AddUser"}))
 
-	_, err := s.registry.DoSubscribe(&url)
+	_, err := r.DoSubscribe(&url)
 	if err != nil {
 		t.Fatal(err)
 	}
 
 	//listener.Close()
 	time.Sleep(1e9)
-	s.registry.Destroy()
+	r.Destroy()
 
-	assert.Equal(t, false, s.registry.IsAvailable())
+	assert.Equal(t, false, r.IsAvailable())
 
 }
 
@@ -92,12 +101,14 @@ func (s *KubernetesRegistryTestSuite) TestProviderDestroy() {
 
 	t := s.T()
 
+	r := s.initRegistry()
+
 	url, _ := common.NewURL("dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider", common.WithParamsValue(constant.CLUSTER_KEY, "mock"), common.WithMethods([]string{"GetUser", "AddUser"}))
-	err := s.registry.Register(url)
+	err := r.Register(url)
 	assert.NoError(t, err)
 
 	//listener.Close()
 	time.Sleep(1e9)
-	s.registry.Destroy()
-	assert.Equal(t, false, s.registry.IsAvailable())
+	r.Destroy()
+	assert.Equal(t, false, r.IsAvailable())
 }
diff --git a/remoting/kubernetes/client_test.go b/remoting/kubernetes/client_test.go
index 846745bb6..c4a4e822e 100644
--- a/remoting/kubernetes/client_test.go
+++ b/remoting/kubernetes/client_test.go
@@ -24,10 +24,9 @@ import (
 	"sync"
 	"testing"
 	"time"
-)
 
-import (
 	"github.com/stretchr/testify/suite"
+
 	v1 "k8s.io/api/core/v1"
 	"k8s.io/client-go/kubernetes"
 	"k8s.io/client-go/kubernetes/fake"
@@ -196,10 +195,29 @@ var clientPodJsonData = `{
 type KubernetesClientTestSuite struct {
 	suite.Suite
 
-	client     *Client
 	currentPod v1.Pod
 }
 
+func (s *KubernetesClientTestSuite) initClient() *Client {
+
+	t := s.T()
+
+	client, err := newMockClient(s.currentPod.GetNamespace(), func() (kubernetes.Interface, error) {
+
+		out := fake.NewSimpleClientset()
+
+		// mock current pod
+		if _, err := out.CoreV1().Pods(s.currentPod.GetNamespace()).Create(&s.currentPod); err != nil {
+			t.Fatal(err)
+		}
+		return out, nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	return client
+}
+
 func (s *KubernetesClientTestSuite) SetupSuite() {
 
 	t := s.T()
@@ -218,39 +236,19 @@ func (s *KubernetesClientTestSuite) SetupSuite() {
 	}
 }
 
-func (s *KubernetesClientTestSuite) TearDownSuite() {
-	s.client.Close()
-}
-
-func (s *KubernetesClientTestSuite) SetupTest() {
-
-	t := s.T()
-	var err error
-	s.client, err = newMockClient(s.currentPod.GetNamespace(), func() (kubernetes.Interface, error) {
-
-		out := fake.NewSimpleClientset()
-
-		// mock current pod
-		if _, err := out.CoreV1().Pods(s.currentPod.GetNamespace()).Create(&s.currentPod); err != nil {
-			t.Fatal(err)
-		}
-		return out, nil
-	})
-	if err != nil {
-		t.Fatal(err)
-	}
-}
-
 func (s *KubernetesClientTestSuite) TestClientValid() {
 
 	t := s.T()
 
-	if s.client.Valid() != true {
+	client := s.initClient()
+	defer client.Close()
+
+	if client.Valid() != true {
 		t.Fatal("client is not valid")
 	}
-	s.client.Close()
 
-	if s.client.Valid() != false {
+	client.Close()
+	if client.Valid() != false {
 		t.Fatal("client is valid")
 	}
 }
@@ -259,14 +257,16 @@ func (s *KubernetesClientTestSuite) TestClientDone() {
 
 	t := s.T()
 
+	client := s.initClient()
+
 	go func() {
 		time.Sleep(time.Second)
-		s.client.Close()
+		client.Close()
 	}()
 
-	<-s.client.Done()
+	<-client.Done()
 
-	if s.client.Valid() == true {
+	if client.Valid() == true {
 		t.Fatal("client should be invalid then")
 	}
 }
@@ -274,14 +274,16 @@ func (s *KubernetesClientTestSuite) TestClientDone() {
 func (s *KubernetesClientTestSuite) TestClientCreateKV() {
 
 	t := s.T()
-	defer s.client.Close()
+
+	client := s.initClient()
+	defer client.Close()
 
 	for _, tc := range tests {
 
 		k := tc.input.k
 		v := tc.input.v
 
-		if err := s.client.Create(k, v); err != nil {
+		if err := client.Create(k, v); err != nil {
 			t.Fatal(err)
 		}
 
@@ -291,7 +293,9 @@ func (s *KubernetesClientTestSuite) TestClientCreateKV() {
 func (s *KubernetesClientTestSuite) TestClientGetChildrenKVList() {
 
 	t := s.T()
-	defer s.client.Close()
+
+	client := s.initClient()
+	defer client.Close()
 
 	expect := make(map[string]string)
 	got := make(map[string]string)
@@ -305,12 +309,12 @@ func (s *KubernetesClientTestSuite) TestClientGetChildrenKVList() {
 			expect[k] = v
 		}
 
-		if err := s.client.Create(k, v); err != nil {
+		if err := client.Create(k, v); err != nil {
 			t.Fatal(err)
 		}
 	}
 
-	kList, vList, err := s.client.GetChildren(prefix)
+	kList, vList, err := client.GetChildren(prefix)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -332,6 +336,8 @@ func (s *KubernetesClientTestSuite) TestClientWatch() {
 
 	t := s.T()
 
+	client := s.initClient()
+
 	wg := sync.WaitGroup{}
 	wg.Add(1)
 
@@ -339,7 +345,7 @@ func (s *KubernetesClientTestSuite) TestClientWatch() {
 
 		defer wg.Done()
 
-		wc, err := s.client.WatchWithPrefix(prefix)
+		wc, err := client.WatchWithPrefix(prefix)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -355,12 +361,12 @@ func (s *KubernetesClientTestSuite) TestClientWatch() {
 		k := tc.input.k
 		v := tc.input.v
 
-		if err := s.client.Create(k, v); err != nil {
+		if err := client.Create(k, v); err != nil {
 			t.Fatal(err)
 		}
 	}
 
-	s.client.Close()
+	client.Close()
 	wg.Wait()
 }
 
-- 
GitLab