diff --git a/common/constant/key.go b/common/constant/key.go index d201570b9ad5415694af5598fba7983289b2b295..45f79d39e1e88958b573554697529557b2a5906d 100644 --- a/common/constant/key.go +++ b/common/constant/key.go @@ -141,3 +141,20 @@ const ( const ( TRACING_REMOTE_SPAN_CTX = "tracing.remote.span.ctx" ) + +const ( + CONSUMER_SIGN_FILTER = "sign" + PROVIDER_AUTH_FILTER = "auth" + SERVICE_AUTH_KEY = "auth" + AUTHENTICATOR_KEY = "authenticator" + DEFAULT_AUTHENTICATOR = "accesskeys" + DEFAULT_ACCESS_KEY_STORAGE = "urlstorage" + ACCESS_KEY_STORAGE_KEY = "accessKey.storage" + SECRET_ACCESS_KEY_KEY = "secretAccessKey" + REQUEST_TIMESTAMP_KEY = "timestamp" + REQUEST_SIGNATURE_KEY = "signature" + AK_KEY = "ak" + SIGNATURE_STRING_FORMAT = "%s#%s#%s#%s" + PARAMTER_SIGNATURE_ENABLE_KEY = "param.sign" + CONSUMER = "consumer" +) diff --git a/common/extension/auth.go b/common/extension/auth.go new file mode 100644 index 0000000000000000000000000000000000000000..39c215660dac125226cb9c05d005b18930b94f2a --- /dev/null +++ b/common/extension/auth.go @@ -0,0 +1,33 @@ +package extension + +import ( + "github.com/apache/dubbo-go/filter" +) + +var( + authenticators = make(map[string]func() filter.Authenticator) + accesskeyStorages = make(map[string]func() filter.AccesskeyStorage) +) + + +func SetAuthenticator(name string, fcn func() filter.Authenticator) { + authenticators[name] = fcn +} + +func GetAuthenticator(name string) filter.Authenticator { + if clusters[name] == nil { + panic("cluster for " + name + " is not existing, make sure you have import the package.") + } + return authenticators[name]() +} + +func SetAccesskeyStorages(name string, fcn func() filter.AccesskeyStorage) { + accesskeyStorages[name] = fcn +} + +func GetAccesskeyStorages(name string) filter.AccesskeyStorage { + if clusters[name] == nil { + panic("cluster for " + name + " is not existing, make sure you have import the package.") + } + return accesskeyStorages[name]() +} \ No newline at end of file diff --git a/common/url.go b/common/url.go index a073e013f47a2acff4782ffa4444203fa0cec9b5..64a5930c6e5f6788c8a828d519f23da3b8a5c264 100644 --- a/common/url.go +++ b/common/url.go @@ -339,6 +339,26 @@ func (c URL) ServiceKey() string { return buf.String() } +func (c *URL) ColonSeparatedKey() string { + intf := c.GetParam(constant.INTERFACE_KEY, strings.TrimPrefix(c.Path, "/")) + if intf == "" { + return "" + } + buf := &bytes.Buffer{} + buf.WriteString(intf) + buf.WriteString(":") + version := c.GetParam(constant.VERSION_KEY, "") + if version != "" && version != "0.0.0" { + buf.WriteString(version) + } + group := c.GetParam(constant.GROUP_KEY, "") + buf.WriteString(":") + if group != "" { + buf.WriteString(group) + } + return buf.String() +} + // EncodedServiceKey ... func (c *URL) EncodedServiceKey() string { serviceKey := c.ServiceKey() diff --git a/common/url_test.go b/common/url_test.go index c70c58bc215b6449311d43f9f9cffeb89623f80c..9d80bc52b8ee5c74abb29b3b65487156ebf6ae2c 100644 --- a/common/url_test.go +++ b/common/url_test.go @@ -271,3 +271,17 @@ func TestClone(t *testing.T) { assert.Equal(t, u1.Protocol, "dubbo") assert.Equal(t, u2.Protocol, "provider") } + +func TestColonSeparatedKey(t *testing.T) { + u1, _ := NewURL(context.TODO(), "dubbo://127.0.0.1:20000") + u1.AddParam(constant.INTERFACE_KEY, "com.ikurento.user.UserProvider") + + assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+"::") + u1.AddParam(constant.VERSION_KEY, "version1") + assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+":version1:") + u1.AddParam(constant.GROUP_KEY, "group1") + assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+":version1:group1") + u1.SetParam(constant.VERSION_KEY, "") + assert.Equal(t, u1.ColonSeparatedKey(), u1.GetParam(constant.INTERFACE_KEY, "")+"::group1") + +} diff --git a/filter/accesskey.go b/filter/accesskey.go new file mode 100644 index 0000000000000000000000000000000000000000..682e2b190b0d2feccceabf183382ae0a6c75f08b --- /dev/null +++ b/filter/accesskey.go @@ -0,0 +1,10 @@ +package filter + +type AccessKeyPair struct { + AccessKey string + SecretKey string + ConsumerSide string + ProviderSide string + Creator string + Options string +} diff --git a/filter/auth_spi.go b/filter/auth_spi.go new file mode 100644 index 0000000000000000000000000000000000000000..edbea66e0cc51d57795c06dd1e3378d014301360 --- /dev/null +++ b/filter/auth_spi.go @@ -0,0 +1,16 @@ +package filter + +import ( + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/protocol" +) + +type Authenticator interface { + Sign(protocol.Invocation, *common.URL) error + Authenticate(protocol.Invocation, *common.URL) error +} + +type AccesskeyStorage interface { + GetAccesskeyPair(protocol.Invocation, *common.URL) *AccessKeyPair +} + diff --git a/filter/filter_impl/auth/accesskey_storage.go b/filter/filter_impl/auth/accesskey_storage.go new file mode 100644 index 0000000000000000000000000000000000000000..82c393b7cc91b35dcbe2c49a915fc9f966a00a42 --- /dev/null +++ b/filter/filter_impl/auth/accesskey_storage.go @@ -0,0 +1,33 @@ +package auth + +import ( + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/common/constant" + "github.com/apache/dubbo-go/common/extension" + "github.com/apache/dubbo-go/filter" + "github.com/apache/dubbo-go/protocol" +) + +const ( + ACCESS_KEY_ID_KEY = "accessKeyId" + + SECRET_ACCESS_KEY_KEY = "secretAccessKey" +) + +type DefaultAccesskeyStorage struct { +} + +func (storage *DefaultAccesskeyStorage) GetAccesskeyPair(invocation protocol.Invocation, url *common.URL) *filter.AccessKeyPair { + return &filter.AccessKeyPair{ + AccessKey: url.GetParam(ACCESS_KEY_ID_KEY, ""), + SecretKey: url.GetParam(SECRET_ACCESS_KEY_KEY, ""), + } +} + +func init() { + extension.SetAccesskeyStorages(constant.DEFAULT_ACCESS_KEY_STORAGE, GetDefaultAccesskeyStorage) +} + +func GetDefaultAccesskeyStorage() filter.AccesskeyStorage { + return &DefaultAccesskeyStorage{} +} diff --git a/filter/filter_impl/auth/authenticator.go b/filter/filter_impl/auth/authenticator.go new file mode 100644 index 0000000000000000000000000000000000000000..31589a52cb08525b14b147c7ed08f69e304dc4f6 --- /dev/null +++ b/filter/filter_impl/auth/authenticator.go @@ -0,0 +1,101 @@ +package auth + +import ( + "errors" + "fmt" + "strconv" + "time" +) + +import ( + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/common/constant" + "github.com/apache/dubbo-go/common/extension" + "github.com/apache/dubbo-go/filter" + "github.com/apache/dubbo-go/protocol" + invocation_impl "github.com/apache/dubbo-go/protocol/invocation" +) + +func init() { + extension.SetAuthenticator(constant.DEFAULT_AUTHENTICATOR, GetDefaultAuthenticator) +} + +type DefaultAuthenticator struct { +} + +func (authenticator *DefaultAuthenticator) Sign(invocation protocol.Invocation, url *common.URL) error { + currentTimeMillis := strconv.Itoa(time.Now().Second() * 1000) + + consumer := url.GetParam(constant.APPLICATION_KEY, "") + accessKeyPair, err := getAccessKeyPair(invocation, url) + if err != nil { + return errors.New("get accesskey pair failed, cause: " + err.Error()) + } + inv := invocation.(*invocation_impl.RPCInvocation) + signature, err := getSignature(url, invocation, accessKeyPair.SecretKey, currentTimeMillis) + if err != nil { + return err + } + inv.SetAttachments(constant.REQUEST_SIGNATURE_KEY, signature) + inv.SetAttachments(constant.REQUEST_TIMESTAMP_KEY, currentTimeMillis) + inv.SetAttachments(constant.AK_KEY, accessKeyPair.AccessKey) + inv.SetAttachments(constant.CONSUMER, consumer) + return nil +} + +func getSignature(url *common.URL, invocation protocol.Invocation, secrectKey string, currentTime string) (string, error) { + + requestString := fmt.Sprintf(constant.SIGNATURE_STRING_FORMAT, + url.ColonSeparatedKey(), invocation.MethodName(), secrectKey, currentTime) + var signature string + if parameterEncrypt := url.GetParamBool(constant.PARAMTER_SIGNATURE_ENABLE_KEY, false); parameterEncrypt { + var err error + if signature, err = SignWithParams(invocation.Arguments(), requestString, secrectKey); err != nil { + // TODO + return "", errors.New("sign the request with params failed, cause:" + err.Error()) + } + } else { + signature = Sign(requestString, secrectKey) + } + + return signature, nil +} + +func (authenticator *DefaultAuthenticator) Authenticate(invocation protocol.Invocation, url *common.URL) error { + accessKeyId := invocation.AttachmentsByKey(constant.AK_KEY, "") + + requestTimestamp := invocation.AttachmentsByKey(constant.REQUEST_TIMESTAMP_KEY, "") + originSignature := invocation.AttachmentsByKey(constant.REQUEST_SIGNATURE_KEY, "") + consumer := invocation.AttachmentsByKey(constant.CONSUMER, "") + if IsEmpty(accessKeyId, false) || IsEmpty(consumer, false) || IsEmpty(requestTimestamp, false) || IsEmpty(originSignature, false) { + return errors.New("failed to authenticate, maybe consumer not enable the auth") + } + + accessKeyPair, err := getAccessKeyPair(invocation, url) + if err != nil { + return errors.New("failed to authenticate , can't load the accessKeyPair") + } + + computeSignature, err := getSignature(url, invocation, accessKeyPair.SecretKey, requestTimestamp) + if err != nil { + return err + } + if success := computeSignature == originSignature; !success { + return errors.New("failed to authenticate, signature is not correct") + } + return nil +} + +func getAccessKeyPair(invocation protocol.Invocation, url *common.URL) (*filter.AccessKeyPair, error) { + accesskeyStorage := extension.GetAccesskeyStorages(url.GetParam(constant.ACCESS_KEY_STORAGE_KEY, constant.DEFAULT_ACCESS_KEY_STORAGE)) + accessKeyPair := accesskeyStorage.GetAccesskeyPair(invocation, url) + if accessKeyPair == nil || IsEmpty(accessKeyPair.AccessKey, false) || IsEmpty(accessKeyPair.SecretKey, true) { + return nil, errors.New("accessKeyId or secretAccessKey not found") + } else { + return accessKeyPair, nil + } +} + +func GetDefaultAuthenticator() filter.Authenticator { + return &DefaultAuthenticator{} +} diff --git a/filter/filter_impl/auth/consumer_sign.go b/filter/filter_impl/auth/consumer_sign.go new file mode 100644 index 0000000000000000000000000000000000000000..f17d7e5b647dae2d11d66e75d2060fccb55aad41 --- /dev/null +++ b/filter/filter_impl/auth/consumer_sign.go @@ -0,0 +1,39 @@ +package auth + +import ( + "fmt" +) +import ( + "github.com/apache/dubbo-go/common/constant" + "github.com/apache/dubbo-go/common/extension" + "github.com/apache/dubbo-go/common/logger" + "github.com/apache/dubbo-go/filter" + "github.com/apache/dubbo-go/protocol" +) + +type ConsumerSignFilter struct { +} + +func init() { + extension.SetFilter(constant.CONSUMER_SIGN_FILTER, getConsumerSignFilter) +} + +func (filter *ConsumerSignFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { + logger.Infof("invoking ConsumerSign filter.") + url := invoker.GetUrl() + shouldAuth := url.GetParamBool(constant.SERVICE_AUTH_KEY, false) + if shouldAuth { + authenticator := extension.GetAuthenticator(url.GetParam(constant.AUTHENTICATOR_KEY, constant.DEFAULT_AUTHENTICATOR)) + if err := authenticator.Sign(invocation, &url); err != nil { + panic(fmt.Sprintf("Sign for invocation %s # %s failed", url.ServiceKey(), invocation.MethodName())) + } + } + return invoker.Invoke(invocation) +} + +func (filter *ConsumerSignFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { + return result +} +func getConsumerSignFilter() filter.Filter { + return &ConsumerSignFilter{} +} diff --git a/filter/filter_impl/auth/provider_auth.go b/filter/filter_impl/auth/provider_auth.go new file mode 100644 index 0000000000000000000000000000000000000000..80dffd7bd7d38b5eb2ac60b2bff35ff66aa8ca8f --- /dev/null +++ b/filter/filter_impl/auth/provider_auth.go @@ -0,0 +1,39 @@ +package auth + +import ( + "github.com/apache/dubbo-go/common/constant" + "github.com/apache/dubbo-go/common/extension" + "github.com/apache/dubbo-go/common/logger" + "github.com/apache/dubbo-go/filter" + "github.com/apache/dubbo-go/protocol" +) + +type ProviderAuthFilter struct { +} + +func init() { + extension.SetFilter(constant.PROVIDER_AUTH_FILTER, getProviderAuthFilter) +} + +func (filter *ProviderAuthFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { + logger.Infof("invoking providerAuth filter.") + url := invoker.GetUrl() + shouldAuth := url.GetParamBool(constant.SERVICE_AUTH_KEY, false) + if shouldAuth { + authenticator := extension.GetAuthenticator(url.GetParam(constant.AUTHENTICATOR_KEY, constant.DEFAULT_AUTHENTICATOR)) + if err := authenticator.Authenticate(invocation, &url); err != nil { + logger.Infof("auth the request: %v occur exception, cause: %s", invocation, err.Error()) + return &protocol.RPCResult{ + Err: err, + } + } + } + return invoker.Invoke(invocation) +} + +func (filter *ProviderAuthFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { + return result +} +func getProviderAuthFilter() filter.Filter { + return &ProviderAuthFilter{} +} diff --git a/filter/filter_impl/auth/sign_util.go b/filter/filter_impl/auth/sign_util.go new file mode 100644 index 0000000000000000000000000000000000000000..086c5293ff7550bb2786ae7bc9f7a5e5e72145fd --- /dev/null +++ b/filter/filter_impl/auth/sign_util.go @@ -0,0 +1,54 @@ +package auth + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/gob" + "errors" + "strings" +) + +func Sign(metadata, key string) string { + return doSign([]byte(metadata), key) +} + +func SignWithParams(params []interface{}, metadata, key string) (string, error) { + if params == nil || len(params) == 0 { + return Sign(metadata, key), nil + } + + data := append(params, metadata) + if bytes, err := toBytes(data); err != nil { + // TODO + return "", errors.New("data convert to bytes failed") + } else { + return doSign(bytes, key), nil + } +} + +func toBytes(data []interface{}) ([]byte, error) { + var b bytes.Buffer + enc := gob.NewEncoder(&b) + if err := enc.Encode(data); err != nil { + return nil, errors.New("") + } + return b.Bytes(), nil +} + +func doSign(bytes []byte, key string) string { + sum256 := sha256.Sum256(bytes) + return base64.URLEncoding.EncodeToString(sum256[:]) +} + +func IsEmpty(s string, allowSpace bool) bool { + if len(s) == 0 { + return true + } + if !allowSpace { + if strings.TrimSpace(s) == "" { + return true + } + } + return false +} diff --git a/filter/filter_impl/auth/sign_util_test.go b/filter/filter_impl/auth/sign_util_test.go new file mode 100644 index 0000000000000000000000000000000000000000..d3212134155549398592943cdba81a38cfcd5572 --- /dev/null +++ b/filter/filter_impl/auth/sign_util_test.go @@ -0,0 +1,95 @@ +package auth + +import ( + "reflect" + "testing" +) +import ( + "github.com/stretchr/testify/assert" +) + +func TestIsEmpty(t *testing.T) { + type args struct { + s string + allowSpace bool + } + tests := []struct { + name string + args args + want bool + }{ + // TODO: Add test cases. + {"test1", args{s: " ", allowSpace: false}, true}, + {"test2", args{s: " ", allowSpace: true}, false}, + {"test3", args{s: "hello,dubbo", allowSpace: false}, false}, + {"test4", args{s: "hello,dubbo", allowSpace: true}, false}, + {"test5", args{s: "", allowSpace: true}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsEmpty(tt.args.s, tt.args.allowSpace); got != tt.want { + t.Errorf("IsEmpty() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSign(t *testing.T) { + + metadata := "com.ikurento.user.UserProvider::sayHi" + key := "key" + signature := Sign(metadata, key) + assert.NotNil(t, signature) + +} + +func TestSignWithParams(t *testing.T) { + +} + +func Test_doSign(t *testing.T) { + type args struct { + bytes []byte + key string + } + tests := []struct { + name string + args args + want string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := doSign(tt.args.bytes, tt.args.key); got != tt.want { + t.Errorf("doSign() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_toBytes(t *testing.T) { + type args struct { + data []interface{} + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := toBytes(tt.args.data) + if (err != nil) != tt.wantErr { + t.Errorf("toBytes() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("toBytes() got = %v, want %v", got, tt.want) + } + }) + } +}