From 56c3f9064b237a97ebdb8d0f55a0dbfcc6e5615a Mon Sep 17 00:00:00 2001
From: aliiohs <rzy1107@163.com>
Date: Wed, 19 Jun 2019 16:24:18 +0800
Subject: [PATCH] Add routing-related features and add some test case for
 condition router

---
 cluster/router/condition_router.go      | 123 +++++++++++++++++++-----
 cluster/router/condition_router_test.go |  80 ++++++++++++---
 common/constant/default.go              |   2 +-
 common/url.go                           |  42 +++++++-
 common/utils/hashset.go                 |  70 ++++++++++++++
 5 files changed, 281 insertions(+), 36 deletions(-)
 create mode 100644 common/utils/hashset.go

diff --git a/cluster/router/condition_router.go b/cluster/router/condition_router.go
index ceeca8e01..a8074f739 100644
--- a/cluster/router/condition_router.go
+++ b/cluster/router/condition_router.go
@@ -1,8 +1,10 @@
 package router
 
 import (
+	"encoding/base64"
 	"github.com/apache/dubbo-go/cluster"
 	"github.com/apache/dubbo-go/common/constant"
+	"github.com/apache/dubbo-go/common/utils"
 	"regexp"
 	"strings"
 
@@ -15,6 +17,8 @@ const (
 	RoutePattern = `([&!=,]*)\\s*([^&!=,\\s]+)`
 )
 
+var itemExists = struct{}{}
+
 type ConditionRouter struct {
 	Pattern       string
 	Url           common.URL
@@ -28,7 +32,7 @@ func (c *ConditionRouter) Route(invokers []protocol.Invoker, url common.URL, inv
 	if len(invokers) == 0 {
 		return invokers, nil
 	}
-	if !c.matchWhen(url, invocation) {
+	if !c.MatchWhen(url, invocation) {
 		return invokers, nil
 	}
 	var result []protocol.Invoker
@@ -36,7 +40,7 @@ func (c *ConditionRouter) Route(invokers []protocol.Invoker, url common.URL, inv
 		return result, nil
 	}
 	for _, invoker := range invokers {
-		if c.matchThen(invoker.GetUrl(), url) {
+		if c.MatchThen(invoker.GetUrl(), url) {
 			result = append(result, invoker)
 		}
 	}
@@ -70,7 +74,18 @@ func (c ConditionRouter) CompareTo(r cluster.Router) int {
 func newConditionRouter(url common.URL) (*ConditionRouter, error) {
 	var whenRule string
 	var thenRule string
-	//rule := url.GetParam("rule", "")
+
+	ruleDec, err := base64.URLEncoding.DecodeString(url.GetParam("rule", ""))
+	rule := string(ruleDec)
+	if err != nil || rule == "" {
+		return nil, perrors.Errorf("Illegal route rule!")
+	}
+	rule = strings.Replace(rule, "consumer.", "", -1)
+	rule = strings.Replace(rule, "provider.", "", -1)
+	i := strings.Index(rule, "=>")
+	whenRule = strings.Trim(If(i < 0, "", rule[0:i]).(string), " ")
+	thenRule = strings.Trim(If(i < 0, rule, rule[i+2:]).(string), " ")
+
 	w, err := parseRule(whenRule)
 	if err != nil {
 		return nil, perrors.Errorf("%s", "")
@@ -100,7 +115,7 @@ func parseRule(rule string) (map[string]MatchPair, error) {
 		return condition, nil
 	}
 	var pair MatchPair
-	values := make(map[string]interface{})
+	values := utils.NewSet()
 
 	reg := regexp.MustCompile(`([&!=,]*)\s*([^&!=,\s]+)`)
 
@@ -112,7 +127,10 @@ func parseRule(rule string) (map[string]MatchPair, error) {
 
 		switch separator {
 		case "":
-			pair = MatchPair{}
+			pair = MatchPair{
+				Matches:    utils.NewSet(),
+				Mismatches: utils.NewSet(),
+			}
 			condition[content] = pair
 		case "&":
 			if r, ok := condition[content]; ok {
@@ -126,18 +144,18 @@ func parseRule(rule string) (map[string]MatchPair, error) {
 				return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0])
 			}
 			values = pair.Matches
-			values[content] = ""
+			values.Add(content)
 		case "!=":
 			if &pair == nil {
 				return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0])
 			}
-			values = pair.Matches
-			values[content] = ""
+			values = pair.Mismatches
+			values.Add(content)
 		case ",":
-			if len(values) == 0 {
+			if values.Empty() {
 				return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0])
 			}
-			values[content] = ""
+			values.Add(content)
 		default:
 			return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0])
 
@@ -149,16 +167,16 @@ func parseRule(rule string) (map[string]MatchPair, error) {
 
 }
 
-func (c *ConditionRouter) matchWhen(url common.URL, invocation protocol.Invocation) bool {
+func (c *ConditionRouter) MatchWhen(url common.URL, invocation protocol.Invocation) bool {
 
-	return len(c.WhenCondition) == 0 || len(c.WhenCondition) == 0 || matchCondition(c.WhenCondition, &url, nil, invocation)
+	return len(c.WhenCondition) == 0 || MatchCondition(c.WhenCondition, &url, nil, invocation)
 }
-func (c *ConditionRouter) matchThen(url common.URL, param common.URL) bool {
+func (c *ConditionRouter) MatchThen(url common.URL, param common.URL) bool {
 
-	return !(len(c.ThenCondition) == 0) && matchCondition(c.ThenCondition, &url, &param, nil)
+	return len(c.ThenCondition) > 0 && MatchCondition(c.ThenCondition, &url, &param, nil)
 }
 
-func matchCondition(pairs map[string]MatchPair, url *common.URL, param *common.URL, invocation protocol.Invocation) bool {
+func MatchCondition(pairs map[string]MatchPair, url *common.URL, param *common.URL, invocation protocol.Invocation) bool {
 	sample := url.ToMap()
 	result := false
 	for key, matchPair := range pairs {
@@ -168,18 +186,18 @@ func matchCondition(pairs map[string]MatchPair, url *common.URL, param *common.U
 			sampleValue = invocation.MethodName()
 		} else {
 			sampleValue = sample[key]
-			if &sampleValue == nil {
-				sampleValue = sample[constant.DEFAULT_KEY_PREFIX+key]
+			if sampleValue == "" {
+				sampleValue = sample[constant.PREFIX_DEFAULT_KEY+key]
 			}
 		}
-		if &sampleValue != nil {
+		if sampleValue != "" {
 			if !matchPair.isMatch(sampleValue, param) {
 				return false
 			} else {
 				result = true
 			}
 		} else {
-			if !(len(matchPair.Matches) == 0) {
+			if !(matchPair.Matches.Empty()) {
 				return false
 			} else {
 				result = true
@@ -198,11 +216,72 @@ func If(b bool, t, f interface{}) interface{} {
 }
 
 type MatchPair struct {
-	Matches    map[string]interface{}
-	Mismatches map[string]interface{}
+	Matches    *utils.HashSet
+	Mismatches *utils.HashSet
 }
 
-func (pair MatchPair) isMatch(s string, param *common.URL) bool {
+func (pair MatchPair) isMatch(value string, param *common.URL) bool {
+
+	if !pair.Matches.Empty() && pair.Mismatches.Empty() {
+
+		for match := range pair.Matches.Items {
+			if isMatchGlobPattern(match.(string), value, param) {
+				return true
+			}
+		}
+		return false
+	}
+	if !pair.Mismatches.Empty() && pair.Matches.Empty() {
+
+		for mismatch := range pair.Mismatches.Items {
+			if isMatchGlobPattern(mismatch.(string), value, param) {
+				return false
+			}
+		}
+		return true
+	}
+	if !pair.Mismatches.Empty() && !pair.Matches.Empty() {
+		for mismatch := range pair.Mismatches.Items {
+			if isMatchGlobPattern(mismatch.(string), value, param) {
+				return false
+			}
+		}
+		for match := range pair.Matches.Items {
+			if isMatchGlobPattern(match.(string), value, param) {
+				return true
+			}
+		}
+		return false
+	}
 
 	return false
 }
+
+func isMatchGlobPattern(pattern string, value string, param *common.URL) bool {
+	if param != nil && strings.HasPrefix(pattern, "$") {
+		pattern = param.GetRawParameter(pattern[1:])
+	}
+	if "*" == pattern {
+		return true
+	}
+	if pattern == "" && value == "" {
+		return true
+	}
+	if pattern == "" || value == "" {
+		return false
+	}
+	i := strings.LastIndex(pattern, "*")
+	switch i {
+	case -1:
+		return value == pattern
+	case len(pattern) - 1:
+		return strings.HasPrefix(value, pattern[0:i])
+	case 0:
+		return strings.HasSuffix(value, pattern[:i+1])
+	default:
+		prefix := pattern[0:1]
+		suffix := pattern[i+1:]
+		return strings.HasPrefix(value, prefix) && strings.HasSuffix(value, suffix)
+
+	}
+}
diff --git a/cluster/router/condition_router_test.go b/cluster/router/condition_router_test.go
index 83cb25e58..5c6c2ecfd 100644
--- a/cluster/router/condition_router_test.go
+++ b/cluster/router/condition_router_test.go
@@ -2,6 +2,7 @@ package router
 
 import (
 	"context"
+	"encoding/base64"
 	perrors "errors"
 	"fmt"
 	"github.com/apache/dubbo-go/common"
@@ -68,10 +69,7 @@ func (bi *MockInvoker) Destroy() {
 	bi.destroyed = true
 	bi.available = false
 }
-func Test_parseRule(t *testing.T) {
-	parseRule("host = 10.20.153.10 => host = 10.20.153.11")
 
-}
 func LocalIp() string {
 	addrs, err := net.InterfaceAddrs()
 	if err != nil {
@@ -88,21 +86,81 @@ func LocalIp() string {
 	return ip
 }
 func TestRoute_matchWhen(t *testing.T) {
-
+	rpcInvacation := &invocation.RPCInvocation{}
+	rule := base64.URLEncoding.EncodeToString([]byte("=> host = 1.2.3.4"))
+	router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule))
+	cUrl, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService")
+
+	matchWhen := router.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation)
+	assert.Equal(t, true, matchWhen)
+
+	rule1 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4"))
+	router1, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule1))
+	matchWhen1 := router1.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation)
+	assert.Equal(t, true, matchWhen1)
+
+	rule2 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.1,3.3.3.3 & host !=1.1.1.1 => host = 1.2.3.4"))
+	router2, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule2))
+	matchWhen2 := router2.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation)
+	assert.Equal(t, false, matchWhen2)
+
+	rule3 := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.4 & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4"))
+	router3, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule3))
+	matchWhen3 := router3.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation)
+	assert.Equal(t, true, matchWhen3)
+
+	rule4 := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.* & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4"))
+	router4, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule4))
+	matchWhen4 := router4.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation)
+	assert.Equal(t, true, matchWhen4)
+
+	rule5 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.*,3.3.3.3 & host != 1.1.1.1 => host = 1.2.3.4"))
+	router5, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule5))
+	matchWhen5 := router5.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation)
+	assert.Equal(t, false, matchWhen5)
+
+	rule6 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.*,3.3.3.3 & host != 1.1.1.2 => host = 1.2.3.4"))
+	router6, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule6))
+	matchWhen6 := router6.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation)
+	assert.Equal(t, true, matchWhen6)
 }
 func TestRoute_matchFilter(t *testing.T) {
 	url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService?default.serialization=fastjson")
 	url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
 	url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
 	invokers := []protocol.Invoker{NewMockInvoker(url1, 1), NewMockInvoker(url2, 2), NewMockInvoker(url3, 3)}
-	option := common.WithParamsValue("rule", "host = "+LocalIp()+" => "+" host = 10.20.3.3")
-	option1 := common.WithParamsValue("force", "true")
-	sUrl, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService", option, option1)
-
-	router1, _ := NewConditionRouterFactory().GetRouter(sUrl)
+	rule1 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 10.20.3.3"))
+	rule2 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 10.20.3.* & host != 10.20.3.3"))
+	rule3 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 10.20.3.3  & host != 10.20.3.3"))
+	rule4 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 10.20.3.2,10.20.3.3,10.20.3.4"))
+	rule5 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host != 10.20.3.3"))
+	rule6 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " serialization = fastjson"))
+	router1, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule1))
+	router2, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule2))
+	router3, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule3))
+	router4, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule4))
+	router5, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule5))
+	router6, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule6))
 	cUrl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService")
 
-	routers, _ := router1.Route(invokers, cUrl, &invocation.RPCInvocation{})
-	assert.Equal(t, 1, len(routers))
+	fileredInvokers1, _ := router1.Route(invokers, cUrl, &invocation.RPCInvocation{})
+	fileredInvokers2, _ := router2.Route(invokers, cUrl, &invocation.RPCInvocation{})
+	fileredInvokers3, _ := router3.Route(invokers, cUrl, &invocation.RPCInvocation{})
+	fileredInvokers4, _ := router4.Route(invokers, cUrl, &invocation.RPCInvocation{})
+	fileredInvokers5, _ := router5.Route(invokers, cUrl, &invocation.RPCInvocation{})
+	fileredInvokers6, _ := router6.Route(invokers, cUrl, &invocation.RPCInvocation{})
+	assert.Equal(t, 1, len(fileredInvokers1))
+	assert.Equal(t, 0, len(fileredInvokers2))
+	assert.Equal(t, 0, len(fileredInvokers3))
+	assert.Equal(t, 1, len(fileredInvokers4))
+	assert.Equal(t, 2, len(fileredInvokers5))
+	assert.Equal(t, 1, len(fileredInvokers6))
+
+}
 
+func getRouteUrl(rule string) common.URL {
+	url, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService")
+	url.AddParam("rule", rule)
+	url.AddParam("force", "true")
+	return url
 }
diff --git a/common/constant/default.go b/common/constant/default.go
index 63ce0e515..51d98ce1d 100644
--- a/common/constant/default.go
+++ b/common/constant/default.go
@@ -33,7 +33,7 @@ const (
 
 const (
 	DEFAULT_KEY               = "default"
-	DEFAULT_KEY_PREFIX        = "default."
+	PREFIX_DEFAULT_KEY        = "default."
 	DEFAULT_SERVICE_FILTERS   = "echo"
 	DEFAULT_REFERENCE_FILTERS = ""
 	ECHO                      = "$echo"
diff --git a/common/url.go b/common/url.go
index 88452f643..4f3546880 100644
--- a/common/url.go
+++ b/common/url.go
@@ -262,6 +262,11 @@ func (c URL) Service() string {
 	}
 	return ""
 }
+
+func (c *URL) AddParam(key string, value string) {
+	c.Params.Add(key, value)
+}
+
 func (c URL) GetParam(s string, d string) string {
 	var r string
 	if r = c.Params.Get(s); r == "" {
@@ -270,6 +275,28 @@ func (c URL) GetParam(s string, d string) string {
 	return r
 }
 
+func (c URL) GetRawParameter(key string) string {
+	if "protocol" == key {
+		return c.Protocol
+	}
+	if "username" == key {
+		return c.Username
+	}
+	if "password" == key {
+		return c.Password
+	}
+	if "host" == key {
+		return c.Ip
+	}
+	if "port" == key {
+		return c.Port
+	}
+	if "path" == key {
+		return c.Path
+	}
+	return c.Params.Get(key)
+}
+
 // GetParamBool
 func (c URL) GetParamBool(s string, d bool) bool {
 
@@ -311,6 +338,10 @@ func (c URL) GetMethodParam(method string, key string, d string) string {
 func (c URL) ToMap() map[string]string {
 
 	paramsMap := make(map[string]string)
+
+	for k, v := range c.Params {
+		paramsMap[k] = v[0]
+	}
 	if c.Protocol != "" {
 		paramsMap["protocol"] = c.Protocol
 	}
@@ -320,8 +351,15 @@ func (c URL) ToMap() map[string]string {
 	if c.Password != "" {
 		paramsMap["password"] = c.Password
 	}
-	if c.Ip != "" {
-		paramsMap["host"] = c.Ip
+	if c.Location != "" {
+		paramsMap["host"] = strings.Split(c.Location, ":")[0]
+		var port string
+		if strings.Contains(c.Location, ":") {
+			port = strings.Split(c.Location, ":")[1]
+		} else {
+			port = "0"
+		}
+		paramsMap["port"] = port
 	}
 	if c.Protocol != "" {
 		paramsMap["protocol"] = c.Protocol
diff --git a/common/utils/hashset.go b/common/utils/hashset.go
new file mode 100644
index 000000000..a03c17d83
--- /dev/null
+++ b/common/utils/hashset.go
@@ -0,0 +1,70 @@
+package utils
+
+import (
+	"fmt"
+	"strings"
+)
+
+var itemExists = struct{}{}
+
+type HashSet struct {
+	Items map[interface{}]struct{}
+}
+
+func NewSet(values ...interface{}) *HashSet {
+	set := &HashSet{Items: make(map[interface{}]struct{})}
+	if len(values) > 0 {
+		set.Add(values...)
+	}
+	return set
+}
+
+func (set *HashSet) Add(items ...interface{}) {
+	for _, item := range items {
+		set.Items[item] = itemExists
+	}
+}
+
+func (set *HashSet) Remove(items ...interface{}) {
+	for _, item := range items {
+		delete(set.Items, item)
+	}
+}
+
+func (set *HashSet) Contains(items ...interface{}) bool {
+	for _, item := range items {
+		if _, contains := set.Items[item]; !contains {
+			return false
+		}
+	}
+	return true
+}
+func (set *HashSet) Empty() bool {
+	return set.Size() == 0
+}
+func (set *HashSet) Size() int {
+	return len(set.Items)
+}
+
+func (set *HashSet) Clear() {
+	set.Items = make(map[interface{}]struct{})
+}
+
+func (set *HashSet) Values() []interface{} {
+	values := make([]interface{}, set.Size())
+	count := 0
+	for item := range set.Items {
+		values[count] = item
+		count++
+	}
+	return values
+}
+func (set *HashSet) String() string {
+	str := "HashSet\n"
+	var items []string
+	for k := range set.Items {
+		items = append(items, fmt.Sprintf("%v", k))
+	}
+	str += strings.Join(items, ", ")
+	return str
+}
-- 
GitLab