diff --git a/common/constant/default.go b/common/constant/default.go index 6e0f8488783ebe66939436ca14670395e2719be7..3d9177b21dde779200371db7287d97e500d62c2b 100644 --- a/common/constant/default.go +++ b/common/constant/default.go @@ -46,7 +46,7 @@ const ( const ( DEFAULT_KEY = "default" PREFIX_DEFAULT_KEY = "default." - DEFAULT_SERVICE_FILTERS = "echo,token,accesslog,tps,execute,pshutdown" + DEFAULT_SERVICE_FILTERS = "echo,token,accesslog,tps,generic-service,execute,pshutdown" DEFAULT_REFERENCE_FILTERS = "cshutdown" GENERIC_REFERENCE_FILTERS = "generic" GENERIC = "$invoke" diff --git a/filter/impl/generic_service_filter.go b/filter/impl/generic_service_filter.go new file mode 100644 index 0000000000000000000000000000000000000000..9848fd8e58caf641640f2d76202b4cdfc7f56017 --- /dev/null +++ b/filter/impl/generic_service_filter.go @@ -0,0 +1,98 @@ +package impl + +import ( + hessian "github.com/apache/dubbo-go-hessian2" + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/common/constant" + "github.com/apache/dubbo-go/common/extension" + "github.com/apache/dubbo-go/common/logger" + "github.com/apache/dubbo-go/filter" + "github.com/apache/dubbo-go/protocol" + invocation2 "github.com/apache/dubbo-go/protocol/invocation" + "github.com/mitchellh/mapstructure" + "reflect" + "strings" +) + +const ( + GENERIC_SERVICE = "generic-service" + GENERIC_SERIALIZATION_DEFAULT = "true" +) + +func init() { + extension.SetFilter(GENERIC_SERVICE, GetGenericServiceFilter) +} + +type GenericServiceFilter struct{} + +func (ef *GenericServiceFilter) Invoke(invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { + logger.Infof("invoking generic service filter.") + logger.Debugf("generic service filter methodName:%v,args:%v", invocation.MethodName(), len(invocation.Arguments())) + if invocation.MethodName() != constant.GENERIC || len(invocation.Arguments()) != 3 { + return invoker.Invoke(invocation) + } + var ( + err error + methodName string + newParams []interface{} + genericKey string + argsType []reflect.Type + oldParams []hessian.Object + ) + url := invoker.GetUrl() + methodName = invocation.Arguments()[0].(string) + // get service + svc := common.ServiceMap.GetService(url.Protocol, strings.TrimPrefix(url.Path, "/")) + // get method + method := svc.Method()[methodName] + if method == nil { + logger.Errorf("[Generic Service Filter] Don't have this method: %v", method) + return &protocol.RPCResult{} + } + argsType = method.ArgsType() + genericKey = invocation.AttachmentsByKey(constant.GENERIC_KEY, GENERIC_SERIALIZATION_DEFAULT) + if genericKey == GENERIC_SERIALIZATION_DEFAULT { + oldParams = invocation.Arguments()[2].([]hessian.Object) + } else { + logger.Errorf("[Generic Service Filter] Don't support this generic: %v", genericKey) + return &protocol.RPCResult{} + } + if len(oldParams) != len(argsType) { + logger.Errorf("[Generic Service Filter] method:%s invocation arguments number was wrong", methodName) + return &protocol.RPCResult{} + } + // oldParams convert to newParams + for i := range argsType { + var newParam interface{} + if argsType[i].Kind() == reflect.Ptr { + newParam = reflect.New(argsType[i].Elem()).Interface() + err = mapstructure.Decode(oldParams[i], newParam) + } else if argsType[i].Kind() == reflect.Struct || argsType[i].Kind() == reflect.Slice { + newParam = reflect.New(argsType[i]).Interface() + err = mapstructure.Decode(oldParams[i], newParam) + newParam = reflect.ValueOf(newParam).Elem().Interface() + } else { + newParam = oldParams[i] + } + if err != nil { + logger.Errorf("[Generic Service Filter] decode arguments map to struct wrong") + } + newParams = append(newParams, newParam) + } + newInvocation := invocation2.NewRPCInvocation(methodName, newParams, invocation.Attachments()) + newInvocation.SetReply(invocation.Reply()) + return invoker.Invoke(newInvocation) +} + +func (ef *GenericServiceFilter) OnResponse(result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result { + if invocation.MethodName() == constant.GENERIC && len(invocation.Arguments()) == 3 && result.Result() != nil { + s := reflect.ValueOf(result.Result()).Elem().Interface() + r := struct2MapAll(s) + result.SetResult(r) + } + return result +} + +func GetGenericServiceFilter() filter.Filter { + return &GenericServiceFilter{} +} diff --git a/filter/impl/generic_service_filter_test.go b/filter/impl/generic_service_filter_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e8b625bf0a8714dadac66a32c2173f4d63818e2b --- /dev/null +++ b/filter/impl/generic_service_filter_test.go @@ -0,0 +1,107 @@ +package impl + +import ( + "context" + "errors" + hessian "github.com/apache/dubbo-go-hessian2" + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/common/proxy/proxy_factory" + "github.com/apache/dubbo-go/protocol" + "github.com/apache/dubbo-go/protocol/invocation" + "github.com/stretchr/testify/assert" + "reflect" + "testing" +) + +type TestStruct struct { + AaAa string + BaBa string `m:"baBa"` + XxYy struct { + xxXx string `m:"xxXx"` + Xx string `m:"xx"` + } `m:"xxYy"` +} + +func (c *TestStruct) JavaClassName() string { + return "com.test.testStruct" +} + +type TestService struct { +} + +func (ts *TestService) MethodOne(ctx context.Context, test1 *TestStruct, test2 []TestStruct, + test3 interface{}, test4 []interface{}, test5 *string) (*TestStruct, error) { + if test1 == nil { + return nil, errors.New("param test1 is nil") + } + if test2 == nil { + return nil, errors.New("param test2 is nil") + } + if test3 == nil { + return nil, errors.New("param test3 is nil") + } + if test4 == nil { + return nil, errors.New("param test4 is nil") + } + if test5 == nil { + return nil, errors.New("param test5 is nil") + } + return &TestStruct{}, nil +} + +func (s *TestService) Reference() string { + return "com.test.Path" +} + +func TestGenericServiceFilter_Invoke(t *testing.T) { + hessian.RegisterPOJO(&TestStruct{}) + methodName := "$invoke" + m := make(map[string]interface{}) + m["AaAa"] = "nihao" + x := make(map[string]interface{}) + x["xxXX"] = "nihaoxxx" + m["XxYy"] = x + aurguments := []interface{}{ + "MethodOne", + nil, + []hessian.Object{ + hessian.Object(m), + hessian.Object(append(make([]map[string]interface{}, 1), m)), + hessian.Object("111"), + hessian.Object(append(make([]map[string]interface{}, 1), m)), + hessian.Object("222")}, + } + s := &TestService{} + _, _ = common.ServiceMap.Register("testprotocol", s) + rpcInvocation := invocation.NewRPCInvocation(methodName, aurguments, nil) + filter := GetGenericServiceFilter() + url, _ := common.NewURL(context.Background(), "testprotocol://127.0.0.1:20000/com.test.Path") + result := filter.Invoke(&proxy_factory.ProxyInvoker{BaseInvoker: *protocol.NewBaseInvoker(url)}, rpcInvocation) + assert.NotNil(t, result) + assert.Nil(t, result.Error()) +} + +func TestGenericServiceFilter_Response(t *testing.T) { + ts := &TestStruct{ + AaAa: "aaa", + BaBa: "bbb", + XxYy: struct { + xxXx string `m:"xxXx"` + Xx string `m:"xx"` + }{}, + } + result := &protocol.RPCResult{ + Rest: ts, + } + aurguments := []interface{}{ + "MethodOne", + nil, + []hessian.Object{nil}, + } + filter := GetGenericServiceFilter() + methodName := "$invoke" + rpcInvocation := invocation.NewRPCInvocation(methodName, aurguments, nil) + r := filter.OnResponse(result, nil, rpcInvocation) + assert.NotNil(t, r.Result()) + assert.Equal(t, reflect.ValueOf(r.Result()).Kind(), reflect.Map) +}