Skip to content
Snippets Groups Projects
Commit 47f300a0 authored by yì jí's avatar yì jí Committed by GitHub
Browse files

Merge pull request #66 from zonghaishang/loadbalance_round_robin

Loadbalance round robin
parents 89425d55 04e1e4ea
No related branches found
No related tags found
No related merge requests found
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package loadbalance
import (
"math"
"sync"
"sync/atomic"
"time"
)
import (
"github.com/apache/dubbo-go/cluster"
"github.com/apache/dubbo-go/common/extension"
"github.com/apache/dubbo-go/protocol"
)
const (
RoundRobin = "roundrobin"
COMPLETE = 0
UPDATING = 1
)
var (
methodWeightMap sync.Map // [string]invokers
state int32 = COMPLETE // update lock acquired ?
recyclePeriod int64 = 60 * time.Second.Nanoseconds()
)
func init() {
extension.SetLoadbalance(RoundRobin, NewRoundRobinLoadBalance)
}
type roundRobinLoadBalance struct{}
func NewRoundRobinLoadBalance() cluster.LoadBalance {
return &roundRobinLoadBalance{}
}
func (lb *roundRobinLoadBalance) Select(invokers []protocol.Invoker, invocation protocol.Invocation) protocol.Invoker {
count := len(invokers)
if count == 0 {
return nil
}
if count == 1 {
return invokers[0]
}
key := invokers[0].GetUrl().Path + "." + invocation.MethodName()
cache, _ := methodWeightMap.LoadOrStore(key, &cachedInvokers{})
cachedInvokers := cache.(*cachedInvokers)
var (
clean = false
totalWeight = int64(0)
maxCurrentWeight = int64(math.MinInt64)
now = time.Now()
selectedInvoker protocol.Invoker
selectedWeightRobin *weightedRoundRobin
)
for _, invoker := range invokers {
var weight = GetWeight(invoker, invocation)
if weight < 0 {
weight = 0
}
identifier := invoker.GetUrl().Key()
loaded, found := cachedInvokers.LoadOrStore(identifier, &weightedRoundRobin{weight: weight})
weightRobin := loaded.(*weightedRoundRobin)
if !found {
clean = true
}
if weightRobin.Weight() != weight {
weightRobin.setWeight(weight)
}
currentWeight := weightRobin.increaseCurrent()
weightRobin.lastUpdate = &now
if currentWeight > maxCurrentWeight {
maxCurrentWeight = currentWeight
selectedInvoker = invoker
selectedWeightRobin = weightRobin
}
totalWeight += weight
}
cleanIfRequired(clean, cachedInvokers, &now)
if selectedWeightRobin != nil {
selectedWeightRobin.Current(totalWeight)
return selectedInvoker
}
// should never happen
return invokers[0]
}
func cleanIfRequired(clean bool, invokers *cachedInvokers, now *time.Time) {
if clean && atomic.CompareAndSwapInt32(&state, COMPLETE, UPDATING) {
defer atomic.CompareAndSwapInt32(&state, UPDATING, COMPLETE)
invokers.Range(func(identify, robin interface{}) bool {
weightedRoundRobin := robin.(*weightedRoundRobin)
elapsed := now.Sub(*weightedRoundRobin.lastUpdate).Nanoseconds()
if elapsed > recyclePeriod {
invokers.Delete(identify)
}
return true
})
}
}
// Record the weight of the invoker
type weightedRoundRobin struct {
weight int64
current int64
lastUpdate *time.Time
}
func (robin *weightedRoundRobin) Weight() int64 {
return atomic.LoadInt64(&robin.weight)
}
func (robin *weightedRoundRobin) setWeight(weight int64) {
robin.weight = weight
robin.current = 0
}
func (robin *weightedRoundRobin) increaseCurrent() int64 {
return atomic.AddInt64(&robin.current, robin.weight)
}
func (robin *weightedRoundRobin) Current(delta int64) {
atomic.AddInt64(&robin.current, -1*delta)
}
type cachedInvokers struct {
sync.Map /*[string]weightedRoundRobin*/
}
package loadbalance
import (
"context"
"fmt"
"strconv"
"testing"
)
import (
"github.com/stretchr/testify/assert"
)
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/protocol"
"github.com/apache/dubbo-go/protocol/invocation"
)
func TestRoundRobinSelect(t *testing.T) {
loadBalance := NewRoundRobinLoadBalance()
var invokers []protocol.Invoker
url, _ := common.NewURL(context.TODO(), "dubbo://192.168.1.0:20000/org.apache.demo.HelloService")
invokers = append(invokers, protocol.NewBaseInvoker(url))
i := loadBalance.Select(invokers, &invocation.RPCInvocation{})
assert.True(t, i.GetUrl().URLEqual(url))
for i := 1; i < 10; i++ {
url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/org.apache.demo.HelloService", i))
invokers = append(invokers, protocol.NewBaseInvoker(url))
}
loadBalance.Select(invokers, &invocation.RPCInvocation{})
}
func TestRoundRobinByWeight(t *testing.T) {
loadBalance := NewRoundRobinLoadBalance()
var invokers []protocol.Invoker
loop := 10
for i := 1; i <= loop; i++ {
url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/org.apache.demo.HelloService?weight=%v", i, i))
invokers = append(invokers, protocol.NewBaseInvoker(url))
}
loop = (1 + loop) * loop / 2
selected := make(map[protocol.Invoker]int)
for i := 1; i <= loop; i++ {
invoker := loadBalance.Select(invokers, &invocation.RPCInvocation{})
selected[invoker]++
}
for _, i := range invokers {
w, _ := strconv.Atoi(i.GetUrl().GetParam("weight", "-1"))
assert.True(t, selected[i] == w)
}
}
......@@ -28,7 +28,8 @@ import (
func GetWeight(invoker protocol.Invoker, invocation protocol.Invocation) int64 {
url := invoker.GetUrl()
weight := url.GetMethodParamInt(invocation.MethodName(), constant.WEIGHT_KEY, constant.DEFAULT_WEIGHT)
weight := url.GetMethodParamInt64(invocation.MethodName(), constant.WEIGHT_KEY, constant.DEFAULT_WEIGHT)
if weight > 0 {
//get service register time an do warm up time
now := time.Now().Unix()
......
......@@ -20,6 +20,7 @@ package common
import (
"context"
"fmt"
"math"
"net"
"net/url"
"strconv"
......@@ -288,6 +289,15 @@ func (c URL) GetMethodParamInt(method string, key string, d int64) int64 {
return int64(r)
}
func (c URL) GetMethodParamInt64(method string, key string, d int64) int64 {
r := c.GetMethodParamInt(method, key, math.MinInt64)
if r == math.MinInt64 {
return c.GetParamInt(key, d)
}
return r
}
func (c URL) GetMethodParam(method string, key string, d string) string {
var r string
if r = c.Params.Get("methods." + method + "." + key); r == "" {
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment