Skip to content
Snippets Groups Projects
Commit a8862607 authored by vito.he's avatar vito.he Committed by GitHub
Browse files

Merge pull request #323 from CodingSinger/auth

support sign and auth for request
parents 78e40ef1 e652b2a3
No related branches found
No related tags found
No related merge requests found
Showing
with 761 additions and 0 deletions
......@@ -141,3 +141,21 @@ const (
const (
TRACING_REMOTE_SPAN_CTX = "tracing.remote.span.ctx"
)
const (
CONSUMER_SIGN_FILTER = "sign"
PROVIDER_AUTH_FILTER = "auth"
SERVICE_AUTH_KEY = "auth"
AUTHENTICATOR_KEY = "authenticator"
DEFAULT_AUTHENTICATOR = "accesskeys"
DEFAULT_ACCESS_KEY_STORAGE = "urlstorage"
ACCESS_KEY_STORAGE_KEY = "accessKey.storage"
REQUEST_TIMESTAMP_KEY = "timestamp"
REQUEST_SIGNATURE_KEY = "signature"
AK_KEY = "ak"
SIGNATURE_STRING_FORMAT = "%s#%s#%s#%s"
PARAMTER_SIGNATURE_ENABLE_KEY = "param.sign"
CONSUMER = "consumer"
ACCESS_KEY_ID_KEY = "accessKeyId"
SECRET_ACCESS_KEY_KEY = "secretAccessKey"
)
package extension
import (
"github.com/apache/dubbo-go/filter"
)
var (
authenticators = make(map[string]func() filter.Authenticator)
accesskeyStorages = make(map[string]func() filter.AccessKeyStorage)
)
func SetAuthenticator(name string, fcn func() filter.Authenticator) {
authenticators[name] = fcn
}
func GetAuthenticator(name string) filter.Authenticator {
if authenticators[name] == nil {
panic("authenticator for " + name + " is not existing, make sure you have import the package.")
}
return authenticators[name]()
}
func SetAccesskeyStorages(name string, fcn func() filter.AccessKeyStorage) {
accesskeyStorages[name] = fcn
}
func GetAccesskeyStorages(name string) filter.AccessKeyStorage {
if accesskeyStorages[name] == nil {
panic("accesskeyStorages for " + name + " is not existing, make sure you have import the package.")
}
return accesskeyStorages[name]()
}
......@@ -339,6 +339,28 @@ func (c URL) ServiceKey() string {
return buf.String()
}
// ColonSeparatedKey
// The format is "{interface}:[version]:[group]"
func (c *URL) ColonSeparatedKey() string {
intf := c.GetParam(constant.INTERFACE_KEY, strings.TrimPrefix(c.Path, "/"))
if intf == "" {
return ""
}
buf := &bytes.Buffer{}
buf.WriteString(intf)
buf.WriteString(":")
version := c.GetParam(constant.VERSION_KEY, "")
if version != "" && version != "0.0.0" {
buf.WriteString(version)
}
group := c.GetParam(constant.GROUP_KEY, "")
buf.WriteString(":")
if group != "" {
buf.WriteString(group)
}
return buf.String()
}
// EncodedServiceKey ...
func (c *URL) EncodedServiceKey() string {
serviceKey := c.ServiceKey()
......
......@@ -271,3 +271,17 @@ func TestClone(t *testing.T) {
assert.Equal(t, u1.Protocol, "dubbo")
assert.Equal(t, u2.Protocol, "provider")
}
func TestColonSeparatedKey(t *testing.T) {
u1, _ := NewURL(context.TODO(), "dubbo://127.0.0.1:20000")
u1.AddParam(constant.INTERFACE_KEY, "com.ikurento.user.UserProvider")
assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+"::")
u1.AddParam(constant.VERSION_KEY, "version1")
assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+":version1:")
u1.AddParam(constant.GROUP_KEY, "group1")
assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+":version1:group1")
u1.SetParam(constant.VERSION_KEY, "")
assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+"::group1")
}
......@@ -67,6 +67,8 @@ type ServiceConfig struct {
TpsLimitRejectedHandler string `yaml:"tps.limit.rejected.handler" json:"tps.limit.rejected.handler,omitempty" property:"tps.limit.rejected.handler"`
ExecuteLimit string `yaml:"execute.limit" json:"execute.limit,omitempty" property:"execute.limit"`
ExecuteLimitRejectedHandler string `yaml:"execute.limit.rejected.handler" json:"execute.limit.rejected.handler,omitempty" property:"execute.limit.rejected.handler"`
Auth string `yaml:"auth" json:"auth,omitempty" property:"auth"`
ParamSign string `yaml:"param.sign" json:"param.sign,omitempty" property:"param.sign"`
unexported *atomic.Bool
exported *atomic.Bool
......@@ -220,6 +222,10 @@ func (c *ServiceConfig) getUrlMap() url.Values {
urlMap.Set(constant.EXECUTE_LIMIT_KEY, c.ExecuteLimit)
urlMap.Set(constant.EXECUTE_REJECTED_EXECUTION_HANDLER_KEY, c.ExecuteLimitRejectedHandler)
// auth filter
urlMap.Set(constant.SERVICE_AUTH_KEY, c.Auth)
urlMap.Set(constant.PARAMTER_SIGNATURE_ENABLE_KEY, c.ParamSign)
for _, v := range c.Methods {
prefix := "methods." + v.Name + "."
urlMap.Set(prefix+constant.LOADBALANCE_KEY, v.Loadbalance)
......
package filter
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/protocol"
)
type AccessKeyPair struct {
AccessKey string `yaml:"accessKey" json:"accessKey,omitempty" property:"accessKey"`
SecretKey string `yaml:"secretKey" json:"secretKey,omitempty" property:"secretKey"`
ConsumerSide string `yaml:"consumerSide" json:"ConsumerSide,consumerSide" property:"consumerSide"`
ProviderSide string `yaml:"providerSide" json:"providerSide,omitempty" property:"providerSide"`
Creator string `yaml:"creator" json:"creator,omitempty" property:"creator"`
Options string `yaml:"options" json:"options,omitempty" property:"options"`
}
// AccessKeyStorage
// This SPI Extension support us to store our AccessKeyPair or load AccessKeyPair from other
// storage, such as filesystem.
type AccessKeyStorage interface {
GetAccessKeyPair(protocol.Invocation, *common.URL) *AccessKeyPair
}
package filter
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/protocol"
)
// Authenticator
type Authenticator interface {
// Sign
// give a sign to request
Sign(protocol.Invocation, *common.URL) error
// Authenticate
// verify the signature of the request is valid or not
Authenticate(protocol.Invocation, *common.URL) error
}
package auth
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/extension"
"github.com/apache/dubbo-go/filter"
"github.com/apache/dubbo-go/protocol"
)
// DefaultAccesskeyStorage
// The default implementation of AccesskeyStorage
type DefaultAccesskeyStorage struct {
}
// GetAccessKeyPair
// get AccessKeyPair from url by the key "accessKeyId" and "secretAccessKey"
func (storage *DefaultAccesskeyStorage) GetAccessKeyPair(invocation protocol.Invocation, url *common.URL) *filter.AccessKeyPair {
return &filter.AccessKeyPair{
AccessKey: url.GetParam(constant.ACCESS_KEY_ID_KEY, ""),
SecretKey: url.GetParam(constant.SECRET_ACCESS_KEY_KEY, ""),
}
}
func init() {
extension.SetAccesskeyStorages(constant.DEFAULT_ACCESS_KEY_STORAGE, GetDefaultAccesskeyStorage)
}
func GetDefaultAccesskeyStorage() filter.AccessKeyStorage {
return &DefaultAccesskeyStorage{}
}
package auth
import (
"net/url"
"testing"
)
import (
"github.com/stretchr/testify/assert"
)
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/constant"
invocation2 "github.com/apache/dubbo-go/protocol/invocation"
)
func TestDefaultAccesskeyStorage_GetAccesskeyPair(t *testing.T) {
url := common.NewURLWithOptions(
common.WithParams(url.Values{}),
common.WithParamsValue(constant.SECRET_ACCESS_KEY_KEY, "skey"),
common.WithParamsValue(constant.ACCESS_KEY_ID_KEY, "akey"))
invocation := &invocation2.RPCInvocation{}
storage := GetDefaultAccesskeyStorage()
accesskeyPair := storage.GetAccessKeyPair(invocation, url)
assert.Equal(t, "skey", accesskeyPair.SecretKey)
assert.Equal(t, "akey", accesskeyPair.AccessKey)
}
package auth
import (
"context"
"fmt"
)
import (
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/extension"
"github.com/apache/dubbo-go/common/logger"
"github.com/apache/dubbo-go/filter"
"github.com/apache/dubbo-go/protocol"
)
// ConsumerSignFilter
// This filter is working for signing the request on consumer side
type ConsumerSignFilter struct {
}
func init() {
extension.SetFilter(constant.CONSUMER_SIGN_FILTER, getConsumerSignFilter)
}
func (csf *ConsumerSignFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
logger.Infof("invoking ConsumerSign filter.")
url := invoker.GetUrl()
err := doAuthWork(&url, func(authenticator filter.Authenticator) error {
return authenticator.Sign(invocation, &url)
})
if err != nil {
panic(fmt.Sprintf("Sign for invocation %s # %s failed", url.ServiceKey(), invocation.MethodName()))
}
return invoker.Invoke(ctx, invocation)
}
func (csf *ConsumerSignFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
func getConsumerSignFilter() filter.Filter {
return &ConsumerSignFilter{}
}
package auth
import (
"context"
"testing"
)
import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/protocol"
"github.com/apache/dubbo-go/protocol/invocation"
"github.com/apache/dubbo-go/protocol/mock"
)
func TestConsumerSignFilter_Invoke(t *testing.T) {
url, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider?interface=com.ikurento.user.UserProvider&group=gg&version=2.6.0")
url.SetParam(constant.SECRET_ACCESS_KEY_KEY, "sk")
url.SetParam(constant.ACCESS_KEY_ID_KEY, "ak")
inv := invocation.NewRPCInvocation("test", []interface{}{"OK"}, nil)
filter := &ConsumerSignFilter{}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
invoker := mock.NewMockInvoker(ctrl)
result := &protocol.RPCResult{}
invoker.EXPECT().Invoke(inv).Return(result).Times(2)
invoker.EXPECT().GetUrl().Return(url).Times(2)
assert.Equal(t, result, filter.Invoke(context.Background(), invoker, inv))
url.SetParam(constant.SERVICE_AUTH_KEY, "true")
assert.Equal(t, result, filter.Invoke(context.Background(), invoker, inv))
}
package auth
import (
"errors"
"fmt"
"github.com/apache/dubbo-go/filter"
"strconv"
"time"
)
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/extension"
"github.com/apache/dubbo-go/protocol"
invocation_impl "github.com/apache/dubbo-go/protocol/invocation"
)
func init() {
extension.SetAuthenticator(constant.DEFAULT_AUTHENTICATOR, GetDefaultAuthenticator)
}
// DefaultAuthenticator
// The default implemetation of Authenticator
type DefaultAuthenticator struct {
}
// Sign
// add the signature for the invocation
func (authenticator *DefaultAuthenticator) Sign(invocation protocol.Invocation, url *common.URL) error {
currentTimeMillis := strconv.Itoa(int(time.Now().Unix() * 1000))
consumer := url.GetParam(constant.APPLICATION_KEY, "")
accessKeyPair, err := getAccessKeyPair(invocation, url)
if err != nil {
return errors.New("get accesskey pair failed, cause: " + err.Error())
}
inv := invocation.(*invocation_impl.RPCInvocation)
signature, err := getSignature(url, invocation, accessKeyPair.SecretKey, currentTimeMillis)
if err != nil {
return err
}
inv.SetAttachments(constant.REQUEST_SIGNATURE_KEY, signature)
inv.SetAttachments(constant.REQUEST_TIMESTAMP_KEY, currentTimeMillis)
inv.SetAttachments(constant.AK_KEY, accessKeyPair.AccessKey)
inv.SetAttachments(constant.CONSUMER, consumer)
return nil
}
// getSignature
// get signature by the metadata and params of the invocation
func getSignature(url *common.URL, invocation protocol.Invocation, secrectKey string, currentTime string) (string, error) {
requestString := fmt.Sprintf(constant.SIGNATURE_STRING_FORMAT,
url.ColonSeparatedKey(), invocation.MethodName(), secrectKey, currentTime)
var signature string
if parameterEncrypt := url.GetParamBool(constant.PARAMTER_SIGNATURE_ENABLE_KEY, false); parameterEncrypt {
var err error
if signature, err = SignWithParams(invocation.Arguments(), requestString, secrectKey); err != nil {
// TODO
return "", errors.New("sign the request with params failed, cause:" + err.Error())
}
} else {
signature = Sign(requestString, secrectKey)
}
return signature, nil
}
// Authenticate
// This method verifies whether the signature sent by the requester is correct
func (authenticator *DefaultAuthenticator) Authenticate(invocation protocol.Invocation, url *common.URL) error {
accessKeyId := invocation.AttachmentsByKey(constant.AK_KEY, "")
requestTimestamp := invocation.AttachmentsByKey(constant.REQUEST_TIMESTAMP_KEY, "")
originSignature := invocation.AttachmentsByKey(constant.REQUEST_SIGNATURE_KEY, "")
consumer := invocation.AttachmentsByKey(constant.CONSUMER, "")
if IsEmpty(accessKeyId, false) || IsEmpty(consumer, false) ||
IsEmpty(requestTimestamp, false) || IsEmpty(originSignature, false) {
return errors.New("failed to authenticate your ak/sk, maybe the consumer has not enabled the auth")
}
accessKeyPair, err := getAccessKeyPair(invocation, url)
if err != nil {
return errors.New("failed to authenticate , can't load the accessKeyPair")
}
computeSignature, err := getSignature(url, invocation, accessKeyPair.SecretKey, requestTimestamp)
if err != nil {
return err
}
if success := computeSignature == originSignature; !success {
return errors.New("failed to authenticate, signature is not correct")
}
return nil
}
func getAccessKeyPair(invocation protocol.Invocation, url *common.URL) (*filter.AccessKeyPair, error) {
accesskeyStorage := extension.GetAccesskeyStorages(url.GetParam(constant.ACCESS_KEY_STORAGE_KEY, constant.DEFAULT_ACCESS_KEY_STORAGE))
accessKeyPair := accesskeyStorage.GetAccessKeyPair(invocation, url)
if accessKeyPair == nil || IsEmpty(accessKeyPair.AccessKey, false) || IsEmpty(accessKeyPair.SecretKey, true) {
return nil, errors.New("accessKeyId or secretAccessKey not found")
} else {
return accessKeyPair, nil
}
}
func GetDefaultAuthenticator() filter.Authenticator {
return &DefaultAuthenticator{}
}
func doAuthWork(url *common.URL, do func(filter.Authenticator) error) error {
shouldAuth := url.GetParamBool(constant.SERVICE_AUTH_KEY, false)
if shouldAuth {
authenticator := extension.GetAuthenticator(url.GetParam(constant.AUTHENTICATOR_KEY, constant.DEFAULT_AUTHENTICATOR))
return do(authenticator)
}
return nil
}
package auth
import (
"context"
"fmt"
"net/url"
"strconv"
"testing"
"time"
)
import (
"github.com/stretchr/testify/assert"
)
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/protocol/invocation"
)
func TestDefaultAuthenticator_Authenticate(t *testing.T) {
secret := "dubbo-sk"
access := "dubbo-ak"
testurl, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider?interface=com.ikurento.user.UserProvider&group=gg&version=2.6.0")
testurl.SetParam(constant.PARAMTER_SIGNATURE_ENABLE_KEY, "true")
testurl.SetParam(constant.ACCESS_KEY_ID_KEY, access)
testurl.SetParam(constant.SECRET_ACCESS_KEY_KEY, secret)
parmas := []interface{}{"OK", struct {
Name string
Id int64
}{"YUYU", 1}}
inv := invocation.NewRPCInvocation("test", parmas, nil)
requestTime := strconv.Itoa(int(time.Now().Unix() * 1000))
signature, _ := getSignature(&testurl, inv, secret, requestTime)
var authenticator = &DefaultAuthenticator{}
invcation := invocation.NewRPCInvocation("test", parmas, map[string]string{
constant.REQUEST_SIGNATURE_KEY: signature,
constant.CONSUMER: "test",
constant.REQUEST_TIMESTAMP_KEY: requestTime,
constant.AK_KEY: access,
})
err := authenticator.Authenticate(invcation, &testurl)
assert.Nil(t, err)
// modify the params
invcation = invocation.NewRPCInvocation("test", parmas[:1], map[string]string{
constant.REQUEST_SIGNATURE_KEY: signature,
constant.CONSUMER: "test",
constant.REQUEST_TIMESTAMP_KEY: requestTime,
constant.AK_KEY: access,
})
err = authenticator.Authenticate(invcation, &testurl)
assert.NotNil(t, err)
}
func TestDefaultAuthenticator_Sign(t *testing.T) {
authenticator := &DefaultAuthenticator{}
testurl, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider?application=test&interface=com.ikurento.user.UserProvider&group=gg&version=2.6.0")
testurl.SetParam(constant.ACCESS_KEY_ID_KEY, "akey")
testurl.SetParam(constant.SECRET_ACCESS_KEY_KEY, "skey")
testurl.SetParam(constant.PARAMTER_SIGNATURE_ENABLE_KEY, "false")
inv := invocation.NewRPCInvocation("test", []interface{}{"OK"}, nil)
_ = authenticator.Sign(inv, &testurl)
assert.NotEqual(t, inv.AttachmentsByKey(constant.REQUEST_SIGNATURE_KEY, ""), "")
assert.NotEqual(t, inv.AttachmentsByKey(constant.CONSUMER, ""), "")
assert.NotEqual(t, inv.AttachmentsByKey(constant.REQUEST_TIMESTAMP_KEY, ""), "")
assert.Equal(t, inv.AttachmentsByKey(constant.AK_KEY, ""), "akey")
}
func Test_getAccessKeyPairSuccess(t *testing.T) {
testurl := common.NewURLWithOptions(
common.WithParams(url.Values{}),
common.WithParamsValue(constant.SECRET_ACCESS_KEY_KEY, "skey"),
common.WithParamsValue(constant.ACCESS_KEY_ID_KEY, "akey"))
invcation := invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, nil)
_, e := getAccessKeyPair(invcation, testurl)
assert.Nil(t, e)
}
func Test_getAccessKeyPairFailed(t *testing.T) {
defer func() {
e := recover()
assert.NotNil(t, e)
}()
testurl := common.NewURLWithOptions(
common.WithParams(url.Values{}),
common.WithParamsValue(constant.ACCESS_KEY_ID_KEY, "akey"))
invcation := invocation.NewRPCInvocation("MethodName", []interface{}{"OK"}, nil)
_, e := getAccessKeyPair(invcation, testurl)
assert.NotNil(t, e)
testurl = common.NewURLWithOptions(
common.WithParams(url.Values{}),
common.WithParamsValue(constant.SECRET_ACCESS_KEY_KEY, "skey"),
common.WithParamsValue(constant.ACCESS_KEY_ID_KEY, "akey"), common.WithParamsValue(constant.ACCESS_KEY_STORAGE_KEY, "dubbo"))
_, e = getAccessKeyPair(invcation, testurl)
}
func Test_getSignatureWithinParams(t *testing.T) {
testurl, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider?interface=com.ikurento.user.UserProvider&group=gg&version=2.6.0")
testurl.SetParam(constant.PARAMTER_SIGNATURE_ENABLE_KEY, "true")
inv := invocation.NewRPCInvocation("test", []interface{}{"OK"}, map[string]string{
"": "",
})
secret := "dubbo"
current := strconv.Itoa(int(time.Now().Unix() * 1000))
signature, _ := getSignature(&testurl, inv, secret, current)
requestString := fmt.Sprintf(constant.SIGNATURE_STRING_FORMAT,
testurl.ColonSeparatedKey(), inv.MethodName(), secret, current)
s, _ := SignWithParams(inv.Arguments(), requestString, secret)
assert.False(t, IsEmpty(signature, false))
assert.Equal(t, s, signature)
}
func Test_getSignature(t *testing.T) {
testurl, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider?interface=com.ikurento.user.UserProvider&group=gg&version=2.6.0")
testurl.SetParam(constant.PARAMTER_SIGNATURE_ENABLE_KEY, "false")
inv := invocation.NewRPCInvocation("test", []interface{}{"OK"}, nil)
secret := "dubbo"
current := strconv.Itoa(int(time.Now().Unix() * 1000))
signature, _ := getSignature(&testurl, inv, secret, current)
requestString := fmt.Sprintf(constant.SIGNATURE_STRING_FORMAT,
testurl.ColonSeparatedKey(), inv.MethodName(), secret, current)
s := Sign(requestString, secret)
assert.False(t, IsEmpty(signature, false))
assert.Equal(t, s, signature)
}
package auth
import (
"context"
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/extension"
"github.com/apache/dubbo-go/common/logger"
"github.com/apache/dubbo-go/filter"
"github.com/apache/dubbo-go/protocol"
)
// ProviderAuthFilter
// This filter is used to verify the correctness of the signature on provider side
type ProviderAuthFilter struct {
}
func init() {
extension.SetFilter(constant.PROVIDER_AUTH_FILTER, getProviderAuthFilter)
}
func (paf *ProviderAuthFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
logger.Infof("invoking providerAuth filter.")
url := invoker.GetUrl()
err := doAuthWork(&url, func(authenticator filter.Authenticator) error {
return authenticator.Authenticate(invocation, &url)
})
if err != nil {
logger.Infof("auth the request: %v occur exception, cause: %s", invocation, err.Error())
return &protocol.RPCResult{
Err: err,
}
}
return invoker.Invoke(ctx, invocation)
}
func (paf *ProviderAuthFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
func getProviderAuthFilter() filter.Filter {
return &ProviderAuthFilter{}
}
package auth
import (
"context"
"strconv"
"testing"
"time"
)
import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/protocol"
"github.com/apache/dubbo-go/protocol/invocation"
"github.com/apache/dubbo-go/protocol/mock"
)
func TestProviderAuthFilter_Invoke(t *testing.T) {
secret := "dubbo-sk"
access := "dubbo-ak"
url, _ := common.NewURL(context.TODO(), "dubbo://127.0.0.1:20000/com.ikurento.user.UserProvider?interface=com.ikurento.user.UserProvider&group=gg&version=2.6.0")
url.SetParam(constant.ACCESS_KEY_ID_KEY, access)
url.SetParam(constant.SECRET_ACCESS_KEY_KEY, secret)
parmas := []interface{}{
"OK",
struct {
Name string
Id int64
}{"YUYU", 1},
}
inv := invocation.NewRPCInvocation("test", parmas, nil)
requestTime := strconv.Itoa(int(time.Now().Unix() * 1000))
signature, _ := getSignature(&url, inv, secret, requestTime)
inv = invocation.NewRPCInvocation("test", []interface{}{"OK"}, map[string]string{
constant.REQUEST_SIGNATURE_KEY: signature,
constant.CONSUMER: "test",
constant.REQUEST_TIMESTAMP_KEY: requestTime,
constant.AK_KEY: access,
})
ctrl := gomock.NewController(t)
filter := &ProviderAuthFilter{}
defer ctrl.Finish()
invoker := mock.NewMockInvoker(ctrl)
result := &protocol.RPCResult{}
invoker.EXPECT().Invoke(inv).Return(result).Times(2)
invoker.EXPECT().GetUrl().Return(url).Times(2)
assert.Equal(t, result, filter.Invoke(context.Background(), invoker, inv))
url.SetParam(constant.SERVICE_AUTH_KEY, "true")
assert.Equal(t, result, filter.Invoke(context.Background(), invoker, inv))
}
package auth
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"strings"
)
// Sign
// get a signature string with given information, such as metadata or parameters
func Sign(metadata, key string) string {
return doSign([]byte(metadata), key)
}
func SignWithParams(params []interface{}, metadata, key string) (string, error) {
if params == nil || len(params) == 0 {
return Sign(metadata, key), nil
}
data := append(params, metadata)
if bytes, err := toBytes(data); err != nil {
// TODO
return "", errors.New("data convert to bytes failed")
} else {
return doSign(bytes, key), nil
}
}
func toBytes(data []interface{}) ([]byte, error) {
if bytes, err := json.Marshal(data); err != nil {
return nil, errors.New("")
} else {
return bytes, nil
}
}
func doSign(bytes []byte, key string) string {
mac := hmac.New(sha256.New, []byte(key))
mac.Write(bytes)
signature := mac.Sum(nil)
return base64.URLEncoding.EncodeToString(signature)
}
func IsEmpty(s string, allowSpace bool) bool {
if len(s) == 0 {
return true
}
if !allowSpace {
return strings.TrimSpace(s) == ""
}
return false
}
package auth
import (
"testing"
)
import (
"github.com/stretchr/testify/assert"
)
func TestIsEmpty(t *testing.T) {
type args struct {
s string
allowSpace bool
}
tests := []struct {
name string
args args
want bool
}{
// TODO: Add test cases.
{"test1", args{s: " ", allowSpace: false}, true},
{"test2", args{s: " ", allowSpace: true}, false},
{"test3", args{s: "hello,dubbo", allowSpace: false}, false},
{"test4", args{s: "hello,dubbo", allowSpace: true}, false},
{"test5", args{s: "", allowSpace: true}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsEmpty(tt.args.s, tt.args.allowSpace); got != tt.want {
t.Errorf("IsEmpty() = %v, want %v", got, tt.want)
}
})
}
}
func TestSign(t *testing.T) {
metadata := "com.ikurento.user.UserProvider::sayHi"
key := "key"
signature := Sign(metadata, key)
assert.NotNil(t, signature)
}
func TestSignWithParams(t *testing.T) {
metadata := "com.ikurento.user.UserProvider::sayHi"
key := "key"
params := []interface{}{
"a", 1, struct {
Name string
Id int64
}{"YuYu", 1},
}
signature, _ := SignWithParams(params, metadata, key)
assert.False(t, IsEmpty(signature, false))
}
func Test_doSign(t *testing.T) {
sign := doSign([]byte("DubboGo"), "key")
sign1 := doSign([]byte("DubboGo"), "key")
sign2 := doSign([]byte("DubboGo"), "key2")
assert.NotNil(t, sign)
assert.Equal(t, sign1, sign)
assert.NotEqual(t, sign1, sign2)
}
func Test_toBytes(t *testing.T) {
params := []interface{}{
"a", 1, struct {
Name string
Id int64
}{"YuYu", 1},
}
params2 := []interface{}{
"a", 1, struct {
Name string
Id int64
}{"YuYu", 1},
}
jsonBytes, _ := toBytes(params)
jsonBytes2, _ := toBytes(params2)
assert.NotNil(t, jsonBytes)
assert.Equal(t, jsonBytes, jsonBytes2)
}
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment