From 1e36a61edfd8596a69b5fc5ba8b37eff2ea3c9e4 Mon Sep 17 00:00:00 2001 From: Tyler Yahn Date: Mon, 18 May 2020 18:54:38 -0700 Subject: [PATCH] Fix panic in gRPC UnaryServerInfo (#740) Fixes unresolved issue identified in #691 and attempted in #697. Adds unit test to ensure the UnaryServerInfo function does not panic during an error returned from the handler and appropriately annotates the span with the correct event. Restructures the interceptor to remove this class of errors. Co-authored-by: Joshua MacDonald --- plugin/grpctrace/interceptor.go | 66 ++++++++++++++-------------- plugin/grpctrace/interceptor_test.go | 38 ++++++++++++++++ 2 files changed, 70 insertions(+), 34 deletions(-) diff --git a/plugin/grpctrace/interceptor.go b/plugin/grpctrace/interceptor.go index a84aec9e12d..0981954c4af 100644 --- a/plugin/grpctrace/interceptor.go +++ b/plugin/grpctrace/interceptor.go @@ -43,9 +43,29 @@ var ( messageUncompressedSizeKey = kv.Key("message.uncompressed_size") ) +type messageType string + +// Event adds an event of the messageType to the span associated with the +// passed context with id and size (if message is a proto message). +func (m messageType) Event(ctx context.Context, id int, message interface{}) { + span := trace.SpanFromContext(ctx) + if p, ok := message.(proto.Message); ok { + span.AddEvent(ctx, "message", + messageTypeKey.String(string(m)), + messageIDKey.Int(id), + messageUncompressedSizeKey.Int(proto.Size(p)), + ) + } else { + span.AddEvent(ctx, "message", + messageTypeKey.String(string(m)), + messageIDKey.Int(id), + ) + } +} + const ( - messageTypeSent = "SENT" - messageTypeReceived = "RECEIVED" + messageSent messageType = "SENT" + messageReceived messageType = "RECEIVED" ) // UnaryClientInterceptor returns a grpc.UnaryClientInterceptor suitable @@ -80,11 +100,11 @@ func UnaryClientInterceptor(tracer trace.Tracer) grpc.UnaryClientInterceptor { Inject(ctx, &metadataCopy) ctx = metadata.NewOutgoingContext(ctx, metadataCopy) - addEventForMessageSent(ctx, 1, req) + messageSent.Event(ctx, 1, req) err := invoker(ctx, method, req, reply, cc, opts...) - addEventForMessageReceived(ctx, 1, reply) + messageReceived.Event(ctx, 1, reply) if err != nil { s, _ := status.FromError(err) @@ -134,7 +154,7 @@ func (w *clientStream) RecvMsg(m interface{}) error { w.events <- streamEvent{errorEvent, err} } else { w.receivedMessageID++ - addEventForMessageReceived(w.Context(), w.receivedMessageID, m) + messageReceived.Event(w.Context(), w.receivedMessageID, m) } return err @@ -144,7 +164,7 @@ func (w *clientStream) SendMsg(m interface{}) error { err := w.ClientStream.SendMsg(m) w.sentMessageID++ - addEventForMessageSent(w.Context(), w.sentMessageID, m) + messageSent.Event(w.Context(), w.sentMessageID, m) if err != nil { w.events <- streamEvent{errorEvent, err} @@ -297,15 +317,15 @@ func UnaryServerInterceptor(tracer trace.Tracer) grpc.UnaryServerInterceptor { ) defer span.End() - addEventForMessageReceived(ctx, 1, req) + messageReceived.Event(ctx, 1, req) resp, err := handler(ctx, req) - - addEventForMessageSent(ctx, 1, resp) - if err != nil { s, _ := status.FromError(err) span.SetStatus(s.Code(), s.Message()) + messageSent.Event(ctx, 1, s.Proto()) + } else { + messageSent.Event(ctx, 1, resp) } return resp, err @@ -331,7 +351,7 @@ func (w *serverStream) RecvMsg(m interface{}) error { if err == nil { w.receivedMessageID++ - addEventForMessageReceived(w.Context(), w.receivedMessageID, m) + messageReceived.Event(w.Context(), w.receivedMessageID, m) } return err @@ -341,7 +361,7 @@ func (w *serverStream) SendMsg(m interface{}) error { err := w.ServerStream.SendMsg(m) w.sentMessageID++ - addEventForMessageSent(w.Context(), w.sentMessageID, m) + messageSent.Event(w.Context(), w.sentMessageID, m) return err } @@ -435,25 +455,3 @@ func serviceFromFullMethod(method string) string { return match[1] } - -func addEventForMessageReceived(ctx context.Context, id int, m interface{}) { - size := proto.Size(m.(proto.Message)) - - span := trace.SpanFromContext(ctx) - span.AddEvent(ctx, "message", - messageTypeKey.String(messageTypeReceived), - messageIDKey.Int(id), - messageUncompressedSizeKey.Int(size), - ) -} - -func addEventForMessageSent(ctx context.Context, id int, m interface{}) { - size := proto.Size(m.(proto.Message)) - - span := trace.SpanFromContext(ctx) - span.AddEvent(ctx, "message", - messageTypeKey.String(messageTypeSent), - messageIDKey.Int(id), - messageUncompressedSizeKey.Int(size), - ) -} diff --git a/plugin/grpctrace/interceptor_test.go b/plugin/grpctrace/interceptor_test.go index a92c7177f67..211a9c36efc 100644 --- a/plugin/grpctrace/interceptor_test.go +++ b/plugin/grpctrace/interceptor_test.go @@ -20,8 +20,12 @@ import ( "time" "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" "go.opentelemetry.io/otel/api/kv" "go.opentelemetry.io/otel/api/kv/value" @@ -373,3 +377,37 @@ func TestStreamClientInterceptor(t *testing.T) { validate("RECEIVED", events[i+1].Attributes) } } + +func TestServerInterceptorError(t *testing.T) { + exp := &testExporter{spanMap: make(map[string]*export.SpanData)} + tp, err := sdktrace.NewProvider( + sdktrace.WithSyncer(exp), + sdktrace.WithConfig(sdktrace.Config{ + DefaultSampler: sdktrace.AlwaysSample(), + }), + ) + require.NoError(t, err) + + tracer := tp.Tracer("grpctrace/Server") + usi := UnaryServerInterceptor(tracer) + deniedErr := status.Error(codes.PermissionDenied, "PERMISSION_DENIED_TEXT") + handler := func(_ context.Context, _ interface{}) (interface{}, error) { + return nil, deniedErr + } + _, err = usi(context.Background(), &mockProtoMessage{}, &grpc.UnaryServerInfo{}, handler) + require.Error(t, err) + assert.Equal(t, err, deniedErr) + + span, ok := exp.spanMap[""] + if !ok { + t.Fatalf("failed to export error span") + } + assert.Equal(t, span.StatusCode, codes.PermissionDenied) + assert.Contains(t, deniedErr.Error(), span.StatusMessage) + assert.Len(t, span.MessageEvents, 2) + assert.Equal(t, []kv.KeyValue{ + kv.String("message.type", "SENT"), + kv.Int("message.id", 1), + kv.Int("message.uncompressed_size", 26), + }, span.MessageEvents[1].Attributes) +}