diff --git a/cluster/router/condition_router.go b/cluster/router/condition_router.go index f8f05f116ff43b969ad9a1e4b8589368edd7cf93..f1d12af38d202512ae2135f60accaff3fdcfdd08 100644 --- a/cluster/router/condition_router.go +++ b/cluster/router/condition_router.go @@ -21,14 +21,18 @@ import ( "reflect" "regexp" "strings" +) +import ( "github.com/apache/dubbo-go/common" "github.com/apache/dubbo-go/common/constant" "github.com/apache/dubbo-go/common/logger" "github.com/apache/dubbo-go/common/utils" "github.com/apache/dubbo-go/gostd/container" "github.com/apache/dubbo-go/protocol" +) +import ( perrors "github.com/pkg/errors" ) @@ -56,7 +60,7 @@ func newConditionRouter(url *common.URL) (*ConditionRouter, error) { then map[string]MatchPair ) rule, err := url.GetParamAndDecoded(constant.RULE_KEY) - if err != nil || rule == "" { + if err != nil || len(rule) == 0 { return nil, perrors.Errorf("Illegal route rule!") } rule = strings.Replace(rule, "consumer.", "", -1) @@ -80,13 +84,13 @@ func newConditionRouter(url *common.URL) (*ConditionRouter, error) { if err != nil { return nil, perrors.Errorf("%s", "") } - if whenRule == "" || "true" == whenRule { - when = make(map[string]MatchPair) + if len(whenRule) == 0 || "true" == whenRule { + when = make(map[string]MatchPair, 16) } else { when = w } - if thenRule == "" || "false" == thenRule { - when = make(map[string]MatchPair) + if len(thenRule) == 0 || "false" == thenRule { + when = make(map[string]MatchPair, 16) } else { then = t } @@ -147,8 +151,8 @@ func (c *ConditionRouter) Route(invokers []protocol.Invoker, url common.URL, inv } func parseRule(rule string) (map[string]MatchPair, error) { - condition := make(map[string]MatchPair) - if rule == "" { + condition := make(map[string]MatchPair, 16) + if len(rule) == 0 { return condition, nil } var pair MatchPair @@ -230,11 +234,11 @@ func MatchCondition(pairs map[string]MatchPair, url *common.URL, param *common.U sampleValue = invocation.MethodName() } else { sampleValue = sample[key] - if sampleValue == "" { + if len(sampleValue) == 0 { sampleValue = sample[constant.PREFIX_DEFAULT_KEY+key] } } - if sampleValue != "" { + if len(sampleValue) > 0 { if !matchPair.isMatch(sampleValue, param) { return false, nil } else { @@ -251,13 +255,6 @@ func MatchCondition(pairs map[string]MatchPair, url *common.URL, param *common.U return result, nil } -func If(b bool, t, f interface{}) interface{} { - if b { - return t - } - return f -} - type MatchPair struct { Matches *container.HashSet Mismatches *container.HashSet @@ -305,10 +302,10 @@ func isMatchGlobPattern(pattern string, value string, param *common.URL) bool { if "*" == pattern { return true } - if pattern == "" && value == "" { + if len(pattern) == 0 && len(value) == 0 { return true } - if pattern == "" || value == "" { + if len(pattern) == 0 || len(value) == 0 { return false } i := strings.LastIndex(pattern, "*") diff --git a/cluster/router/condition_router_test.go b/cluster/router/condition_router_test.go index e60cda31791e097bb83b16da0ab30385662d95cd..b4c6a829d667765c903faf61bfbd68c290b5a6ad 100644 --- a/cluster/router/condition_router_test.go +++ b/cluster/router/condition_router_test.go @@ -23,13 +23,17 @@ import ( "fmt" "reflect" "testing" +) +import ( "github.com/apache/dubbo-go/common" "github.com/apache/dubbo-go/common/logger" "github.com/apache/dubbo-go/common/utils" "github.com/apache/dubbo-go/protocol" "github.com/apache/dubbo-go/protocol/invocation" +) +import ( perrors "github.com/pkg/errors" "github.com/stretchr/testify/assert" ) @@ -176,7 +180,7 @@ func TestRoute_matchFilter(t *testing.T) { } func TestRoute_methodRoute(t *testing.T) { - inv := invocation.NewRPCInvocationForUT("getFoo", []reflect.Type{}, []interface{}{}) + inv := invocation.NewRPCInvocationWithOptions(invocation.WithMethodName("getFoo"), invocation.WithParameterTypes([]reflect.Type{}), invocation.WithArguments([]interface{}{})) rule := base64.URLEncoding.EncodeToString([]byte("host !=4.4.4.* & host = 2.2.2.2,1.1.1.1,3.3.3.3 => host = 1.2.3.4")) router, _ := NewConditionRouterFactory().Router(getRouteUrl(rule)) url, _ := common.NewURL(context.TODO(), "consumer://1.1.1.1/com.foo.BarService?methods=setFoo,getFoo,findFoo") diff --git a/common/url.go b/common/url.go index 5343bf58512259d97afb6870758b073de4b4a674..47f875101f690a46e6281f7b14c6b026e77a79ad 100644 --- a/common/url.go +++ b/common/url.go @@ -300,25 +300,22 @@ func (c URL) GetParamAndDecoded(key string) (string, error) { } func (c URL) GetRawParam(key string) string { - if "protocol" == key { + switch key { + case "protocol": return c.Protocol - } - if "username" == key { + case "username": return c.Username - } - if "host" == key { + case "host": return strings.Split(c.Location, ":")[0] - } - if "password" == key { + case "password": return c.Password - } - if "port" == key { + case "port": return c.Port - } - if "path" == key { + case "path": return c.Path + default: + return c.Params.Get(key) } - return c.Params.Get(key) } // GetParamBool diff --git a/protocol/invocation/rpcinvocation.go b/protocol/invocation/rpcinvocation.go index 183596eecb39bf560738b37aabcfc4e6e49ac230..d515cc4c8ad4bcdcc88eccd4b1e8ddb545a17315 100644 --- a/protocol/invocation/rpcinvocation.go +++ b/protocol/invocation/rpcinvocation.go @@ -69,12 +69,32 @@ func NewRPCInvocationForProvider(methodName string, arguments []interface{}, att } } -func NewRPCInvocationForUT(methodName string, parameterTypes []reflect.Type, arguments []interface{}) *RPCInvocation { - return &RPCInvocation{ - methodName: methodName, - arguments: arguments, - parameterTypes: parameterTypes, +type option func(invo *RPCInvocation) + +func WithMethodName(methodName string) option { + return func(invo *RPCInvocation) { + invo.methodName = methodName + } +} + +func WithParameterTypes(parameterTypes []reflect.Type) option { + return func(invo *RPCInvocation) { + invo.parameterTypes = parameterTypes + } +} + +func WithArguments(arguments []interface{}) option { + return func(invo *RPCInvocation) { + invo.arguments = arguments + } +} + +func NewRPCInvocationWithOptions(opts ...option) *RPCInvocation { + invo := &RPCInvocation{} + for _, opt := range opts { + opt(invo) } + return invo } func (r *RPCInvocation) MethodName() string {