Skip to content
Snippets Groups Projects
Commit aa705b60 authored by vito.he's avatar vito.he Committed by GitHub
Browse files

Merge pull request #270 from CodingSinger/sticky

support sticky connection
parents 86852f62 689b3473
No related branches found
No related tags found
No related merge requests found
......@@ -35,6 +35,7 @@ type baseClusterInvoker struct {
directory cluster.Directory
availablecheck bool
destroyed *atomic.Bool
stickyInvoker protocol.Invoker
}
func newBaseClusterInvoker(directory cluster.Directory) baseClusterInvoker {
......@@ -56,7 +57,9 @@ func (invoker *baseClusterInvoker) Destroy() {
}
func (invoker *baseClusterInvoker) IsAvailable() bool {
//TODO:sticky connection
if invoker.stickyInvoker != nil {
return invoker.stickyInvoker.IsAvailable()
}
return invoker.directory.IsAvailable()
}
......@@ -83,15 +86,42 @@ func (invoker *baseClusterInvoker) checkWhetherDestroyed() error {
}
func (invoker *baseClusterInvoker) doSelect(lb cluster.LoadBalance, invocation protocol.Invocation, invokers []protocol.Invoker, invoked []protocol.Invoker) protocol.Invoker {
//todo:sticky connect
var selectedInvoker protocol.Invoker
url := invokers[0].GetUrl()
sticky := url.GetParamBool(constant.STICKY_KEY, false)
//Get the service method sticky config if have
sticky = url.GetMethodParamBool(invocation.MethodName(), constant.STICKY_KEY, sticky)
if invoker.stickyInvoker != nil && !isInvoked(invoker.stickyInvoker, invokers) {
invoker.stickyInvoker = nil
}
if sticky && invoker.stickyInvoker != nil && (invoked == nil || !isInvoked(invoker.stickyInvoker, invoked)) {
if invoker.availablecheck && invoker.stickyInvoker.IsAvailable() {
return invoker.stickyInvoker
}
}
selectedInvoker = invoker.doSelectInvoker(lb, invocation, invokers, invoked)
if sticky {
invoker.stickyInvoker = selectedInvoker
}
return selectedInvoker
}
func (invoker *baseClusterInvoker) doSelectInvoker(lb cluster.LoadBalance, invocation protocol.Invocation, invokers []protocol.Invoker, invoked []protocol.Invoker) protocol.Invoker {
if len(invokers) == 1 {
return invokers[0]
}
selectedInvoker := lb.Select(invokers, invocation)
//judge to if the selectedInvoker is invoked
if !selectedInvoker.IsAvailable() || !invoker.availablecheck || isInvoked(selectedInvoker, invoked) {
if (!selectedInvoker.IsAvailable() && invoker.availablecheck) || isInvoked(selectedInvoker, invoked) {
// do reselect
var reslectInvokers []protocol.Invoker
......@@ -106,13 +136,12 @@ func (invoker *baseClusterInvoker) doSelect(lb cluster.LoadBalance, invocation p
}
if len(reslectInvokers) > 0 {
return lb.Select(reslectInvokers, invocation)
selectedInvoker = lb.Select(reslectInvokers, invocation)
} else {
return nil
}
}
return selectedInvoker
}
func isInvoked(selectedInvoker protocol.Invoker, invoked []protocol.Invoker) bool {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package cluster_impl
import (
"context"
"fmt"
"testing"
)
import (
"github.com/stretchr/testify/assert"
)
import (
"github.com/apache/dubbo-go/cluster/loadbalance"
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/protocol"
"github.com/apache/dubbo-go/protocol/invocation"
)
func Test_StickyNormal(t *testing.T) {
invokers := []protocol.Invoker{}
for i := 0; i < 10; i++ {
url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i))
url.SetParam("sticky", "true")
invokers = append(invokers, NewMockInvoker(url, 1))
}
base := &baseClusterInvoker{}
base.availablecheck = true
invoked := []protocol.Invoker{}
result := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked)
result1 := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked)
assert.Equal(t, result, result1)
}
func Test_StickyNormalWhenError(t *testing.T) {
invokers := []protocol.Invoker{}
for i := 0; i < 10; i++ {
url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i))
url.SetParam("sticky", "true")
invokers = append(invokers, NewMockInvoker(url, 1))
}
base := &baseClusterInvoker{}
base.availablecheck = true
invoked := []protocol.Invoker{}
result := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked)
invoked = append(invoked, result)
result1 := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked)
assert.NotEqual(t, result, result1)
}
......@@ -67,7 +67,7 @@ func Test_FailbackSuceess(t *testing.T) {
invoker := mock.NewMockInvoker(ctrl)
clusterInvoker := registerFailback(t, invoker).(*failbackClusterInvoker)
invoker.EXPECT().GetUrl().Return(failbackUrl).Times(1)
invoker.EXPECT().GetUrl().Return(failbackUrl).AnyTimes()
mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}}
invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult)
......
......@@ -64,7 +64,7 @@ func Test_FailfastInvokeSuccess(t *testing.T) {
invoker := mock.NewMockInvoker(ctrl)
clusterInvoker := registerFailfast(t, invoker)
invoker.EXPECT().GetUrl().Return(failfastUrl)
invoker.EXPECT().GetUrl().Return(failfastUrl).AnyTimes()
mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}}
......@@ -84,7 +84,7 @@ func Test_FailfastInvokeFail(t *testing.T) {
invoker := mock.NewMockInvoker(ctrl)
clusterInvoker := registerFailfast(t, invoker)
invoker.EXPECT().GetUrl().Return(failfastUrl)
invoker.EXPECT().GetUrl().Return(failfastUrl).AnyTimes()
mockResult := &protocol.RPCResult{Err: perrors.New("error")}
......
......@@ -64,7 +64,7 @@ func Test_FailSafeInvokeSuccess(t *testing.T) {
invoker := mock.NewMockInvoker(ctrl)
clusterInvoker := register_failsafe(t, invoker)
invoker.EXPECT().GetUrl().Return(failsafeUrl)
invoker.EXPECT().GetUrl().Return(failsafeUrl).AnyTimes()
mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}}
......@@ -83,7 +83,7 @@ func Test_FailSafeInvokeFail(t *testing.T) {
invoker := mock.NewMockInvoker(ctrl)
clusterInvoker := register_failsafe(t, invoker)
invoker.EXPECT().GetUrl().Return(failsafeUrl)
invoker.EXPECT().GetUrl().Return(failsafeUrl).AnyTimes()
mockResult := &protocol.RPCResult{Err: perrors.New("error")}
......
......@@ -55,6 +55,7 @@ const (
WEIGHT_KEY = "weight"
WARMUP_KEY = "warmup"
RETRIES_KEY = "retries"
STICKY_KEY = "sticky"
BEAN_NAME = "bean.name"
FAIL_BACK_TASKS_KEY = "failbacktasks"
FORKS_KEY = "forks"
......
......@@ -447,6 +447,11 @@ func (c URL) GetMethodParam(method string, key string, d string) string {
return r
}
func (c URL) GetMethodParamBool(method string, key string, d bool) bool {
r := c.GetParamBool("methods."+method+"."+key, d)
return r
}
func (c *URL) RemoveParams(set *gxset.HashSet) {
c.paramsLock.Lock()
defer c.paramsLock.Unlock()
......
......@@ -217,6 +217,18 @@ func TestURL_GetMethodParam(t *testing.T) {
assert.Equal(t, "1s", v)
}
func TestURL_GetMethodParamBool(t *testing.T) {
params := url.Values{}
params.Set("methods.GetValue.async", "true")
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetMethodParamBool("GetValue", "async", false)
assert.Equal(t, true, v)
u = URL{}
v = u.GetMethodParamBool("GetValue2", "async", false)
assert.Equal(t, false, v)
}
func TestMergeUrl(t *testing.T) {
referenceUrlParams := url.Values{}
referenceUrlParams.Set(constant.CLUSTER_KEY, "random")
......
......@@ -36,6 +36,7 @@ type MethodConfig struct {
TpsLimitStrategy string `yaml:"tps.limit.strategy" json:"tps.limit.strategy,omitempty" property:"tps.limit.strategy"`
ExecuteLimit string `yaml:"execute.limit" json:"execute.limit,omitempty" property:"execute.limit"`
ExecuteLimitRejectedHandler string `yaml:"execute.limit.rejected.handler" json:"execute.limit.rejected.handler,omitempty" property:"execute.limit.rejected.handler"`
Sticky bool `yaml:"sticky" json:"sticky,omitempty" property:"sticky"`
}
func (c *MethodConfig) Prefix() string {
......
......@@ -60,6 +60,7 @@ type ReferenceConfig struct {
invoker protocol.Invoker
urls []*common.URL
Generic bool `yaml:"generic" json:"generic,omitempty" property:"generic"`
Sticky bool `yaml:"sticky" json:"sticky,omitempty" property:"sticky"`
}
func (c *ReferenceConfig) Prefix() string {
......@@ -175,6 +176,7 @@ func (refconfig *ReferenceConfig) getUrlMap() url.Values {
urlMap.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
//getty invoke async or sync
urlMap.Set(constant.ASYNC_KEY, strconv.FormatBool(refconfig.Async))
urlMap.Set(constant.STICKY_KEY, strconv.FormatBool(refconfig.Sticky))
//application info
urlMap.Set(constant.APPLICATION_KEY, consumerConfig.ApplicationConfig.Name)
......@@ -195,6 +197,7 @@ func (refconfig *ReferenceConfig) getUrlMap() url.Values {
for _, v := range refconfig.Methods {
urlMap.Set("methods."+v.Name+"."+constant.LOADBALANCE_KEY, v.Loadbalance)
urlMap.Set("methods."+v.Name+"."+constant.RETRIES_KEY, v.Retries)
urlMap.Set("methods."+v.Name+"."+constant.STICKY_KEY, strconv.FormatBool(v.Sticky))
}
return urlMap
......
......@@ -86,6 +86,7 @@ func doInitConsumer() {
"serviceid": "soa.mock",
"forks": "5",
},
Sticky: false,
Registry: "shanghai_reg1,shanghai_reg2,hangzhou_reg1,hangzhou_reg2",
InterfaceName: "com.MockService",
Protocol: "mock",
......@@ -104,6 +105,7 @@ func doInitConsumer() {
Name: "GetUser1",
Retries: "2",
Loadbalance: "random",
Sticky: true,
},
},
},
......@@ -291,6 +293,24 @@ func Test_Forking(t *testing.T) {
consumerConfig = nil
}
func Test_Sticky(t *testing.T) {
doInitConsumer()
extension.SetProtocol("dubbo", GetProtocol)
extension.SetProtocol("registry", GetProtocol)
m := consumerConfig.References["MockService"]
m.Url = "dubbo://127.0.0.1:20000;registry://127.0.0.2:20000"
reference := consumerConfig.References["MockService"]
reference.Refer()
referenceSticky := reference.invoker.GetUrl().GetParam(constant.STICKY_KEY, "false")
assert.Equal(t, "false", referenceSticky)
method0StickKey := reference.invoker.GetUrl().GetMethodParam(reference.Methods[0].Name, constant.STICKY_KEY, "false")
assert.Equal(t, "false", method0StickKey)
method1StickKey := reference.invoker.GetUrl().GetMethodParam(reference.Methods[1].Name, constant.STICKY_KEY, "false")
assert.Equal(t, "true", method1StickKey)
}
func GetProtocol() protocol.Protocol {
if regProtocol != nil {
return regProtocol
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment