From dec6c16d4e70f7f06b4dd5af6a1ac557bbc53a39 Mon Sep 17 00:00:00 2001
From: Ming Deng <mindeng@ebay.com>
Date: Mon, 14 Oct 2019 22:24:25 +0800
Subject: [PATCH] Refactor MethodServiceTpsLimitImpl; Read config from yml

---
 config/method_config.go                       | 15 +--
 config/service_config.go                      | 78 +++++++++-------
 .../impl/tps_limiter_method_service_impl.go   | 92 +++++++++----------
 3 files changed, 102 insertions(+), 83 deletions(-)

diff --git a/config/method_config.go b/config/method_config.go
index ac9242a23..431a30c1d 100644
--- a/config/method_config.go
+++ b/config/method_config.go
@@ -25,12 +25,15 @@ import (
 )
 
 type MethodConfig struct {
-	InterfaceId   string
-	InterfaceName string
-	Name          string `yaml:"name"  json:"name,omitempty" property:"name"`
-	Retries       string `yaml:"retries"  json:"retries,omitempty" property:"retries"`
-	Loadbalance   string `yaml:"loadbalance"  json:"loadbalance,omitempty" property:"loadbalance"`
-	Weight        int64  `yaml:"weight"  json:"weight,omitempty" property:"weight"`
+	InterfaceId      string
+	InterfaceName    string
+	Name             string `yaml:"name"  json:"name,omitempty" property:"name"`
+	Retries          string `yaml:"retries"  json:"retries,omitempty" property:"retries"`
+	Loadbalance      string `yaml:"loadbalance"  json:"loadbalance,omitempty" property:"loadbalance"`
+	Weight           int64  `yaml:"weight"  json:"weight,omitempty" property:"weight"`
+	TpsLimitInterval string `yaml:"tps.limit.interval" json:"tps.limit.interval,omitempty" property:"tps.limit.interval"`
+	TpsLimitRate     string `yaml:"tps.limit.rate" json:"tps.limit.rate,omitempty" property:"tps.limit.rate"`
+	TpsLimitStrategy string `yaml:"tps.limit.strategy" json:"tps.limit.strategy,omitempty" property:"tps.limit.strategy"`
 }
 
 func (c *MethodConfig) Prefix() string {
diff --git a/config/service_config.go b/config/service_config.go
index ee0457937..784257b8c 100644
--- a/config/service_config.go
+++ b/config/service_config.go
@@ -43,27 +43,32 @@ import (
 )
 
 type ServiceConfig struct {
-	context       context.Context
-	id            string
-	Filter        string            `yaml:"filter" json:"filter,omitempty" property:"filter"`
-	Protocol      string            `default:"dubbo"  required:"true"  yaml:"protocol"  json:"protocol,omitempty" property:"protocol"` //multi protocol support, split by ','
-	InterfaceName string            `required:"true"  yaml:"interface"  json:"interface,omitempty" property:"interface"`
-	Registry      string            `yaml:"registry"  json:"registry,omitempty"  property:"registry"`
-	Cluster       string            `default:"failover" yaml:"cluster"  json:"cluster,omitempty" property:"cluster"`
-	Loadbalance   string            `default:"random" yaml:"loadbalance"  json:"loadbalance,omitempty"  property:"loadbalance"`
-	Group         string            `yaml:"group"  json:"group,omitempty" property:"group"`
-	Version       string            `yaml:"version"  json:"version,omitempty" property:"version" `
-	Methods       []*MethodConfig   `yaml:"methods"  json:"methods,omitempty" property:"methods"`
-	Warmup        string            `yaml:"warmup"  json:"warmup,omitempty"  property:"warmup"`
-	Retries       string            `yaml:"retries"  json:"retries,omitempty" property:"retries"`
-	Params        map[string]string `yaml:"params"  json:"params,omitempty" property:"params"`
-	Token         string            `yaml:"token" json:"token,omitempty" property:"token"`
-	AccessLog     string            `yaml:"accesslog" json:"accesslog,omitempty" property:"accesslog"`
-	unexported    *atomic.Bool
-	exported      *atomic.Bool
-	rpcService    common.RPCService
-	cacheProtocol protocol.Protocol
-	cacheMutex    sync.Mutex
+	context                 context.Context
+	id                      string
+	Filter                  string            `yaml:"filter" json:"filter,omitempty" property:"filter"`
+	Protocol                string            `default:"dubbo"  required:"true"  yaml:"protocol"  json:"protocol,omitempty" property:"protocol"` // multi protocol support, split by ','
+	InterfaceName           string            `required:"true"  yaml:"interface"  json:"interface,omitempty" property:"interface"`
+	Registry                string            `yaml:"registry"  json:"registry,omitempty"  property:"registry"`
+	Cluster                 string            `default:"failover" yaml:"cluster"  json:"cluster,omitempty" property:"cluster"`
+	Loadbalance             string            `default:"random" yaml:"loadbalance"  json:"loadbalance,omitempty"  property:"loadbalance"`
+	Group                   string            `yaml:"group"  json:"group,omitempty" property:"group"`
+	Version                 string            `yaml:"version"  json:"version,omitempty" property:"version" `
+	Methods                 []*MethodConfig   `yaml:"methods"  json:"methods,omitempty" property:"methods"`
+	Warmup                  string            `yaml:"warmup"  json:"warmup,omitempty"  property:"warmup"`
+	Retries                 string            `yaml:"retries"  json:"retries,omitempty" property:"retries"`
+	Params                  map[string]string `yaml:"params"  json:"params,omitempty" property:"params"`
+	Token                   string            `yaml:"token" json:"token,omitempty" property:"token"`
+	AccessLog               string            `yaml:"accesslog" json:"accesslog,omitempty" property:"accesslog"`
+	TpsLimiter              string            `yaml:"tps.limiter" json:"tps.limiter,omitempty" property:"tps.limiter"`
+	TpsLimitInterval        string            `yaml:"tps.limit.interval" json:"tps.limit.interval,omitempty" property:"tps.limit.interval"`
+	TpsLimitRate            string            `yaml:"tps.limit.rate" json:"tps.limit.rate,omitempty" property:"tps.limit.rate"`
+	TpsLimitStrategy        string            `yaml:"tps.limit.strategy" json:"tps.limit.strategy,omitempty" property:"tps.limit.strategy"`
+	TpsLimitRejectedHandler string            `yaml:"tps.limit.rejected.handler" json:"tps.limit.rejected.handler,omitempty" property:"tps.limit.rejected.handler"`
+	unexported              *atomic.Bool
+	exported                *atomic.Bool
+	rpcService              common.RPCService
+	cacheProtocol           protocol.Protocol
+	cacheMutex              sync.Mutex
 }
 
 func (c *ServiceConfig) Prefix() string {
@@ -94,9 +99,9 @@ func NewServiceConfig(id string, context context.Context) *ServiceConfig {
 }
 
 func (srvconfig *ServiceConfig) Export() error {
-	//TODO: config center start here
+	// TODO: config center start here
 
-	//TODO:delay export
+	// TODO:delay export
 	if srvconfig.unexported != nil && srvconfig.unexported.Load() {
 		err := perrors.Errorf("The service %v has already unexported! ", srvconfig.InterfaceName)
 		logger.Errorf(err.Error())
@@ -111,7 +116,7 @@ func (srvconfig *ServiceConfig) Export() error {
 	urlMap := srvconfig.getUrlMap()
 
 	for _, proto := range loadProtocol(srvconfig.Protocol, providerConfig.Protocols) {
-		//registry the service reflect
+		// registry the service reflect
 		methods, err := common.ServiceMap.Register(proto.Name, srvconfig.rpcService)
 		if err != nil {
 			err := perrors.Errorf("The service %v  export the protocol %v error! Error message is %v .", srvconfig.InterfaceName, proto.Name, err.Error())
@@ -164,7 +169,7 @@ func (srvconfig *ServiceConfig) Implement(s common.RPCService) {
 
 func (srvconfig *ServiceConfig) getUrlMap() url.Values {
 	urlMap := url.Values{}
-	//first set user params
+	// first set user params
 	for k, v := range srvconfig.Params {
 		urlMap.Set(k, v)
 	}
@@ -177,7 +182,7 @@ func (srvconfig *ServiceConfig) getUrlMap() url.Values {
 	urlMap.Set(constant.GROUP_KEY, srvconfig.Group)
 	urlMap.Set(constant.VERSION_KEY, srvconfig.Version)
 	urlMap.Set(constant.ROLE_KEY, strconv.Itoa(common.PROVIDER))
-	//application info
+	// application info
 	urlMap.Set(constant.APPLICATION_KEY, providerConfig.ApplicationConfig.Name)
 	urlMap.Set(constant.ORGANIZATION_KEY, providerConfig.ApplicationConfig.Organization)
 	urlMap.Set(constant.NAME_KEY, providerConfig.ApplicationConfig.Name)
@@ -186,16 +191,27 @@ func (srvconfig *ServiceConfig) getUrlMap() url.Values {
 	urlMap.Set(constant.OWNER_KEY, providerConfig.ApplicationConfig.Owner)
 	urlMap.Set(constant.ENVIRONMENT_KEY, providerConfig.ApplicationConfig.Environment)
 
-	//filter
+	// filter
 	urlMap.Set(constant.SERVICE_FILTER_KEY, mergeValue(providerConfig.Filter, srvconfig.Filter, constant.DEFAULT_SERVICE_FILTERS))
 
-	//filter special config
+	// filter special config
 	urlMap.Set(constant.ACCESS_LOG_KEY, srvconfig.AccessLog)
+	// tps limiter
+	urlMap.Set(constant.TPS_LIMIT_STRATEGY_KEY, srvconfig.TpsLimitStrategy)
+	urlMap.Set(constant.TPS_LIMIT_INTERVAL_KEY, srvconfig.TpsLimitInterval)
+	urlMap.Set(constant.TPS_LIMIT_RATE_KEY, srvconfig.TpsLimitRate)
+	urlMap.Set(constant.TPS_LIMITER_KEY, srvconfig.TpsLimiter)
+	urlMap.Set(constant.TPS_REJECTED_EXECUTION_HANDLER_KEY, srvconfig.TpsLimitRejectedHandler)
 
 	for _, v := range srvconfig.Methods {
-		urlMap.Set("methods."+v.Name+"."+constant.LOADBALANCE_KEY, v.Loadbalance)
-		urlMap.Set("methods."+v.Name+"."+constant.RETRIES_KEY, v.Retries)
-		urlMap.Set("methods."+v.Name+"."+constant.WEIGHT_KEY, strconv.FormatInt(v.Weight, 10))
+		prefix := "methods." + v.Name + "."
+		urlMap.Set(prefix+constant.LOADBALANCE_KEY, v.Loadbalance)
+		urlMap.Set(prefix+constant.RETRIES_KEY, v.Retries)
+		urlMap.Set(prefix+constant.WEIGHT_KEY, strconv.FormatInt(v.Weight, 10))
+
+		urlMap.Set(prefix+constant.TPS_LIMIT_STRATEGY_KEY, srvconfig.TpsLimitStrategy)
+		urlMap.Set(prefix+constant.TPS_LIMIT_INTERVAL_KEY, srvconfig.TpsLimitInterval)
+		urlMap.Set(prefix+constant.TPS_LIMIT_RATE_KEY, srvconfig.TpsLimitRate)
 	}
 
 	return urlMap
diff --git a/filter/impl/tps_limiter_method_service_impl.go b/filter/impl/tps_limiter_method_service_impl.go
index 8ea32c2ae..2a30c30b1 100644
--- a/filter/impl/tps_limiter_method_service_impl.go
+++ b/filter/impl/tps_limiter_method_service_impl.go
@@ -46,8 +46,8 @@ func init() {
  *   interface : "com.ikurento.user.UserProvider"
  *   ... # other configuration
  *   tps.limiter: "method-service" # the name of MethodServiceTpsLimiterImpl. It's the default limiter too.
- *   tps.interval: 5000 # interval, the time unit is ms
- *   tps.rate: 300 # the max value in the interval. <0 means that the service will not be limited.
+ *   tps.limit.interval: 5000 # interval, the time unit is ms
+ *   tps.limit.rate: 300 # the max value in the interval. <0 means that the service will not be limited.
  *   methods:
  *    - name: "GetUser"
  *      tps.interval: 3000
@@ -61,66 +61,66 @@ type MethodServiceTpsLimiterImpl struct {
 
 func (limiter MethodServiceTpsLimiterImpl) IsAllowable(url common.URL, invocation protocol.Invocation) bool {
 
-	serviceLimitRate, err:= strconv.ParseInt(url.GetParam(constant.TPS_LIMIT_RATE_KEY,
-		constant.DEFAULT_TPS_LIMIT_RATE), 0, 0)
+	methodConfigPrefix := "methods." + invocation.MethodName() + "."
 
-	if err != nil {
-		panic(fmt.Sprintf("Can not parse the %s for url %s, please check your configuration!",
-			constant.TPS_LIMIT_RATE_KEY, url.String()))
-	}
-	methodLimitRateConfig := invocation.AttachmentsByKey(constant.TPS_LIMIT_RATE_KEY, "")
+	methodLimitRateConfig := url.GetParam(methodConfigPrefix+constant.TPS_LIMIT_RATE_KEY, "")
+	methodIntervalConfig := url.GetParam(methodConfigPrefix+constant.TPS_LIMIT_INTERVAL_KEY, "")
 
-	// both method-level and service-level don't have the configuration of tps limit
-	if serviceLimitRate < 0 && len(methodLimitRateConfig) <= 0 {
-		return true
+	limitTarget := url.ServiceKey()
+
+	// method-level tps limit
+	if len(methodIntervalConfig) > 0 || len(methodLimitRateConfig) > 0 {
+		limitTarget = limitTarget + "#" + invocation.MethodName()
 	}
 
-	limitRate := serviceLimitRate
-	// the method has tps limit configuration
-	if len(methodLimitRateConfig) >0 {
-		limitRate, err = strconv.ParseInt(methodLimitRateConfig, 0, 0)
-		if err != nil {
-			panic(fmt.Sprintf("Can not parse the %s for invocation %s # %s, please check your configuration!",
-				constant.TPS_LIMIT_RATE_KEY, url.ServiceKey(), invocation.MethodName()))
-		}
+	limitState, found := limiter.tpsState.Load(limitTarget)
+	if found {
+		return limitState.(filter.TpsLimitStrategy).IsAllowable()
 	}
 
-	// 1. the serviceLimitRate < 0 and methodRateConfig is empty string
-	// 2. the methodLimitRate < 0
-	if limitRate < 0{
+	limitRate := getLimitConfig(methodLimitRateConfig, url, invocation,
+		constant.TPS_LIMIT_RATE_KEY,
+		constant.DEFAULT_TPS_LIMIT_RATE)
+
+	if limitRate < 0 {
 		return true
 	}
 
-	serviceInterval, err := strconv.ParseInt(url.GetParam(constant.TPS_LIMIT_INTERVAL_KEY,
-		constant.DEFAULT_TPS_LIMIT_INTERVAL), 0, 0)
-
-	if err != nil || serviceInterval <= 0{
-		panic(fmt.Sprintf("The %s must be positive, please check your configuration!",
-			constant.TPS_LIMIT_INTERVAL_KEY))
+	limitInterval := getLimitConfig(methodIntervalConfig, url, invocation,
+		constant.TPS_LIMIT_INTERVAL_KEY,
+		constant.DEFAULT_TPS_LIMIT_INTERVAL)
+	if limitInterval <= 0{
+		panic(fmt.Sprintf("The interval must be positive, please check your configuration! url: %s", url.String()))
 	}
-	limitInterval := serviceInterval
-	methodIntervalConfig := invocation.AttachmentsByKey(constant.TPS_LIMIT_INTERVAL_KEY, "")
-	// there is the interval configuration of method-level
-	if len(methodIntervalConfig) > 0 {
-		limitInterval, err = strconv.ParseInt(methodIntervalConfig, 0, 0)
-		if err != nil || limitInterval <= 0{
+
+	limitStrategyConfig := url.GetParam(methodConfigPrefix+constant.TPS_LIMIT_STRATEGY_KEY,
+		url.GetParam(constant.TPS_LIMIT_STRATEGY_KEY, constant.DEFAULT_KEY))
+	limitStateCreator := extension.GetTpsLimitStrategyCreator(limitStrategyConfig)
+	limitState, _ = limiter.tpsState.LoadOrStore(limitTarget, limitStateCreator(int(limitRate), int(limitInterval)))
+	return limitState.(filter.TpsLimitStrategy).IsAllowable()
+}
+
+func getLimitConfig(methodLevelConfig string,
+	url common.URL,
+	invocation protocol.Invocation,
+	configKey string,
+	defaultVal string) int64 {
+
+	if len(methodLevelConfig) > 0 {
+		result, err := strconv.ParseInt(methodLevelConfig, 0, 0)
+		if err != nil {
 			panic(fmt.Sprintf("The %s for invocation %s # %s must be positive, please check your configuration!",
-				constant.TPS_LIMIT_INTERVAL_KEY, url.ServiceKey(), invocation.MethodName()))
+				configKey, url.ServiceKey(), invocation.MethodName()))
 		}
+		return result
 	}
 
-	limitTarget := url.ServiceKey()
+	result, err := strconv.ParseInt(url.GetParam(configKey, defaultVal), 0, 0)
 
-	// method-level tps limit
-	if len(methodIntervalConfig) > 0 || len(methodLimitRateConfig) >0  {
-		limitTarget = limitTarget + "#" + invocation.MethodName()
+	if err != nil {
+		panic(fmt.Sprintf("Cannot parse the configuration %s, please check your configuration!", configKey))
 	}
-
-	limitStrategyConfig := invocation.AttachmentsByKey(constant.TPS_LIMIT_STRATEGY_KEY,
-		url.GetParam(constant.TPS_LIMIT_STRATEGY_KEY, constant.DEFAULT_KEY))
-	limitStateCreator := extension.GetTpsLimitStrategyCreator(limitStrategyConfig)
-	limitState, _ := limiter.tpsState.LoadOrStore(limitTarget, limitStateCreator(int(limitRate), int(limitInterval)))
-	return limitState.(filter.TpsLimitStrategy).IsAllowable()
+	return result
 }
 
 var methodServiceTpsLimiterInstance *MethodServiceTpsLimiterImpl
-- 
GitLab