Skip to content
Snippets Groups Projects
Commit 996218c9 authored by vito.he's avatar vito.he
Browse files

Fix:fix url params unsafe

parent 7b6ad298
No related branches found
No related tags found
No related merge requests found
......@@ -71,7 +71,7 @@ type baseUrl struct {
Port string
//url.Values is not safe map, add to avoid concurrent map read and map write error
paramsLock sync.RWMutex
Params url.Values
params url.Values
PrimitiveURL string
ctx context.Context
}
......@@ -108,13 +108,13 @@ func WithMethods(methods []string) option {
func WithParams(params url.Values) option {
return func(url *URL) {
url.Params = params
url.params = params
}
}
func WithParamsValue(key, val string) option {
return func(url *URL) {
url.Params.Set(key, val)
url.SetParam(key, val)
}
}
......@@ -189,7 +189,7 @@ func NewURL(ctx context.Context, urlString string, opts ...option) (URL, error)
return s, perrors.Errorf("url.Parse(url string{%s}), error{%v}", rawUrlString, err)
}
s.Params, err = url.ParseQuery(serviceUrl.RawQuery)
s.params, err = url.ParseQuery(serviceUrl.RawQuery)
if err != nil {
return s, perrors.Errorf("url.ParseQuery(raw url string{%s}), error{%v}", serviceUrl.RawQuery, err)
}
......@@ -237,7 +237,9 @@ func (c URL) String() string {
buildString := fmt.Sprintf(
"%s://%s:%s@%s:%s%s?",
c.Protocol, c.Username, c.Password, c.Ip, c.Port, c.Path)
buildString += c.Params.Encode()
c.paramsLock.RLock()
buildString += c.params.Encode()
c.paramsLock.RUnlock()
return buildString
}
......@@ -291,26 +293,46 @@ func (c URL) Service() string {
func (c *URL) AddParam(key string, value string) {
c.paramsLock.Lock()
c.Params.Add(key, value)
c.params.Add(key, value)
c.paramsLock.Unlock()
}
func (c *URL) SetParam(key string, value string) {
c.paramsLock.Lock()
c.params.Set(key, value)
c.paramsLock.Unlock()
}
func (c *URL) RangeParams(f func(key, value string) bool) {
c.paramsLock.RLock()
for k, v := range c.params {
if !f(k, v[0]) {
break
}
}
c.paramsLock.RUnlock()
}
func (c URL) GetParam(s string, d string) string {
var r string
c.paramsLock.RLock()
if r = c.Params.Get(s); len(r) == 0 {
if r = c.params.Get(s); len(r) == 0 {
r = d
}
c.paramsLock.RUnlock()
return r
}
func (c URL) GetParamAndDecoded(key string) (string, error) {
c.paramsLock.RLock()
defer c.paramsLock.RUnlock()
ruleDec, err := base64.URLEncoding.DecodeString(c.GetParam(key, ""))
value := string(ruleDec)
return value, err
}
func (c URL) GetRawParam(key string) string {
c.paramsLock.RLock()
defer c.paramsLock.RUnlock()
switch key {
case "protocol":
return c.Protocol
......@@ -325,7 +347,7 @@ func (c URL) GetRawParam(key string) string {
case "path":
return c.Path
default:
return c.Params.Get(key)
return c.params.Get(key)
}
}
......@@ -334,7 +356,9 @@ func (c URL) GetParamBool(s string, d bool) bool {
var r bool
var err error
if r, err = strconv.ParseBool(c.Params.Get(s)); err != nil {
c.paramsLock.RLock()
defer c.paramsLock.RUnlock()
if r, err = strconv.ParseBool(c.params.Get(s)); err != nil {
return d
}
return r
......@@ -343,7 +367,9 @@ func (c URL) GetParamBool(s string, d bool) bool {
func (c URL) GetParamInt(s string, d int64) int64 {
var r int
var err error
if r, err = strconv.Atoi(c.Params.Get(s)); r == 0 || err != nil {
c.paramsLock.RLock()
defer c.paramsLock.RUnlock()
if r, err = strconv.Atoi(c.params.Get(s)); r == 0 || err != nil {
return d
}
return int64(r)
......@@ -352,7 +378,9 @@ func (c URL) GetParamInt(s string, d int64) int64 {
func (c URL) GetMethodParamInt(method string, key string, d int64) int64 {
var r int
var err error
if r, err = strconv.Atoi(c.Params.Get("methods." + method + "." + key)); r == 0 || err != nil {
c.paramsLock.RLock()
defer c.paramsLock.RUnlock()
if r, err = strconv.Atoi(c.params.Get("methods." + method + "." + key)); r == 0 || err != nil {
return d
}
return int64(r)
......@@ -369,7 +397,7 @@ func (c URL) GetMethodParamInt64(method string, key string, d int64) int64 {
func (c URL) GetMethodParam(method string, key string, d string) string {
var r string
if r = c.Params.Get("methods." + method + "." + key); r == "" {
if r = c.params.Get("methods." + method + "." + key); r == "" {
r = d
}
return r
......@@ -380,9 +408,11 @@ func (c URL) ToMap() map[string]string {
paramsMap := make(map[string]string)
for k, v := range c.Params {
paramsMap[k] = v[0]
}
c.RangeParams(func(key, value string) bool {
paramsMap[key] = value
return true
})
if c.Protocol != "" {
paramsMap["protocol"] = c.Protocol
}
......@@ -421,19 +451,19 @@ func MergeUrl(serviceUrl URL, referenceUrl *URL) URL {
mergedUrl := serviceUrl
//iterator the referenceUrl if serviceUrl not have the key ,merge in
for k, v := range referenceUrl.Params {
if _, ok := mergedUrl.Params[k]; !ok {
mergedUrl.Params.Set(k, v[0])
referenceUrl.RangeParams(func(key, value string) bool {
if v := mergedUrl.GetParam(key, ""); len(v) == 0 {
mergedUrl.SetParam(key, value)
}
}
return true
})
//loadBalance,cluster,retries strategy config
methodConfigMergeFcn := mergeNormalParam(mergedUrl, referenceUrl, []string{constant.LOADBALANCE_KEY, constant.CLUSTER_KEY, constant.RETRIES_KEY})
//remote timestamp
if v := serviceUrl.Params.Get(constant.TIMESTAMP_KEY); len(v) > 0 {
mergedUrl.Params.Set(constant.REMOTE_TIMESTAMP_KEY, v)
mergedUrl.Params.Set(constant.TIMESTAMP_KEY, referenceUrl.Params.Get(constant.TIMESTAMP_KEY))
if v := serviceUrl.GetParam(constant.TIMESTAMP_KEY, ""); len(v) > 0 {
mergedUrl.SetParam(constant.REMOTE_TIMESTAMP_KEY, v)
mergedUrl.SetParam(constant.TIMESTAMP_KEY, referenceUrl.GetParam(constant.TIMESTAMP_KEY, ""))
}
//finally execute methodConfigMergeFcn
......@@ -449,12 +479,12 @@ func MergeUrl(serviceUrl URL, referenceUrl *URL) URL {
func mergeNormalParam(mergedUrl URL, referenceUrl *URL, paramKeys []string) []func(method string) {
var methodConfigMergeFcn = []func(method string){}
for _, paramKey := range paramKeys {
if v := referenceUrl.Params.Get(paramKey); len(v) > 0 {
mergedUrl.Params.Set(paramKey, v)
if v := referenceUrl.GetParam(paramKey, ""); len(v) > 0 {
mergedUrl.SetParam(paramKey, v)
}
methodConfigMergeFcn = append(methodConfigMergeFcn, func(method string) {
if v := referenceUrl.Params.Get(method + "." + paramKey); len(v) > 0 {
mergedUrl.Params.Set(method+"."+paramKey, v)
if v := referenceUrl.GetParam(method+"."+paramKey, ""); len(v) > 0 {
mergedUrl.SetParam(method+"."+paramKey, v)
}
})
}
......
......@@ -52,7 +52,7 @@ func TestNewURLWithOptions(t *testing.T) {
assert.Equal(t, "127.0.0.1", u.Ip)
assert.Equal(t, "8080", u.Port)
assert.Equal(t, methods, u.Methods)
assert.Equal(t, params, u.Params)
assert.Equal(t, params, u.params)
}
func TestURL(t *testing.T) {
......@@ -74,7 +74,7 @@ func TestURL(t *testing.T) {
assert.Equal(t, "anyhost=true&application=BDTService&category=providers&default.timeout=10000&dubbo=dubbo-"+
"provider-golang-1.0.0&environment=dev&interface=com.ikurento.user.UserProvider&ip=192.168.56.1&methods=GetUser%"+
"2C&module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&side=provider&timeout=3000&t"+
"imestamp=1556509797245", u.Params.Encode())
"imestamp=1556509797245", u.params.Encode())
assert.Equal(t, "dubbo://:@127.0.0.1:20000/com.ikurento.user.UserProvider?anyhost=true&application=BDTServi"+
"ce&category=providers&default.timeout=10000&dubbo=dubbo-provider-golang-1.0.0&environment=dev&interface=com.ikure"+
......@@ -101,7 +101,7 @@ func TestURLWithoutSchema(t *testing.T) {
assert.Equal(t, "anyhost=true&application=BDTService&category=providers&default.timeout=10000&dubbo=dubbo-"+
"provider-golang-1.0.0&environment=dev&interface=com.ikurento.user.UserProvider&ip=192.168.56.1&methods=GetUser%"+
"2C&module=dubbogo+user-info+server&org=ikurento.com&owner=ZX&pid=1447&revision=0.0.1&side=provider&timeout=3000&t"+
"imestamp=1556509797245", u.Params.Encode())
"imestamp=1556509797245", u.params.Encode())
assert.Equal(t, "dubbo://:@127.0.0.1:20000/com.ikurento.user.UserProvider?anyhost=true&application=BDTServi"+
"ce&category=providers&default.timeout=10000&dubbo=dubbo-provider-golang-1.0.0&environment=dev&interface=com.ikure"+
......@@ -124,7 +124,7 @@ func TestURL_URLEqual(t *testing.T) {
func TestURL_GetParam(t *testing.T) {
params := url.Values{}
params.Set("key", "value")
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetParam("key", "default")
assert.Equal(t, "value", v)
......@@ -136,7 +136,7 @@ func TestURL_GetParam(t *testing.T) {
func TestURL_GetParamInt(t *testing.T) {
params := url.Values{}
params.Set("key", "3")
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetParamInt("key", 1)
assert.Equal(t, int64(3), v)
......@@ -148,7 +148,7 @@ func TestURL_GetParamInt(t *testing.T) {
func TestURL_GetParamBool(t *testing.T) {
params := url.Values{}
params.Set("force", "true")
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetParamBool("force", false)
assert.Equal(t, true, v)
......@@ -161,7 +161,7 @@ func TestURL_GetParamAndDecoded(t *testing.T) {
rule := "host = 2.2.2.2,1.1.1.1,3.3.3.3 & host !=1.1.1.1 => host = 1.2.3.4"
params := url.Values{}
params.Set("rule", base64.URLEncoding.EncodeToString([]byte(rule)))
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v, _ := u.GetParamAndDecoded("rule")
assert.Equal(t, rule, v)
}
......@@ -196,7 +196,7 @@ func TestURL_ToMap(t *testing.T) {
func TestURL_GetMethodParamInt(t *testing.T) {
params := url.Values{}
params.Set("methods.GetValue.timeout", "3")
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetMethodParamInt("GetValue", "timeout", 1)
assert.Equal(t, int64(3), v)
......@@ -208,7 +208,7 @@ func TestURL_GetMethodParamInt(t *testing.T) {
func TestURL_GetMethodParam(t *testing.T) {
params := url.Values{}
params.Set("methods.GetValue.timeout", "3s")
u := URL{baseUrl: baseUrl{Params: params}}
u := URL{baseUrl: baseUrl{params: params}}
v := u.GetMethodParam("GetValue", "timeout", "1s")
assert.Equal(t, "3s", v)
......
......@@ -62,7 +62,7 @@ func (pfw *ProtocolFilterWrapper) Destroy() {
}
func buildInvokerChain(invoker protocol.Invoker, key string) protocol.Invoker {
filtName := invoker.GetUrl().Params.Get(key)
filtName := invoker.GetUrl().GetParam(key, "")
if filtName == "" {
return invoker
}
......
......@@ -62,9 +62,12 @@ func buildService(url common.URL) (*consul.AgentServiceRegistration, error) {
// tags
tags := make([]string, 0, 8)
for k := range url.Params {
tags = append(tags, k+"="+url.Params.Get(k))
}
url.RangeParams(func(key, value string) bool {
tags = append(tags, key+"="+value)
return true
})
tags = append(tags, "dubbo")
// meta
......
......@@ -69,7 +69,7 @@ func TestSubscribe_Group(t *testing.T) {
regurl, _ := common.NewURL(context.TODO(), "mock://127.0.0.1:1111")
suburl, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000")
suburl.Params.Set(constant.CLUSTER_KEY, "mock")
suburl.SetParam(constant.CLUSTER_KEY, "mock")
regurl.SubURL = &suburl
mockRegistry, _ := registry.NewMockRegistry(&common.URL{})
registryDirectory, _ := NewRegistryDirectory(&regurl, mockRegistry)
......
......@@ -257,10 +257,11 @@ func (r *etcdV3Registry) registerProvider(svc common.URL) error {
}
params := url.Values{}
for k, v := range svc.Params {
params[k] = v
}
svc.RangeParams(func(key, value string) bool {
params[key] = []string{value}
return true
})
params.Add("pid", processID)
params.Add("ip", localIP)
params.Add("anyhost", "true")
......
......@@ -63,7 +63,7 @@ func (suite *RegistryTestSuite) TestSubscribe() {
}
//consumer register
regurl.Params.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
regurl.SetParam(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
reg2 := initRegistry(t)
reg2.Register(url)
......
......@@ -119,10 +119,13 @@ func appendParam(target *bytes.Buffer, url common.URL, key string) {
func createRegisterParam(url common.URL, serviceName string) vo.RegisterInstanceParam {
category := getCategory(url)
params := make(map[string]string, len(url.Params)+3)
for k := range url.Params {
params[k] = url.Params.Get(k)
}
params := make(map[string]string)
url.RangeParams(func(key, value string) bool {
params[key] = value
return true
})
params[constant.NACOS_CATEGORY_KEY] = category
params[constant.NACOS_PROTOCOL_KEY] = url.Protocol
params[constant.NACOS_PATH_KEY] = url.Path
......
......@@ -66,7 +66,7 @@ func TestNacosRegistry_Subscribe(t *testing.T) {
return
}
regurl.Params.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
regurl.SetParam(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
reg2, _ := newNacosRegistry(&regurl)
listener, err := reg2.Subscribe(url)
assert.Nil(t, err)
......@@ -111,7 +111,7 @@ func TestNacosRegistry_Subscribe_del(t *testing.T) {
return
}
regurl.Params.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
regurl.SetParam(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
reg2, _ := newNacosRegistry(&regurl)
listener, err := reg2.Subscribe(url1)
assert.Nil(t, err)
......
......@@ -80,7 +80,6 @@ type zkRegistry struct {
configListener *RegistryConfigurationListener
//for provider
zkPath map[string]int // key = protocol://ip:port/interface
}
func newZkRegistry(url *common.URL) (registry.Registry, error) {
......@@ -271,9 +270,11 @@ func (r *zkRegistry) register(c common.URL) error {
return perrors.WithStack(err)
}
params = url.Values{}
for k, v := range c.Params {
params[k] = v
}
c.RangeParams(func(key, value string) bool {
params[key] = []string{value}
return true
})
params.Add("pid", processID)
params.Add("ip", localIP)
......
......@@ -60,7 +60,7 @@ func Test_Subscribe(t *testing.T) {
}
//consumer register
regurl.Params.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
regurl.SetParam(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
_, reg2, _ := newMockZkRegistry(&regurl, zookeeper.WithTestCluster(ts))
reg2.Register(url)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment