diff --git a/cluster/loadbalance/least_active.go b/cluster/loadbalance/least_active.go
index aa69f3cc207ae7465bc6d5472bc075d0902c8978..a1e8516698d23118fdb42e855dabd1cb485ac41c 100644
--- a/cluster/loadbalance/least_active.go
+++ b/cluster/loadbalance/least_active.go
@@ -63,7 +63,7 @@ func (lb *leastActiveLoadBalance) Select(invokers []protocol.Invoker, invocation
for i := 0; i < count; i++ {
invoker := invokers[i]
// Active number
- active := protocol.GetStatus(invoker.GetUrl(), invocation.MethodName()).GetActive()
+ active := protocol.GetMethodStatus(invoker.GetUrl(), invocation.MethodName()).GetActive()
// current weight (maybe in warmUp)
weight := GetWeight(invoker, invocation)
// There are smaller active services
diff --git a/filter/filter_impl/active_filter.go b/filter/filter_impl/active_filter.go
index b12f776322986b46e6ab0ca878e9d83bf74822e8..e23e4dde74fdeb7f56c5ccad9caa3202e92882a4 100644
--- a/filter/filter_impl/active_filter.go
+++ b/filter/filter_impl/active_filter.go
@@ -17,14 +17,22 @@
package filter_impl
+import (
+ "strconv"
+)
+
import (
"github.com/apache/dubbo-go/common/extension"
"github.com/apache/dubbo-go/common/logger"
"github.com/apache/dubbo-go/filter"
"github.com/apache/dubbo-go/protocol"
+ invocation2 "github.com/apache/dubbo-go/protocol/invocation"
)
-const active = "active"
+const (
+ active = "active"
+ dubboInvokeStartTime = "dubboInvokeStartTime"
+)
func init() {
extension.SetFilter(active, GetActiveFilter)
@@ -36,13 +44,21 @@ type ActiveFilter struct {
func (ef *ActiveFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
logger.Infof("invoking active filter. %v,%v", invocation.MethodName(), len(invocation.Arguments()))
+ invocation.(*invocation2.RPCInvocation).SetAttachments(dubboInvokeStartTime, strconv.FormatInt(protocol.CurrentTimeMillis(), 10))
protocol.BeginCount(invoker.GetUrl(), invocation.MethodName())
return invoker.Invoke(invocation)
}
func (ef *ActiveFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
- protocol.EndCount(invoker.GetUrl(), invocation.MethodName())
+ startTime, err := strconv.ParseInt(invocation.(*invocation2.RPCInvocation).AttachmentsByKey(dubboInvokeStartTime, "0"), 10, 64)
+ if err != nil {
+ result.SetError(err)
+ logger.Errorf("parse dubbo_invoke_start_time to int64 failed")
+ return result
+ }
+ elapsed := protocol.CurrentTimeMillis() - startTime
+ protocol.EndCount(invoker.GetUrl(), invocation.MethodName(), elapsed, result.Error() == nil)
return result
}
diff --git a/filter/filter_impl/active_filter_test.go b/filter/filter_impl/active_filter_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..acc4f9121641bfbbc484a711c0ea04dffeab55e3
--- /dev/null
+++ b/filter/filter_impl/active_filter_test.go
@@ -0,0 +1,66 @@
+package filter_impl
+
+import (
+ "context"
+ "errors"
+ "strconv"
+ "testing"
+)
+
+import (
+ "github.com/golang/mock/gomock"
+ "github.com/stretchr/testify/assert"
+)
+
+import (
+ "github.com/apache/dubbo-go/common"
+ "github.com/apache/dubbo-go/protocol"
+ "github.com/apache/dubbo-go/protocol/invocation"
+ "github.com/apache/dubbo-go/protocol/mock"
+)
+
+func TestActiveFilter_Invoke(t *testing.T) {
+ invoc := invocation.NewRPCInvocation("test", []interface{}{"OK"}, make(map[string]string, 0))
+ url, _ := common.NewURL(context.TODO(), "dubbo://192.168.10.10:20000/com.ikurento.user.UserProvider")
+ filter := ActiveFilter{}
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ invoker := mock.NewMockInvoker(ctrl)
+ invoker.EXPECT().Invoke(gomock.Any()).Return(nil)
+ invoker.EXPECT().GetUrl().Return(url).Times(1)
+ filter.Invoke(invoker, invoc)
+ assert.True(t, invoc.AttachmentsByKey(dubboInvokeStartTime, "") != "")
+
+}
+
+func TestActiveFilter_OnResponse(t *testing.T) {
+ c := protocol.CurrentTimeMillis()
+ elapsed := 100
+ invoc := invocation.NewRPCInvocation("test", []interface{}{"OK"}, map[string]string{
+ dubboInvokeStartTime: strconv.FormatInt(c-int64(elapsed), 10),
+ })
+ url, _ := common.NewURL(context.TODO(), "dubbo://192.168.10.10:20000/com.ikurento.user.UserProvider")
+ filter := ActiveFilter{}
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ invoker := mock.NewMockInvoker(ctrl)
+ invoker.EXPECT().GetUrl().Return(url).Times(1)
+ result := &protocol.RPCResult{
+ Err: errors.New("test"),
+ }
+ filter.OnResponse(result, invoker, invoc)
+ methodStatus := protocol.GetMethodStatus(url, "test")
+ urlStatus := protocol.GetURLStatus(url)
+
+ assert.Equal(t, int32(1), methodStatus.GetTotal())
+ assert.Equal(t, int32(1), urlStatus.GetTotal())
+ assert.Equal(t, int32(1), methodStatus.GetFailed())
+ assert.Equal(t, int32(1), urlStatus.GetFailed())
+ assert.Equal(t, int32(1), methodStatus.GetSuccessiveRequestFailureCount())
+ assert.Equal(t, int32(1), urlStatus.GetSuccessiveRequestFailureCount())
+ assert.True(t, methodStatus.GetFailedElapsed() >= int64(elapsed))
+ assert.True(t, urlStatus.GetFailedElapsed() >= int64(elapsed))
+ assert.True(t, urlStatus.GetLastRequestFailedTimestamp() != int64(0))
+ assert.True(t, methodStatus.GetLastRequestFailedTimestamp() != int64(0))
+
+}
diff --git a/protocol/rpc_status.go b/protocol/rpc_status.go
index 3a8bfbc87f285e0e86269d44c47d6771566d97b1..67f05e98020298b04096d2ba05874143324a7c7e 100644
--- a/protocol/rpc_status.go
+++ b/protocol/rpc_status.go
@@ -20,6 +20,7 @@ package protocol
import (
"sync"
"sync/atomic"
+ "time"
)
import (
@@ -27,18 +28,69 @@ import (
)
var (
- methodStatistics sync.Map // url -> { methodName : RpcStatus}
+ methodStatistics sync.Map // url -> { methodName : RPCStatus}
+ serviceStatistic sync.Map // url -> RPCStatus
)
-type RpcStatus struct {
- active int32
+type RPCStatus struct {
+ active int32
+ failed int32
+ total int32
+ totalElapsed int64
+ failedElapsed int64
+ maxElapsed int64
+ failedMaxElapsed int64
+ succeededMaxElapsed int64
+ successiveRequestFailureCount int32
+ lastRequestFailedTimestamp int64
}
-func (rpc *RpcStatus) GetActive() int32 {
+func (rpc *RPCStatus) GetActive() int32 {
return atomic.LoadInt32(&rpc.active)
}
-func GetStatus(url common.URL, methodName string) *RpcStatus {
+func (rpc *RPCStatus) GetFailed() int32 {
+ return atomic.LoadInt32(&rpc.failed)
+}
+
+func (rpc *RPCStatus) GetTotal() int32 {
+ return atomic.LoadInt32(&rpc.total)
+}
+
+func (rpc *RPCStatus) GetTotalElapsed() int64 {
+ return atomic.LoadInt64(&rpc.totalElapsed)
+}
+
+func (rpc *RPCStatus) GetFailedElapsed() int64 {
+ return atomic.LoadInt64(&rpc.failedElapsed)
+}
+
+func (rpc *RPCStatus) GetMaxElapsed() int64 {
+ return atomic.LoadInt64(&rpc.maxElapsed)
+}
+
+func (rpc *RPCStatus) GetFailedMaxElapsed() int64 {
+ return atomic.LoadInt64(&rpc.failedMaxElapsed)
+}
+
+func (rpc *RPCStatus) GetSucceededMaxElapsed() int64 {
+ return atomic.LoadInt64(&rpc.succeededMaxElapsed)
+}
+
+func (rpc *RPCStatus) GetLastRequestFailedTimestamp() int64 {
+ return atomic.LoadInt64(&rpc.lastRequestFailedTimestamp)
+}
+
+func (rpc *RPCStatus) GetSuccessiveRequestFailureCount() int32 {
+ return atomic.LoadInt32(&rpc.successiveRequestFailureCount)
+}
+
+func GetURLStatus(url common.URL) *RPCStatus {
+ rpcStatus, _ := serviceStatistic.LoadOrStore(url.Key(), &RPCStatus{})
+ return rpcStatus.(*RPCStatus)
+}
+
+func GetMethodStatus(url common.URL, methodName string) *RPCStatus {
identifier := url.Key()
methodMap, found := methodStatistics.Load(identifier)
if !found {
@@ -49,27 +101,53 @@ func GetStatus(url common.URL, methodName string) *RpcStatus {
methodActive := methodMap.(*sync.Map)
rpcStatus, found := methodActive.Load(methodName)
if !found {
- rpcStatus = &RpcStatus{}
+ rpcStatus = &RPCStatus{}
methodActive.Store(methodName, rpcStatus)
}
- status := rpcStatus.(*RpcStatus)
+ status := rpcStatus.(*RPCStatus)
return status
}
func BeginCount(url common.URL, methodName string) {
- beginCount0(GetStatus(url, methodName))
+ beginCount0(GetURLStatus(url))
+ beginCount0(GetMethodStatus(url, methodName))
}
-func EndCount(url common.URL, methodName string) {
- endCount0(GetStatus(url, methodName))
+func EndCount(url common.URL, methodName string, elapsed int64, succeeded bool) {
+ endCount0(GetURLStatus(url), elapsed, succeeded)
+ endCount0(GetMethodStatus(url, methodName), elapsed, succeeded)
}
// private methods
-func beginCount0(rpcStatus *RpcStatus) {
+func beginCount0(rpcStatus *RPCStatus) {
atomic.AddInt32(&rpcStatus.active, 1)
}
-func endCount0(rpcStatus *RpcStatus) {
+func endCount0(rpcStatus *RPCStatus, elapsed int64, succeeded bool) {
atomic.AddInt32(&rpcStatus.active, -1)
+ atomic.AddInt32(&rpcStatus.total, 1)
+ atomic.AddInt64(&rpcStatus.totalElapsed, elapsed)
+
+ if rpcStatus.maxElapsed < elapsed {
+ atomic.StoreInt64(&rpcStatus.maxElapsed, elapsed)
+ }
+ if succeeded {
+ if rpcStatus.succeededMaxElapsed < elapsed {
+ atomic.StoreInt64(&rpcStatus.succeededMaxElapsed, elapsed)
+ }
+ atomic.StoreInt32(&rpcStatus.successiveRequestFailureCount, 0)
+ } else {
+ atomic.StoreInt64(&rpcStatus.lastRequestFailedTimestamp, time.Now().Unix())
+ atomic.AddInt32(&rpcStatus.successiveRequestFailureCount, 1)
+ atomic.AddInt32(&rpcStatus.failed, 1)
+ atomic.AddInt64(&rpcStatus.failedElapsed, elapsed)
+ if rpcStatus.failedMaxElapsed < elapsed {
+ atomic.StoreInt64(&rpcStatus.failedMaxElapsed, elapsed)
+ }
+ }
+}
+
+func CurrentTimeMillis() int64 {
+ return time.Now().UnixNano() / int64(time.Millisecond)
}
diff --git a/protocol/rpc_status_test.go b/protocol/rpc_status_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..ce2b4dc0d0fae2b271dbaeb3fdafab8858a7aa0c
--- /dev/null
+++ b/protocol/rpc_status_test.go
@@ -0,0 +1,152 @@
+package protocol
+
+import (
+ "context"
+ "strconv"
+ "testing"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+)
+
+import (
+ "github.com/apache/dubbo-go/common"
+)
+
+func TestBeginCount(t *testing.T) {
+ defer destroy()
+
+ url, _ := common.NewURL(context.TODO(), "dubbo://192.168.10.10:20000/com.ikurento.user.UserProvider")
+ BeginCount(url, "test")
+ urlStatus := GetURLStatus(url)
+ methodStatus := GetMethodStatus(url, "test")
+ methodStatus1 := GetMethodStatus(url, "test1")
+ assert.Equal(t, int32(1), methodStatus.active)
+ assert.Equal(t, int32(1), urlStatus.active)
+ assert.Equal(t, int32(0), methodStatus1.active)
+
+}
+
+func TestEndCount(t *testing.T) {
+ defer destroy()
+
+ url, _ := common.NewURL(context.TODO(), "dubbo://192.168.10.10:20000/com.ikurento.user.UserProvider")
+ EndCount(url, "test", 100, true)
+ urlStatus := GetURLStatus(url)
+ methodStatus := GetMethodStatus(url, "test")
+ assert.Equal(t, int32(-1), methodStatus.active)
+ assert.Equal(t, int32(-1), urlStatus.active)
+ assert.Equal(t, int32(1), methodStatus.total)
+ assert.Equal(t, int32(1), urlStatus.total)
+}
+
+func TestGetMethodStatus(t *testing.T) {
+ defer destroy()
+
+ url, _ := common.NewURL(context.TODO(), "dubbo://192.168.10.10:20000/com.ikurento.user.UserProvider")
+ status := GetMethodStatus(url, "test")
+ assert.NotNil(t, status)
+ assert.Equal(t, int32(0), status.total)
+}
+
+func TestGetUrlStatus(t *testing.T) {
+ defer destroy()
+
+ url, _ := common.NewURL(context.TODO(), "dubbo://192.168.10.10:20000/com.ikurento.user.UserProvider")
+ status := GetURLStatus(url)
+ assert.NotNil(t, status)
+ assert.Equal(t, int32(0), status.total)
+}
+
+func Test_beginCount0(t *testing.T) {
+ defer destroy()
+
+ url, _ := common.NewURL(context.TODO(), "dubbo://192.168.10.10:20000/com.ikurento.user.UserProvider")
+ status := GetURLStatus(url)
+ beginCount0(status)
+ assert.Equal(t, int32(1), status.active)
+}
+
+func Test_All(t *testing.T) {
+ defer destroy()
+
+ url, _ := common.NewURL(context.TODO(), "dubbo://192.168.10.10:20000/com.ikurento.user.UserProvider")
+ request(url, "test", 100, false, true)
+ urlStatus := GetURLStatus(url)
+ methodStatus := GetMethodStatus(url, "test")
+ assert.Equal(t, int32(1), methodStatus.total)
+ assert.Equal(t, int32(1), urlStatus.total)
+ assert.Equal(t, int32(0), methodStatus.active)
+ assert.Equal(t, int32(0), urlStatus.active)
+ assert.Equal(t, int32(0), methodStatus.failed)
+ assert.Equal(t, int32(0), urlStatus.failed)
+ assert.Equal(t, int32(0), methodStatus.successiveRequestFailureCount)
+ assert.Equal(t, int32(0), urlStatus.successiveRequestFailureCount)
+ assert.Equal(t, int64(100), methodStatus.totalElapsed)
+ assert.Equal(t, int64(100), urlStatus.totalElapsed)
+ request(url, "test", 100, false, false)
+ request(url, "test", 100, false, false)
+ request(url, "test", 100, false, false)
+ request(url, "test", 100, false, false)
+ request(url, "test", 100, false, false)
+ assert.Equal(t, int32(6), methodStatus.total)
+ assert.Equal(t, int32(6), urlStatus.total)
+ assert.Equal(t, int32(5), methodStatus.failed)
+ assert.Equal(t, int32(5), urlStatus.failed)
+ assert.Equal(t, int32(5), methodStatus.successiveRequestFailureCount)
+ assert.Equal(t, int32(5), urlStatus.successiveRequestFailureCount)
+ assert.Equal(t, int64(600), methodStatus.totalElapsed)
+ assert.Equal(t, int64(600), urlStatus.totalElapsed)
+ assert.Equal(t, int64(500), methodStatus.failedElapsed)
+ assert.Equal(t, int64(500), urlStatus.failedElapsed)
+
+ request(url, "test", 100, false, true)
+ assert.Equal(t, int32(0), methodStatus.successiveRequestFailureCount)
+ assert.Equal(t, int32(0), urlStatus.successiveRequestFailureCount)
+
+ request(url, "test", 200, false, false)
+ request(url, "test", 200, false, false)
+ assert.Equal(t, int32(2), methodStatus.successiveRequestFailureCount)
+ assert.Equal(t, int32(2), urlStatus.successiveRequestFailureCount)
+ assert.Equal(t, int64(200), methodStatus.maxElapsed)
+ assert.Equal(t, int64(200), urlStatus.maxElapsed)
+
+ request(url, "test1", 200, false, false)
+ request(url, "test1", 200, false, false)
+ request(url, "test1", 200, false, false)
+ assert.Equal(t, int32(5), urlStatus.successiveRequestFailureCount)
+ methodStatus1 := GetMethodStatus(url, "test1")
+ assert.Equal(t, int32(2), methodStatus.successiveRequestFailureCount)
+ assert.Equal(t, int32(3), methodStatus1.successiveRequestFailureCount)
+
+}
+
+func request(url common.URL, method string, elapsed int64, active, succeeded bool) {
+ BeginCount(url, method)
+ if !active {
+ EndCount(url, method, elapsed, succeeded)
+ }
+}
+
+func TestCurrentTimeMillis(t *testing.T) {
+ defer destroy()
+ c := CurrentTimeMillis()
+ assert.NotNil(t, c)
+ str := strconv.FormatInt(c, 10)
+ i, _ := strconv.ParseInt(str, 10, 64)
+ assert.Equal(t, c, i)
+}
+
+func destroy() {
+ delete1 := func(key interface{}, value interface{}) bool {
+ methodStatistics.Delete(key)
+ return true
+ }
+ methodStatistics.Range(delete1)
+ delete2 := func(key interface{}, value interface{}) bool {
+ serviceStatistic.Delete(key)
+ return true
+ }
+ serviceStatistic.Range(delete2)
+}