Skip to content
Snippets Groups Projects
Commit 6477526d authored by Ooo0oO0o0oO's avatar Ooo0oO0o0oO
Browse files

[WIP] support sign and auth for request

parent 3977bc6c
No related branches found
No related tags found
No related merge requests found
......@@ -141,3 +141,20 @@ const (
const (
TRACING_REMOTE_SPAN_CTX = "tracing.remote.span.ctx"
)
const (
CONSUMER_SIGN_FILTER = "sign"
PROVIDER_AUTH_FILTER = "auth"
SERVICE_AUTH_KEY = "auth"
AUTHENTICATOR_KEY = "authenticator"
DEFAULT_AUTHENTICATOR = "accesskeys"
DEFAULT_ACCESS_KEY_STORAGE = "urlstorage"
ACCESS_KEY_STORAGE_KEY = "accessKey.storage"
SECRET_ACCESS_KEY_KEY = "secretAccessKey"
REQUEST_TIMESTAMP_KEY = "timestamp"
REQUEST_SIGNATURE_KEY = "signature"
AK_KEY = "ak"
SIGNATURE_STRING_FORMAT = "%s#%s#%s#%s"
PARAMTER_SIGNATURE_ENABLE_KEY = "param.sign"
CONSUMER = "consumer"
)
package extension
import (
"github.com/apache/dubbo-go/filter"
)
var(
authenticators = make(map[string]func() filter.Authenticator)
accesskeyStorages = make(map[string]func() filter.AccesskeyStorage)
)
func SetAuthenticator(name string, fcn func() filter.Authenticator) {
authenticators[name] = fcn
}
func GetAuthenticator(name string) filter.Authenticator {
if clusters[name] == nil {
panic("cluster for " + name + " is not existing, make sure you have import the package.")
}
return authenticators[name]()
}
func SetAccesskeyStorages(name string, fcn func() filter.AccesskeyStorage) {
accesskeyStorages[name] = fcn
}
func GetAccesskeyStorages(name string) filter.AccesskeyStorage {
if clusters[name] == nil {
panic("cluster for " + name + " is not existing, make sure you have import the package.")
}
return accesskeyStorages[name]()
}
\ No newline at end of file
......@@ -339,6 +339,26 @@ func (c URL) ServiceKey() string {
return buf.String()
}
func (c *URL) ColonSeparatedKey() string {
intf := c.GetParam(constant.INTERFACE_KEY, strings.TrimPrefix(c.Path, "/"))
if intf == "" {
return ""
}
buf := &bytes.Buffer{}
buf.WriteString(intf)
buf.WriteString(":")
version := c.GetParam(constant.VERSION_KEY, "")
if version != "" && version != "0.0.0" {
buf.WriteString(version)
}
group := c.GetParam(constant.GROUP_KEY, "")
buf.WriteString(":")
if group != "" {
buf.WriteString(group)
}
return buf.String()
}
// EncodedServiceKey ...
func (c *URL) EncodedServiceKey() string {
serviceKey := c.ServiceKey()
......
......@@ -271,3 +271,17 @@ func TestClone(t *testing.T) {
assert.Equal(t, u1.Protocol, "dubbo")
assert.Equal(t, u2.Protocol, "provider")
}
func TestColonSeparatedKey(t *testing.T) {
u1, _ := NewURL(context.TODO(), "dubbo://127.0.0.1:20000")
u1.AddParam(constant.INTERFACE_KEY, "com.ikurento.user.UserProvider")
assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+"::")
u1.AddParam(constant.VERSION_KEY, "version1")
assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+":version1:")
u1.AddParam(constant.GROUP_KEY, "group1")
assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+":version1:group1")
u1.SetParam(constant.VERSION_KEY, "")
assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+"::group1")
}
package filter
type AccessKeyPair struct {
AccessKey string
SecretKey string
ConsumerSide string
ProviderSide string
Creator string
Options string
}
package filter
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/protocol"
)
type Authenticator interface {
Sign(protocol.Invocation, *common.URL) error
Authenticate(protocol.Invocation, *common.URL) error
}
type AccesskeyStorage interface {
GetAccesskeyPair(protocol.Invocation, *common.URL) *AccessKeyPair
}
package auth
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/extension"
"github.com/apache/dubbo-go/filter"
"github.com/apache/dubbo-go/protocol"
)
const (
ACCESS_KEY_ID_KEY = "accessKeyId"
SECRET_ACCESS_KEY_KEY = "secretAccessKey"
)
type DefaultAccesskeyStorage struct {
}
func (storage *DefaultAccesskeyStorage) GetAccesskeyPair(invocation protocol.Invocation, url *common.URL) *filter.AccessKeyPair {
return &filter.AccessKeyPair{
AccessKey: url.GetParam(ACCESS_KEY_ID_KEY, ""),
SecretKey: url.GetParam(SECRET_ACCESS_KEY_KEY, ""),
}
}
func init() {
extension.SetAccesskeyStorages(constant.DEFAULT_ACCESS_KEY_STORAGE, GetDefaultAccesskeyStorage)
}
func GetDefaultAccesskeyStorage() filter.AccesskeyStorage {
return &DefaultAccesskeyStorage{}
}
package auth
import (
"errors"
"fmt"
"strconv"
"time"
)
import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/extension"
"github.com/apache/dubbo-go/filter"
"github.com/apache/dubbo-go/protocol"
invocation_impl "github.com/apache/dubbo-go/protocol/invocation"
)
func init() {
extension.SetAuthenticator(constant.DEFAULT_AUTHENTICATOR, GetDefaultAuthenticator)
}
type DefaultAuthenticator struct {
}
func (authenticator *DefaultAuthenticator) Sign(invocation protocol.Invocation, url *common.URL) error {
currentTimeMillis := strconv.Itoa(time.Now().Second() * 1000)
consumer := url.GetParam(constant.APPLICATION_KEY, "")
accessKeyPair, err := getAccessKeyPair(invocation, url)
if err != nil {
return errors.New("get accesskey pair failed, cause: " + err.Error())
}
inv := invocation.(*invocation_impl.RPCInvocation)
signature, err := getSignature(url, invocation, accessKeyPair.SecretKey, currentTimeMillis)
if err != nil {
return err
}
inv.SetAttachments(constant.REQUEST_SIGNATURE_KEY, signature)
inv.SetAttachments(constant.REQUEST_TIMESTAMP_KEY, currentTimeMillis)
inv.SetAttachments(constant.AK_KEY, accessKeyPair.AccessKey)
inv.SetAttachments(constant.CONSUMER, consumer)
return nil
}
func getSignature(url *common.URL, invocation protocol.Invocation, secrectKey string, currentTime string) (string, error) {
requestString := fmt.Sprintf(constant.SIGNATURE_STRING_FORMAT,
url.ColonSeparatedKey(), invocation.MethodName(), secrectKey, currentTime)
var signature string
if parameterEncrypt := url.GetParamBool(constant.PARAMTER_SIGNATURE_ENABLE_KEY, false); parameterEncrypt {
var err error
if signature, err = SignWithParams(invocation.Arguments(), requestString, secrectKey); err != nil {
// TODO
return "", errors.New("sign the request with params failed, cause:" + err.Error())
}
} else {
signature = Sign(requestString, secrectKey)
}
return signature, nil
}
func (authenticator *DefaultAuthenticator) Authenticate(invocation protocol.Invocation, url *common.URL) error {
accessKeyId := invocation.AttachmentsByKey(constant.AK_KEY, "")
requestTimestamp := invocation.AttachmentsByKey(constant.REQUEST_TIMESTAMP_KEY, "")
originSignature := invocation.AttachmentsByKey(constant.REQUEST_SIGNATURE_KEY, "")
consumer := invocation.AttachmentsByKey(constant.CONSUMER, "")
if IsEmpty(accessKeyId, false) || IsEmpty(consumer, false) || IsEmpty(requestTimestamp, false) || IsEmpty(originSignature, false) {
return errors.New("failed to authenticate, maybe consumer not enable the auth")
}
accessKeyPair, err := getAccessKeyPair(invocation, url)
if err != nil {
return errors.New("failed to authenticate , can't load the accessKeyPair")
}
computeSignature, err := getSignature(url, invocation, accessKeyPair.SecretKey, requestTimestamp)
if err != nil {
return err
}
if success := computeSignature == originSignature; !success {
return errors.New("failed to authenticate, signature is not correct")
}
return nil
}
func getAccessKeyPair(invocation protocol.Invocation, url *common.URL) (*filter.AccessKeyPair, error) {
accesskeyStorage := extension.GetAccesskeyStorages(url.GetParam(constant.ACCESS_KEY_STORAGE_KEY, constant.DEFAULT_ACCESS_KEY_STORAGE))
accessKeyPair := accesskeyStorage.GetAccesskeyPair(invocation, url)
if accessKeyPair == nil || IsEmpty(accessKeyPair.AccessKey, false) || IsEmpty(accessKeyPair.SecretKey, true) {
return nil, errors.New("accessKeyId or secretAccessKey not found")
} else {
return accessKeyPair, nil
}
}
func GetDefaultAuthenticator() filter.Authenticator {
return &DefaultAuthenticator{}
}
package auth
import (
"fmt"
)
import (
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/extension"
"github.com/apache/dubbo-go/common/logger"
"github.com/apache/dubbo-go/filter"
"github.com/apache/dubbo-go/protocol"
)
type ConsumerSignFilter struct {
}
func init() {
extension.SetFilter(constant.CONSUMER_SIGN_FILTER, getConsumerSignFilter)
}
func (filter *ConsumerSignFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
logger.Infof("invoking ConsumerSign filter.")
url := invoker.GetUrl()
shouldAuth := url.GetParamBool(constant.SERVICE_AUTH_KEY, false)
if shouldAuth {
authenticator := extension.GetAuthenticator(url.GetParam(constant.AUTHENTICATOR_KEY, constant.DEFAULT_AUTHENTICATOR))
if err := authenticator.Sign(invocation, &url); err != nil {
panic(fmt.Sprintf("Sign for invocation %s # %s failed", url.ServiceKey(), invocation.MethodName()))
}
}
return invoker.Invoke(invocation)
}
func (filter *ConsumerSignFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
func getConsumerSignFilter() filter.Filter {
return &ConsumerSignFilter{}
}
package auth
import (
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/extension"
"github.com/apache/dubbo-go/common/logger"
"github.com/apache/dubbo-go/filter"
"github.com/apache/dubbo-go/protocol"
)
type ProviderAuthFilter struct {
}
func init() {
extension.SetFilter(constant.PROVIDER_AUTH_FILTER, getProviderAuthFilter)
}
func (filter *ProviderAuthFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
logger.Infof("invoking providerAuth filter.")
url := invoker.GetUrl()
shouldAuth := url.GetParamBool(constant.SERVICE_AUTH_KEY, false)
if shouldAuth {
authenticator := extension.GetAuthenticator(url.GetParam(constant.AUTHENTICATOR_KEY, constant.DEFAULT_AUTHENTICATOR))
if err := authenticator.Authenticate(invocation, &url); err != nil {
logger.Infof("auth the request: %v occur exception, cause: %s", invocation, err.Error())
return &protocol.RPCResult{
Err: err,
}
}
}
return invoker.Invoke(invocation)
}
func (filter *ProviderAuthFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
func getProviderAuthFilter() filter.Filter {
return &ProviderAuthFilter{}
}
package auth
import (
"bytes"
"crypto/sha256"
"encoding/base64"
"encoding/gob"
"errors"
"strings"
)
func Sign(metadata, key string) string {
return doSign([]byte(metadata), key)
}
func SignWithParams(params []interface{}, metadata, key string) (string, error) {
if params == nil || len(params) == 0 {
return Sign(metadata, key), nil
}
data := append(params, metadata)
if bytes, err := toBytes(data); err != nil {
// TODO
return "", errors.New("data convert to bytes failed")
} else {
return doSign(bytes, key), nil
}
}
func toBytes(data []interface{}) ([]byte, error) {
var b bytes.Buffer
enc := gob.NewEncoder(&b)
if err := enc.Encode(data); err != nil {
return nil, errors.New("")
}
return b.Bytes(), nil
}
func doSign(bytes []byte, key string) string {
sum256 := sha256.Sum256(bytes)
return base64.URLEncoding.EncodeToString(sum256[:])
}
func IsEmpty(s string, allowSpace bool) bool {
if len(s) == 0 {
return true
}
if !allowSpace {
if strings.TrimSpace(s) == "" {
return true
}
}
return false
}
package auth
import (
"reflect"
"testing"
)
import (
"github.com/stretchr/testify/assert"
)
func TestIsEmpty(t *testing.T) {
type args struct {
s string
allowSpace bool
}
tests := []struct {
name string
args args
want bool
}{
// TODO: Add test cases.
{"test1", args{s: " ", allowSpace: false}, true},
{"test2", args{s: " ", allowSpace: true}, false},
{"test3", args{s: "hello,dubbo", allowSpace: false}, false},
{"test4", args{s: "hello,dubbo", allowSpace: true}, false},
{"test5", args{s: "", allowSpace: true}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsEmpty(tt.args.s, tt.args.allowSpace); got != tt.want {
t.Errorf("IsEmpty() = %v, want %v", got, tt.want)
}
})
}
}
func TestSign(t *testing.T) {
metadata := "com.ikurento.user.UserProvider::sayHi"
key := "key"
signature := Sign(metadata, key)
assert.NotNil(t, signature)
}
func TestSignWithParams(t *testing.T) {
}
func Test_doSign(t *testing.T) {
type args struct {
bytes []byte
key string
}
tests := []struct {
name string
args args
want string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := doSign(tt.args.bytes, tt.args.key); got != tt.want {
t.Errorf("doSign() = %v, want %v", got, tt.want)
}
})
}
}
func Test_toBytes(t *testing.T) {
type args struct {
data []interface{}
}
tests := []struct {
name string
args args
want []byte
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := toBytes(tt.args.data)
if (err != nil) != tt.wantErr {
t.Errorf("toBytes() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("toBytes() got = %v, want %v", got, tt.want)
}
})
}
}
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment