From bd45b58aae7e0df9532927e28e600ccc38905cfe Mon Sep 17 00:00:00 2001
From: Joe Zou <yixian.zou@gmail.com>
Date: Wed, 22 Jan 2020 17:08:27 +0800
Subject: [PATCH] Update router

---
 cluster/directory/base_directory.go   | 39 ++++++++++-----------------
 cluster/directory/static_directory.go | 17 ++++--------
 cluster/router/chain/chain.go         | 15 ++++++-----
 cluster/router/chain/chain_test.go    |  3 ++-
 cluster/router/chain/factory.go       |  6 ++---
 cluster/router/router.go              |  8 +++---
 common/extension/router_factory.go    |  6 ++---
 registry/directory/directory.go       | 14 +++-------
 8 files changed, 43 insertions(+), 65 deletions(-)

diff --git a/cluster/directory/base_directory.go b/cluster/directory/base_directory.go
index 90b231d7c..4dc5ef225 100644
--- a/cluster/directory/base_directory.go
+++ b/cluster/directory/base_directory.go
@@ -19,6 +19,7 @@ package directory
 
 import (
 	"github.com/apache/dubbo-go/cluster/router"
+	"github.com/apache/dubbo-go/cluster/router/chain"
 	"sync"
 )
 import (
@@ -28,18 +29,21 @@ import (
 	"github.com/apache/dubbo-go/common"
 	"github.com/apache/dubbo-go/common/constant"
 	"github.com/apache/dubbo-go/common/extension"
-	"github.com/apache/dubbo-go/common/logger"
 	gxset "github.com/dubbogo/gost/container/set"
 )
 
 var routerURLSet = gxset.NewSet()
 
 type BaseDirectory struct {
-	url       *common.URL
-	destroyed *atomic.Bool
-	routers   []router.Router
-	mutex     sync.Mutex
-	once      sync.Once
+	url         *common.URL
+	destroyed   *atomic.Bool
+	mutex       sync.Mutex
+	once        sync.Once
+	routerChain router.Chain
+}
+
+func (dir *BaseDirectory) RouterChain() router.Chain {
+	return dir.routerChain
 }
 
 func GetRouterURLSet() *gxset.HashSet {
@@ -48,8 +52,9 @@ func GetRouterURLSet() *gxset.HashSet {
 
 func NewBaseDirectory(url *common.URL) BaseDirectory {
 	return BaseDirectory{
-		url:       url,
-		destroyed: atomic.NewBool(false),
+		url:         url,
+		destroyed:   atomic.NewBool(false),
+		routerChain: &chain.RouterChain{},
 	}
 }
 func (dir *BaseDirectory) Destroyed() bool {
@@ -75,23 +80,7 @@ func (dir *BaseDirectory) SetRouters(routers []router.Router) {
 
 	dir.mutex.Lock()
 	defer dir.mutex.Unlock()
-	dir.routers = routers
-}
-
-func (dir *BaseDirectory) Routers() []router.Router {
-	dir.once.Do(func() {
-		rs := routerURLSet.Values()
-		for _, r := range rs {
-			factory := extension.GetRouterFactory(r.(*common.URL).GetParam("router", "condition"))
-			router, err := factory.Router(r.(*common.URL))
-			if err != nil {
-				logger.Errorf("router fail! error:%v", err)
-				continue
-			}
-			dir.routers = append(dir.routers, router)
-		}
-	})
-	return dir.routers
+	dir.routerChain.AddRouters(routers)
 }
 
 func (dir *BaseDirectory) Destroy(doDestroy func()) {
diff --git a/cluster/directory/static_directory.go b/cluster/directory/static_directory.go
index 381731aef..6160e523d 100644
--- a/cluster/directory/static_directory.go
+++ b/cluster/directory/static_directory.go
@@ -17,13 +17,8 @@
 
 package directory
 
-import (
-	"reflect"
-)
-
 import (
 	"github.com/apache/dubbo-go/common"
-	"github.com/apache/dubbo-go/common/constant"
 	"github.com/apache/dubbo-go/protocol"
 )
 
@@ -59,15 +54,13 @@ func (dir *staticDirectory) IsAvailable() bool {
 
 func (dir *staticDirectory) List(invocation protocol.Invocation) []protocol.Invoker {
 	invokers := dir.invokers
-	localRouters := dir.routers
+	routerChain := dir.RouterChain()
 
-	for _, router := range localRouters {
-		if reflect.ValueOf(router.Url()).IsNil() || router.Url().GetParamBool(constant.RUNTIME_KEY, false) {
-			dirUrl := dir.GetUrl()
-			return router.Route(invokers, &dirUrl, invocation)
-		}
+	if routerChain == nil {
+		return invokers
 	}
-	return invokers
+	dirUrl := dir.GetUrl()
+	return routerChain.Route(invokers, &dirUrl, invocation)
 }
 
 func (dir *staticDirectory) Destroy() {
diff --git a/cluster/router/chain/chain.go b/cluster/router/chain/chain.go
index 1107e854a..60aeb017b 100644
--- a/cluster/router/chain/chain.go
+++ b/cluster/router/chain/chain.go
@@ -23,11 +23,12 @@ import (
 	"github.com/apache/dubbo-go/common/extension"
 	"github.com/apache/dubbo-go/common/logger"
 	"github.com/apache/dubbo-go/protocol"
+	perrors "github.com/pkg/errors"
 	"sort"
 )
 
 // RouterChain Router chain
-type Chain struct {
+type RouterChain struct {
 	//full list of addresses from registry, classified by method name.
 	invokers []protocol.Invoker
 	//containing all routers, reconstruct every time 'route://' urls change.
@@ -37,14 +38,14 @@ type Chain struct {
 	builtinRouters []router.Router
 }
 
-func (c Chain) Route(invoker []protocol.Invoker, url *common.URL, invocation protocol.Invocation) []protocol.Invoker {
+func (c RouterChain) Route(invoker []protocol.Invoker, url *common.URL, invocation protocol.Invocation) []protocol.Invoker {
 	finalInvokers := invoker
 	for _, r := range c.routers {
 		finalInvokers = r.Route(invoker, url, invocation)
 	}
 	return finalInvokers
 }
-func (c Chain) AddRouters(routers []router.Router) {
+func (c RouterChain) AddRouters(routers []router.Router) {
 	newRouters := make([]router.Router, 0)
 	newRouters = append(newRouters, c.builtinRouters...)
 	newRouters = append(newRouters, routers...)
@@ -52,10 +53,10 @@ func (c Chain) AddRouters(routers []router.Router) {
 	c.routers = newRouters
 }
 
-func NewRouterChain(url *common.URL) *Chain {
+func NewRouterChain(url *common.URL) (*RouterChain, error) {
 	routerFactories := extension.GetRouters()
 	if len(routerFactories) == 0 {
-		return nil
+		return nil, perrors.Errorf("Illegal route rule!")
 	}
 	routers := make([]router.Router, 0)
 	for _, routerFactory := range routerFactories {
@@ -72,12 +73,12 @@ func NewRouterChain(url *common.URL) *Chain {
 
 	sortRouter(newRouters)
 
-	chain := &Chain{
+	chain := &RouterChain{
 		builtinRouters: routers,
 		routers:        newRouters,
 	}
 
-	return chain
+	return chain, nil
 }
 
 func sortRouter(routers []router.Router) {
diff --git a/cluster/router/chain/chain_test.go b/cluster/router/chain/chain_test.go
index 0389977cf..9942549f1 100644
--- a/cluster/router/chain/chain_test.go
+++ b/cluster/router/chain/chain_test.go
@@ -62,7 +62,8 @@ conditions:
 	assert.Nil(t, err)
 	assert.NotNil(t, configuration)
 
-	chain := NewRouterChain(getRouteUrl("test-condition"))
+	chain, err := NewRouterChain(getRouteUrl("test-condition"))
+	assert.Nil(t, err)
 	assert.Equal(t, 1, len(chain.routers))
 	appRouter := chain.routers[0].(*condition.AppRouter)
 
diff --git a/cluster/router/chain/factory.go b/cluster/router/chain/factory.go
index e2166726b..1f24e6df7 100644
--- a/cluster/router/chain/factory.go
+++ b/cluster/router/chain/factory.go
@@ -29,10 +29,10 @@ func init() {
 
 type RouterChainFactory struct{}
 
-func (c RouterChainFactory) Router(*common.URL) (router.RouterChain, error) {
-	panic("implement me")
+func (c RouterChainFactory) Router(url *common.URL) (router.Chain, error) {
+	return NewRouterChain(url)
 }
 
-func NewRouterChainFactory() router.RouterChainFactory {
+func NewRouterChainFactory() router.ChainFactory {
 	return RouterChainFactory{}
 }
diff --git a/cluster/router/router.go b/cluster/router/router.go
index cfebdb52d..e095ca508 100644
--- a/cluster/router/router.go
+++ b/cluster/router/router.go
@@ -27,9 +27,9 @@ type RouterFactory interface {
 	Router(*common.URL) (Router, error)
 }
 
-// Extension - Router Chain
-type RouterChainFactory interface {
-	Router(*common.URL) (RouterChain, error)
+// ChainFactory Extension - Router Chain
+type ChainFactory interface {
+	Router(*common.URL) (Chain, error)
 }
 
 type Router interface {
@@ -38,7 +38,7 @@ type Router interface {
 	Url() common.URL
 }
 
-type RouterChain interface {
+type Chain interface {
 	Route([]protocol.Invoker, *common.URL, protocol.Invocation) []protocol.Invoker
 	AddRouters([]Router)
 }
diff --git a/common/extension/router_factory.go b/common/extension/router_factory.go
index 574cbba82..e6a1cec1a 100644
--- a/common/extension/router_factory.go
+++ b/common/extension/router_factory.go
@@ -23,7 +23,7 @@ import (
 
 var (
 	routers      = make(map[string]func() router.RouterFactory)
-	routerChains = make(map[string]func() router.RouterChainFactory)
+	routerChains = make(map[string]func() router.ChainFactory)
 )
 
 func SetRouterFactory(name string, fun func() router.RouterFactory) {
@@ -37,11 +37,11 @@ func GetRouterFactory(name string) router.RouterFactory {
 	return routers[name]()
 }
 
-func SetRouterChainsFactory(name string, fun func() router.RouterChainFactory) {
+func SetRouterChainsFactory(name string, fun func() router.ChainFactory) {
 	routerChains[name] = fun
 }
 
-func GetRouterChainsFactory(name string) router.RouterChainFactory {
+func GetRouterChainsFactory(name string) router.ChainFactory {
 	if routers[name] == nil {
 		panic("router_chain_factory for " + name + " is not existing, make sure you have import the package.")
 	}
diff --git a/registry/directory/directory.go b/registry/directory/directory.go
index 807bb471b..80c6d8b78 100644
--- a/registry/directory/directory.go
+++ b/registry/directory/directory.go
@@ -19,7 +19,6 @@ package directory
 
 import (
 	"github.com/apache/dubbo-go/cluster/router"
-	"reflect"
 	"strings"
 	"sync"
 	"time"
@@ -271,17 +270,12 @@ func (dir *registryDirectory) cacheInvoker(url *common.URL) {
 //select the protocol invokers from the directory
 func (dir *registryDirectory) List(invocation protocol.Invocation) []protocol.Invoker {
 	invokers := dir.cacheInvokers
-	localRouters := dir.Routers()
+	routerChain := dir.RouterChain()
 
-	if len(localRouters) > 0 {
-		for _, router := range localRouters {
-			if reflect.ValueOf(router.Url()).IsValid() || router.Url().GetParamBool(constant.RUNTIME_KEY, false) {
-				invokers = router.Route(invokers, dir.cacheOriginUrl, invocation)
-			}
-		}
+	if routerChain == nil {
+		return invokers
 	}
-	return invokers
-
+	return routerChain.Route(invokers, dir.cacheOriginUrl, invocation)
 }
 
 func (dir *registryDirectory) IsAvailable() bool {
-- 
GitLab