From 0c34c47c00bb58a39dcc4b4e4ec67358e37f17db Mon Sep 17 00:00:00 2001
From: "382673304@qq.com" <382673304@qq.com>
Date: Fri, 30 Oct 2020 13:42:10 +0800
Subject: [PATCH] feat: add grpc max message size config

---
 common/constant/key.go         |  1 +
 config/service_config.go       |  2 ++
 protocol/grpc/client.go        |  7 ++++++-
 protocol/grpc/grpc_protocol.go |  4 ++++
 protocol/grpc/server.go        | 11 ++++++++++-
 5 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/common/constant/key.go b/common/constant/key.go
index 943338f8e..c7db08a55 100644
--- a/common/constant/key.go
+++ b/common/constant/key.go
@@ -25,6 +25,7 @@ const (
 	GROUP_KEY                = "group"
 	VERSION_KEY              = "version"
 	INTERFACE_KEY            = "interface"
+	GRPC_MESSAGE_SIZE_KEY    = "message_size"
 	PATH_KEY                 = "path"
 	SERVICE_KEY              = "service"
 	METHODS_KEY              = "methods"
diff --git a/config/service_config.go b/config/service_config.go
index 48632a1b1..764d5dc86 100644
--- a/config/service_config.go
+++ b/config/service_config.go
@@ -73,6 +73,7 @@ type ServiceConfig struct {
 	Auth                        string            `yaml:"auth" json:"auth,omitempty" property:"auth"`
 	ParamSign                   string            `yaml:"param.sign" json:"param.sign,omitempty" property:"param.sign"`
 	Tag                         string            `yaml:"tag" json:"tag,omitempty" property:"tag"`
+	GrpcMaxMessageSize          int               `default:"4" yaml:"max_message_size" json:"max_message_size,omitempty"`
 
 	Protocols     map[string]*ProtocolConfig
 	unexported    *atomic.Bool
@@ -271,6 +272,7 @@ func (c *ServiceConfig) getUrlMap() url.Values {
 	urlMap.Set(constant.ROLE_KEY, strconv.Itoa(common.PROVIDER))
 	urlMap.Set(constant.RELEASE_KEY, "dubbo-golang-"+constant.Version)
 	urlMap.Set(constant.SIDE_KEY, (common.RoleType(common.PROVIDER)).Role())
+	urlMap.Set(constant.GRPC_MESSAGE_SIZE_KEY, strconv.Itoa(c.GrpcMaxMessageSize))
 	// todo: move
 	urlMap.Set(constant.SERIALIZATION_KEY, c.Serialization)
 	// application info
diff --git a/protocol/grpc/client.go b/protocol/grpc/client.go
index a0ab0be80..6f9fc22a0 100644
--- a/protocol/grpc/client.go
+++ b/protocol/grpc/client.go
@@ -19,6 +19,7 @@ package grpc
 
 import (
 	"reflect"
+	"strconv"
 )
 
 import (
@@ -93,9 +94,13 @@ func NewClient(url common.URL) *Client {
 	// if global trace instance was set , it means trace function enabled. If not , will return Nooptracer
 	tracer := opentracing.GlobalTracer()
 	dailOpts := make([]grpc.DialOption, 0, 4)
+	maxMessageSize, _ := strconv.Atoi(url.GetParam(constant.GRPC_MESSAGE_SIZE_KEY, "4"))
 	dailOpts = append(dailOpts, grpc.WithInsecure(), grpc.WithBlock(), grpc.WithUnaryInterceptor(
 		otgrpc.OpenTracingClientInterceptor(tracer, otgrpc.LogPayloads())),
-		grpc.WithDefaultCallOptions(grpc.CallContentSubtype(clientConf.ContentSubType)))
+		grpc.WithDefaultCallOptions(
+			grpc.CallContentSubtype(clientConf.ContentSubType),
+			grpc.MaxCallRecvMsgSize(1024*1024*maxMessageSize),
+			grpc.MaxCallSendMsgSize(1024*1024*maxMessageSize)))
 	conn, err := grpc.Dial(url.Location, dailOpts...)
 	if err != nil {
 		panic(err)
diff --git a/protocol/grpc/grpc_protocol.go b/protocol/grpc/grpc_protocol.go
index 68594a4b3..296497aca 100644
--- a/protocol/grpc/grpc_protocol.go
+++ b/protocol/grpc/grpc_protocol.go
@@ -18,6 +18,8 @@
 package grpc
 
 import (
+	"github.com/apache/dubbo-go/common/constant"
+	"strconv"
 	"sync"
 )
 
@@ -76,7 +78,9 @@ func (gp *GrpcProtocol) openServer(url common.URL) {
 		gp.serverLock.Lock()
 		_, ok = gp.serverMap[url.Location]
 		if !ok {
+			grpcMessageSize, _ := strconv.Atoi(url.GetParam(constant.GRPC_MESSAGE_SIZE_KEY, "4"))
 			srv := NewServer()
+			srv.SetBufferSize(grpcMessageSize)
 			gp.serverMap[url.Location] = srv
 			srv.Start(url)
 		}
diff --git a/protocol/grpc/server.go b/protocol/grpc/server.go
index 2b7b1addd..d6ed29dc9 100644
--- a/protocol/grpc/server.go
+++ b/protocol/grpc/server.go
@@ -40,6 +40,7 @@ import (
 // Server is a gRPC server
 type Server struct {
 	grpcServer *grpc.Server
+	bufferSize int
 }
 
 // NewServer creates a new server
@@ -57,6 +58,10 @@ type DubboGrpcService interface {
 	ServiceDesc() *grpc.ServiceDesc
 }
 
+func (s *Server) SetBufferSize(n int) {
+	s.bufferSize = n
+}
+
 // Start gRPC server with @url
 func (s *Server) Start(url common.URL) {
 	var (
@@ -72,7 +77,11 @@ func (s *Server) Start(url common.URL) {
 	// if global trace instance was set, then server tracer instance can be get. If not , will return Nooptracer
 	tracer := opentracing.GlobalTracer()
 	server := grpc.NewServer(
-		grpc.UnaryInterceptor(otgrpc.OpenTracingServerInterceptor(tracer)))
+		grpc.UnaryInterceptor(otgrpc.OpenTracingServerInterceptor(tracer)),
+		grpc.MaxRecvMsgSize(1024*1024*s.bufferSize),
+		grpc.MaxSendMsgSize(1024*1024*s.bufferSize))
+	fmt.Println("-------------------")
+	fmt.Println("size = ", s.bufferSize)
 
 	key := url.GetParam(constant.BEAN_NAME_KEY, "")
 	service := config.GetProviderService(key)
-- 
GitLab