Skip to content
Snippets Groups Projects
Unverified Commit d6ecdb64 authored by Xin.Zh's avatar Xin.Zh Committed by GitHub
Browse files

Merge pull request #100 from aliiohs/feature/routing_rule

Feature/routing rule
parents b1720e7e 6e6f13b0
No related branches found
No related tags found
No related merge requests found
......@@ -21,4 +21,4 @@ classes
vendor/
logs/
.vscode/
......@@ -25,7 +25,7 @@ import (
// Extension - Router
type RouterFactory interface {
Router(common.URL) Router
Router(*common.URL) (Router, error)
}
type Router interface {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package router
import (
"reflect"
"regexp"
"strings"
)
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/logger"
"github.com/apache/dubbo-go/common/utils"
"github.com/apache/dubbo-go/gostd/container"
"github.com/apache/dubbo-go/protocol"
)
import (
perrors "github.com/pkg/errors"
)
const (
ROUTE_PATTERN = `([&!=,]*)\\s*([^&!=,\\s]+)`
FORCE = "force"
PRIORITY = "priority"
)
//ConditionRouter condition router struct
type ConditionRouter struct {
Pattern string
Url *common.URL
Priority int64
Force bool
WhenCondition map[string]MatchPair
ThenCondition map[string]MatchPair
}
func newConditionRouter(url *common.URL) (*ConditionRouter, error) {
var (
whenRule string
thenRule string
when map[string]MatchPair
then map[string]MatchPair
)
rule, err := url.GetParamAndDecoded(constant.RULE_KEY)
if err != nil || len(rule) == 0 {
return nil, perrors.Errorf("Illegal route rule!")
}
rule = strings.Replace(rule, "consumer.", "", -1)
rule = strings.Replace(rule, "provider.", "", -1)
i := strings.Index(rule, "=>")
if i > 0 {
whenRule = rule[0:i]
}
if i < 0 {
thenRule = rule
} else {
thenRule = rule[i+2:]
}
whenRule = strings.Trim(whenRule, " ")
thenRule = strings.Trim(thenRule, " ")
w, err := parseRule(whenRule)
if err != nil {
return nil, perrors.Errorf("%s", "")
}
t, err := parseRule(thenRule)
if err != nil {
return nil, perrors.Errorf("%s", "")
}
if len(whenRule) == 0 || "true" == whenRule {
when = make(map[string]MatchPair, 16)
} else {
when = w
}
if len(thenRule) == 0 || "false" == thenRule {
when = make(map[string]MatchPair, 16)
} else {
then = t
}
return &ConditionRouter{
ROUTE_PATTERN,
url,
url.GetParamInt(PRIORITY, 0),
url.GetParamBool(FORCE, false),
when,
then,
}, nil
}
//Router determine the target server list.
func (c *ConditionRouter) Route(invokers []protocol.Invoker, url common.URL, invocation protocol.Invocation) []protocol.Invoker {
if len(invokers) == 0 {
return invokers
}
isMatchWhen, err := c.MatchWhen(url, invocation)
if err != nil {
var urls []string
for _, invo := range invokers {
urls = append(urls, reflect.TypeOf(invo).String())
}
logger.Warnf("Failed to execute condition router rule: %s , invokers: [%s], cause: %v", c.Url.String(), strings.Join(urls, ","), err)
return invokers
}
if !isMatchWhen {
return invokers
}
var result []protocol.Invoker
if len(c.ThenCondition) == 0 {
return result
}
localIP, _ := utils.GetLocalIP()
for _, invoker := range invokers {
isMatchThen, err := c.MatchThen(invoker.GetUrl(), url)
if err != nil {
var urls []string
for _, invo := range invokers {
urls = append(urls, reflect.TypeOf(invo).String())
}
logger.Warnf("Failed to execute condition router rule: %s , invokers: [%s], cause: %v", c.Url.String(), strings.Join(urls, ","), err)
return invokers
}
if isMatchThen {
result = append(result, invoker)
}
}
if len(result) > 0 {
return result
} else if c.Force {
rule, _ := url.GetParamAndDecoded(constant.RULE_KEY)
logger.Warnf("The route result is empty and force execute. consumer: %s, service: %s, router: %s", localIP, url.Service(), rule)
return result
}
return invokers
}
func parseRule(rule string) (map[string]MatchPair, error) {
condition := make(map[string]MatchPair, 16)
if len(rule) == 0 {
return condition, nil
}
var pair MatchPair
values := container.NewSet()
reg := regexp.MustCompile(`([&!=,]*)\s*([^&!=,\s]+)`)
var startIndex = 0
if indexTuple := reg.FindIndex([]byte(rule)); len(indexTuple) > 0 {
startIndex = indexTuple[0]
}
matches := reg.FindAllSubmatch([]byte(rule), -1)
for _, groups := range matches {
separator := string(groups[1])
content := string(groups[2])
switch separator {
case "":
pair = MatchPair{
Matches: container.NewSet(),
Mismatches: container.NewSet(),
}
condition[content] = pair
case "&":
if r, ok := condition[content]; ok {
pair = r
} else {
pair = MatchPair{
Matches: container.NewSet(),
Mismatches: container.NewSet(),
}
condition[content] = pair
}
case "=":
if &pair == nil {
return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex, startIndex)
}
values = pair.Matches
values.Add(content)
case "!=":
if &pair == nil {
return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex, startIndex)
}
values = pair.Mismatches
values.Add(content)
case ",":
if values.Empty() {
return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex, startIndex)
}
values.Add(content)
default:
return nil, perrors.Errorf("Illegal route rule \"%s\", The error char '%s' at index %d before \"%d\".", rule, separator, startIndex, startIndex)
}
}
return condition, nil
}
//
func (c *ConditionRouter) MatchWhen(url common.URL, invocation protocol.Invocation) (bool, error) {
condition, err := MatchCondition(c.WhenCondition, &url, nil, invocation)
return len(c.WhenCondition) == 0 || condition, err
}
//MatchThen MatchThen
func (c *ConditionRouter) MatchThen(url common.URL, param common.URL) (bool, error) {
condition, err := MatchCondition(c.ThenCondition, &url, &param, nil)
return len(c.ThenCondition) > 0 && condition, err
}
//MatchCondition MatchCondition
func MatchCondition(pairs map[string]MatchPair, url *common.URL, param *common.URL, invocation protocol.Invocation) (bool, error) {
sample := url.ToMap()
if sample == nil {
return true, perrors.Errorf("url is not allowed be nil")
}
result := false
for key, matchPair := range pairs {
var sampleValue string
if invocation != nil && ((constant.METHOD_KEY == key) || (constant.METHOD_KEYS == key)) {
sampleValue = invocation.MethodName()
} else {
sampleValue = sample[key]
if len(sampleValue) == 0 {
sampleValue = sample[constant.PREFIX_DEFAULT_KEY+key]
}
}
if len(sampleValue) > 0 {
if !matchPair.isMatch(sampleValue, param) {
return false, nil
} else {
result = true
}
} else {
if !(matchPair.Matches.Empty()) {
return false, nil
} else {
result = true
}
}
}
return result, nil
}
type MatchPair struct {
Matches *container.HashSet
Mismatches *container.HashSet
}
func (pair MatchPair) isMatch(value string, param *common.URL) bool {
if !pair.Matches.Empty() && pair.Mismatches.Empty() {
for match := range pair.Matches.Items {
if isMatchGlobPattern(match.(string), value, param) {
return true
}
}
return false
}
if !pair.Mismatches.Empty() && pair.Matches.Empty() {
for mismatch := range pair.Mismatches.Items {
if isMatchGlobPattern(mismatch.(string), value, param) {
return false
}
}
return true
}
if !pair.Mismatches.Empty() && !pair.Matches.Empty() {
for mismatch := range pair.Mismatches.Items {
if isMatchGlobPattern(mismatch.(string), value, param) {
return false
}
}
for match := range pair.Matches.Items {
if isMatchGlobPattern(match.(string), value, param) {
return true
}
}
return false
}
return false
}
func isMatchGlobPattern(pattern string, value string, param *common.URL) bool {
if param != nil && strings.HasPrefix(pattern, "$") {
pattern = param.GetRawParam(pattern[1:])
}
if "*" == pattern {
return true
}
if len(pattern) == 0 && len(value) == 0 {
return true
}
if len(pattern) == 0 || len(value) == 0 {
return false
}
i := strings.LastIndex(pattern, "*")
switch i {
case -1:
return value == pattern
case len(pattern) - 1:
return strings.HasPrefix(value, pattern[0:i])
case 0:
return strings.HasSuffix(value, pattern[:i+1])
default:
prefix := pattern[0:1]
suffix := pattern[i+1:]
return strings.HasPrefix(value, prefix) && strings.HasSuffix(value, suffix)
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package router
import (
"context"
"encoding/base64"
"fmt"
"reflect"
"testing"
)
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/logger"
"github.com/apache/dubbo-go/common/utils"
"github.com/apache/dubbo-go/protocol"
"github.com/apache/dubbo-go/protocol/invocation"
)
import (
perrors "github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)
type MockInvoker struct {
url common.URL
available bool
destroyed bool
successCount int
}
func NewMockInvoker(url common.URL, successCount int) *MockInvoker {
return &MockInvoker{
url: url,
available: true,
destroyed: false,
successCount: successCount,
}
}
func (bi *MockInvoker) GetUrl() common.URL {
return bi.url
}
func getRouteUrl(rule string) *common.URL {
url, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService")
url.AddParam("rule", rule)
url.AddParam("force", "true")
return &url
}
func getRouteUrlWithForce(rule, force string) *common.URL {
url, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService")
url.AddParam("rule", rule)
url.AddParam("force", force)
return &url
}
func getRouteUrlWithNoForce(rule string) *common.URL {
url, _ := common.NewURL(context.TODO(), "condition://0.0.0.0/com.foo.BarService")
url.AddParam("rule", rule)
return &url
}
func (bi *MockInvoker) IsAvailable() bool {
return bi.available
}
func (bi *MockInvoker) IsDestroyed() bool {
return bi.destroyed
}
type rest struct {
tried int
success bool
}
var count int
func (bi *MockInvoker) Invoke(invocation protocol.Invocation) protocol.Result {
count++
var success bool
var err error = nil
if count >= bi.successCount {
success = true
} else {
err = perrors.New("error")
}
result := &protocol.RPCResult{Err: err, Rest: rest{tried: count, success: success}}
return result
}
func (bi *MockInvoker) Destroy() {
logger.Infof("Destroy invoker: %v", bi.GetUrl().String())
bi.destroyed = true
bi.available = false
}
func TestRoute_matchWhen(t *testing.T) {
inv := &invocation.RPCInvocation{}
rule := base64.URLEncoding.EncodeToString([]byte("=> host = 1.2.3.4"))
router, _ := NewConditionRouterFactory().Router(getRouteUrl(rule))
cUrl, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService")
matchWhen, _ := router.(*ConditionRouter).MatchWhen(cUrl, inv)
assert.Equal(t, true, matchWhen)
rule1 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4"))
router1, _ := NewConditionRouterFactory().Router(getRouteUrl(rule1))
matchWhen1, _ := router1.(*ConditionRouter).MatchWhen(cUrl, inv)
assert.Equal(t, true, matchWhen1)
rule2 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.1,3.3.3.3 & host !=1.1.1.1 => host = 1.2.3.4"))
router2, _ := NewConditionRouterFactory().Router(getRouteUrl(rule2))
matchWhen2, _ := router2.(*ConditionRouter).MatchWhen(cUrl, inv)
assert.Equal(t, false, matchWhen2)
rule3 := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.4 & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4"))
router3, _ := NewConditionRouterFactory().Router(getRouteUrl(rule3))
matchWhen3, _ := router3.(*ConditionRouter).MatchWhen(cUrl, inv)
assert.Equal(t, true, matchWhen3)
rule4 := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.* & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4"))
router4, _ := NewConditionRouterFactory().Router(getRouteUrl(rule4))
matchWhen4, _ := router4.(*ConditionRouter).MatchWhen(cUrl, inv)
assert.Equal(t, true, matchWhen4)
rule5 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.*,3.3.3.3 & host != 1.1.1.1 => host = 1.2.3.4"))
router5, _ := NewConditionRouterFactory().Router(getRouteUrl(rule5))
matchWhen5, _ := router5.(*ConditionRouter).MatchWhen(cUrl, inv)
assert.Equal(t, false, matchWhen5)
rule6 := base64.URLEncoding.EncodeToString([]byte("host = 2.2.2.2,1.1.1.*,3.3.3.3 & host != 1.1.1.2 => host = 1.2.3.4"))
router6, _ := NewConditionRouterFactory().Router(getRouteUrl(rule6))
matchWhen6, _ := router6.(*ConditionRouter).MatchWhen(cUrl, inv)
assert.Equal(t, true, matchWhen6)
}
func TestRoute_matchFilter(t *testing.T) {
localIP, _ := utils.GetLocalIP()
url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService?default.serialization=fastjson")
url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", localIP))
url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", localIP))
invokers := []protocol.Invoker{NewMockInvoker(url1, 1), NewMockInvoker(url2, 2), NewMockInvoker(url3, 3)}
rule1 := base64.URLEncoding.EncodeToString([]byte("host = " + localIP + " => " + " host = 10.20.3.3"))
rule2 := base64.URLEncoding.EncodeToString([]byte("host = " + localIP + " => " + " host = 10.20.3.* & host != 10.20.3.3"))
rule3 := base64.URLEncoding.EncodeToString([]byte("host = " + localIP + " => " + " host = 10.20.3.3 & host != 10.20.3.3"))
rule4 := base64.URLEncoding.EncodeToString([]byte("host = " + localIP + " => " + " host = 10.20.3.2,10.20.3.3,10.20.3.4"))
rule5 := base64.URLEncoding.EncodeToString([]byte("host = " + localIP + " => " + " host != 10.20.3.3"))
rule6 := base64.URLEncoding.EncodeToString([]byte("host = " + localIP + " => " + " serialization = fastjson"))
router1, _ := NewConditionRouterFactory().Router(getRouteUrl(rule1))
router2, _ := NewConditionRouterFactory().Router(getRouteUrl(rule2))
router3, _ := NewConditionRouterFactory().Router(getRouteUrl(rule3))
router4, _ := NewConditionRouterFactory().Router(getRouteUrl(rule4))
router5, _ := NewConditionRouterFactory().Router(getRouteUrl(rule5))
router6, _ := NewConditionRouterFactory().Router(getRouteUrl(rule6))
cUrl, _ := common.NewURL(context.TODO(), "consumer://"+localIP+"/com.foo.BarService")
fileredInvokers1 := router1.Route(invokers, cUrl, &invocation.RPCInvocation{})
fileredInvokers2 := router2.Route(invokers, cUrl, &invocation.RPCInvocation{})
fileredInvokers3 := router3.Route(invokers, cUrl, &invocation.RPCInvocation{})
fileredInvokers4 := router4.Route(invokers, cUrl, &invocation.RPCInvocation{})
fileredInvokers5 := router5.Route(invokers, cUrl, &invocation.RPCInvocation{})
fileredInvokers6 := router6.Route(invokers, cUrl, &invocation.RPCInvocation{})
assert.Equal(t, 1, len(fileredInvokers1))
assert.Equal(t, 0, len(fileredInvokers2))
assert.Equal(t, 0, len(fileredInvokers3))
assert.Equal(t, 1, len(fileredInvokers4))
assert.Equal(t, 2, len(fileredInvokers5))
assert.Equal(t, 1, len(fileredInvokers6))
}
func TestRoute_methodRoute(t *testing.T) {
inv := invocation.NewRPCInvocationWithOptions(invocation.WithMethodName("getFoo"), invocation.WithParameterTypes([]reflect.Type{}), invocation.WithArguments([]interface{}{}))
rule := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.* & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4"))
router, _ := NewConditionRouterFactory().Router(getRouteUrl(rule))
url, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=setFoo,getFoo,findFoo")
matchWhen, _ := router.(*ConditionRouter).MatchWhen(url, inv)
assert.Equal(t, true, matchWhen)
url1, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=getFoo")
matchWhen, _ = router.(*ConditionRouter).MatchWhen(url1, inv)
assert.Equal(t, true, matchWhen)
url2, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=getFoo")
rule2 := base64.URLEncoding.EncodeToString([]byte("methods=getFoo & host!=1.1.1.1 => host = 1.2.3.4"))
router2, _ := NewConditionRouterFactory().Router(getRouteUrl(rule2))
matchWhen, _ = router2.(*ConditionRouter).MatchWhen(url2, inv)
assert.Equal(t, false, matchWhen)
url3, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=getFoo")
rule3 := base64.URLEncoding.EncodeToString([]byte("methods=getFoo & host=1.1.1.1 => host = 1.2.3.4"))
router3, _ := NewConditionRouterFactory().Router(getRouteUrl(rule3))
matchWhen, _ = router3.(*ConditionRouter).MatchWhen(url3, inv)
assert.Equal(t, true, matchWhen)
}
func TestRoute_ReturnFalse(t *testing.T) {
url, _ := common.NewURL(context.TODO(), "")
localIP, _ := utils.GetLocalIP()
invokers := []protocol.Invoker{NewMockInvoker(url, 1), NewMockInvoker(url, 2), NewMockInvoker(url, 3)}
inv := &invocation.RPCInvocation{}
rule := base64.URLEncoding.EncodeToString([]byte("host = " + localIP + " => false"))
curl, _ := common.NewURL(context.TODO(), "consumer://"+localIP+"/com.foo.BarService")
router, _ := NewConditionRouterFactory().Router(getRouteUrl(rule))
fileredInvokers := router.(*ConditionRouter).Route(invokers, curl, inv)
assert.Equal(t, 0, len(fileredInvokers))
}
func TestRoute_ReturnEmpty(t *testing.T) {
localIP, _ := utils.GetLocalIP()
url, _ := common.NewURL(context.TODO(), "")
invokers := []protocol.Invoker{NewMockInvoker(url, 1), NewMockInvoker(url, 2), NewMockInvoker(url, 3)}
inv := &invocation.RPCInvocation{}
rule := base64.URLEncoding.EncodeToString([]byte("host = " + localIP + " => "))
curl, _ := common.NewURL(context.TODO(), "consumer://"+localIP+"/com.foo.BarService")
router, _ := NewConditionRouterFactory().Router(getRouteUrl(rule))
fileredInvokers := router.(*ConditionRouter).Route(invokers, curl, inv)
assert.Equal(t, 0, len(fileredInvokers))
}
func TestRoute_ReturnAll(t *testing.T) {
localIP, _ := utils.GetLocalIP()
invokers := []protocol.Invoker{&MockInvoker{}, &MockInvoker{}, &MockInvoker{}}
inv := &invocation.RPCInvocation{}
rule := base64.URLEncoding.EncodeToString([]byte("host = " + localIP + " => " + " host = " + localIP))
curl, _ := common.NewURL(context.TODO(), "consumer://"+localIP+"/com.foo.BarService")
router, _ := NewConditionRouterFactory().Router(getRouteUrl(rule))
fileredInvokers := router.(*ConditionRouter).Route(invokers, curl, inv)
assert.Equal(t, invokers, fileredInvokers)
}
func TestRoute_HostFilter(t *testing.T) {
localIP, _ := utils.GetLocalIP()
url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService")
url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", localIP))
url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", localIP))
invoker1 := NewMockInvoker(url1, 1)
invoker2 := NewMockInvoker(url2, 2)
invoker3 := NewMockInvoker(url3, 3)
invokers := []protocol.Invoker{invoker1, invoker2, invoker3}
inv := &invocation.RPCInvocation{}
rule := base64.URLEncoding.EncodeToString([]byte("host = " + localIP + " => " + " host = " + localIP))
curl, _ := common.NewURL(context.TODO(), "consumer://"+localIP+"/com.foo.BarService")
router, _ := NewConditionRouterFactory().Router(getRouteUrl(rule))
fileredInvokers := router.(*ConditionRouter).Route(invokers, curl, inv)
assert.Equal(t, 2, len(fileredInvokers))
assert.Equal(t, invoker2, fileredInvokers[0])
assert.Equal(t, invoker3, fileredInvokers[1])
}
func TestRoute_Empty_HostFilter(t *testing.T) {
localIP, _ := utils.GetLocalIP()
url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService")
url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", localIP))
url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", localIP))
invoker1 := NewMockInvoker(url1, 1)
invoker2 := NewMockInvoker(url2, 2)
invoker3 := NewMockInvoker(url3, 3)
invokers := []protocol.Invoker{invoker1, invoker2, invoker3}
inv := &invocation.RPCInvocation{}
rule := base64.URLEncoding.EncodeToString([]byte(" => " + " host = " + localIP))
curl, _ := common.NewURL(context.TODO(), "consumer://"+localIP+"/com.foo.BarService")
router, _ := NewConditionRouterFactory().Router(getRouteUrl(rule))
fileredInvokers := router.(*ConditionRouter).Route(invokers, curl, inv)
assert.Equal(t, 2, len(fileredInvokers))
assert.Equal(t, invoker2, fileredInvokers[0])
assert.Equal(t, invoker3, fileredInvokers[1])
}
func TestRoute_False_HostFilter(t *testing.T) {
localIP, _ := utils.GetLocalIP()
url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService")
url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", localIP))
url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", localIP))
invoker1 := NewMockInvoker(url1, 1)
invoker2 := NewMockInvoker(url2, 2)
invoker3 := NewMockInvoker(url3, 3)
invokers := []protocol.Invoker{invoker1, invoker2, invoker3}
inv := &invocation.RPCInvocation{}
rule := base64.URLEncoding.EncodeToString([]byte("true => " + " host = " + localIP))
curl, _ := common.NewURL(context.TODO(), "consumer://"+localIP+"/com.foo.BarService")
router, _ := NewConditionRouterFactory().Router(getRouteUrl(rule))
fileredInvokers := router.(*ConditionRouter).Route(invokers, curl, inv)
assert.Equal(t, 2, len(fileredInvokers))
assert.Equal(t, invoker2, fileredInvokers[0])
assert.Equal(t, invoker3, fileredInvokers[1])
}
func TestRoute_Placeholder(t *testing.T) {
localIP, _ := utils.GetLocalIP()
url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService")
url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", localIP))
url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", localIP))
invoker1 := NewMockInvoker(url1, 1)
invoker2 := NewMockInvoker(url2, 2)
invoker3 := NewMockInvoker(url3, 3)
invokers := []protocol.Invoker{invoker1, invoker2, invoker3}
inv := &invocation.RPCInvocation{}
rule := base64.URLEncoding.EncodeToString([]byte("host = " + localIP + " => " + " host = $host"))
curl, _ := common.NewURL(context.TODO(), "consumer://"+localIP+"/com.foo.BarService")
router, _ := NewConditionRouterFactory().Router(getRouteUrl(rule))
fileredInvokers := router.(*ConditionRouter).Route(invokers, curl, inv)
assert.Equal(t, 2, len(fileredInvokers))
assert.Equal(t, invoker2, fileredInvokers[0])
assert.Equal(t, invoker3, fileredInvokers[1])
}
func TestRoute_NoForce(t *testing.T) {
localIP, _ := utils.GetLocalIP()
url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService")
url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", localIP))
url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", localIP))
invoker1 := NewMockInvoker(url1, 1)
invoker2 := NewMockInvoker(url2, 2)
invoker3 := NewMockInvoker(url3, 3)
invokers := []protocol.Invoker{invoker1, invoker2, invoker3}
inv := &invocation.RPCInvocation{}
rule := base64.URLEncoding.EncodeToString([]byte("host = " + localIP + " => " + " host = 1.2.3.4"))
curl, _ := common.NewURL(context.TODO(), "consumer://"+localIP+"/com.foo.BarService")
router, _ := NewConditionRouterFactory().Router(getRouteUrlWithNoForce(rule))
fileredInvokers := router.(*ConditionRouter).Route(invokers, curl, inv)
assert.Equal(t, invokers, fileredInvokers)
}
func TestRoute_Force(t *testing.T) {
localIP, _ := utils.GetLocalIP()
url1, _ := common.NewURL(context.TODO(), "dubbo://10.20.3.3:20880/com.foo.BarService")
url2, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", localIP))
url3, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://%s:20880/com.foo.BarService", localIP))
invoker1 := NewMockInvoker(url1, 1)
invoker2 := NewMockInvoker(url2, 2)
invoker3 := NewMockInvoker(url3, 3)
invokers := []protocol.Invoker{invoker1, invoker2, invoker3}
inv := &invocation.RPCInvocation{}
rule := base64.URLEncoding.EncodeToString([]byte("host = " + localIP + " => " + " host = 1.2.3.4"))
curl, _ := common.NewURL(context.TODO(), "consumer://"+localIP+"/com.foo.BarService")
router, _ := NewConditionRouterFactory().Router(getRouteUrlWithForce(rule, "true"))
fileredInvokers := router.(*ConditionRouter).Route(invokers, curl, inv)
assert.Equal(t, 0, len(fileredInvokers))
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package router
import (
"github.com/apache/dubbo-go/cluster"
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/extension"
)
func init() {
extension.SetRouterFactory("condition", NewConditionRouterFactory)
}
type ConditionRouterFactory struct{}
func NewConditionRouterFactory() cluster.RouterFactory {
return ConditionRouterFactory{}
}
func (c ConditionRouterFactory) Router(url *common.URL) (cluster.Router, error) {
return newConditionRouter(url)
}
......@@ -35,6 +35,7 @@ const (
const (
DEFAULT_KEY = "default"
PREFIX_DEFAULT_KEY = "default."
DEFAULT_SERVICE_FILTERS = "echo"
DEFAULT_REFERENCE_FILTERS = ""
ECHO = "$echo"
......
......@@ -67,6 +67,9 @@ const (
APP_VERSION_KEY = "app.version"
OWNER_KEY = "owner"
ENVIRONMENT_KEY = "environment"
METHOD_KEY = "method"
METHOD_KEYS = "methods"
RULE_KEY = "rule"
)
const (
......
package extension
import (
"github.com/apache/dubbo-go/cluster"
)
var (
routers = make(map[string]func() cluster.RouterFactory)
)
func SetRouterFactory(name string, fun func() cluster.RouterFactory) {
routers[name] = fun
}
func GetRouterFactory(name string) cluster.RouterFactory {
if routers[name] == nil {
panic("router_factory for " + name + " is not existing, make sure you have import the package.")
}
return routers[name]()
}
......@@ -20,6 +20,7 @@ package common
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"math"
"net"
......@@ -275,6 +276,11 @@ func (c URL) Service() string {
}
return ""
}
func (c *URL) AddParam(key string, value string) {
c.Params.Add(key, value)
}
func (c URL) GetParam(s string, d string) string {
var r string
if r = c.Params.Get(s); r == "" {
......@@ -282,6 +288,41 @@ func (c URL) GetParam(s string, d string) string {
}
return r
}
func (c URL) GetParamAndDecoded(key string) (string, error) {
ruleDec, err := base64.URLEncoding.DecodeString(c.GetParam(key, ""))
value := string(ruleDec)
return value, err
}
func (c URL) GetRawParam(key string) string {
switch key {
case "protocol":
return c.Protocol
case "username":
return c.Username
case "host":
return strings.Split(c.Location, ":")[0]
case "password":
return c.Password
case "port":
return c.Port
case "path":
return c.Path
default:
return c.Params.Get(key)
}
}
// GetParamBool
func (c URL) GetParamBool(s string, d bool) bool {
var r bool
var err error
if r, err = strconv.ParseBool(c.Params.Get(s)); err != nil {
return d
}
return r
}
func (c URL) GetParamInt(s string, d int64) int64 {
var r int
......@@ -318,6 +359,45 @@ func (c URL) GetMethodParam(method string, key string, d string) string {
return r
}
// ToMap transfer URL to Map
func (c URL) ToMap() map[string]string {
paramsMap := make(map[string]string)
for k, v := range c.Params {
paramsMap[k] = v[0]
}
if c.Protocol != "" {
paramsMap["protocol"] = c.Protocol
}
if c.Username != "" {
paramsMap["username"] = c.Username
}
if c.Password != "" {
paramsMap["password"] = c.Password
}
if c.Location != "" {
paramsMap["host"] = strings.Split(c.Location, ":")[0]
var port string
if strings.Contains(c.Location, ":") {
port = strings.Split(c.Location, ":")[1]
} else {
port = "0"
}
paramsMap["port"] = port
}
if c.Protocol != "" {
paramsMap["protocol"] = c.Protocol
}
if c.Path != "" {
paramsMap["path"] = c.Path
}
if len(paramsMap) == 0 {
return nil
}
return paramsMap
}
// configuration > reference config >service config
// in this function we should merge the reference local url config into the service url from registry.
//TODO configuration merge, in the future , the configuration center's config should merge too.
......
......@@ -19,6 +19,7 @@ package common
import (
"context"
"encoding/base64"
"net/url"
"testing"
)
......@@ -144,6 +145,54 @@ func TestURL_GetParamInt(t *testing.T) {
assert.Equal(t, int64(1), v)
}
func TestURL_GetParamBool(t *testing.T) {
params := url.Values{}
params.Set("force", "true")
u := URL{baseUrl: baseUrl{Params: params}}
v := u.GetParamBool("force", false)
assert.Equal(t, true, v)
u = URL{}
v = u.GetParamBool("force", false)
assert.Equal(t, false, v)
}
func TestURL_GetParamAndDecoded(t *testing.T) {
rule := "host = 2.2.2.2,1.1.1.1,3.3.3.3 & host !=1.1.1.1 => host = 1.2.3.4"
params := url.Values{}
params.Set("rule", base64.URLEncoding.EncodeToString([]byte(rule)))
u := URL{baseUrl: baseUrl{Params: params}}
v, _ := u.GetParamAndDecoded("rule")
assert.Equal(t, rule, v)
}
func TestURL_GetRawParam(t *testing.T) {
u, _ := NewURL(context.TODO(), "condition://0.0.0.0:8080/com.foo.BarService?serialization=fastjson")
u.Username = "test"
u.Password = "test"
assert.Equal(t, "condition", u.GetRawParam("protocol"))
assert.Equal(t, "0.0.0.0", u.GetRawParam("host"))
assert.Equal(t, "8080", u.GetRawParam("port"))
assert.Equal(t, "test", u.GetRawParam("username"))
assert.Equal(t, "test", u.GetRawParam("password"))
assert.Equal(t, "/com.foo.BarService", u.GetRawParam("path"))
assert.Equal(t, "fastjson", u.GetRawParam("serialization"))
}
func TestURL_ToMap(t *testing.T) {
u, _ := NewURL(context.TODO(), "condition://0.0.0.0:8080/com.foo.BarService?serialization=fastjson")
u.Username = "test"
u.Password = "test"
m := u.ToMap()
assert.Equal(t, 7, len(m))
assert.Equal(t, "condition", m["protocol"])
assert.Equal(t, "0.0.0.0", m["host"])
assert.Equal(t, "8080", m["port"])
assert.Equal(t, "test", m["username"])
assert.Equal(t, "test", m["password"])
assert.Equal(t, "/com.foo.BarService", m["path"])
assert.Equal(t, "fastjson", m["serialization"])
}
func TestURL_GetMethodParamInt(t *testing.T) {
params := url.Values{}
params.Set("methods.GetValue.timeout", "3")
......
File added
......@@ -35,5 +35,6 @@ golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/sourcemap.v1 v1.0.5/go.mod h1:2RlvNNSMglmRrcvhfuzp4hQHwOtjxlbjX7UPY/GXb78=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package container
import (
"fmt"
"strings"
)
var itemExists = struct{}{}
type HashSet struct {
Items map[interface{}]struct{}
}
func NewSet(values ...interface{}) *HashSet {
set := &HashSet{Items: make(map[interface{}]struct{})}
if len(values) > 0 {
set.Add(values...)
}
return set
}
func (set *HashSet) Add(items ...interface{}) {
for _, item := range items {
set.Items[item] = itemExists
}
}
func (set *HashSet) Remove(items ...interface{}) {
for _, item := range items {
delete(set.Items, item)
}
}
func (set *HashSet) Contains(items ...interface{}) bool {
for _, item := range items {
if _, contains := set.Items[item]; !contains {
return false
}
}
return true
}
func (set *HashSet) Empty() bool {
return set.Size() == 0
}
func (set *HashSet) Size() int {
return len(set.Items)
}
func (set *HashSet) Clear() {
set.Items = make(map[interface{}]struct{})
}
func (set *HashSet) Values() []interface{} {
values := make([]interface{}, set.Size())
count := 0
for item := range set.Items {
values[count] = item
count++
}
return values
}
func (set *HashSet) String() string {
str := "HashSet\n"
var items []string
for k := range set.Items {
items = append(items, fmt.Sprintf("%v", k))
}
str += strings.Join(items, ", ")
return str
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package container
import "testing"
func TestSetNew(t *testing.T) {
set := NewSet(2, 1)
if actualValue := set.Size(); actualValue != 2 {
t.Errorf("Got %v expected %v", actualValue, 2)
}
if actualValue := set.Contains(1); actualValue != true {
t.Errorf("Got %v expected %v", actualValue, true)
}
if actualValue := set.Contains(2); actualValue != true {
t.Errorf("Got %v expected %v", actualValue, true)
}
if actualValue := set.Contains(3); actualValue != false {
t.Errorf("Got %v expected %v", actualValue, true)
}
}
func TestSetAdd(t *testing.T) {
set := NewSet()
set.Add()
set.Add(1)
set.Add(2)
set.Add(2, 3)
set.Add()
if actualValue := set.Empty(); actualValue != false {
t.Errorf("Got %v expected %v", actualValue, false)
}
if actualValue := set.Size(); actualValue != 3 {
t.Errorf("Got %v expected %v", actualValue, 3)
}
}
func TestSetContains(t *testing.T) {
set := NewSet()
set.Add(3, 1, 2)
set.Add(2, 3)
set.Add()
if actualValue := set.Contains(); actualValue != true {
t.Errorf("Got %v expected %v", actualValue, true)
}
if actualValue := set.Contains(1); actualValue != true {
t.Errorf("Got %v expected %v", actualValue, true)
}
if actualValue := set.Contains(1, 2, 3); actualValue != true {
t.Errorf("Got %v expected %v", actualValue, true)
}
if actualValue := set.Contains(1, 2, 3, 4); actualValue != false {
t.Errorf("Got %v expected %v", actualValue, false)
}
}
func TestSetRemove(t *testing.T) {
set := NewSet()
set.Add(3, 1, 2)
set.Remove()
if actualValue := set.Size(); actualValue != 3 {
t.Errorf("Got %v expected %v", actualValue, 3)
}
set.Remove(1)
if actualValue := set.Size(); actualValue != 2 {
t.Errorf("Got %v expected %v", actualValue, 2)
}
set.Remove(3)
set.Remove(3)
set.Remove()
set.Remove(2)
if actualValue := set.Size(); actualValue != 0 {
t.Errorf("Got %v expected %v", actualValue, 0)
}
}
func benchmarkContains(b *testing.B, set *HashSet, size int) {
for i := 0; i < b.N; i++ {
for n := 0; n < size; n++ {
set.Contains(n)
}
}
}
func benchmarkAdd(b *testing.B, set *HashSet, size int) {
for i := 0; i < b.N; i++ {
for n := 0; n < size; n++ {
set.Add(n)
}
}
}
func benchmarkRemove(b *testing.B, set *HashSet, size int) {
for i := 0; i < b.N; i++ {
for n := 0; n < size; n++ {
set.Remove(n)
}
}
}
func BenchmarkHashSetContains100(b *testing.B) {
b.StopTimer()
size := 100
set := NewSet()
for n := 0; n < size; n++ {
set.Add(n)
}
b.StartTimer()
benchmarkContains(b, set, size)
}
func BenchmarkHashSetContains1000(b *testing.B) {
b.StopTimer()
size := 1000
set := NewSet()
for n := 0; n < size; n++ {
set.Add(n)
}
b.StartTimer()
benchmarkContains(b, set, size)
}
func BenchmarkHashSetContains10000(b *testing.B) {
b.StopTimer()
size := 10000
set := NewSet()
for n := 0; n < size; n++ {
set.Add(n)
}
b.StartTimer()
benchmarkContains(b, set, size)
}
func BenchmarkHashSetContains100000(b *testing.B) {
b.StopTimer()
size := 100000
set := NewSet()
for n := 0; n < size; n++ {
set.Add(n)
}
b.StartTimer()
benchmarkContains(b, set, size)
}
func BenchmarkHashSetAdd100(b *testing.B) {
b.StopTimer()
size := 100
set := NewSet()
b.StartTimer()
benchmarkAdd(b, set, size)
}
func BenchmarkHashSetAdd1000(b *testing.B) {
b.StopTimer()
size := 1000
set := NewSet()
for n := 0; n < size; n++ {
set.Add(n)
}
b.StartTimer()
benchmarkAdd(b, set, size)
}
func BenchmarkHashSetAdd10000(b *testing.B) {
b.StopTimer()
size := 10000
set := NewSet()
for n := 0; n < size; n++ {
set.Add(n)
}
b.StartTimer()
benchmarkAdd(b, set, size)
}
func BenchmarkHashSetAdd100000(b *testing.B) {
b.StopTimer()
size := 100000
set := NewSet()
for n := 0; n < size; n++ {
set.Add(n)
}
b.StartTimer()
benchmarkAdd(b, set, size)
}
func BenchmarkHashSetRemove100(b *testing.B) {
b.StopTimer()
size := 100
set := NewSet()
for n := 0; n < size; n++ {
set.Add(n)
}
b.StartTimer()
benchmarkRemove(b, set, size)
}
func BenchmarkHashSetRemove1000(b *testing.B) {
b.StopTimer()
size := 1000
set := NewSet()
for n := 0; n < size; n++ {
set.Add(n)
}
b.StartTimer()
benchmarkRemove(b, set, size)
}
func BenchmarkHashSetRemove10000(b *testing.B) {
b.StopTimer()
size := 10000
set := NewSet()
for n := 0; n < size; n++ {
set.Add(n)
}
b.StartTimer()
benchmarkRemove(b, set, size)
}
func BenchmarkHashSetRemove100000(b *testing.B) {
b.StopTimer()
size := 100000
set := NewSet()
for n := 0; n < size; n++ {
set.Add(n)
}
b.StartTimer()
benchmarkRemove(b, set, size)
}
......@@ -69,6 +69,34 @@ func NewRPCInvocationForProvider(methodName string, arguments []interface{}, att
}
}
type option func(invo *RPCInvocation)
func WithMethodName(methodName string) option {
return func(invo *RPCInvocation) {
invo.methodName = methodName
}
}
func WithParameterTypes(parameterTypes []reflect.Type) option {
return func(invo *RPCInvocation) {
invo.parameterTypes = parameterTypes
}
}
func WithArguments(arguments []interface{}) option {
return func(invo *RPCInvocation) {
invo.arguments = arguments
}
}
func NewRPCInvocationWithOptions(opts ...option) *RPCInvocation {
invo := &RPCInvocation{}
for _, opt := range opts {
opt(invo)
}
return invo
}
func (r *RPCInvocation) MethodName() string {
return r.methodName
}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment