From 56c3f9064b237a97ebdb8d0f55a0dbfcc6e5615a Mon Sep 17 00:00:00 2001 From: aliiohs <rzy1107@163.com> Date: Wed, 19 Jun 2019 16:24:18 +0800 Subject: [PATCH] Add routing-related features and add some test case for condition router --- cluster/router/condition_router.go | 123 +++++++++++++++++++----- cluster/router/condition_router_test.go | 80 ++++++++++++--- common/constant/default.go | 2 +- common/url.go | 42 +++++++- common/utils/hashset.go | 70 ++++++++++++++ 5 files changed, 281 insertions(+), 36 deletions(-) create mode 100644 common/utils/hashset.go diff --git a/cluster/router/condition_router.go b/cluster/router/condition_router.go index ceeca8e01..a8074f739 100644 --- a/cluster/router/condition_router.go +++ b/cluster/router/condition_router.go @@ -1,8 +1,10 @@ package router import ( + "encoding/base64" "github.com/apache/dubbo-go/cluster" "github.com/apache/dubbo-go/common/constant" + "github.com/apache/dubbo-go/common/utils" "regexp" "strings" @@ -15,6 +17,8 @@ const ( RoutePattern = `([&!=,]*)\\s*([^&!=,\\s]+)` ) +var itemExists = struct{}{} + type ConditionRouter struct { Pattern string Url common.URL @@ -28,7 +32,7 @@ func (c *ConditionRouter) Route(invokers []protocol.Invoker, url common.URL, inv if len(invokers) == 0 { return invokers, nil } - if !c.matchWhen(url, invocation) { + if !c.MatchWhen(url, invocation) { return invokers, nil } var result []protocol.Invoker @@ -36,7 +40,7 @@ func (c *ConditionRouter) Route(invokers []protocol.Invoker, url common.URL, inv return result, nil } for _, invoker := range invokers { - if c.matchThen(invoker.GetUrl(), url) { + if c.MatchThen(invoker.GetUrl(), url) { result = append(result, invoker) } } @@ -70,7 +74,18 @@ func (c ConditionRouter) CompareTo(r cluster.Router) int { func newConditionRouter(url common.URL) (*ConditionRouter, error) { var whenRule string var thenRule string - //rule := url.GetParam("rule", "") + + ruleDec, err := base64.URLEncoding.DecodeString(url.GetParam("rule", "")) + rule := string(ruleDec) + if err != nil || rule == "" { + return nil, perrors.Errorf("Illegal route rule!") + } + rule = strings.Replace(rule, "consumer.", "", -1) + rule = strings.Replace(rule, "provider.", "", -1) + i := strings.Index(rule, "=>") + whenRule = strings.Trim(If(i < 0, "", rule[0:i]).(string), " ") + thenRule = strings.Trim(If(i < 0, rule, rule[i+2:]).(string), " ") + w, err := parseRule(whenRule) if err != nil { return nil, perrors.Errorf("%s", "") @@ -100,7 +115,7 @@ func parseRule(rule string) (map[string]MatchPair, error) { return condition, nil } var pair MatchPair - values := make(map[string]interface{}) + values := utils.NewSet() reg := regexp.MustCompile(`([&!=,]*)\s*([^&!=,\s]+)`) @@ -112,7 +127,10 @@ func parseRule(rule string) (map[string]MatchPair, error) { switch separator { case "": - pair = MatchPair{} + pair = MatchPair{ + Matches: utils.NewSet(), + Mismatches: utils.NewSet(), + } condition[content] = pair case "&": if r, ok := condition[content]; ok { @@ -126,18 +144,18 @@ func parseRule(rule string) (map[string]MatchPair, error) { return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0]) } values = pair.Matches - values[content] = "" + values.Add(content) case "!=": if &pair == nil { return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0]) } - values = pair.Matches - values[content] = "" + values = pair.Mismatches + values.Add(content) case ",": - if len(values) == 0 { + if values.Empty() { return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0]) } - values[content] = "" + values.Add(content) default: return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0]) @@ -149,16 +167,16 @@ func parseRule(rule string) (map[string]MatchPair, error) { } -func (c *ConditionRouter) matchWhen(url common.URL, invocation protocol.Invocation) bool { +func (c *ConditionRouter) MatchWhen(url common.URL, invocation protocol.Invocation) bool { - return len(c.WhenCondition) == 0 || len(c.WhenCondition) == 0 || matchCondition(c.WhenCondition, &url, nil, invocation) + return len(c.WhenCondition) == 0 || MatchCondition(c.WhenCondition, &url, nil, invocation) } -func (c *ConditionRouter) matchThen(url common.URL, param common.URL) bool { +func (c *ConditionRouter) MatchThen(url common.URL, param common.URL) bool { - return !(len(c.ThenCondition) == 0) && matchCondition(c.ThenCondition, &url, ¶m, nil) + return len(c.ThenCondition) > 0 && MatchCondition(c.ThenCondition, &url, ¶m, nil) } -func matchCondition(pairs map[string]MatchPair, url *common.URL, param *common.URL, invocation protocol.Invocation) bool { +func MatchCondition(pairs map[string]MatchPair, url *common.URL, param *common.URL, invocation protocol.Invocation) bool { sample := url.ToMap() result := false for key, matchPair := range pairs { @@ -168,18 +186,18 @@ func matchCondition(pairs map[string]MatchPair, url *common.URL, param *common.U sampleValue = invocation.MethodName() } else { sampleValue = sample[key] - if &sampleValue == nil { - sampleValue = sample[constant.DEFAULT_KEY_PREFIX+key] + if sampleValue == "" { + sampleValue = sample[constant.PREFIX_DEFAULT_KEY+key] } } - if &sampleValue != nil { + if sampleValue != "" { if !matchPair.isMatch(sampleValue, param) { return false } else { result = true } } else { - if !(len(matchPair.Matches) == 0) { + if !(matchPair.Matches.Empty()) { return false } else { result = true @@ -198,11 +216,72 @@ func If(b bool, t, f interface{}) interface{} { } type MatchPair struct { - Matches map[string]interface{} - Mismatches map[string]interface{} + Matches *utils.HashSet + Mismatches *utils.HashSet } -func (pair MatchPair) isMatch(s string, param *common.URL) bool { +func (pair MatchPair) isMatch(value string, param *common.URL) bool { + + if !pair.Matches.Empty() && pair.Mismatches.Empty() { + + for match := range pair.Matches.Items { + if isMatchGlobPattern(match.(string), value, param) { + return true + } + } + return false + } + if !pair.Mismatches.Empty() && pair.Matches.Empty() { + + for mismatch := range pair.Mismatches.Items { + if isMatchGlobPattern(mismatch.(string), value, param) { + return false + } + } + return true + } + if !pair.Mismatches.Empty() && !pair.Matches.Empty() { + for mismatch := range pair.Mismatches.Items { + if isMatchGlobPattern(mismatch.(string), value, param) { + return false + } + } + for match := range pair.Matches.Items { + if isMatchGlobPattern(match.(string), value, param) { + return true + } + } + return false + } return false } + +func isMatchGlobPattern(pattern string, value string, param *common.URL) bool { + if param != nil && strings.HasPrefix(pattern, "$") { + pattern = param.GetRawParameter(pattern[1:]) + } + if "*" == pattern { + return true + } + if pattern == "" && value == "" { + return true + } + if pattern == "" || value == "" { + return false + } + i := strings.LastIndex(pattern, "*") + switch i { + case -1: + return value == pattern + case len(pattern) - 1: + return strings.HasPrefix(value, pattern[0:i]) + case 0: + return strings.HasSuffix(value, pattern[:i+1]) + default: + prefix := pattern[0:1] + suffix := pattern[i+1:] + return strings.HasPrefix(value, prefix) && strings.HasSuffix(value, suffix) + + } +} diff --git a/cluster/router/condition_router_test.go b/cluster/router/condition_router_test.go index 83cb25e58..5c6c2ecfd 100644 --- a/cluster/router/condition_router_test.go +++ b/cluster/router/condition_router_test.go @@ -2,6 +2,7 @@ package router import ( "context" + "encoding/base64" perrors "errors" "fmt" "github.com/apache/dubbo-go/common" @@ -68,10 +69,7 @@ func (bi *MockInvoker) Destroy() { bi.destroyed = true bi.available = false } -func Test_parseRule(t *testing.T) { - parseRule("host = 10.20.153.10 => host = 10.20.153.11") -} func LocalIp() string { addrs, err := net.InterfaceAddrs() if err != nil { @@ -88,21 +86,81 @@ func LocalIp() string { return ip } func TestRoute_matchWhen(t *testing.T) { - + rpcInvacation := &invocation.RPCInvocation{} + rule := base64.URLEncoding.EncodeToString([]byte("=> host = 1.2.3.4")) + router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule)) + cUrl, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService") + + matchWhen := router.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation) + assert.Equal(t, true, matchWhen) + + rule1 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4")) + router1, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule1)) + matchWhen1 := router1.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation) + assert.Equal(t, true, matchWhen1) + + rule2 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.1,3.3.3.3 & host !=1.1.1.1 => host = 1.2.3.4")) + router2, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule2)) + matchWhen2 := router2.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation) + assert.Equal(t, false, matchWhen2) + + rule3 := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.4 & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4")) + router3, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule3)) + matchWhen3 := router3.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation) + assert.Equal(t, true, matchWhen3) + + rule4 := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.* & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4")) + router4, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule4)) + matchWhen4 := router4.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation) + assert.Equal(t, true, matchWhen4) + + rule5 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.*,3.3.3.3 & host != 1.1.1.1 => host = 1.2.3.4")) + router5, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule5)) + matchWhen5 := router5.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation) + assert.Equal(t, false, matchWhen5) + + rule6 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.*,3.3.3.3 & host != 1.1.1.2 => host = 1.2.3.4")) + router6, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule6)) + matchWhen6 := router6.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation) + assert.Equal(t, true, matchWhen6) } func TestRoute_matchFilter(t *testing.T) { url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService?default.serialization=fastjson") url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp())) url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp())) invokers := []protocol.Invoker{NewMockInvoker(url1, 1), NewMockInvoker(url2, 2), NewMockInvoker(url3, 3)} - option := common.WithParamsValue("rule", "host = "+LocalIp()+" => "+" host = 10.20.3.3") - option1 := common.WithParamsValue("force", "true") - sUrl, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService", option, option1) - - router1, _ := NewConditionRouterFactory().GetRouter(sUrl) + rule1 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 10.20.3.3")) + rule2 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 10.20.3.* & host != 10.20.3.3")) + rule3 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 10.20.3.3 & host != 10.20.3.3")) + rule4 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 10.20.3.2,10.20.3.3,10.20.3.4")) + rule5 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host != 10.20.3.3")) + rule6 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " serialization = fastjson")) + router1, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule1)) + router2, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule2)) + router3, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule3)) + router4, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule4)) + router5, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule5)) + router6, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule6)) cUrl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService") - routers, _ := router1.Route(invokers, cUrl, &invocation.RPCInvocation{}) - assert.Equal(t, 1, len(routers)) + fileredInvokers1, _ := router1.Route(invokers, cUrl, &invocation.RPCInvocation{}) + fileredInvokers2, _ := router2.Route(invokers, cUrl, &invocation.RPCInvocation{}) + fileredInvokers3, _ := router3.Route(invokers, cUrl, &invocation.RPCInvocation{}) + fileredInvokers4, _ := router4.Route(invokers, cUrl, &invocation.RPCInvocation{}) + fileredInvokers5, _ := router5.Route(invokers, cUrl, &invocation.RPCInvocation{}) + fileredInvokers6, _ := router6.Route(invokers, cUrl, &invocation.RPCInvocation{}) + assert.Equal(t, 1, len(fileredInvokers1)) + assert.Equal(t, 0, len(fileredInvokers2)) + assert.Equal(t, 0, len(fileredInvokers3)) + assert.Equal(t, 1, len(fileredInvokers4)) + assert.Equal(t, 2, len(fileredInvokers5)) + assert.Equal(t, 1, len(fileredInvokers6)) + +} +func getRouteUrl(rule string) common.URL { + url, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService") + url.AddParam("rule", rule) + url.AddParam("force", "true") + return url } diff --git a/common/constant/default.go b/common/constant/default.go index 63ce0e515..51d98ce1d 100644 --- a/common/constant/default.go +++ b/common/constant/default.go @@ -33,7 +33,7 @@ const ( const ( DEFAULT_KEY = "default" - DEFAULT_KEY_PREFIX = "default." + PREFIX_DEFAULT_KEY = "default." DEFAULT_SERVICE_FILTERS = "echo" DEFAULT_REFERENCE_FILTERS = "" ECHO = "$echo" diff --git a/common/url.go b/common/url.go index 88452f643..4f3546880 100644 --- a/common/url.go +++ b/common/url.go @@ -262,6 +262,11 @@ func (c URL) Service() string { } return "" } + +func (c *URL) AddParam(key string, value string) { + c.Params.Add(key, value) +} + func (c URL) GetParam(s string, d string) string { var r string if r = c.Params.Get(s); r == "" { @@ -270,6 +275,28 @@ func (c URL) GetParam(s string, d string) string { return r } +func (c URL) GetRawParameter(key string) string { + if "protocol" == key { + return c.Protocol + } + if "username" == key { + return c.Username + } + if "password" == key { + return c.Password + } + if "host" == key { + return c.Ip + } + if "port" == key { + return c.Port + } + if "path" == key { + return c.Path + } + return c.Params.Get(key) +} + // GetParamBool func (c URL) GetParamBool(s string, d bool) bool { @@ -311,6 +338,10 @@ func (c URL) GetMethodParam(method string, key string, d string) string { func (c URL) ToMap() map[string]string { paramsMap := make(map[string]string) + + for k, v := range c.Params { + paramsMap[k] = v[0] + } if c.Protocol != "" { paramsMap["protocol"] = c.Protocol } @@ -320,8 +351,15 @@ func (c URL) ToMap() map[string]string { if c.Password != "" { paramsMap["password"] = c.Password } - if c.Ip != "" { - paramsMap["host"] = c.Ip + if c.Location != "" { + paramsMap["host"] = strings.Split(c.Location, ":")[0] + var port string + if strings.Contains(c.Location, ":") { + port = strings.Split(c.Location, ":")[1] + } else { + port = "0" + } + paramsMap["port"] = port } if c.Protocol != "" { paramsMap["protocol"] = c.Protocol diff --git a/common/utils/hashset.go b/common/utils/hashset.go new file mode 100644 index 000000000..a03c17d83 --- /dev/null +++ b/common/utils/hashset.go @@ -0,0 +1,70 @@ +package utils + +import ( + "fmt" + "strings" +) + +var itemExists = struct{}{} + +type HashSet struct { + Items map[interface{}]struct{} +} + +func NewSet(values ...interface{}) *HashSet { + set := &HashSet{Items: make(map[interface{}]struct{})} + if len(values) > 0 { + set.Add(values...) + } + return set +} + +func (set *HashSet) Add(items ...interface{}) { + for _, item := range items { + set.Items[item] = itemExists + } +} + +func (set *HashSet) Remove(items ...interface{}) { + for _, item := range items { + delete(set.Items, item) + } +} + +func (set *HashSet) Contains(items ...interface{}) bool { + for _, item := range items { + if _, contains := set.Items[item]; !contains { + return false + } + } + return true +} +func (set *HashSet) Empty() bool { + return set.Size() == 0 +} +func (set *HashSet) Size() int { + return len(set.Items) +} + +func (set *HashSet) Clear() { + set.Items = make(map[interface{}]struct{}) +} + +func (set *HashSet) Values() []interface{} { + values := make([]interface{}, set.Size()) + count := 0 + for item := range set.Items { + values[count] = item + count++ + } + return values +} +func (set *HashSet) String() string { + str := "HashSet\n" + var items []string + for k := range set.Items { + items = append(items, fmt.Sprintf("%v", k)) + } + str += strings.Join(items, ", ") + return str +} -- GitLab