diff --git a/common/constant/key.go b/common/constant/key.go index d201570b9ad5415694af5598fba7983289b2b295..eff704371c7c5b66ca11a846ad7603a01f8b5708 100644 --- a/common/constant/key.go +++ b/common/constant/key.go @@ -141,3 +141,21 @@ const ( const ( TRACING_REMOTE_SPAN_CTX = "tracing.remote.span.ctx" ) + +const ( + CONSUMER_SIGN_FILTER = "sign" + PROVIDER_AUTH_FILTER = "auth" + SERVICE_AUTH_KEY = "auth" + AUTHENTICATOR_KEY = "authenticator" + DEFAULT_AUTHENTICATOR = "accesskeys" + DEFAULT_ACCESS_KEY_STORAGE = "urlstorage" + ACCESS_KEY_STORAGE_KEY = "accessKey.storage" + REQUEST_TIMESTAMP_KEY = "timestamp" + REQUEST_SIGNATURE_KEY = "signature" + AK_KEY = "ak" + SIGNATURE_STRING_FORMAT = "%s#%s#%s#%s" + PARAMTER_SIGNATURE_ENABLE_KEY = "param.sign" + CONSUMER = "consumer" + ACCESS_KEY_ID_KEY = "accessKeyId" + SECRET_ACCESS_KEY_KEY = "secretAccessKey" +) diff --git a/common/extension/auth.go b/common/extension/auth.go new file mode 100644 index 0000000000000000000000000000000000000000..e57e22f660b6d4dec63f8b4a06c25b05bd5c8d72 --- /dev/null +++ b/common/extension/auth.go @@ -0,0 +1,32 @@ +package extension + +import ( + "github.com/apache/dubbo-go/filter" +) + +var ( + authenticators = make(map[string]func() filter.Authenticator) + accesskeyStorages = make(map[string]func() filter.AccessKeyStorage) +) + +func SetAuthenticator(name string, fcn func() filter.Authenticator) { + authenticators[name] = fcn +} + +func GetAuthenticator(name string) filter.Authenticator { + if authenticators[name] == nil { + panic("authenticator for " + name + " is not existing, make sure you have import the package.") + } + return authenticators[name]() +} + +func SetAccesskeyStorages(name string, fcn func() filter.AccessKeyStorage) { + accesskeyStorages[name] = fcn +} + +func GetAccesskeyStorages(name string) filter.AccessKeyStorage { + if accesskeyStorages[name] == nil { + panic("accesskeyStorages for " + name + " is not existing, make sure you have import the package.") + } + return accesskeyStorages[name]() +} diff --git a/common/url.go b/common/url.go index a073e013f47a2acff4782ffa4444203fa0cec9b5..b2a514b4e6b1ba105e9f9aa4f9501bce1e613d4d 100644 --- a/common/url.go +++ b/common/url.go @@ -339,6 +339,28 @@ func (c URL) ServiceKey() string { return buf.String() } +// ColonSeparatedKey +// The format is "{interface}:[version]:[group]" +func (c *URL) ColonSeparatedKey() string { + intf := c.GetParam(constant.INTERFACE_KEY, strings.TrimPrefix(c.Path, "/")) + if intf == "" { + return "" + } + buf := &bytes.Buffer{} + buf.WriteString(intf) + buf.WriteString(":") + version := c.GetParam(constant.VERSION_KEY, "") + if version != "" && version != "0.0.0" { + buf.WriteString(version) + } + group := c.GetParam(constant.GROUP_KEY, "") + buf.WriteString(":") + if group != "" { + buf.WriteString(group) + } + return buf.String() +} + // EncodedServiceKey ... func (c *URL) EncodedServiceKey() string { serviceKey := c.ServiceKey() diff --git a/common/url_test.go b/common/url_test.go index c70c58bc215b6449311d43f9f9cffeb89623f80c..9d80bc52b8ee5c74abb29b3b65487156ebf6ae2c 100644 --- a/common/url_test.go +++ b/common/url_test.go @@ -271,3 +271,17 @@ func TestClone(t *testing.T) { assert.Equal(t, u1.Protocol, "dubbo") assert.Equal(t, u2.Protocol, "provider") } + +func TestColonSeparatedKey(t *testing.T) { + u1, _ := NewURL(context.TODO(), "dubbo://127.0.0.1:20000") + u1.AddParam(constant.INTERFACE_KEY, "com.ikurento.user.UserProvider") + + assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+"::") + u1.AddParam(constant.VERSION_KEY, "version1") + assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+":version1:") + u1.AddParam(constant.GROUP_KEY, "group1") + assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+":version1:group1") + u1.SetParam(constant.VERSION_KEY, "") + assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+"::group1") + +} diff --git a/config/service_config.go b/config/service_config.go index 37ec3a3ae611d60d71f5679c1d501bb699351849..2111838395d507ebac4f72883c99dd2bb1615850 100644 --- a/config/service_config.go +++ b/config/service_config.go @@ -67,6 +67,8 @@ type ServiceConfig struct { TpsLimitRejectedHandler string `yaml:"tps.limit.rejected.handler" json:"tps.limit.rejected.handler,omitempty" property:"tps.limit.rejected.handler"` ExecuteLimit string `yaml:"execute.limit" json:"execute.limit,omitempty" property:"execute.limit"` ExecuteLimitRejectedHandler string `yaml:"execute.limit.rejected.handler" json:"execute.limit.rejected.handler,omitempty" property:"execute.limit.rejected.handler"` + Auth string `yaml:"auth" json:"auth,omitempty" property:"auth"` + ParamSign string `yaml:"param.sign" json:"param.sign,omitempty" property:"param.sign"` unexported *atomic.Bool exported *atomic.Bool @@ -220,6 +222,10 @@ func (c *ServiceConfig) getUrlMap() url.Values { urlMap.Set(constant.EXECUTE_LIMIT_KEY, c.ExecuteLimit) urlMap.Set(constant.EXECUTE_REJECTED_EXECUTION_HANDLER_KEY, c.ExecuteLimitRejectedHandler) + // auth filter + urlMap.Set(constant.SERVICE_AUTH_KEY, c.Auth) + urlMap.Set(constant.PARAMTER_SIGNATURE_ENABLE_KEY, c.ParamSign) + for _, v := range c.Methods { prefix := "methods." + v.Name + "." urlMap.Set(prefix+constant.LOADBALANCE_KEY, v.Loadbalance) diff --git a/filter/access_key.go b/filter/access_key.go new file mode 100644 index 0000000000000000000000000000000000000000..c9bdd4ff8993d51e4d5002a1216225e2da074df5 --- /dev/null +++ b/filter/access_key.go @@ -0,0 +1,22 @@ +package filter + +import ( + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/protocol" +) + +type AccessKeyPair struct { + AccessKey string `yaml:"accessKey" json:"accessKey,omitempty" property:"accessKey"` + SecretKey string `yaml:"secretKey" json:"secretKey,omitempty" property:"secretKey"` + ConsumerSide string `yaml:"consumerSide" json:"ConsumerSide,consumerSide" property:"consumerSide"` + ProviderSide string `yaml:"providerSide" json:"providerSide,omitempty" property:"providerSide"` + Creator string `yaml:"creator" json:"creator,omitempty" property:"creator"` + Options string `yaml:"options" json:"options,omitempty" property:"options"` +} + +// AccessKeyStorage +// This SPI Extension support us to store our AccessKeyPair or load AccessKeyPair from other +// storage, such as filesystem. +type AccessKeyStorage interface { + GetAccessKeyPair(protocol.Invocation, *common.URL) *AccessKeyPair +} diff --git a/filter/authenticator.go b/filter/authenticator.go new file mode 100644 index 0000000000000000000000000000000000000000..ce0547b36b03b7078784a6c05c08cd3f89611ca4 --- /dev/null +++ b/filter/authenticator.go @@ -0,0 +1,18 @@ +package filter + +import ( + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/protocol" +) + +// Authenticator +type Authenticator interface { + + // Sign + // give a sign to request + Sign(protocol.Invocation, *common.URL) error + + // Authenticate + // verify the signature of the request is valid or not + Authenticate(protocol.Invocation, *common.URL) error +} diff --git a/filter/filter_impl/auth/accesskey_storage.go b/filter/filter_impl/auth/accesskey_storage.go new file mode 100644 index 0000000000000000000000000000000000000000..0a2bf47cbd377899ba8a0edf4a67026dd827d41f --- /dev/null +++ b/filter/filter_impl/auth/accesskey_storage.go @@ -0,0 +1,31 @@ +package auth + +import ( + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/common/constant" + "github.com/apache/dubbo-go/common/extension" + "github.com/apache/dubbo-go/filter" + "github.com/apache/dubbo-go/protocol" +) + +// DefaultAccesskeyStorage +// The default implementation of AccesskeyStorage +type DefaultAccesskeyStorage struct { +} + +// GetAccessKeyPair +// get AccessKeyPair from url by the key "accessKeyId" and "secretAccessKey" +func (storage *DefaultAccesskeyStorage) GetAccessKeyPair(invocation protocol.Invocation, url *common.URL) *filter.AccessKeyPair { + return &filter.AccessKeyPair{ + AccessKey: url.GetParam(constant.ACCESS_KEY_ID_KEY, ""), + SecretKey: url.GetParam(constant.SECRET_ACCESS_KEY_KEY, ""), + } +} + +func init() { + extension.SetAccesskeyStorages(constant.DEFAULT_ACCESS_KEY_STORAGE, GetDefaultAccesskeyStorage) +} + +func GetDefaultAccesskeyStorage() filter.AccessKeyStorage { + return &DefaultAccesskeyStorage{} +} diff --git a/filter/filter_impl/auth/accesskey_storage_test.go b/filter/filter_impl/auth/accesskey_storage_test.go new file mode 100644 index 0000000000000000000000000000000000000000..6ab861a8673b191be0a8063980e1dc53e4e70f60 --- /dev/null +++ b/filter/filter_impl/auth/accesskey_storage_test.go @@ -0,0 +1,28 @@ +package auth + +import ( + "net/url" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/common/constant" + invocation2 "github.com/apache/dubbo-go/protocol/invocation" +) + +func TestDefaultAccesskeyStorage_GetAccesskeyPair(t *testing.T) { + url := common.NewURLWithOptions( + common.WithParams(url.Values{}), + common.WithParamsValue(constant.SECRET_ACCESS_KEY_KEY, "skey"), + common.WithParamsValue(constant.ACCESS_KEY_ID_KEY, "akey")) + invocation := &invocation2.RPCInvocation{} + storage := GetDefaultAccesskeyStorage() + accesskeyPair := storage.GetAccessKeyPair(invocation, url) + assert.Equal(t, "skey", accesskeyPair.SecretKey) + assert.Equal(t, "akey", accesskeyPair.AccessKey) +} diff --git a/filter/filter_impl/auth/consumer_sign.go b/filter/filter_impl/auth/consumer_sign.go new file mode 100644 index 0000000000000000000000000000000000000000..be86b5c74bb9fd02b96483edb18571d47d205ee7 --- /dev/null +++ b/filter/filter_impl/auth/consumer_sign.go @@ -0,0 +1,43 @@ +package auth + +import ( + "context" + "fmt" +) +import ( + "github.com/apache/dubbo-go/common/constant" + "github.com/apache/dubbo-go/common/extension" + "github.com/apache/dubbo-go/common/logger" + "github.com/apache/dubbo-go/filter" + "github.com/apache/dubbo-go/protocol" +) + +// ConsumerSignFilter +// This filter is working for signing the request on consumer side +type ConsumerSignFilter struct { +} + +func init() { + extension.SetFilter(constant.CONSUMER_SIGN_FILTER, getConsumerSignFilter) +} + +func (csf *ConsumerSignFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { + logger.Infof("invoking ConsumerSign filter.") + url := invoker.GetUrl() + + err := doAuthWork(&url, func(authenticator filter.Authenticator) error { + return authenticator.Sign(invocation, &url) + }) + if err != nil { + panic(fmt.Sprintf("Sign for invocation %s # %s failed", url.ServiceKey(), invocation.MethodName())) + + } + return invoker.Invoke(ctx, invocation) +} + +func (csf *ConsumerSignFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { + return result +} +func getConsumerSignFilter() filter.Filter { + return &ConsumerSignFilter{} +} diff --git a/filter/filter_impl/auth/consumer_sign_test.go b/filter/filter_impl/auth/consumer_sign_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c90a769bcbc32e1f685404e3bfac54e56de83b90 --- /dev/null +++ b/filter/filter_impl/auth/consumer_sign_test.go @@ -0,0 +1,37 @@ +package auth + +import ( + "context" + "testing" +) + +import ( + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/common/constant" + "github.com/apache/dubbo-go/protocol" + "github.com/apache/dubbo-go/protocol/invocation" + "github.com/apache/dubbo-go/protocol/mock" +) + +func TestConsumerSignFilter_Invoke(t *testing.T) { + url, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider?interface=com.ikurento.user.UserProvider&group=gg&version=2.6.0") + url.SetParam(constant.SECRET_ACCESS_KEY_KEY, "sk") + url.SetParam(constant.ACCESS_KEY_ID_KEY, "ak") + inv := invocation.NewRPCInvocation("test", []interface{}{"OK"}, nil) + filter := &ConsumerSignFilter{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + invoker := mock.NewMockInvoker(ctrl) + result := &protocol.RPCResult{} + invoker.EXPECT().Invoke(inv).Return(result).Times(2) + invoker.EXPECT().GetUrl().Return(url).Times(2) + assert.Equal(t, result, filter.Invoke(context.Background(), invoker, inv)) + + url.SetParam(constant.SERVICE_AUTH_KEY, "true") + assert.Equal(t, result, filter.Invoke(context.Background(), invoker, inv)) +} diff --git a/filter/filter_impl/auth/default_authenticator.go b/filter/filter_impl/auth/default_authenticator.go new file mode 100644 index 0000000000000000000000000000000000000000..73eb9cddc0e1b7b4747da4b0f3e883075e349226 --- /dev/null +++ b/filter/filter_impl/auth/default_authenticator.go @@ -0,0 +1,120 @@ +package auth + +import ( + "errors" + "fmt" + "github.com/apache/dubbo-go/filter" + "strconv" + "time" +) + +import ( + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/common/constant" + "github.com/apache/dubbo-go/common/extension" + "github.com/apache/dubbo-go/protocol" + invocation_impl "github.com/apache/dubbo-go/protocol/invocation" +) + +func init() { + extension.SetAuthenticator(constant.DEFAULT_AUTHENTICATOR, GetDefaultAuthenticator) +} + +// DefaultAuthenticator +// The default implemetation of Authenticator +type DefaultAuthenticator struct { +} + +// Sign +// add the signature for the invocation +func (authenticator *DefaultAuthenticator) Sign(invocation protocol.Invocation, url *common.URL) error { + currentTimeMillis := strconv.Itoa(int(time.Now().Unix() * 1000)) + + consumer := url.GetParam(constant.APPLICATION_KEY, "") + accessKeyPair, err := getAccessKeyPair(invocation, url) + if err != nil { + return errors.New("get accesskey pair failed, cause: " + err.Error()) + } + inv := invocation.(*invocation_impl.RPCInvocation) + signature, err := getSignature(url, invocation, accessKeyPair.SecretKey, currentTimeMillis) + if err != nil { + return err + } + inv.SetAttachments(constant.REQUEST_SIGNATURE_KEY, signature) + inv.SetAttachments(constant.REQUEST_TIMESTAMP_KEY, currentTimeMillis) + inv.SetAttachments(constant.AK_KEY, accessKeyPair.AccessKey) + inv.SetAttachments(constant.CONSUMER, consumer) + return nil +} + +// getSignature +// get signature by the metadata and params of the invocation +func getSignature(url *common.URL, invocation protocol.Invocation, secrectKey string, currentTime string) (string, error) { + + requestString := fmt.Sprintf(constant.SIGNATURE_STRING_FORMAT, + url.ColonSeparatedKey(), invocation.MethodName(), secrectKey, currentTime) + var signature string + if parameterEncrypt := url.GetParamBool(constant.PARAMTER_SIGNATURE_ENABLE_KEY, false); parameterEncrypt { + var err error + if signature, err = SignWithParams(invocation.Arguments(), requestString, secrectKey); err != nil { + // TODO + return "", errors.New("sign the request with params failed, cause:" + err.Error()) + } + } else { + signature = Sign(requestString, secrectKey) + } + + return signature, nil +} + +// Authenticate +// This method verifies whether the signature sent by the requester is correct +func (authenticator *DefaultAuthenticator) Authenticate(invocation protocol.Invocation, url *common.URL) error { + accessKeyId := invocation.AttachmentsByKey(constant.AK_KEY, "") + + requestTimestamp := invocation.AttachmentsByKey(constant.REQUEST_TIMESTAMP_KEY, "") + originSignature := invocation.AttachmentsByKey(constant.REQUEST_SIGNATURE_KEY, "") + consumer := invocation.AttachmentsByKey(constant.CONSUMER, "") + if IsEmpty(accessKeyId, false) || IsEmpty(consumer, false) || + IsEmpty(requestTimestamp, false) || IsEmpty(originSignature, false) { + return errors.New("failed to authenticate your ak/sk, maybe the consumer has not enabled the auth") + } + + accessKeyPair, err := getAccessKeyPair(invocation, url) + if err != nil { + return errors.New("failed to authenticate , can't load the accessKeyPair") + } + + computeSignature, err := getSignature(url, invocation, accessKeyPair.SecretKey, requestTimestamp) + if err != nil { + return err + } + if success := computeSignature == originSignature; !success { + return errors.New("failed to authenticate, signature is not correct") + } + return nil +} + +func getAccessKeyPair(invocation protocol.Invocation, url *common.URL) (*filter.AccessKeyPair, error) { + accesskeyStorage := extension.GetAccesskeyStorages(url.GetParam(constant.ACCESS_KEY_STORAGE_KEY, constant.DEFAULT_ACCESS_KEY_STORAGE)) + accessKeyPair := accesskeyStorage.GetAccessKeyPair(invocation, url) + if accessKeyPair == nil || IsEmpty(accessKeyPair.AccessKey, false) || IsEmpty(accessKeyPair.SecretKey, true) { + return nil, errors.New("accessKeyId or secretAccessKey not found") + } else { + return accessKeyPair, nil + } +} + +func GetDefaultAuthenticator() filter.Authenticator { + return &DefaultAuthenticator{} +} + +func doAuthWork(url *common.URL, do func(filter.Authenticator) error) error { + + shouldAuth := url.GetParamBool(constant.SERVICE_AUTH_KEY, false) + if shouldAuth { + authenticator := extension.GetAuthenticator(url.GetParam(constant.AUTHENTICATOR_KEY, constant.DEFAULT_AUTHENTICATOR)) + return do(authenticator) + } + return nil +} diff --git a/filter/filter_impl/auth/default_authenticator_test.go b/filter/filter_impl/auth/default_authenticator_test.go new file mode 100644 index 0000000000000000000000000000000000000000..f1e014b5c6ae64efc662d763614e0dd4940edd86 --- /dev/null +++ b/filter/filter_impl/auth/default_authenticator_test.go @@ -0,0 +1,131 @@ +package auth + +import ( + "context" + "fmt" + "net/url" + "strconv" + "testing" + "time" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/common/constant" + "github.com/apache/dubbo-go/protocol/invocation" +) + +func TestDefaultAuthenticator_Authenticate(t *testing.T) { + secret := "dubbo-sk" + access := "dubbo-ak" + testurl, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider?interface=com.ikurento.user.UserProvider&group=gg&version=2.6.0") + testurl.SetParam(constant.PARAMTER_SIGNATURE_ENABLE_KEY, "true") + testurl.SetParam(constant.ACCESS_KEY_ID_KEY, access) + testurl.SetParam(constant.SECRET_ACCESS_KEY_KEY, secret) + parmas := []interface{}{"OK", struct { + Name string + Id int64 + }{"YUYU", 1}} + inv := invocation.NewRPCInvocation("test", parmas, nil) + requestTime := strconv.Itoa(int(time.Now().Unix() * 1000)) + signature, _ := getSignature(&testurl, inv, secret, requestTime) + + var authenticator = &DefaultAuthenticator{} + + invcation := invocation.NewRPCInvocation("test", parmas, map[string]string{ + constant.REQUEST_SIGNATURE_KEY: signature, + constant.CONSUMER: "test", + constant.REQUEST_TIMESTAMP_KEY: requestTime, + constant.AK_KEY: access, + }) + err := authenticator.Authenticate(invcation, &testurl) + assert.Nil(t, err) + // modify the params + invcation = invocation.NewRPCInvocation("test", parmas[:1], map[string]string{ + constant.REQUEST_SIGNATURE_KEY: signature, + constant.CONSUMER: "test", + constant.REQUEST_TIMESTAMP_KEY: requestTime, + constant.AK_KEY: access, + }) + err = authenticator.Authenticate(invcation, &testurl) + assert.NotNil(t, err) + +} + +func TestDefaultAuthenticator_Sign(t *testing.T) { + authenticator := &DefaultAuthenticator{} + testurl, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider?application=test&interface=com.ikurento.user.UserProvider&group=gg&version=2.6.0") + testurl.SetParam(constant.ACCESS_KEY_ID_KEY, "akey") + testurl.SetParam(constant.SECRET_ACCESS_KEY_KEY, "skey") + testurl.SetParam(constant.PARAMTER_SIGNATURE_ENABLE_KEY, "false") + inv := invocation.NewRPCInvocation("test", []interface{}{"OK"}, nil) + _ = authenticator.Sign(inv, &testurl) + assert.NotEqual(t, inv.AttachmentsByKey(constant.REQUEST_SIGNATURE_KEY, ""), "") + assert.NotEqual(t, inv.AttachmentsByKey(constant.CONSUMER, ""), "") + assert.NotEqual(t, inv.AttachmentsByKey(constant.REQUEST_TIMESTAMP_KEY, ""), "") + assert.Equal(t, inv.AttachmentsByKey(constant.AK_KEY, ""), "akey") + +} + +func Test_getAccessKeyPairSuccess(t *testing.T) { + testurl := common.NewURLWithOptions( + common.WithParams(url.Values{}), + common.WithParamsValue(constant.SECRET_ACCESS_KEY_KEY, "skey"), + common.WithParamsValue(constant.ACCESS_KEY_ID_KEY, "akey")) + invcation := invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, nil) + _, e := getAccessKeyPair(invcation, testurl) + assert.Nil(t, e) +} + +func Test_getAccessKeyPairFailed(t *testing.T) { + defer func() { + e := recover() + assert.NotNil(t, e) + }() + testurl := common.NewURLWithOptions( + common.WithParams(url.Values{}), + common.WithParamsValue(constant.ACCESS_KEY_ID_KEY, "akey")) + invcation := invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, nil) + _, e := getAccessKeyPair(invcation, testurl) + assert.NotNil(t, e) + testurl = common.NewURLWithOptions( + common.WithParams(url.Values{}), + common.WithParamsValue(constant.SECRET_ACCESS_KEY_KEY, "skey"), + common.WithParamsValue(constant.ACCESS_KEY_ID_KEY, "akey"), common.WithParamsValue(constant.ACCESS_KEY_STORAGE_KEY, "dubbo")) + _, e = getAccessKeyPair(invcation, testurl) + +} + +func Test_getSignatureWithinParams(t *testing.T) { + testurl, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider?interface=com.ikurento.user.UserProvider&group=gg&version=2.6.0") + testurl.SetParam(constant.PARAMTER_SIGNATURE_ENABLE_KEY, "true") + inv := invocation.NewRPCInvocation("test", []interface{}{"OK"}, map[string]string{ + "": "", + }) + secret := "dubbo" + current := strconv.Itoa(int(time.Now().Unix() * 1000)) + signature, _ := getSignature(&testurl, inv, secret, current) + requestString := fmt.Sprintf(constant.SIGNATURE_STRING_FORMAT, + testurl.ColonSeparatedKey(), inv.MethodName(), secret, current) + s, _ := SignWithParams(inv.Arguments(), requestString, secret) + assert.False(t, IsEmpty(signature, false)) + assert.Equal(t, s, signature) +} + +func Test_getSignature(t *testing.T) { + testurl, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider?interface=com.ikurento.user.UserProvider&group=gg&version=2.6.0") + testurl.SetParam(constant.PARAMTER_SIGNATURE_ENABLE_KEY, "false") + inv := invocation.NewRPCInvocation("test", []interface{}{"OK"}, nil) + secret := "dubbo" + current := strconv.Itoa(int(time.Now().Unix() * 1000)) + signature, _ := getSignature(&testurl, inv, secret, current) + requestString := fmt.Sprintf(constant.SIGNATURE_STRING_FORMAT, + testurl.ColonSeparatedKey(), inv.MethodName(), secret, current) + s := Sign(requestString, secret) + assert.False(t, IsEmpty(signature, false)) + assert.Equal(t, s, signature) +} diff --git a/filter/filter_impl/auth/provider_auth.go b/filter/filter_impl/auth/provider_auth.go new file mode 100644 index 0000000000000000000000000000000000000000..90804934f6b01a61f021f61f4ee549d744ccee72 --- /dev/null +++ b/filter/filter_impl/auth/provider_auth.go @@ -0,0 +1,43 @@ +package auth + +import ( + "context" + "github.com/apache/dubbo-go/common/constant" + "github.com/apache/dubbo-go/common/extension" + "github.com/apache/dubbo-go/common/logger" + "github.com/apache/dubbo-go/filter" + "github.com/apache/dubbo-go/protocol" +) + +// ProviderAuthFilter +// This filter is used to verify the correctness of the signature on provider side +type ProviderAuthFilter struct { +} + +func init() { + extension.SetFilter(constant.PROVIDER_AUTH_FILTER, getProviderAuthFilter) +} + +func (paf *ProviderAuthFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { + logger.Infof("invoking providerAuth filter.") + url := invoker.GetUrl() + + err := doAuthWork(&url, func(authenticator filter.Authenticator) error { + return authenticator.Authenticate(invocation, &url) + }) + if err != nil { + logger.Infof("auth the request: %v occur exception, cause: %s", invocation, err.Error()) + return &protocol.RPCResult{ + Err: err, + } + } + + return invoker.Invoke(ctx, invocation) +} + +func (paf *ProviderAuthFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { + return result +} +func getProviderAuthFilter() filter.Filter { + return &ProviderAuthFilter{} +} diff --git a/filter/filter_impl/auth/provider_auth_test.go b/filter/filter_impl/auth/provider_auth_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7552a4aa0447e18b40d8160b895d3ab65ee5edb1 --- /dev/null +++ b/filter/filter_impl/auth/provider_auth_test.go @@ -0,0 +1,57 @@ +package auth + +import ( + "context" + "strconv" + "testing" + "time" +) + +import ( + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/common/constant" + "github.com/apache/dubbo-go/protocol" + "github.com/apache/dubbo-go/protocol/invocation" + "github.com/apache/dubbo-go/protocol/mock" +) + +func TestProviderAuthFilter_Invoke(t *testing.T) { + secret := "dubbo-sk" + access := "dubbo-ak" + url, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider?interface=com.ikurento.user.UserProvider&group=gg&version=2.6.0") + url.SetParam(constant.ACCESS_KEY_ID_KEY, access) + url.SetParam(constant.SECRET_ACCESS_KEY_KEY, secret) + parmas := []interface{}{ + "OK", + struct { + Name string + Id int64 + }{"YUYU", 1}, + } + inv := invocation.NewRPCInvocation("test", parmas, nil) + requestTime := strconv.Itoa(int(time.Now().Unix() * 1000)) + signature, _ := getSignature(&url, inv, secret, requestTime) + + inv = invocation.NewRPCInvocation("test", []interface{}{"OK"}, map[string]string{ + constant.REQUEST_SIGNATURE_KEY: signature, + constant.CONSUMER: "test", + constant.REQUEST_TIMESTAMP_KEY: requestTime, + constant.AK_KEY: access, + }) + ctrl := gomock.NewController(t) + filter := &ProviderAuthFilter{} + defer ctrl.Finish() + invoker := mock.NewMockInvoker(ctrl) + result := &protocol.RPCResult{} + invoker.EXPECT().Invoke(inv).Return(result).Times(2) + invoker.EXPECT().GetUrl().Return(url).Times(2) + assert.Equal(t, result, filter.Invoke(context.Background(), invoker, inv)) + url.SetParam(constant.SERVICE_AUTH_KEY, "true") + assert.Equal(t, result, filter.Invoke(context.Background(), invoker, inv)) + +} diff --git a/filter/filter_impl/auth/sign_util.go b/filter/filter_impl/auth/sign_util.go new file mode 100644 index 0000000000000000000000000000000000000000..60698439c5abc1ff0cc555b2ceec77bf2e0e53d5 --- /dev/null +++ b/filter/filter_impl/auth/sign_util.go @@ -0,0 +1,55 @@ +package auth + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "strings" +) + +// Sign +// get a signature string with given information, such as metadata or parameters +func Sign(metadata, key string) string { + return doSign([]byte(metadata), key) +} + +func SignWithParams(params []interface{}, metadata, key string) (string, error) { + if params == nil || len(params) == 0 { + return Sign(metadata, key), nil + } + + data := append(params, metadata) + if bytes, err := toBytes(data); err != nil { + // TODO + return "", errors.New("data convert to bytes failed") + } else { + return doSign(bytes, key), nil + } +} + +func toBytes(data []interface{}) ([]byte, error) { + if bytes, err := json.Marshal(data); err != nil { + return nil, errors.New("") + } else { + return bytes, nil + } +} + +func doSign(bytes []byte, key string) string { + mac := hmac.New(sha256.New, []byte(key)) + mac.Write(bytes) + signature := mac.Sum(nil) + return base64.URLEncoding.EncodeToString(signature) +} + +func IsEmpty(s string, allowSpace bool) bool { + if len(s) == 0 { + return true + } + if !allowSpace { + return strings.TrimSpace(s) == "" + } + return false +} diff --git a/filter/filter_impl/auth/sign_util_test.go b/filter/filter_impl/auth/sign_util_test.go new file mode 100644 index 0000000000000000000000000000000000000000..de6154e8854af99f8e862d94ee45aefcbf26b12b --- /dev/null +++ b/filter/filter_impl/auth/sign_util_test.go @@ -0,0 +1,84 @@ +package auth + +import ( + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +func TestIsEmpty(t *testing.T) { + type args struct { + s string + allowSpace bool + } + tests := []struct { + name string + args args + want bool + }{ + // TODO: Add test cases. + {"test1", args{s: " ", allowSpace: false}, true}, + {"test2", args{s: " ", allowSpace: true}, false}, + {"test3", args{s: "hello,dubbo", allowSpace: false}, false}, + {"test4", args{s: "hello,dubbo", allowSpace: true}, false}, + {"test5", args{s: "", allowSpace: true}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsEmpty(tt.args.s, tt.args.allowSpace); got != tt.want { + t.Errorf("IsEmpty() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSign(t *testing.T) { + metadata := "com.ikurento.user.UserProvider::sayHi" + key := "key" + signature := Sign(metadata, key) + assert.NotNil(t, signature) + +} + +func TestSignWithParams(t *testing.T) { + metadata := "com.ikurento.user.UserProvider::sayHi" + key := "key" + params := []interface{}{ + "a", 1, struct { + Name string + Id int64 + }{"YuYu", 1}, + } + signature, _ := SignWithParams(params, metadata, key) + assert.False(t, IsEmpty(signature, false)) +} + +func Test_doSign(t *testing.T) { + sign := doSign([]byte("DubboGo"), "key") + sign1 := doSign([]byte("DubboGo"), "key") + sign2 := doSign([]byte("DubboGo"), "key2") + assert.NotNil(t, sign) + assert.Equal(t, sign1, sign) + assert.NotEqual(t, sign1, sign2) +} + +func Test_toBytes(t *testing.T) { + params := []interface{}{ + "a", 1, struct { + Name string + Id int64 + }{"YuYu", 1}, + } + params2 := []interface{}{ + "a", 1, struct { + Name string + Id int64 + }{"YuYu", 1}, + } + jsonBytes, _ := toBytes(params) + jsonBytes2, _ := toBytes(params2) + assert.NotNil(t, jsonBytes) + assert.Equal(t, jsonBytes, jsonBytes2) +}