From 129b3423ba5fac13697409982879c076af50d301 Mon Sep 17 00:00:00 2001
From: aliiohs <rzy1107@163.com>
Date: Thu, 20 Jun 2019 23:37:44 +0800
Subject: [PATCH] fix some bug and add some test case for condition router

---
 cluster/router/condition_router.go      |  29 ++--
 cluster/router/condition_router_test.go | 188 +++++++++++++++++++++++-
 cluster/router/router_factory.go        |   2 +-
 common/url.go                           |   6 +-
 protocol/invocation/rpcinvocation.go    |   8 +
 5 files changed, 214 insertions(+), 19 deletions(-)

diff --git a/cluster/router/condition_router.go b/cluster/router/condition_router.go
index a8074f739..a175e5c5d 100644
--- a/cluster/router/condition_router.go
+++ b/cluster/router/condition_router.go
@@ -40,6 +40,7 @@ func (c *ConditionRouter) Route(invokers []protocol.Invoker, url common.URL, inv
 		return result, nil
 	}
 	for _, invoker := range invokers {
+
 		if c.MatchThen(invoker.GetUrl(), url) {
 			result = append(result, invoker)
 		}
@@ -50,7 +51,7 @@ func (c *ConditionRouter) Route(invokers []protocol.Invoker, url common.URL, inv
 		//todo 鏃ュ織
 		return result, nil
 	}
-	return result, nil
+	return invokers, nil
 }
 
 func (c ConditionRouter) CompareTo(r cluster.Router) int {
@@ -101,8 +102,8 @@ func newConditionRouter(url common.URL) (*ConditionRouter, error) {
 	return &ConditionRouter{
 		RoutePattern,
 		url,
-		url.GetParamInt("Priority", 0),
-		url.GetParamBool("Force", false),
+		url.GetParamInt("priority", 0),
+		url.GetParamBool("force", false),
 		when,
 		then,
 	}, nil
@@ -118,8 +119,10 @@ func parseRule(rule string) (map[string]MatchPair, error) {
 	values := utils.NewSet()
 
 	reg := regexp.MustCompile(`([&!=,]*)\s*([^&!=,\s]+)`)
-
-	startIndex := reg.FindIndex([]byte(rule))
+	var startIndex = 0
+	if indexTuple := reg.FindIndex([]byte(rule)); len(indexTuple) > 0 {
+		startIndex = indexTuple[0]
+	}
 	matches := reg.FindAllSubmatch([]byte(rule), -1)
 	for _, groups := range matches {
 		separator := string(groups[1])
@@ -136,28 +139,31 @@ func parseRule(rule string) (map[string]MatchPair, error) {
 			if r, ok := condition[content]; ok {
 				pair = r
 			} else {
-				pair = MatchPair{}
+				pair = MatchPair{
+					Matches:    utils.NewSet(),
+					Mismatches: utils.NewSet(),
+				}
 				condition[content] = pair
 			}
 		case "=":
 			if &pair == nil {
-				return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0])
+				return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex, startIndex)
 			}
 			values = pair.Matches
 			values.Add(content)
 		case "!=":
 			if &pair == nil {
-				return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0])
+				return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex, startIndex)
 			}
 			values = pair.Mismatches
 			values.Add(content)
 		case ",":
 			if values.Empty() {
-				return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0])
+				return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex, startIndex)
 			}
 			values.Add(content)
 		default:
-			return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0])
+			return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex, startIndex)
 
 		}
 	}
@@ -178,6 +184,9 @@ func (c *ConditionRouter) MatchThen(url common.URL, param common.URL) bool {
 
 func MatchCondition(pairs map[string]MatchPair, url *common.URL, param *common.URL, invocation protocol.Invocation) bool {
 	sample := url.ToMap()
+	if len(sample) == 0 {
+		return true
+	}
 	result := false
 	for key, matchPair := range pairs {
 		var sampleValue string
diff --git a/cluster/router/condition_router_test.go b/cluster/router/condition_router_test.go
index 5c6c2ecfd..b5a32be0d 100644
--- a/cluster/router/condition_router_test.go
+++ b/cluster/router/condition_router_test.go
@@ -11,6 +11,7 @@ import (
 	"github.com/apache/dubbo-go/protocol/invocation"
 	"github.com/stretchr/testify/assert"
 	"net"
+	"reflect"
 	"testing"
 )
 
@@ -35,6 +36,24 @@ func (bi *MockInvoker) GetUrl() common.URL {
 	return bi.url
 }
 
+func getRouteUrl(rule string) common.URL {
+	url, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService")
+	url.AddParam("rule", rule)
+	url.AddParam("force", "true")
+	return url
+}
+func getRouteUrlWithForce(rule, force string) common.URL {
+	url, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService")
+	url.AddParam("rule", rule)
+	url.AddParam("force", force)
+	return url
+}
+func getRouteUrlWithNoForce(rule string) common.URL {
+	url, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService")
+	url.AddParam("rule", rule)
+	return url
+}
+
 func (bi *MockInvoker) IsAvailable() bool {
 	return bi.available
 }
@@ -158,9 +177,168 @@ func TestRoute_matchFilter(t *testing.T) {
 
 }
 
-func getRouteUrl(rule string) common.URL {
-	url, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService")
-	url.AddParam("rule", rule)
-	url.AddParam("force", "true")
-	return url
+func TestRoute_methodRoute(t *testing.T) {
+
+	inv := invocation.NewRPCInvocation("getFoo", []reflect.Type{}, []interface{}{})
+	rule := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.* & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4"))
+	router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule))
+
+	url, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=setFoo,getFoo,findFoo")
+	matchWhen := router.(*ConditionRouter).MatchWhen(url, inv)
+	assert.Equal(t, true, matchWhen)
+
+	url1, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=getFoo")
+	matchWhen = router.(*ConditionRouter).MatchWhen(url1, inv)
+	assert.Equal(t, true, matchWhen)
+
+	url2, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=getFoo")
+	rule2 := base64.URLEncoding.EncodeToString([]byte("methods=getFoo & host!=1.1.1.1 => host = 1.2.3.4"))
+	router2, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule2))
+	matchWhen = router2.(*ConditionRouter).MatchWhen(url2, inv)
+	assert.Equal(t, false, matchWhen)
+
+	url3, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=getFoo")
+	rule3 := base64.URLEncoding.EncodeToString([]byte("methods=getFoo & host=1.1.1.1 => host = 1.2.3.4"))
+	router3, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule3))
+	matchWhen = router3.(*ConditionRouter).MatchWhen(url3, inv)
+	assert.Equal(t, true, matchWhen)
+
+}
+
+func TestRoute_ReturnFalse(t *testing.T) {
+	url, _ := common.NewURL(context.TODO(), "")
+	invokers := []protocol.Invoker{NewMockInvoker(url, 1), NewMockInvoker(url, 2), NewMockInvoker(url, 3)}
+	inv := &invocation.RPCInvocation{}
+	rule := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => false"))
+	curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService")
+
+	router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule))
+	fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv)
+	assert.Equal(t, 0, len(fileredInvokers))
+}
+func TestRoute_ReturnEmpty(t *testing.T) {
+	url, _ := common.NewURL(context.TODO(), "")
+	invokers := []protocol.Invoker{NewMockInvoker(url, 1), NewMockInvoker(url, 2), NewMockInvoker(url, 3)}
+	inv := &invocation.RPCInvocation{}
+	rule := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => "))
+	curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService")
+
+	router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule))
+	fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv)
+	assert.Equal(t, 0, len(fileredInvokers))
+}
+func TestRoute_ReturnAll(t *testing.T) {
+	invokers := []protocol.Invoker{&MockInvoker{}, &MockInvoker{}, &MockInvoker{}}
+	inv := &invocation.RPCInvocation{}
+	rule := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = " + LocalIp()))
+	curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService")
+
+	router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule))
+	fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv)
+	assert.Equal(t, invokers, fileredInvokers)
+}
+
+func TestRoute_HostFilter(t *testing.T) {
+	url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService")
+	url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
+	url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
+	invoker1 := NewMockInvoker(url1, 1)
+	invoker2 := NewMockInvoker(url2, 2)
+	invoker3 := NewMockInvoker(url3, 3)
+	invokers := []protocol.Invoker{invoker1, invoker2, invoker3}
+	inv := &invocation.RPCInvocation{}
+	rule := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = " + LocalIp()))
+	curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService")
+
+	router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule))
+	fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv)
+	assert.Equal(t, 2, len(fileredInvokers))
+	assert.Equal(t, invoker2, fileredInvokers[0])
+	assert.Equal(t, invoker3, fileredInvokers[1])
+}
+func TestRoute_Empty_HostFilter(t *testing.T) {
+	url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService")
+	url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
+	url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
+	invoker1 := NewMockInvoker(url1, 1)
+	invoker2 := NewMockInvoker(url2, 2)
+	invoker3 := NewMockInvoker(url3, 3)
+	invokers := []protocol.Invoker{invoker1, invoker2, invoker3}
+	inv := &invocation.RPCInvocation{}
+	rule := base64.URLEncoding.EncodeToString([]byte(" => " + " host = " + LocalIp()))
+	curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService")
+
+	router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule))
+	fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv)
+	assert.Equal(t, 2, len(fileredInvokers))
+	assert.Equal(t, invoker2, fileredInvokers[0])
+	assert.Equal(t, invoker3, fileredInvokers[1])
+}
+func TestRoute_False_HostFilter(t *testing.T) {
+	url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService")
+	url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
+	url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
+	invoker1 := NewMockInvoker(url1, 1)
+	invoker2 := NewMockInvoker(url2, 2)
+	invoker3 := NewMockInvoker(url3, 3)
+	invokers := []protocol.Invoker{invoker1, invoker2, invoker3}
+	inv := &invocation.RPCInvocation{}
+	rule := base64.URLEncoding.EncodeToString([]byte("true => " + " host = " + LocalIp()))
+	curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService")
+
+	router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule))
+	fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv)
+	assert.Equal(t, 2, len(fileredInvokers))
+	assert.Equal(t, invoker2, fileredInvokers[0])
+	assert.Equal(t, invoker3, fileredInvokers[1])
+}
+func TestRoute_Placeholder(t *testing.T) {
+	url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService")
+	url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
+	url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
+	invoker1 := NewMockInvoker(url1, 1)
+	invoker2 := NewMockInvoker(url2, 2)
+	invoker3 := NewMockInvoker(url3, 3)
+	invokers := []protocol.Invoker{invoker1, invoker2, invoker3}
+	inv := &invocation.RPCInvocation{}
+	rule := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = $host"))
+	curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService")
+
+	router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule))
+	fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv)
+	assert.Equal(t, 2, len(fileredInvokers))
+	assert.Equal(t, invoker2, fileredInvokers[0])
+	assert.Equal(t, invoker3, fileredInvokers[1])
+}
+func TestRoute_NoForce(t *testing.T) {
+	url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService")
+	url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
+	url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
+	invoker1 := NewMockInvoker(url1, 1)
+	invoker2 := NewMockInvoker(url2, 2)
+	invoker3 := NewMockInvoker(url3, 3)
+	invokers := []protocol.Invoker{invoker1, invoker2, invoker3}
+	inv := &invocation.RPCInvocation{}
+	rule := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 1.2.3.4"))
+	curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService")
+
+	router, _ := NewConditionRouterFactory().GetRouter(getRouteUrlWithNoForce(rule))
+	fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv)
+	assert.Equal(t, invokers, fileredInvokers)
+}
+func TestRoute_Force(t *testing.T) {
+	url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService")
+	url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
+	url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
+	invoker1 := NewMockInvoker(url1, 1)
+	invoker2 := NewMockInvoker(url2, 2)
+	invoker3 := NewMockInvoker(url3, 3)
+	invokers := []protocol.Invoker{invoker1, invoker2, invoker3}
+	inv := &invocation.RPCInvocation{}
+	rule := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 1.2.3.4"))
+	curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService")
+
+	router, _ := NewConditionRouterFactory().GetRouter(getRouteUrlWithForce(rule, "true"))
+	fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv)
+	assert.Equal(t, 0, len(fileredInvokers))
 }
diff --git a/cluster/router/router_factory.go b/cluster/router/router_factory.go
index 3134804e2..04ff0502c 100644
--- a/cluster/router/router_factory.go
+++ b/cluster/router/router_factory.go
@@ -7,7 +7,7 @@ import (
 )
 
 func init() {
-	extension.SetRouterFactory("conditionRouterFactory", NewConditionRouterFactory)
+	extension.SetRouterFactory("condition", NewConditionRouterFactory)
 }
 
 type ConditionRouterFactory struct {
diff --git a/common/url.go b/common/url.go
index 4f3546880..e6fdf49e9 100644
--- a/common/url.go
+++ b/common/url.go
@@ -282,12 +282,12 @@ func (c URL) GetRawParameter(key string) string {
 	if "username" == key {
 		return c.Username
 	}
+	if "host" == key {
+		return strings.Split(c.Location, ":")[0]
+	}
 	if "password" == key {
 		return c.Password
 	}
-	if "host" == key {
-		return c.Ip
-	}
 	if "port" == key {
 		return c.Port
 	}
diff --git a/protocol/invocation/rpcinvocation.go b/protocol/invocation/rpcinvocation.go
index c8f45c561..008e883ae 100644
--- a/protocol/invocation/rpcinvocation.go
+++ b/protocol/invocation/rpcinvocation.go
@@ -69,6 +69,14 @@ func NewRPCInvocationForProvider(methodName string, arguments []interface{}, att
 	}
 }
 
+func NewRPCInvocation(methodName string, parameterTypes []reflect.Type, arguments []interface{}) *RPCInvocation {
+	return &RPCInvocation{
+		methodName:     methodName,
+		arguments:      arguments,
+		parameterTypes: parameterTypes,
+	}
+}
+
 func (r *RPCInvocation) MethodName() string {
 	return r.methodName
 }
-- 
GitLab