Skip to content
Snippets Groups Projects
Commit 4e517c32 authored by 邹毅贤's avatar 邹毅贤
Browse files

fix review comment

parent 43917bc1
No related branches found
No related tags found
No related merge requests found
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
package directory package directory
import ( import (
"github.com/apache/dubbo-go/common/logger"
"sync" "sync"
) )
...@@ -75,17 +76,29 @@ func (dir *BaseDirectory) GetDirectoryUrl() *common.URL { ...@@ -75,17 +76,29 @@ func (dir *BaseDirectory) GetDirectoryUrl() *common.URL {
return dir.url return dir.url
} }
func (dir *BaseDirectory) SetRouters(routers []router.Router) { func (dir *BaseDirectory) SetRouters(urls []*common.URL) {
routerKey := dir.GetUrl().GetParam(constant.ROUTER_KEY, "") if len(urls) == 0 {
if len(routerKey) > 0 { return
factory := extension.GetRouterFactory(dir.GetUrl().Protocol) }
url := dir.GetUrl()
router, err := factory.Router(&url) routers := make([]router.Router, len(urls), len(urls))
if err == nil {
routers = append(routers, router) for _, url := range urls {
routerKey := url.GetParam(constant.ROUTER_KEY, "")
if len(routerKey) > 0 {
factory := extension.GetRouterFactory(url.Protocol)
r, err := factory.Router(url)
if err != nil {
logger.Errorf("Create router fail. router key: %s, error: %v", routerKey, url.Service(), err)
return
}
routers = append(routers, r)
} }
} }
logger.Infof("Init file condition router success, size: %v", len(routers))
dir.routerChain.AddRouters(routers) dir.routerChain.AddRouters(routers)
} }
......
...@@ -56,7 +56,7 @@ func (c RouterChain) Route(invoker []protocol.Invoker, url *common.URL, invocati ...@@ -56,7 +56,7 @@ func (c RouterChain) Route(invoker []protocol.Invoker, url *common.URL, invocati
c.mutex.RUnlock() c.mutex.RUnlock()
for _, r := range rs { for _, r := range rs {
finalInvokers = r.Route(invoker, url, invocation) finalInvokers = r.Route(finalInvokers, url, invocation)
} }
return finalInvokers return finalInvokers
} }
......
...@@ -121,31 +121,31 @@ func TestRoute_matchWhen(t *testing.T) { ...@@ -121,31 +121,31 @@ func TestRoute_matchWhen(t *testing.T) {
rule := base64.URLEncoding.EncodeToString([]byte("=> host = 1.2.3.4")) rule := base64.URLEncoding.EncodeToString([]byte("=> host = 1.2.3.4"))
router, _ := NewConditionRouterFactory().Router(getRouteUrl(rule)) router, _ := NewConditionRouterFactory().Router(getRouteUrl(rule))
cUrl, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService") cUrl, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService")
matchWhen, _ := router.(*ConditionRouter).MatchWhen(&cUrl, inv) matchWhen := router.(*ConditionRouter).MatchWhen(&cUrl, inv)
assert.Equal(t, true, matchWhen) assert.Equal(t, true, matchWhen)
rule1 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4")) rule1 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4"))
router1, _ := NewConditionRouterFactory().Router(getRouteUrl(rule1)) router1, _ := NewConditionRouterFactory().Router(getRouteUrl(rule1))
matchWhen1, _ := router1.(*ConditionRouter).MatchWhen(&cUrl, inv) matchWhen1 := router1.(*ConditionRouter).MatchWhen(&cUrl, inv)
assert.Equal(t, true, matchWhen1) assert.Equal(t, true, matchWhen1)
rule2 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.1,3.3.3.3 & host !=1.1.1.1 => host = 1.2.3.4")) rule2 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.1,3.3.3.3 & host !=1.1.1.1 => host = 1.2.3.4"))
router2, _ := NewConditionRouterFactory().Router(getRouteUrl(rule2)) router2, _ := NewConditionRouterFactory().Router(getRouteUrl(rule2))
matchWhen2, _ := router2.(*ConditionRouter).MatchWhen(&cUrl, inv) matchWhen2 := router2.(*ConditionRouter).MatchWhen(&cUrl, inv)
assert.Equal(t, false, matchWhen2) assert.Equal(t, false, matchWhen2)
rule3 := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.4 & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4")) rule3 := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.4 & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4"))
router3, _ := NewConditionRouterFactory().Router(getRouteUrl(rule3)) router3, _ := NewConditionRouterFactory().Router(getRouteUrl(rule3))
matchWhen3, _ := router3.(*ConditionRouter).MatchWhen(&cUrl, inv) matchWhen3 := router3.(*ConditionRouter).MatchWhen(&cUrl, inv)
assert.Equal(t, true, matchWhen3) assert.Equal(t, true, matchWhen3)
rule4 := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.* & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4")) rule4 := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.* & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4"))
router4, _ := NewConditionRouterFactory().Router(getRouteUrl(rule4)) router4, _ := NewConditionRouterFactory().Router(getRouteUrl(rule4))
matchWhen4, _ := router4.(*ConditionRouter).MatchWhen(&cUrl, inv) matchWhen4 := router4.(*ConditionRouter).MatchWhen(&cUrl, inv)
assert.Equal(t, true, matchWhen4) assert.Equal(t, true, matchWhen4)
rule5 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.*,3.3.3.3 & host != 1.1.1.1 => host = 1.2.3.4")) rule5 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.*,3.3.3.3 & host != 1.1.1.1 => host = 1.2.3.4"))
router5, _ := NewConditionRouterFactory().Router(getRouteUrl(rule5)) router5, _ := NewConditionRouterFactory().Router(getRouteUrl(rule5))
matchWhen5, _ := router5.(*ConditionRouter).MatchWhen(&cUrl, inv) matchWhen5 := router5.(*ConditionRouter).MatchWhen(&cUrl, inv)
assert.Equal(t, false, matchWhen5) assert.Equal(t, false, matchWhen5)
rule6 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.*,3.3.3.3 & host != 1.1.1.2 => host = 1.2.3.4")) rule6 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.*,3.3.3.3 & host != 1.1.1.2 => host = 1.2.3.4"))
router6, _ := NewConditionRouterFactory().Router(getRouteUrl(rule6)) router6, _ := NewConditionRouterFactory().Router(getRouteUrl(rule6))
matchWhen6, _ := router6.(*ConditionRouter).MatchWhen(&cUrl, inv) matchWhen6 := router6.(*ConditionRouter).MatchWhen(&cUrl, inv)
assert.Equal(t, true, matchWhen6) assert.Equal(t, true, matchWhen6)
} }
...@@ -189,20 +189,20 @@ func TestRoute_methodRoute(t *testing.T) { ...@@ -189,20 +189,20 @@ func TestRoute_methodRoute(t *testing.T) {
rule := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.* & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4")) rule := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.* & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4"))
router, _ := NewConditionRouterFactory().Router(getRouteUrl(rule)) router, _ := NewConditionRouterFactory().Router(getRouteUrl(rule))
url, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=setFoo,getFoo,findFoo") url, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=setFoo,getFoo,findFoo")
matchWhen, _ := router.(*ConditionRouter).MatchWhen(&url, inv) matchWhen := router.(*ConditionRouter).MatchWhen(&url, inv)
assert.Equal(t, true, matchWhen) assert.Equal(t, true, matchWhen)
url1, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=getFoo") url1, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=getFoo")
matchWhen, _ = router.(*ConditionRouter).MatchWhen(&url1, inv) matchWhen = router.(*ConditionRouter).MatchWhen(&url1, inv)
assert.Equal(t, true, matchWhen) assert.Equal(t, true, matchWhen)
url2, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=getFoo") url2, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=getFoo")
rule2 := base64.URLEncoding.EncodeToString([]byte("methods=getFoo & host!=1.1.1.1 => host = 1.2.3.4")) rule2 := base64.URLEncoding.EncodeToString([]byte("methods=getFoo & host!=1.1.1.1 => host = 1.2.3.4"))
router2, _ := NewConditionRouterFactory().Router(getRouteUrl(rule2)) router2, _ := NewConditionRouterFactory().Router(getRouteUrl(rule2))
matchWhen, _ = router2.(*ConditionRouter).MatchWhen(&url2, inv) matchWhen = router2.(*ConditionRouter).MatchWhen(&url2, inv)
assert.Equal(t, false, matchWhen) assert.Equal(t, false, matchWhen)
url3, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=getFoo") url3, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=getFoo")
rule3 := base64.URLEncoding.EncodeToString([]byte("methods=getFoo & host=1.1.1.1 => host = 1.2.3.4")) rule3 := base64.URLEncoding.EncodeToString([]byte("methods=getFoo & host=1.1.1.1 => host = 1.2.3.4"))
router3, _ := NewConditionRouterFactory().Router(getRouteUrl(rule3)) router3, _ := NewConditionRouterFactory().Router(getRouteUrl(rule3))
matchWhen, _ = router3.(*ConditionRouter).MatchWhen(&url3, inv) matchWhen = router3.(*ConditionRouter).MatchWhen(&url3, inv)
assert.Equal(t, true, matchWhen) assert.Equal(t, true, matchWhen)
} }
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
package condition package condition
import ( import (
"reflect"
"regexp" "regexp"
"strings" "strings"
) )
...@@ -33,8 +32,8 @@ import ( ...@@ -33,8 +32,8 @@ import (
"github.com/apache/dubbo-go/common/constant" "github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/logger" "github.com/apache/dubbo-go/common/logger"
"github.com/apache/dubbo-go/protocol" "github.com/apache/dubbo-go/protocol"
gxset "github.com/dubbogo/gost/container/set" "github.com/dubbogo/gost/container/set"
gxnet "github.com/dubbogo/gost/net" "github.com/dubbogo/gost/net"
) )
const ( const (
...@@ -149,16 +148,7 @@ func (c *ConditionRouter) Route(invokers []protocol.Invoker, url *common.URL, in ...@@ -149,16 +148,7 @@ func (c *ConditionRouter) Route(invokers []protocol.Invoker, url *common.URL, in
if len(invokers) == 0 { if len(invokers) == 0 {
return invokers return invokers
} }
isMatchWhen, err := c.MatchWhen(url, invocation) isMatchWhen := c.MatchWhen(url, invocation)
if err != nil {
var urls []string
for _, invo := range invokers {
urls = append(urls, reflect.TypeOf(invo).String())
}
logger.Warnf("Failed to execute condition router rule: %s , invokers: [%s], cause: %v", c.url.String(), strings.Join(urls, ","), err)
return invokers
}
if !isMatchWhen { if !isMatchWhen {
return invokers return invokers
} }
...@@ -168,15 +158,7 @@ func (c *ConditionRouter) Route(invokers []protocol.Invoker, url *common.URL, in ...@@ -168,15 +158,7 @@ func (c *ConditionRouter) Route(invokers []protocol.Invoker, url *common.URL, in
} }
for _, invoker := range invokers { for _, invoker := range invokers {
invokerUrl := invoker.GetUrl() invokerUrl := invoker.GetUrl()
isMatchThen, err := c.MatchThen(&invokerUrl, url) isMatchThen := c.MatchThen(&invokerUrl, url)
if err != nil {
var urls []string
for _, invo := range invokers {
urls = append(urls, reflect.TypeOf(invo).String())
}
logger.Warnf("Failed to execute condition router rule: %s , invokers: [%s], cause: %v", c.url.String(), strings.Join(urls, ","), err)
return invokers
}
if isMatchThen { if isMatchThen {
result = append(result, invoker) result = append(result, invoker)
} }
...@@ -260,22 +242,23 @@ func getStartIndex(rule string) int { ...@@ -260,22 +242,23 @@ func getStartIndex(rule string) int {
} }
// //
func (c *ConditionRouter) MatchWhen(url *common.URL, invocation protocol.Invocation) (bool, error) { func (c *ConditionRouter) MatchWhen(url *common.URL, invocation protocol.Invocation) bool {
condition, err := matchCondition(c.WhenCondition, url, nil, invocation) condition := matchCondition(c.WhenCondition, url, nil, invocation)
return len(c.WhenCondition) == 0 || condition, err return len(c.WhenCondition) == 0 || condition
} }
//MatchThen MatchThen //MatchThen MatchThen
func (c *ConditionRouter) MatchThen(url *common.URL, param *common.URL) (bool, error) { func (c *ConditionRouter) MatchThen(url *common.URL, param *common.URL) bool {
condition, err := matchCondition(c.ThenCondition, url, param, nil) condition := matchCondition(c.ThenCondition, url, param, nil)
return len(c.ThenCondition) > 0 && condition, err return len(c.ThenCondition) > 0 && condition
} }
//MatchCondition MatchCondition //MatchCondition MatchCondition
func matchCondition(pairs map[string]MatchPair, url *common.URL, param *common.URL, invocation protocol.Invocation) (bool, error) { func matchCondition(pairs map[string]MatchPair, url *common.URL, param *common.URL, invocation protocol.Invocation) bool {
sample := url.ToMap() sample := url.ToMap()
if sample == nil { if sample == nil {
return true, perrors.Errorf("url is not allowed be nil") // because url.ToMap() may return nil, but it should continue to process make condition
sample = make(map[string]string)
} }
var result bool var result bool
for key, matchPair := range pairs { for key, matchPair := range pairs {
...@@ -291,19 +274,19 @@ func matchCondition(pairs map[string]MatchPair, url *common.URL, param *common.U ...@@ -291,19 +274,19 @@ func matchCondition(pairs map[string]MatchPair, url *common.URL, param *common.U
} }
if len(sampleValue) > 0 { if len(sampleValue) > 0 {
if !matchPair.isMatch(sampleValue, param) { if !matchPair.isMatch(sampleValue, param) {
return false, nil return false
} }
result = true result = true
} else { } else {
if !(matchPair.Matches.Empty()) { if !(matchPair.Matches.Empty()) {
return false, nil return false
} }
result = true result = true
} }
} }
return result, nil return result
} }
// MatchPair ... // MatchPair ...
......
...@@ -29,7 +29,6 @@ import ( ...@@ -29,7 +29,6 @@ import (
import ( import (
"github.com/apache/dubbo-go/cluster/directory" "github.com/apache/dubbo-go/cluster/directory"
"github.com/apache/dubbo-go/cluster/router"
"github.com/apache/dubbo-go/cluster/router/chain" "github.com/apache/dubbo-go/cluster/router/chain"
"github.com/apache/dubbo-go/common" "github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/constant" "github.com/apache/dubbo-go/common/constant"
...@@ -42,7 +41,6 @@ import ( ...@@ -42,7 +41,6 @@ import (
"github.com/apache/dubbo-go/protocol/protocolwrapper" "github.com/apache/dubbo-go/protocol/protocolwrapper"
"github.com/apache/dubbo-go/registry" "github.com/apache/dubbo-go/registry"
"github.com/apache/dubbo-go/remoting" "github.com/apache/dubbo-go/remoting"
gxset "github.com/dubbogo/gost/container/set"
) )
// Options ... // Options ...
...@@ -137,11 +135,7 @@ func (dir *registryDirectory) refreshInvokers(res *registry.ServiceEvent) { ...@@ -137,11 +135,7 @@ func (dir *registryDirectory) refreshInvokers(res *registry.ServiceEvent) {
} }
if len(urls) > 0 { if len(urls) > 0 {
routers := toRouters(urls) dir.SetRouters(urls)
logger.Infof("Init file condition router success, size: %v", len(routers))
if len(routers) > 0 {
dir.SetRouters(routers)
}
} }
//dir.cacheService.EventTypeAdd(res.Path, dir.serviceTTL) //dir.cacheService.EventTypeAdd(res.Path, dir.serviceTTL)
...@@ -161,37 +155,6 @@ func (dir *registryDirectory) refreshInvokers(res *registry.ServiceEvent) { ...@@ -161,37 +155,6 @@ func (dir *registryDirectory) refreshInvokers(res *registry.ServiceEvent) {
dir.cacheInvokers = newInvokers dir.cacheInvokers = newInvokers
} }
func toRouters(urls []*common.URL) []router.Router {
if len(urls) == 0 {
return nil
}
routerMap := gxset.NewSet()
for _, url := range urls {
if url.Protocol == constant.EMPTY_PROTOCOL {
continue
}
routerKey := url.GetParam(constant.ROUTER_KEY, "")
if routerKey == "" {
continue
}
url.Protocol = routerKey
factory := extension.GetRouterFactory(url.GetParam(constant.ROUTER_KEY, routerKey))
router, e := factory.Router(url)
if e != nil {
logger.Error("factory.Router(url){%s} , error : %s", url, e)
}
routerMap.Add(router)
}
routers := make([]router.Router, 0)
for _, v := range routerMap.Values() {
routers = append(routers, v.(router.Router))
}
return routers
}
func (dir *registryDirectory) toGroupInvokers() []protocol.Invoker { func (dir *registryDirectory) toGroupInvokers() []protocol.Invoker {
newInvokersList := []protocol.Invoker{} newInvokersList := []protocol.Invoker{}
groupInvokersMap := make(map[string][]protocol.Invoker) groupInvokersMap := make(map[string][]protocol.Invoker)
......
...@@ -19,7 +19,6 @@ package directory ...@@ -19,7 +19,6 @@ package directory
import ( import (
"context" "context"
"encoding/base64"
"net/url" "net/url"
"strconv" "strconv"
"testing" "testing"
...@@ -42,7 +41,6 @@ import ( ...@@ -42,7 +41,6 @@ import (
"github.com/apache/dubbo-go/protocol/protocolwrapper" "github.com/apache/dubbo-go/protocol/protocolwrapper"
"github.com/apache/dubbo-go/registry" "github.com/apache/dubbo-go/registry"
"github.com/apache/dubbo-go/remoting" "github.com/apache/dubbo-go/remoting"
gxnet "github.com/dubbogo/gost/net"
) )
func init() { func init() {
...@@ -205,20 +203,3 @@ func normalRegistryDir(noMockEvent ...bool) (*registryDirectory, *registry.MockR ...@@ -205,20 +203,3 @@ func normalRegistryDir(noMockEvent ...bool) (*registryDirectory, *registry.MockR
} }
return registryDirectory, mockRegistry.(*registry.MockRegistry) return registryDirectory, mockRegistry.(*registry.MockRegistry)
} }
func TestToRouter(t *testing.T) {
localIP, _ := gxnet.GetLocalIP()
rule := base64.URLEncoding.EncodeToString([]byte("host = " + localIP + " => " + " host = 10.20.3.3"))
url, _ := common.NewURL(
context.TODO(),
"dubbo://0.0.0.0/com.foo.BarService",
common.WithParamsValue(constant.RULE_KEY, rule),
common.WithParamsValue(constant.ROUTER_KEY, "condition"),
)
urls := make([]*common.URL, 0)
urls = append(urls, &url)
routers := toRouters(urls)
assert.Equal(t, 1, len(routers))
router := routers[0]
assert.Equal(t, "condition", router.Url().Protocol)
}
...@@ -393,8 +393,13 @@ func (r *zkRegistry) register(c common.URL) error { ...@@ -393,8 +393,13 @@ func (r *zkRegistry) register(c common.URL) error {
case common.ROUTER: case common.ROUTER:
dubboPath = fmt.Sprintf("/dubbo/%s/%s", c.Service(), common.DubboNodes[common.ROUTER]) dubboPath = fmt.Sprintf("/dubbo/%s/%s", c.Service(), common.DubboNodes[common.ROUTER])
r.cltLock.Lock() r.cltLock.Lock()
err = r.client.Create(dubboPath) client := r.client
r.cltLock.Unlock() r.cltLock.Unlock()
if client == nil {
logger.Errorf("zkClient.create(path{%s}) = client is null", dubboPath)
return perrors.WithStack(err)
}
err = client.Create(dubboPath)
if err != nil { if err != nil {
logger.Errorf("zkClient.create(path{%s}) = error{%v}", dubboPath, perrors.WithStack(err)) logger.Errorf("zkClient.create(path{%s}) = error{%v}", dubboPath, perrors.WithStack(err))
return perrors.WithStack(err) return perrors.WithStack(err)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment