Skip to content
Snippets Groups Projects
Commit 56c3f906 authored by aliiohs's avatar aliiohs
Browse files

Add routing-related features and add some test case for condition router

parent 4c536173
No related branches found
No related tags found
No related merge requests found
package router
import (
"encoding/base64"
"github.com/apache/dubbo-go/cluster"
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/utils"
"regexp"
"strings"
......@@ -15,6 +17,8 @@ const (
RoutePattern = `([&!=,]*)\\s*([^&!=,\\s]+)`
)
var itemExists = struct{}{}
type ConditionRouter struct {
Pattern string
Url common.URL
......@@ -28,7 +32,7 @@ func (c *ConditionRouter) Route(invokers []protocol.Invoker, url common.URL, inv
if len(invokers) == 0 {
return invokers, nil
}
if !c.matchWhen(url, invocation) {
if !c.MatchWhen(url, invocation) {
return invokers, nil
}
var result []protocol.Invoker
......@@ -36,7 +40,7 @@ func (c *ConditionRouter) Route(invokers []protocol.Invoker, url common.URL, inv
return result, nil
}
for _, invoker := range invokers {
if c.matchThen(invoker.GetUrl(), url) {
if c.MatchThen(invoker.GetUrl(), url) {
result = append(result, invoker)
}
}
......@@ -70,7 +74,18 @@ func (c ConditionRouter) CompareTo(r cluster.Router) int {
func newConditionRouter(url common.URL) (*ConditionRouter, error) {
var whenRule string
var thenRule string
//rule := url.GetParam("rule", "")
ruleDec, err := base64.URLEncoding.DecodeString(url.GetParam("rule", ""))
rule := string(ruleDec)
if err != nil || rule == "" {
return nil, perrors.Errorf("Illegal route rule!")
}
rule = strings.Replace(rule, "consumer.", "", -1)
rule = strings.Replace(rule, "provider.", "", -1)
i := strings.Index(rule, "=>")
whenRule = strings.Trim(If(i < 0, "", rule[0:i]).(string), " ")
thenRule = strings.Trim(If(i < 0, rule, rule[i+2:]).(string), " ")
w, err := parseRule(whenRule)
if err != nil {
return nil, perrors.Errorf("%s", "")
......@@ -100,7 +115,7 @@ func parseRule(rule string) (map[string]MatchPair, error) {
return condition, nil
}
var pair MatchPair
values := make(map[string]interface{})
values := utils.NewSet()
reg := regexp.MustCompile(`([&!=,]*)\s*([^&!=,\s]+)`)
......@@ -112,7 +127,10 @@ func parseRule(rule string) (map[string]MatchPair, error) {
switch separator {
case "":
pair = MatchPair{}
pair = MatchPair{
Matches: utils.NewSet(),
Mismatches: utils.NewSet(),
}
condition[content] = pair
case "&":
if r, ok := condition[content]; ok {
......@@ -126,18 +144,18 @@ func parseRule(rule string) (map[string]MatchPair, error) {
return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0])
}
values = pair.Matches
values[content] = ""
values.Add(content)
case "!=":
if &pair == nil {
return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0])
}
values = pair.Matches
values[content] = ""
values = pair.Mismatches
values.Add(content)
case ",":
if len(values) == 0 {
if values.Empty() {
return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0])
}
values[content] = ""
values.Add(content)
default:
return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex[0], startIndex[0])
......@@ -149,16 +167,16 @@ func parseRule(rule string) (map[string]MatchPair, error) {
}
func (c *ConditionRouter) matchWhen(url common.URL, invocation protocol.Invocation) bool {
func (c *ConditionRouter) MatchWhen(url common.URL, invocation protocol.Invocation) bool {
return len(c.WhenCondition) == 0 || len(c.WhenCondition) == 0 || matchCondition(c.WhenCondition, &url, nil, invocation)
return len(c.WhenCondition) == 0 || MatchCondition(c.WhenCondition, &url, nil, invocation)
}
func (c *ConditionRouter) matchThen(url common.URL, param common.URL) bool {
func (c *ConditionRouter) MatchThen(url common.URL, param common.URL) bool {
return !(len(c.ThenCondition) == 0) && matchCondition(c.ThenCondition, &url, &param, nil)
return len(c.ThenCondition) > 0 && MatchCondition(c.ThenCondition, &url, &param, nil)
}
func matchCondition(pairs map[string]MatchPair, url *common.URL, param *common.URL, invocation protocol.Invocation) bool {
func MatchCondition(pairs map[string]MatchPair, url *common.URL, param *common.URL, invocation protocol.Invocation) bool {
sample := url.ToMap()
result := false
for key, matchPair := range pairs {
......@@ -168,18 +186,18 @@ func matchCondition(pairs map[string]MatchPair, url *common.URL, param *common.U
sampleValue = invocation.MethodName()
} else {
sampleValue = sample[key]
if &sampleValue == nil {
sampleValue = sample[constant.DEFAULT_KEY_PREFIX+key]
if sampleValue == "" {
sampleValue = sample[constant.PREFIX_DEFAULT_KEY+key]
}
}
if &sampleValue != nil {
if sampleValue != "" {
if !matchPair.isMatch(sampleValue, param) {
return false
} else {
result = true
}
} else {
if !(len(matchPair.Matches) == 0) {
if !(matchPair.Matches.Empty()) {
return false
} else {
result = true
......@@ -198,11 +216,72 @@ func If(b bool, t, f interface{}) interface{} {
}
type MatchPair struct {
Matches map[string]interface{}
Mismatches map[string]interface{}
Matches *utils.HashSet
Mismatches *utils.HashSet
}
func (pair MatchPair) isMatch(s string, param *common.URL) bool {
func (pair MatchPair) isMatch(value string, param *common.URL) bool {
if !pair.Matches.Empty() && pair.Mismatches.Empty() {
for match := range pair.Matches.Items {
if isMatchGlobPattern(match.(string), value, param) {
return true
}
}
return false
}
if !pair.Mismatches.Empty() && pair.Matches.Empty() {
for mismatch := range pair.Mismatches.Items {
if isMatchGlobPattern(mismatch.(string), value, param) {
return false
}
}
return true
}
if !pair.Mismatches.Empty() && !pair.Matches.Empty() {
for mismatch := range pair.Mismatches.Items {
if isMatchGlobPattern(mismatch.(string), value, param) {
return false
}
}
for match := range pair.Matches.Items {
if isMatchGlobPattern(match.(string), value, param) {
return true
}
}
return false
}
return false
}
func isMatchGlobPattern(pattern string, value string, param *common.URL) bool {
if param != nil && strings.HasPrefix(pattern, "$") {
pattern = param.GetRawParameter(pattern[1:])
}
if "*" == pattern {
return true
}
if pattern == "" && value == "" {
return true
}
if pattern == "" || value == "" {
return false
}
i := strings.LastIndex(pattern, "*")
switch i {
case -1:
return value == pattern
case len(pattern) - 1:
return strings.HasPrefix(value, pattern[0:i])
case 0:
return strings.HasSuffix(value, pattern[:i+1])
default:
prefix := pattern[0:1]
suffix := pattern[i+1:]
return strings.HasPrefix(value, prefix) && strings.HasSuffix(value, suffix)
}
}
......@@ -2,6 +2,7 @@ package router
import (
"context"
"encoding/base64"
perrors "errors"
"fmt"
"github.com/apache/dubbo-go/common"
......@@ -68,10 +69,7 @@ func (bi *MockInvoker) Destroy() {
bi.destroyed = true
bi.available = false
}
func Test_parseRule(t *testing.T) {
parseRule("host = 10.20.153.10 => host = 10.20.153.11")
}
func LocalIp() string {
addrs, err := net.InterfaceAddrs()
if err != nil {
......@@ -88,21 +86,81 @@ func LocalIp() string {
return ip
}
func TestRoute_matchWhen(t *testing.T) {
rpcInvacation := &invocation.RPCInvocation{}
rule := base64.URLEncoding.EncodeToString([]byte("=> host = 1.2.3.4"))
router, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule))
cUrl, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService")
matchWhen := router.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation)
assert.Equal(t, true, matchWhen)
rule1 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4"))
router1, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule1))
matchWhen1 := router1.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation)
assert.Equal(t, true, matchWhen1)
rule2 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.1,3.3.3.3 & host !=1.1.1.1 => host = 1.2.3.4"))
router2, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule2))
matchWhen2 := router2.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation)
assert.Equal(t, false, matchWhen2)
rule3 := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.4 & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4"))
router3, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule3))
matchWhen3 := router3.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation)
assert.Equal(t, true, matchWhen3)
rule4 := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.* & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4"))
router4, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule4))
matchWhen4 := router4.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation)
assert.Equal(t, true, matchWhen4)
rule5 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.*,3.3.3.3 & host != 1.1.1.1 => host = 1.2.3.4"))
router5, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule5))
matchWhen5 := router5.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation)
assert.Equal(t, false, matchWhen5)
rule6 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.*,3.3.3.3 & host != 1.1.1.2 => host = 1.2.3.4"))
router6, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule6))
matchWhen6 := router6.(*ConditionRouter).MatchWhen(cUrl, rpcInvacation)
assert.Equal(t, true, matchWhen6)
}
func TestRoute_matchFilter(t *testing.T) {
url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService?default.serialization=fastjson")
url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", LocalIp()))
invokers := []protocol.Invoker{NewMockInvoker(url1, 1), NewMockInvoker(url2, 2), NewMockInvoker(url3, 3)}
option := common.WithParamsValue("rule", "host = "+LocalIp()+" => "+" host = 10.20.3.3")
option1 := common.WithParamsValue("force", "true")
sUrl, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService", option, option1)
router1, _ := NewConditionRouterFactory().GetRouter(sUrl)
rule1 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 10.20.3.3"))
rule2 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 10.20.3.* & host != 10.20.3.3"))
rule3 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 10.20.3.3 & host != 10.20.3.3"))
rule4 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host = 10.20.3.2,10.20.3.3,10.20.3.4"))
rule5 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " host != 10.20.3.3"))
rule6 := base64.URLEncoding.EncodeToString([]byte("host = " + LocalIp() + " => " + " serialization = fastjson"))
router1, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule1))
router2, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule2))
router3, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule3))
router4, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule4))
router5, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule5))
router6, _ := NewConditionRouterFactory().GetRouter(getRouteUrl(rule6))
cUrl, _ := common.NewURL(context.TODO(), "consumer://"+LocalIp()+"/com.foo.BarService")
routers, _ := router1.Route(invokers, cUrl, &invocation.RPCInvocation{})
assert.Equal(t, 1, len(routers))
fileredInvokers1, _ := router1.Route(invokers, cUrl, &invocation.RPCInvocation{})
fileredInvokers2, _ := router2.Route(invokers, cUrl, &invocation.RPCInvocation{})
fileredInvokers3, _ := router3.Route(invokers, cUrl, &invocation.RPCInvocation{})
fileredInvokers4, _ := router4.Route(invokers, cUrl, &invocation.RPCInvocation{})
fileredInvokers5, _ := router5.Route(invokers, cUrl, &invocation.RPCInvocation{})
fileredInvokers6, _ := router6.Route(invokers, cUrl, &invocation.RPCInvocation{})
assert.Equal(t, 1, len(fileredInvokers1))
assert.Equal(t, 0, len(fileredInvokers2))
assert.Equal(t, 0, len(fileredInvokers3))
assert.Equal(t, 1, len(fileredInvokers4))
assert.Equal(t, 2, len(fileredInvokers5))
assert.Equal(t, 1, len(fileredInvokers6))
}
func getRouteUrl(rule string) common.URL {
url, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService")
url.AddParam("rule", rule)
url.AddParam("force", "true")
return url
}
......@@ -33,7 +33,7 @@ const (
const (
DEFAULT_KEY = "default"
DEFAULT_KEY_PREFIX = "default."
PREFIX_DEFAULT_KEY = "default."
DEFAULT_SERVICE_FILTERS = "echo"
DEFAULT_REFERENCE_FILTERS = ""
ECHO = "$echo"
......
......@@ -262,6 +262,11 @@ func (c URL) Service() string {
}
return ""
}
func (c *URL) AddParam(key string, value string) {
c.Params.Add(key, value)
}
func (c URL) GetParam(s string, d string) string {
var r string
if r = c.Params.Get(s); r == "" {
......@@ -270,6 +275,28 @@ func (c URL) GetParam(s string, d string) string {
return r
}
func (c URL) GetRawParameter(key string) string {
if "protocol" == key {
return c.Protocol
}
if "username" == key {
return c.Username
}
if "password" == key {
return c.Password
}
if "host" == key {
return c.Ip
}
if "port" == key {
return c.Port
}
if "path" == key {
return c.Path
}
return c.Params.Get(key)
}
// GetParamBool
func (c URL) GetParamBool(s string, d bool) bool {
......@@ -311,6 +338,10 @@ func (c URL) GetMethodParam(method string, key string, d string) string {
func (c URL) ToMap() map[string]string {
paramsMap := make(map[string]string)
for k, v := range c.Params {
paramsMap[k] = v[0]
}
if c.Protocol != "" {
paramsMap["protocol"] = c.Protocol
}
......@@ -320,8 +351,15 @@ func (c URL) ToMap() map[string]string {
if c.Password != "" {
paramsMap["password"] = c.Password
}
if c.Ip != "" {
paramsMap["host"] = c.Ip
if c.Location != "" {
paramsMap["host"] = strings.Split(c.Location, ":")[0]
var port string
if strings.Contains(c.Location, ":") {
port = strings.Split(c.Location, ":")[1]
} else {
port = "0"
}
paramsMap["port"] = port
}
if c.Protocol != "" {
paramsMap["protocol"] = c.Protocol
......
package utils
import (
"fmt"
"strings"
)
var itemExists = struct{}{}
type HashSet struct {
Items map[interface{}]struct{}
}
func NewSet(values ...interface{}) *HashSet {
set := &HashSet{Items: make(map[interface{}]struct{})}
if len(values) > 0 {
set.Add(values...)
}
return set
}
func (set *HashSet) Add(items ...interface{}) {
for _, item := range items {
set.Items[item] = itemExists
}
}
func (set *HashSet) Remove(items ...interface{}) {
for _, item := range items {
delete(set.Items, item)
}
}
func (set *HashSet) Contains(items ...interface{}) bool {
for _, item := range items {
if _, contains := set.Items[item]; !contains {
return false
}
}
return true
}
func (set *HashSet) Empty() bool {
return set.Size() == 0
}
func (set *HashSet) Size() int {
return len(set.Items)
}
func (set *HashSet) Clear() {
set.Items = make(map[interface{}]struct{})
}
func (set *HashSet) Values() []interface{} {
values := make([]interface{}, set.Size())
count := 0
for item := range set.Items {
values[count] = item
count++
}
return values
}
func (set *HashSet) String() string {
str := "HashSet\n"
var items []string
for k := range set.Items {
items = append(items, fmt.Sprintf("%v", k))
}
str += strings.Join(items, ", ")
return str
}
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment