Skip to content

Commit

Permalink
interceptor: new APIs for chaining server interceptors. (#3336)
Browse files Browse the repository at this point in the history
  • Loading branch information
tukeJonny committed Feb 12, 2020
1 parent 95c834a commit d0235e4
Show file tree
Hide file tree
Showing 2 changed files with 376 additions and 0 deletions.
92 changes: 92 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ type serverOptions struct {
dc Decompressor
unaryInt UnaryServerInterceptor
streamInt StreamServerInterceptor
chainUnaryInts []UnaryServerInterceptor
chainStreamInts []StreamServerInterceptor
inTapHandle tap.ServerInHandle
statsHandler stats.Handler
maxConcurrentStreams uint32
Expand Down Expand Up @@ -311,6 +313,16 @@ func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
})
}

// ChainUnaryInterceptor returns a ServerOption that specifies the chained interceptor
// for unary RPCs. The first interceptor will be the outer most,
// while the last interceptor will be the inner most wrapper around the real call.
// All unary interceptors added by this method will be chained.
func ChainUnaryInterceptor(interceptors ...UnaryServerInterceptor) ServerOption {
return newFuncServerOption(func(o *serverOptions) {
o.chainUnaryInts = append(o.chainUnaryInts, interceptors...)
})
}

// StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the
// server. Only one stream interceptor can be installed.
func StreamInterceptor(i StreamServerInterceptor) ServerOption {
Expand All @@ -322,6 +334,16 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
})
}

// ChainStreamInterceptor returns a ServerOption that specifies the chained interceptor
// for stream RPCs. The first interceptor will be the outer most,
// while the last interceptor will be the inner most wrapper around the real call.
// All stream interceptors added by this method will be chained.
func ChainStreamInterceptor(interceptors ...StreamServerInterceptor) ServerOption {
return newFuncServerOption(func(o *serverOptions) {
o.chainStreamInts = append(o.chainStreamInts, interceptors...)
})
}

// InTapHandle returns a ServerOption that sets the tap handle for all the server
// transport to be created. Only one can be installed.
func InTapHandle(h tap.ServerInHandle) ServerOption {
Expand Down Expand Up @@ -404,6 +426,8 @@ func NewServer(opt ...ServerOption) *Server {
done: grpcsync.NewEvent(),
czData: new(channelzData),
}
chainUnaryServerInterceptors(s)
chainStreamServerInterceptors(s)
s.cv = sync.NewCond(&s.mu)
if EnableTracing {
_, file, line, _ := runtime.Caller(1)
Expand Down Expand Up @@ -886,6 +910,40 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
return err
}

// chainUnaryServerInterceptors chains all unary server interceptors into one.
func chainUnaryServerInterceptors(s *Server) {
// Prepend opts.unaryInt to the chaining interceptors if it exists, since unaryInt will
// be executed before any other chained interceptors.
interceptors := s.opts.chainUnaryInts
if s.opts.unaryInt != nil {
interceptors = append([]UnaryServerInterceptor{s.opts.unaryInt}, s.opts.chainUnaryInts...)
}

var chainedInt UnaryServerInterceptor
if len(interceptors) == 0 {
chainedInt = nil
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
chainedInt = func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler))
}
}

s.opts.unaryInt = chainedInt
}

// getChainUnaryHandler recursively generate the chained UnaryHandler
func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler {
if curr == len(interceptors)-1 {
return finalHandler
}

return func(ctx context.Context, req interface{}) (interface{}, error) {
return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler))
}
}

func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) {
stat := stream.Stat()
overallTimer := stat.NewTimer("/processUnaryRPC")
Expand Down Expand Up @@ -1165,6 +1223,40 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
return err
}

// chainStreamServerInterceptors chains all stream server interceptors into one.
func chainStreamServerInterceptors(s *Server) {
// Prepend opts.streamInt to the chaining interceptors if it exists, since streamInt will
// be executed before any other chained interceptors.
interceptors := s.opts.chainStreamInts
if s.opts.streamInt != nil {
interceptors = append([]StreamServerInterceptor{s.opts.streamInt}, s.opts.chainStreamInts...)
}

var chainedInt StreamServerInterceptor
if len(interceptors) == 0 {
chainedInt = nil
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
chainedInt = func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler))
}
}

s.opts.streamInt = chainedInt
}

// getChainStreamHandler recursively generate the chained StreamHandler
func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler {
if curr == len(interceptors)-1 {
return finalHandler
}

return func(srv interface{}, ss ServerStream) error {
return interceptors[curr+1](srv, ss, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler))
}
}

func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
stat := stream.Stat()
overallTimer := stat.NewTimer("/processStreamingRPC")
Expand Down
Loading

0 comments on commit d0235e4

Please sign in to comment.