diff --git a/common/url.go b/common/url.go index eca1d27992f7ea1e46f292d26544756842867de5..60a53c6df5aa4bc3d208067ee3b7b74034b4045b 100644 --- a/common/url.go +++ b/common/url.go @@ -71,7 +71,7 @@ type baseUrl struct { Port string //url.Values is not safe map, add to avoid concurrent map read and map write error paramsLock sync.RWMutex - Params url.Values + params url.Values PrimitiveURL string ctx context.Context } @@ -108,13 +108,13 @@ func WithMethods(methods []string) option { func WithParams(params url.Values) option { return func(url *URL) { - url.Params = params + url.params = params } } func WithParamsValue(key, val string) option { return func(url *URL) { - url.Params.Set(key, val) + url.SetParam(key, val) } } @@ -189,7 +189,7 @@ func NewURL(ctx context.Context, urlString string, opts ...option) (URL, error) return s, perrors.Errorf("url.Parse(url string{%s}), error{%v}", rawUrlString, err) } - s.Params, err = url.ParseQuery(serviceUrl.RawQuery) + s.params, err = url.ParseQuery(serviceUrl.RawQuery) if err != nil { return s, perrors.Errorf("url.ParseQuery(raw url string{%s}), error{%v}", serviceUrl.RawQuery, err) } @@ -237,7 +237,9 @@ func (c URL) String() string { buildString := fmt.Sprintf( "%s://%s:%s@%s:%s%s?", c.Protocol, c.Username, c.Password, c.Ip, c.Port, c.Path) - buildString += c.Params.Encode() + c.paramsLock.RLock() + buildString += c.params.Encode() + c.paramsLock.RUnlock() return buildString } @@ -291,26 +293,46 @@ func (c URL) Service() string { func (c *URL) AddParam(key string, value string) { c.paramsLock.Lock() - c.Params.Add(key, value) + c.params.Add(key, value) + c.paramsLock.Unlock() +} + +func (c *URL) SetParam(key string, value string) { + c.paramsLock.Lock() + c.params.Set(key, value) c.paramsLock.Unlock() } +func (c *URL) RangeParams(f func(key, value string) bool) { + c.paramsLock.RLock() + for k, v := range c.params { + if !f(k, v[0]) { + break + } + } + c.paramsLock.RUnlock() +} + func (c URL) GetParam(s string, d string) string { var r string c.paramsLock.RLock() - if r = c.Params.Get(s); len(r) == 0 { + if r = c.params.Get(s); len(r) == 0 { r = d } c.paramsLock.RUnlock() return r } func (c URL) GetParamAndDecoded(key string) (string, error) { + c.paramsLock.RLock() + defer c.paramsLock.RUnlock() ruleDec, err := base64.URLEncoding.DecodeString(c.GetParam(key, "")) value := string(ruleDec) return value, err } func (c URL) GetRawParam(key string) string { + c.paramsLock.RLock() + defer c.paramsLock.RUnlock() switch key { case "protocol": return c.Protocol @@ -325,7 +347,7 @@ func (c URL) GetRawParam(key string) string { case "path": return c.Path default: - return c.Params.Get(key) + return c.params.Get(key) } } @@ -334,7 +356,9 @@ func (c URL) GetParamBool(s string, d bool) bool { var r bool var err error - if r, err = strconv.ParseBool(c.Params.Get(s)); err != nil { + c.paramsLock.RLock() + defer c.paramsLock.RUnlock() + if r, err = strconv.ParseBool(c.params.Get(s)); err != nil { return d } return r @@ -343,7 +367,9 @@ func (c URL) GetParamBool(s string, d bool) bool { func (c URL) GetParamInt(s string, d int64) int64 { var r int var err error - if r, err = strconv.Atoi(c.Params.Get(s)); r == 0 || err != nil { + c.paramsLock.RLock() + defer c.paramsLock.RUnlock() + if r, err = strconv.Atoi(c.params.Get(s)); r == 0 || err != nil { return d } return int64(r) @@ -352,7 +378,9 @@ func (c URL) GetParamInt(s string, d int64) int64 { func (c URL) GetMethodParamInt(method string, key string, d int64) int64 { var r int var err error - if r, err = strconv.Atoi(c.Params.Get("methods." + method + "." + key)); r == 0 || err != nil { + c.paramsLock.RLock() + defer c.paramsLock.RUnlock() + if r, err = strconv.Atoi(c.params.Get("methods." + method + "." + key)); r == 0 || err != nil { return d } return int64(r) @@ -369,7 +397,7 @@ func (c URL) GetMethodParamInt64(method string, key string, d int64) int64 { func (c URL) GetMethodParam(method string, key string, d string) string { var r string - if r = c.Params.Get("methods." + method + "." + key); r == "" { + if r = c.params.Get("methods." + method + "." + key); r == "" { r = d } return r @@ -380,9 +408,11 @@ func (c URL) ToMap() map[string]string { paramsMap := make(map[string]string) - for k, v := range c.Params { - paramsMap[k] = v[0] - } + c.RangeParams(func(key, value string) bool { + paramsMap[key] = value + return true + }) + if c.Protocol != "" { paramsMap["protocol"] = c.Protocol } @@ -421,19 +451,19 @@ func MergeUrl(serviceUrl URL, referenceUrl *URL) URL { mergedUrl := serviceUrl //iterator the referenceUrl if serviceUrl not have the key ,merge in - - for k, v := range referenceUrl.Params { - if _, ok := mergedUrl.Params[k]; !ok { - mergedUrl.Params.Set(k, v[0]) + referenceUrl.RangeParams(func(key, value string) bool { + if v := mergedUrl.GetParam(key, ""); len(v) == 0 { + mergedUrl.SetParam(key, value) } - } + return true + }) //loadBalance,cluster,retries strategy config methodConfigMergeFcn := mergeNormalParam(mergedUrl, referenceUrl, []string{constant.LOADBALANCE_KEY, constant.CLUSTER_KEY, constant.RETRIES_KEY}) //remote timestamp - if v := serviceUrl.Params.Get(constant.TIMESTAMP_KEY); len(v) > 0 { - mergedUrl.Params.Set(constant.REMOTE_TIMESTAMP_KEY, v) - mergedUrl.Params.Set(constant.TIMESTAMP_KEY, referenceUrl.Params.Get(constant.TIMESTAMP_KEY)) + if v := serviceUrl.GetParam(constant.TIMESTAMP_KEY, ""); len(v) > 0 { + mergedUrl.SetParam(constant.REMOTE_TIMESTAMP_KEY, v) + mergedUrl.SetParam(constant.TIMESTAMP_KEY, referenceUrl.GetParam(constant.TIMESTAMP_KEY, "")) } //finally execute methodConfigMergeFcn @@ -449,12 +479,12 @@ func MergeUrl(serviceUrl URL, referenceUrl *URL) URL { func mergeNormalParam(mergedUrl URL, referenceUrl *URL, paramKeys []string) []func(method string) { var methodConfigMergeFcn = []func(method string){} for _, paramKey := range paramKeys { - if v := referenceUrl.Params.Get(paramKey); len(v) > 0 { - mergedUrl.Params.Set(paramKey, v) + if v := referenceUrl.GetParam(paramKey, ""); len(v) > 0 { + mergedUrl.SetParam(paramKey, v) } methodConfigMergeFcn = append(methodConfigMergeFcn, func(method string) { - if v := referenceUrl.Params.Get(method + "." + paramKey); len(v) > 0 { - mergedUrl.Params.Set(method+"."+paramKey, v) + if v := referenceUrl.GetParam(method+"."+paramKey, ""); len(v) > 0 { + mergedUrl.SetParam(method+"."+paramKey, v) } }) } diff --git a/common/url_test.go b/common/url_test.go index 9366b0c1cc39910bfec48efd052fd1fa31cf1750..4b7d9376a41d8ab4023dad145201ace04fd6eac3 100644 --- a/common/url_test.go +++ b/common/url_test.go @@ -52,7 +52,7 @@ func TestNewURLWithOptions(t *testing.T) { assert.Equal(t, "127.0.0.1", u.Ip) assert.Equal(t, "8080", u.Port) assert.Equal(t, methods, u.Methods) - assert.Equal(t, params, u.Params) + assert.Equal(t, params, u.params) } func TestURL(t *testing.T) { @@ -74,7 +74,7 @@ func TestURL(t *testing.T) { assert.Equal(t, "anyhost=true&application=BDTService&category=providers&default.timeout=10000&dubbo=dubbo-"+ "provider-golang-1.0.0&environment=dev&interface=com.ikurento.user.UserProvider&ip=192.168.56.1&methods=GetUser%"+ "2C&module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&side=provider&timeout=3000&t"+ - "imestamp=1556509797245", u.Params.Encode()) + "imestamp=1556509797245", u.params.Encode()) assert.Equal(t, "dubbo://:@127.0.0.1:20000/com.ikurento.user.UserProvider?anyhost=true&application=BDTServi"+ "ce&category=providers&default.timeout=10000&dubbo=dubbo-provider-golang-1.0.0&environment=dev&interface=com.ikure"+ @@ -101,7 +101,7 @@ func TestURLWithoutSchema(t *testing.T) { assert.Equal(t, "anyhost=true&application=BDTService&category=providers&default.timeout=10000&dubbo=dubbo-"+ "provider-golang-1.0.0&environment=dev&interface=com.ikurento.user.UserProvider&ip=192.168.56.1&methods=GetUser%"+ "2C&module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&side=provider&timeout=3000&t"+ - "imestamp=1556509797245", u.Params.Encode()) + "imestamp=1556509797245", u.params.Encode()) assert.Equal(t, "dubbo://:@127.0.0.1:20000/com.ikurento.user.UserProvider?anyhost=true&application=BDTServi"+ "ce&category=providers&default.timeout=10000&dubbo=dubbo-provider-golang-1.0.0&environment=dev&interface=com.ikure"+ @@ -124,7 +124,7 @@ func TestURL_URLEqual(t *testing.T) { func TestURL_GetParam(t *testing.T) { params := url.Values{} params.Set("key", "value") - u := URL{baseUrl: baseUrl{Params: params}} + u := URL{baseUrl: baseUrl{params: params}} v := u.GetParam("key", "default") assert.Equal(t, "value", v) @@ -136,7 +136,7 @@ func TestURL_GetParam(t *testing.T) { func TestURL_GetParamInt(t *testing.T) { params := url.Values{} params.Set("key", "3") - u := URL{baseUrl: baseUrl{Params: params}} + u := URL{baseUrl: baseUrl{params: params}} v := u.GetParamInt("key", 1) assert.Equal(t, int64(3), v) @@ -148,7 +148,7 @@ func TestURL_GetParamInt(t *testing.T) { func TestURL_GetParamBool(t *testing.T) { params := url.Values{} params.Set("force", "true") - u := URL{baseUrl: baseUrl{Params: params}} + u := URL{baseUrl: baseUrl{params: params}} v := u.GetParamBool("force", false) assert.Equal(t, true, v) @@ -161,7 +161,7 @@ func TestURL_GetParamAndDecoded(t *testing.T) { rule := "host = 2.2.2.2,1.1.1.1,3.3.3.3 & host !=1.1.1.1 => host = 1.2.3.4" params := url.Values{} params.Set("rule", base64.URLEncoding.EncodeToString([]byte(rule))) - u := URL{baseUrl: baseUrl{Params: params}} + u := URL{baseUrl: baseUrl{params: params}} v, _ := u.GetParamAndDecoded("rule") assert.Equal(t, rule, v) } @@ -196,7 +196,7 @@ func TestURL_ToMap(t *testing.T) { func TestURL_GetMethodParamInt(t *testing.T) { params := url.Values{} params.Set("methods.GetValue.timeout", "3") - u := URL{baseUrl: baseUrl{Params: params}} + u := URL{baseUrl: baseUrl{params: params}} v := u.GetMethodParamInt("GetValue", "timeout", 1) assert.Equal(t, int64(3), v) @@ -208,7 +208,7 @@ func TestURL_GetMethodParamInt(t *testing.T) { func TestURL_GetMethodParam(t *testing.T) { params := url.Values{} params.Set("methods.GetValue.timeout", "3s") - u := URL{baseUrl: baseUrl{Params: params}} + u := URL{baseUrl: baseUrl{params: params}} v := u.GetMethodParam("GetValue", "timeout", "1s") assert.Equal(t, "3s", v) diff --git a/protocol/protocolwrapper/protocol_filter_wrapper.go b/protocol/protocolwrapper/protocol_filter_wrapper.go index b1392fff511dba2e2cbedf2547d6be2d4276a912..7c58fabea3cccf5a39e1622fedd4a3a297e05983 100644 --- a/protocol/protocolwrapper/protocol_filter_wrapper.go +++ b/protocol/protocolwrapper/protocol_filter_wrapper.go @@ -62,7 +62,7 @@ func (pfw *ProtocolFilterWrapper) Destroy() { } func buildInvokerChain(invoker protocol.Invoker, key string) protocol.Invoker { - filtName := invoker.GetUrl().Params.Get(key) + filtName := invoker.GetUrl().GetParam(key, "") if filtName == "" { return invoker } diff --git a/registry/consul/utils.go b/registry/consul/utils.go index 6ecb0573ec5683d66795f539a8beb2eff5b6be2c..ee17fcc0df43228e26b40f3ac3f992147fc33d6e 100644 --- a/registry/consul/utils.go +++ b/registry/consul/utils.go @@ -62,9 +62,12 @@ func buildService(url common.URL) (*consul.AgentServiceRegistration, error) { // tags tags := make([]string, 0, 8) - for k := range url.Params { - tags = append(tags, k+"="+url.Params.Get(k)) - } + + url.RangeParams(func(key, value string) bool { + tags = append(tags, key+"="+value) + return true + }) + tags = append(tags, "dubbo") // meta diff --git a/registry/directory/directory_test.go b/registry/directory/directory_test.go index f31165d0a2e32c89b3d15df3df4e2048dadcb5e5..9b48da80df1f13684210ccabf38e9780ae8d3f42 100644 --- a/registry/directory/directory_test.go +++ b/registry/directory/directory_test.go @@ -69,7 +69,7 @@ func TestSubscribe_Group(t *testing.T) { regurl, _ := common.NewURL(context.TODO(), "mock://127.0.0.1:1111") suburl, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000") - suburl.Params.Set(constant.CLUSTER_KEY, "mock") + suburl.SetParam(constant.CLUSTER_KEY, "mock") regurl.SubURL = &suburl mockRegistry, _ := registry.NewMockRegistry(&common.URL{}) registryDirectory, _ := NewRegistryDirectory(®url, mockRegistry) diff --git a/registry/etcdv3/registry.go b/registry/etcdv3/registry.go index 5802142989e5d8297f027ddecff3d7780070729f..96d237f18b9e6461d0330f8b39eec7b7aa5d3a3e 100644 --- a/registry/etcdv3/registry.go +++ b/registry/etcdv3/registry.go @@ -257,10 +257,11 @@ func (r *etcdV3Registry) registerProvider(svc common.URL) error { } params := url.Values{} - for k, v := range svc.Params { - params[k] = v - } + svc.RangeParams(func(key, value string) bool { + params[key] = []string{value} + return true + }) params.Add("pid", processID) params.Add("ip", localIP) params.Add("anyhost", "true") diff --git a/registry/etcdv3/registry_test.go b/registry/etcdv3/registry_test.go index 26204c74ad4305278e33d9c8b50199cfa578bf8a..6d8fc240561f76264f79b6452014629b3e6e1868 100644 --- a/registry/etcdv3/registry_test.go +++ b/registry/etcdv3/registry_test.go @@ -63,7 +63,7 @@ func (suite *RegistryTestSuite) TestSubscribe() { } //consumer register - regurl.Params.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER)) + regurl.SetParam(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER)) reg2 := initRegistry(t) reg2.Register(url) diff --git a/registry/nacos/registry.go b/registry/nacos/registry.go index bf86ead7a31f5873078b9dc4acd3f0dcf6aec783..d6235cb23072ef7a22d99d065890f6068f4b9ea0 100644 --- a/registry/nacos/registry.go +++ b/registry/nacos/registry.go @@ -119,10 +119,13 @@ func appendParam(target *bytes.Buffer, url common.URL, key string) { func createRegisterParam(url common.URL, serviceName string) vo.RegisterInstanceParam { category := getCategory(url) - params := make(map[string]string, len(url.Params)+3) - for k := range url.Params { - params[k] = url.Params.Get(k) - } + params := make(map[string]string) + + url.RangeParams(func(key, value string) bool { + params[key] = value + return true + }) + params[constant.NACOS_CATEGORY_KEY] = category params[constant.NACOS_PROTOCOL_KEY] = url.Protocol params[constant.NACOS_PATH_KEY] = url.Path diff --git a/registry/nacos/registry_test.go b/registry/nacos/registry_test.go index 9ce9dcfe4d9f7d3974a3d07e093f59888e73a91d..bb43a5feed677425f148f5af9351143adfd9272f 100644 --- a/registry/nacos/registry_test.go +++ b/registry/nacos/registry_test.go @@ -66,7 +66,7 @@ func TestNacosRegistry_Subscribe(t *testing.T) { return } - regurl.Params.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER)) + regurl.SetParam(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER)) reg2, _ := newNacosRegistry(®url) listener, err := reg2.Subscribe(url) assert.Nil(t, err) @@ -111,7 +111,7 @@ func TestNacosRegistry_Subscribe_del(t *testing.T) { return } - regurl.Params.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER)) + regurl.SetParam(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER)) reg2, _ := newNacosRegistry(®url) listener, err := reg2.Subscribe(url1) assert.Nil(t, err) diff --git a/registry/zookeeper/registry.go b/registry/zookeeper/registry.go index ff57eb638aa8919720b9eeed1cb4603cc2928cf2..86fe52940cce55a723acc9fd72b2496dccc87572 100644 --- a/registry/zookeeper/registry.go +++ b/registry/zookeeper/registry.go @@ -80,7 +80,6 @@ type zkRegistry struct { configListener *RegistryConfigurationListener //for provider zkPath map[string]int // key = protocol://ip:port/interface - } func newZkRegistry(url *common.URL) (registry.Registry, error) { @@ -271,9 +270,11 @@ func (r *zkRegistry) register(c common.URL) error { return perrors.WithStack(err) } params = url.Values{} - for k, v := range c.Params { - params[k] = v - } + + c.RangeParams(func(key, value string) bool { + params[key] = []string{value} + return true + }) params.Add("pid", processID) params.Add("ip", localIP) diff --git a/registry/zookeeper/registry_test.go b/registry/zookeeper/registry_test.go index 2b5e2f8f7caf749be28bf3ff6e5d14980d70f2f4..2ac946a3fa292e824d74501210349b0401bfdd5f 100644 --- a/registry/zookeeper/registry_test.go +++ b/registry/zookeeper/registry_test.go @@ -60,7 +60,7 @@ func Test_Subscribe(t *testing.T) { } //consumer register - regurl.Params.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER)) + regurl.SetParam(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER)) _, reg2, _ := newMockZkRegistry(®url, zookeeper.WithTestCluster(ts)) reg2.Register(url)