diff --git a/cluster/loadbalance/round_robin.go b/cluster/loadbalance/round_robin.go index a17f7c6c7551876f9896865da2f19355475ada3c..3169b3037bd6a6224209f123238ac923faae2aeb 100644 --- a/cluster/loadbalance/round_robin.go +++ b/cluster/loadbalance/round_robin.go @@ -66,7 +66,7 @@ func (lb *roundRobinLoadBalance) Select(invokers []protocol.Invoker, invocation totalWeight := int64(0) maxCurrentWeight := int64(math.MinInt64) var selectedInvoker protocol.Invoker - var selectedWeightRobin weightedRoundRobin + var selectedWeightRobin *weightedRoundRobin now := time.Now() for _, invoker := range invokers { @@ -92,18 +92,19 @@ func (lb *roundRobinLoadBalance) Select(invokers []protocol.Invoker, invocation if currentWeight > maxCurrentWeight { maxCurrentWeight = currentWeight selectedInvoker = invoker - selectedWeightRobin = *weightRobin + selectedWeightRobin = weightRobin } totalWeight += weight } cleanIfRequired(clean, cachedInvokers, &now) - if selectedInvoker != nil { + if selectedWeightRobin != nil { selectedWeightRobin.Current(totalWeight) return selectedInvoker } + // should never happen return invokers[0] } diff --git a/cluster/loadbalance/round_robin_test.go b/cluster/loadbalance/round_robin_test.go index fc07a61bda58c8906cbcfcc16a63331ca88d7279..75ead5765a90d30021553f5f6332ff5cb81086ff 100644 --- a/cluster/loadbalance/round_robin_test.go +++ b/cluster/loadbalance/round_robin_test.go @@ -3,6 +3,7 @@ package loadbalance import ( "context" "fmt" + "strconv" "testing" ) @@ -31,5 +32,28 @@ func TestRoundRobinSelect(t *testing.T) { invokers = append(invokers, protocol.NewBaseInvoker(url)) } loadBalance.Select(invokers, &invocation.RPCInvocation{}) +} + +func TestRoundRobinByWeight(t *testing.T) { + loadBalance := NewRoundRobinLoadBalance() + var invokers []protocol.Invoker + loop := 10 + for i := 1; i <= loop; i++ { + url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/org.apache.demo.HelloService?weight=%v", i, i)) + invokers = append(invokers, protocol.NewBaseInvoker(url)) + } + + loop = (1 + loop) * loop / 2 + selected := make(map[protocol.Invoker]int) + + for i := 1; i <= loop; i++ { + invoker := loadBalance.Select(invokers, &invocation.RPCInvocation{}) + selected[invoker]++ + } + + for _, i := range invokers { + w, _ := strconv.Atoi(i.GetUrl().GetParam("weight", "-1")) + assert.True(t, selected[i] == w) + } } diff --git a/cluster/loadbalance/util.go b/cluster/loadbalance/util.go index fdbeb14dd31fa8b2581f46d3adacf5355b12878b..84ef39c40cff32cfb479ea41924535a1ad17ee29 100644 --- a/cluster/loadbalance/util.go +++ b/cluster/loadbalance/util.go @@ -24,7 +24,8 @@ import ( func GetWeight(invoker protocol.Invoker, invocation protocol.Invocation) int64 { url := invoker.GetUrl() - weight := url.GetMethodParamInt(invocation.MethodName(), constant.WEIGHT_KEY, constant.DEFAULT_WEIGHT) + weight := url.GetMethodParamInt64(invocation.MethodName(), constant.WEIGHT_KEY, constant.DEFAULT_WEIGHT) + if weight > 0 { //get service register time an do warm up time now := time.Now().Unix() diff --git a/common/url.go b/common/url.go index 9b0c6352ae781a4f6eacb36e3f727e2b7ac55232..55843a03e07ddda88b7982c44a67b59f1889ad7f 100644 --- a/common/url.go +++ b/common/url.go @@ -17,6 +17,7 @@ package common import ( "context" "fmt" + "math" "net" "net/url" "strconv" @@ -285,6 +286,15 @@ func (c URL) GetMethodParamInt(method string, key string, d int64) int64 { return int64(r) } +func (c URL) GetMethodParamInt64(method string, key string, d int64) int64 { + r := c.GetMethodParamInt(method, key, math.MinInt64) + if r == math.MinInt64 { + return c.GetParamInt(key, d) + } + + return r +} + func (c URL) GetMethodParam(method string, key string, d string) string { var r string if r = c.Params.Get("methods." + method + "." + key); r == "" {