From d0235e4d6be307de7d197d94248f678a414cb864 Mon Sep 17 00:00:00 2001 From: tukeJonny Date: Thu, 13 Feb 2020 04:11:50 +0900 Subject: [PATCH] interceptor: new APIs for chaining server interceptors. (#3336) --- server.go | 92 ++++++++++++++ test/server_test.go | 284 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 376 insertions(+) create mode 100644 test/server_test.go diff --git a/server.go b/server.go index ef46a53f60f9..0be1824abbd9 100644 --- a/server.go +++ b/server.go @@ -116,6 +116,8 @@ type serverOptions struct { dc Decompressor unaryInt UnaryServerInterceptor streamInt StreamServerInterceptor + chainUnaryInts []UnaryServerInterceptor + chainStreamInts []StreamServerInterceptor inTapHandle tap.ServerInHandle statsHandler stats.Handler maxConcurrentStreams uint32 @@ -311,6 +313,16 @@ func UnaryInterceptor(i UnaryServerInterceptor) ServerOption { }) } +// ChainUnaryInterceptor returns a ServerOption that specifies the chained interceptor +// for unary RPCs. The first interceptor will be the outer most, +// while the last interceptor will be the inner most wrapper around the real call. +// All unary interceptors added by this method will be chained. +func ChainUnaryInterceptor(interceptors ...UnaryServerInterceptor) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.chainUnaryInts = append(o.chainUnaryInts, interceptors...) + }) +} + // StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the // server. Only one stream interceptor can be installed. func StreamInterceptor(i StreamServerInterceptor) ServerOption { @@ -322,6 +334,16 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption { }) } +// ChainStreamInterceptor returns a ServerOption that specifies the chained interceptor +// for stream RPCs. The first interceptor will be the outer most, +// while the last interceptor will be the inner most wrapper around the real call. +// All stream interceptors added by this method will be chained. +func ChainStreamInterceptor(interceptors ...StreamServerInterceptor) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.chainStreamInts = append(o.chainStreamInts, interceptors...) + }) +} + // InTapHandle returns a ServerOption that sets the tap handle for all the server // transport to be created. Only one can be installed. func InTapHandle(h tap.ServerInHandle) ServerOption { @@ -404,6 +426,8 @@ func NewServer(opt ...ServerOption) *Server { done: grpcsync.NewEvent(), czData: new(channelzData), } + chainUnaryServerInterceptors(s) + chainStreamServerInterceptors(s) s.cv = sync.NewCond(&s.mu) if EnableTracing { _, file, line, _ := runtime.Caller(1) @@ -886,6 +910,40 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str return err } +// chainUnaryServerInterceptors chains all unary server interceptors into one. +func chainUnaryServerInterceptors(s *Server) { + // Prepend opts.unaryInt to the chaining interceptors if it exists, since unaryInt will + // be executed before any other chained interceptors. + interceptors := s.opts.chainUnaryInts + if s.opts.unaryInt != nil { + interceptors = append([]UnaryServerInterceptor{s.opts.unaryInt}, s.opts.chainUnaryInts...) + } + + var chainedInt UnaryServerInterceptor + if len(interceptors) == 0 { + chainedInt = nil + } else if len(interceptors) == 1 { + chainedInt = interceptors[0] + } else { + chainedInt = func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) { + return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler)) + } + } + + s.opts.unaryInt = chainedInt +} + +// getChainUnaryHandler recursively generate the chained UnaryHandler +func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler { + if curr == len(interceptors)-1 { + return finalHandler + } + + return func(ctx context.Context, req interface{}) (interface{}, error) { + return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler)) + } +} + func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) { stat := stream.Stat() overallTimer := stat.NewTimer("/processUnaryRPC") @@ -1165,6 +1223,40 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return err } +// chainStreamServerInterceptors chains all stream server interceptors into one. +func chainStreamServerInterceptors(s *Server) { + // Prepend opts.streamInt to the chaining interceptors if it exists, since streamInt will + // be executed before any other chained interceptors. + interceptors := s.opts.chainStreamInts + if s.opts.streamInt != nil { + interceptors = append([]StreamServerInterceptor{s.opts.streamInt}, s.opts.chainStreamInts...) + } + + var chainedInt StreamServerInterceptor + if len(interceptors) == 0 { + chainedInt = nil + } else if len(interceptors) == 1 { + chainedInt = interceptors[0] + } else { + chainedInt = func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error { + return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler)) + } + } + + s.opts.streamInt = chainedInt +} + +// getChainStreamHandler recursively generate the chained StreamHandler +func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler { + if curr == len(interceptors)-1 { + return finalHandler + } + + return func(srv interface{}, ss ServerStream) error { + return interceptors[curr+1](srv, ss, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler)) + } +} + func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) { stat := stream.Stat() overallTimer := stat.NewTimer("/processStreamingRPC") diff --git a/test/server_test.go b/test/server_test.go new file mode 100644 index 000000000000..c6a5fe74bd55 --- /dev/null +++ b/test/server_test.go @@ -0,0 +1,284 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package test + +import ( + "context" + "io" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + testpb "google.golang.org/grpc/test/grpc_testing" +) + +type ctxKey string + +func (s) TestChainUnaryServerInterceptor(t *testing.T) { + var ( + firstIntKey = ctxKey("firstIntKey") + secondIntKey = ctxKey("secondIntKey") + ) + + firstInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if ctx.Value(firstIntKey) != nil { + return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", firstIntKey) + } + if ctx.Value(secondIntKey) != nil { + return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", secondIntKey) + } + + firstCtx := context.WithValue(ctx, firstIntKey, 0) + resp, err := handler(firstCtx, req) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to handle request at firstInt") + } + + simpleResp, ok := resp.(*testpb.SimpleResponse) + if !ok { + return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at firstInt") + } + return &testpb.SimpleResponse{ + Payload: &testpb.Payload{ + Type: simpleResp.GetPayload().GetType(), + Body: append(simpleResp.GetPayload().GetBody(), '1'), + }, + }, nil + } + + secondInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if ctx.Value(firstIntKey) == nil { + return nil, status.Errorf(codes.Internal, "second interceptor should have %v in context", firstIntKey) + } + if ctx.Value(secondIntKey) != nil { + return nil, status.Errorf(codes.Internal, "second interceptor should not have %v in context", secondIntKey) + } + + secondCtx := context.WithValue(ctx, secondIntKey, 1) + resp, err := handler(secondCtx, req) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to handle request at secondInt") + } + + simpleResp, ok := resp.(*testpb.SimpleResponse) + if !ok { + return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at secondInt") + } + return &testpb.SimpleResponse{ + Payload: &testpb.Payload{ + Type: simpleResp.GetPayload().GetType(), + Body: append(simpleResp.GetPayload().GetBody(), '2'), + }, + }, nil + } + + lastInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if ctx.Value(firstIntKey) == nil { + return nil, status.Errorf(codes.Internal, "last interceptor should have %v in context", firstIntKey) + } + if ctx.Value(secondIntKey) == nil { + return nil, status.Errorf(codes.Internal, "last interceptor should not have %v in context", secondIntKey) + } + + resp, err := handler(ctx, req) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to handle request at lastInt at lastInt") + } + + simpleResp, ok := resp.(*testpb.SimpleResponse) + if !ok { + return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at lastInt") + } + return &testpb.SimpleResponse{ + Payload: &testpb.Payload{ + Type: simpleResp.GetPayload().GetType(), + Body: append(simpleResp.GetPayload().GetBody(), '3'), + }, + }, nil + } + + sopts := []grpc.ServerOption{ + grpc.ChainUnaryInterceptor(firstInt, secondInt, lastInt), + } + + ss := &stubServer{ + unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 0) + if err != nil { + return nil, status.Errorf(codes.Aborted, "failed to make payload: %v", err) + } + + return &testpb.SimpleResponse{ + Payload: payload, + }, nil + }, + } + if err := ss.Start(sopts); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + resp, err := ss.client.UnaryCall(context.Background(), &testpb.SimpleRequest{}) + if s, ok := status.FromError(err); !ok || s.Code() != codes.OK { + t.Fatalf("ss.client.UnaryCall(context.Background(), _) = %v, %v; want nil, ", resp, err) + } + + respBytes := resp.Payload.GetBody() + if string(respBytes) != "321" { + t.Fatalf("invalid response: want=%s, but got=%s", "321", resp) + } +} + +func (s) TestChainOnBaseUnaryServerInterceptor(t *testing.T) { + baseIntKey := ctxKey("baseIntKey") + + baseInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if ctx.Value(baseIntKey) != nil { + return nil, status.Errorf(codes.Internal, "base interceptor should not have %v in context", baseIntKey) + } + + baseCtx := context.WithValue(ctx, baseIntKey, 1) + return handler(baseCtx, req) + } + + chainInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if ctx.Value(baseIntKey) == nil { + return nil, status.Errorf(codes.Internal, "chain interceptor should have %v in context", baseIntKey) + } + + return handler(ctx, req) + } + + sopts := []grpc.ServerOption{ + grpc.UnaryInterceptor(baseInt), + grpc.ChainUnaryInterceptor(chainInt), + } + + ss := &stubServer{ + emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + if err := ss.Start(sopts); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + resp, err := ss.client.EmptyCall(context.Background(), &testpb.Empty{}) + if s, ok := status.FromError(err); !ok || s.Code() != codes.OK { + t.Fatalf("ss.client.EmptyCall(context.Background(), _) = %v, %v; want nil, ", resp, err) + } +} + +func (s) TestChainStreamServerInterceptor(t *testing.T) { + callCounts := make([]int, 4) + + firstInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if callCounts[0] != 0 { + return status.Errorf(codes.Internal, "callCounts[0] should be 0, but got=%d", callCounts[0]) + } + if callCounts[1] != 0 { + return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1]) + } + if callCounts[2] != 0 { + return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) + } + if callCounts[3] != 0 { + return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) + } + callCounts[0]++ + return handler(srv, stream) + } + + secondInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if callCounts[0] != 1 { + return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0]) + } + if callCounts[1] != 0 { + return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1]) + } + if callCounts[2] != 0 { + return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) + } + if callCounts[3] != 0 { + return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) + } + callCounts[1]++ + return handler(srv, stream) + } + + lastInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if callCounts[0] != 1 { + return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0]) + } + if callCounts[1] != 1 { + return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1]) + } + if callCounts[2] != 0 { + return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) + } + if callCounts[3] != 0 { + return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) + } + callCounts[2]++ + return handler(srv, stream) + } + + sopts := []grpc.ServerOption{ + grpc.ChainStreamInterceptor(firstInt, secondInt, lastInt), + } + + ss := &stubServer{ + fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error { + if callCounts[0] != 1 { + return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0]) + } + if callCounts[1] != 1 { + return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1]) + } + if callCounts[2] != 1 { + return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) + } + if callCounts[3] != 0 { + return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) + } + callCounts[3]++ + return nil + }, + } + if err := ss.Start(sopts); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + stream, err := ss.client.FullDuplexCall(context.Background()) + if err != nil { + t.Fatalf("failed to FullDuplexCall: %v", err) + } + + _, err = stream.Recv() + if err != io.EOF { + t.Fatalf("failed to recv from stream: %v", err) + } + + if callCounts[3] != 1 { + t.Fatalf("callCounts[3] should be 1, but got=%d", callCounts[3]) + } +}