diff --git a/cluster/router/condition_router.go b/cluster/router/condition_router.go index a8074f73928301173aa7ab5760509cb0a5e53ea2..a175e5c5dfbbefa30edd5553e944765d6bbe8066 100644 --- a/cluster/router/condition_router.go +++ b/cluster/router/condition_router.go @@ -40,6 +40,7 @@ func (c *ConditionRouter) Route(invokers []protocol.Invoker, url common.URL, inv return result, nil } for _, invoker := range invokers { + if c.MatchThen(invoker.GetUrl(), url) { result = append(result, invoker) } @@ -50,7 +51,7 @@ func (c *ConditionRouter) Route(invokers []protocol.Invoker, url common.URL, inv //todo 鏃ュ織 return result, nil } - return result, nil + return invokers, nil } func (c ConditionRouter) CompareTo(r cluster.Router) int { @@ -101,8 +102,8 @@ func newConditionRouter(url common.URL) (*ConditionRouter, error) { return &ConditionRouter{ RoutePattern, url, - url.GetParamInt("Priority", 0), - url.GetParamBool("Force", false), + url.GetParamInt("priority", 0), + url.GetParamBool("force", false), when, then, }, nil @@ -118,8 +119,10 @@ func parseRule(rule string) (map[string]MatchPair, error) { values := utils.NewSet() reg := regexp.MustCompile(`([&!=,]*)\s*([^&!=,\s]+)`) - - startIndex := reg.FindIndex([]byte(rule)) + var startIndex = 0 + if indexTuple := reg.FindIndex([]byte(rule)); len(indexTuple) > 0 { + startIndex = indexTuple[0] + } matches := reg.FindAllSubmatch([]byte(rule), -1) for _, groups := range matches { separator := string(groups[1]) @@ -136,28 +139,31 @@ func parseRule(rule string) (map[string]MatchPair, error) { if r, ok := condition[content]; ok { pair = r } else { - pair = MatchPair{} + pair = MatchPair{ + Matches: utils.NewSet(), + Mismatches: utils.NewSet(), + } condition[content] = pair } case "=": if &pair == nil { - return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0]) + return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex, startIndex) } values = pair.Matches values.Add(content) case "!=": if &pair == nil { - return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0]) + return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex, startIndex) } values = pair.Mismatches values.Add(content) case ",": if values.Empty() { - return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0]) + return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex, startIndex) } values.Add(content) default: - return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0]) + return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex, startIndex) } } @@ -178,6 +184,9 @@ func (c *ConditionRouter) MatchThen(url common.URL, param common.URL) bool { func MatchCondition(pairs map[string]MatchPair, url *common.URL, param *common.URL, invocation protocol.Invocation) bool { sample := url.ToMap() + if len(sample) == 0 { + return true + } result := false for key, matchPair := range pairs { var sampleValue string diff --git a/cluster/router/condition_router_test.go b/cluster/router/condition_router_test.go index 5c6c2ecfd2e6ee6b74b9407f9d0ea736b9a248a3..b5a32be0de4316cddf1c909f9646d701a0b0b8f9 100644 --- a/cluster/router/condition_router_test.go +++ b/cluster/router/condition_router_test.go @@ -11,6 +11,7 @@ import ( "github.com/apache/dubbo-go/protocol/invocation" "github.com/stretchr/testify/assert" "net" + "reflect" "testing" ) @@ -35,6 +36,24 @@ func (bi *MockInvoker) GetUrl() common.URL { return bi.url } +func getRouteUrl(rule string) common.URL { + url, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService") + url.AddParam("rule", rule) + url.AddParam("force", "true") + return url +} +func getRouteUrlWithForce(rule, force string) common.URL { + url, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService") + url.AddParam("rule", rule) + url.AddParam("force", force) + return url +} +func getRouteUrlWithNoForce(rule string) common.URL { + url, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService") + url.AddParam("rule", rule) + return url +} + func (bi *MockInvoker) IsAvailable() bool { return bi.available } @@ -158,9 +177,168 @@ func TestRoute_matchFilter(t *testing.T) { } -func getRouteUrl(rule string) common.URL { - url, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService") - url.AddParam("rule", rule) - url.AddParam("force", "true") - return url +func TestRoute_methodRoute(t *testing.T) { + + inv := invocation.NewRPCInvocation("getFoo", []reflect.Type{}, []interface{}{}) + rule := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.* & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4")) + router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule)) + + url, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=setFoo,getFoo,findFoo") + matchWhen := router.(*ConditionRouter).MatchWhen(url, inv) + assert.Equal(t, true, matchWhen) + + url1, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=getFoo") + matchWhen = router.(*ConditionRouter).MatchWhen(url1, inv) + assert.Equal(t, true, matchWhen) + + url2, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=getFoo") + rule2 := base64.URLEncoding.EncodeToString([]byte("methods=getFoo & host!=1.1.1.1 => host = 1.2.3.4")) + router2, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule2)) + matchWhen = router2.(*ConditionRouter).MatchWhen(url2, inv) + assert.Equal(t, false, matchWhen) + + url3, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=getFoo") + rule3 := base64.URLEncoding.EncodeToString([]byte("methods=getFoo & host=1.1.1.1 => host = 1.2.3.4")) + router3, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule3)) + matchWhen = router3.(*ConditionRouter).MatchWhen(url3, inv) + assert.Equal(t, true, matchWhen) + +} + +func TestRoute_ReturnFalse(t *testing.T) { + url, _ := common.NewURL(context.TODO(), "") + invokers := []protocol.Invoker{NewMockInvoker(url, 1), NewMockInvoker(url, 2), NewMockInvoker(url, 3)} + inv := &invocation.RPCInvocation{} + rule := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => false")) + curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService") + + router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule)) + fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv) + assert.Equal(t, 0, len(fileredInvokers)) +} +func TestRoute_ReturnEmpty(t *testing.T) { + url, _ := common.NewURL(context.TODO(), "") + invokers := []protocol.Invoker{NewMockInvoker(url, 1), NewMockInvoker(url, 2), NewMockInvoker(url, 3)} + inv := &invocation.RPCInvocation{} + rule := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => ")) + curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService") + + router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule)) + fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv) + assert.Equal(t, 0, len(fileredInvokers)) +} +func TestRoute_ReturnAll(t *testing.T) { + invokers := []protocol.Invoker{&MockInvoker{}, &MockInvoker{}, &MockInvoker{}} + inv := &invocation.RPCInvocation{} + rule := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = " + LocalIp())) + curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService") + + router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule)) + fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv) + assert.Equal(t, invokers, fileredInvokers) +} + +func TestRoute_HostFilter(t *testing.T) { + url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService") + url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp())) + url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp())) + invoker1 := NewMockInvoker(url1, 1) + invoker2 := NewMockInvoker(url2, 2) + invoker3 := NewMockInvoker(url3, 3) + invokers := []protocol.Invoker{invoker1, invoker2, invoker3} + inv := &invocation.RPCInvocation{} + rule := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = " + LocalIp())) + curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService") + + router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule)) + fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv) + assert.Equal(t, 2, len(fileredInvokers)) + assert.Equal(t, invoker2, fileredInvokers[0]) + assert.Equal(t, invoker3, fileredInvokers[1]) +} +func TestRoute_Empty_HostFilter(t *testing.T) { + url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService") + url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp())) + url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp())) + invoker1 := NewMockInvoker(url1, 1) + invoker2 := NewMockInvoker(url2, 2) + invoker3 := NewMockInvoker(url3, 3) + invokers := []protocol.Invoker{invoker1, invoker2, invoker3} + inv := &invocation.RPCInvocation{} + rule := base64.URLEncoding.EncodeToString([]byte(" => " + " host = " + LocalIp())) + curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService") + + router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule)) + fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv) + assert.Equal(t, 2, len(fileredInvokers)) + assert.Equal(t, invoker2, fileredInvokers[0]) + assert.Equal(t, invoker3, fileredInvokers[1]) +} +func TestRoute_False_HostFilter(t *testing.T) { + url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService") + url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp())) + url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp())) + invoker1 := NewMockInvoker(url1, 1) + invoker2 := NewMockInvoker(url2, 2) + invoker3 := NewMockInvoker(url3, 3) + invokers := []protocol.Invoker{invoker1, invoker2, invoker3} + inv := &invocation.RPCInvocation{} + rule := base64.URLEncoding.EncodeToString([]byte("true => " + " host = " + LocalIp())) + curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService") + + router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule)) + fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv) + assert.Equal(t, 2, len(fileredInvokers)) + assert.Equal(t, invoker2, fileredInvokers[0]) + assert.Equal(t, invoker3, fileredInvokers[1]) +} +func TestRoute_Placeholder(t *testing.T) { + url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService") + url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp())) + url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp())) + invoker1 := NewMockInvoker(url1, 1) + invoker2 := NewMockInvoker(url2, 2) + invoker3 := NewMockInvoker(url3, 3) + invokers := []protocol.Invoker{invoker1, invoker2, invoker3} + inv := &invocation.RPCInvocation{} + rule := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = $host")) + curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService") + + router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule)) + fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv) + assert.Equal(t, 2, len(fileredInvokers)) + assert.Equal(t, invoker2, fileredInvokers[0]) + assert.Equal(t, invoker3, fileredInvokers[1]) +} +func TestRoute_NoForce(t *testing.T) { + url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService") + url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp())) + url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp())) + invoker1 := NewMockInvoker(url1, 1) + invoker2 := NewMockInvoker(url2, 2) + invoker3 := NewMockInvoker(url3, 3) + invokers := []protocol.Invoker{invoker1, invoker2, invoker3} + inv := &invocation.RPCInvocation{} + rule := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 1.2.3.4")) + curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService") + + router, _ := NewConditionRouterFactory().GetRouter(getRouteUrlWithNoForce(rule)) + fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv) + assert.Equal(t, invokers, fileredInvokers) +} +func TestRoute_Force(t *testing.T) { + url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService") + url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp())) + url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp())) + invoker1 := NewMockInvoker(url1, 1) + invoker2 := NewMockInvoker(url2, 2) + invoker3 := NewMockInvoker(url3, 3) + invokers := []protocol.Invoker{invoker1, invoker2, invoker3} + inv := &invocation.RPCInvocation{} + rule := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 1.2.3.4")) + curl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService") + + router, _ := NewConditionRouterFactory().GetRouter(getRouteUrlWithForce(rule, "true")) + fileredInvokers, _ := router.(*ConditionRouter).Route(invokers, curl, inv) + assert.Equal(t, 0, len(fileredInvokers)) } diff --git a/cluster/router/router_factory.go b/cluster/router/router_factory.go index 3134804e2787abca31e1e8b68182bfc23264ce5a..04ff0502cd587f0568dec4e420d34e688619883e 100644 --- a/cluster/router/router_factory.go +++ b/cluster/router/router_factory.go @@ -7,7 +7,7 @@ import ( ) func init() { - extension.SetRouterFactory("conditionRouterFactory", NewConditionRouterFactory) + extension.SetRouterFactory("condition", NewConditionRouterFactory) } type ConditionRouterFactory struct { diff --git a/common/url.go b/common/url.go index 4f354688092d702ae58fd6c30e1adea387e5d399..e6fdf49e9f8e0b6eab28e6efa8abf31c364a5c40 100644 --- a/common/url.go +++ b/common/url.go @@ -282,12 +282,12 @@ func (c URL) GetRawParameter(key string) string { if "username" == key { return c.Username } + if "host" == key { + return strings.Split(c.Location, ":")[0] + } if "password" == key { return c.Password } - if "host" == key { - return c.Ip - } if "port" == key { return c.Port } diff --git a/protocol/invocation/rpcinvocation.go b/protocol/invocation/rpcinvocation.go index c8f45c561ef88f49d33b755f988ff9a125e30c8f..008e883aed317a78d68fc0f0d55f49df66b2dee0 100644 --- a/protocol/invocation/rpcinvocation.go +++ b/protocol/invocation/rpcinvocation.go @@ -69,6 +69,14 @@ func NewRPCInvocationForProvider(methodName string, arguments []interface{}, att } } +func NewRPCInvocation(methodName string, parameterTypes []reflect.Type, arguments []interface{}) *RPCInvocation { + return &RPCInvocation{ + methodName: methodName, + arguments: arguments, + parameterTypes: parameterTypes, + } +} + func (r *RPCInvocation) MethodName() string { return r.methodName }