diff --git a/instrumentation/net/http/otelhttp/transport.go b/instrumentation/net/http/otelhttp/transport.go index 38cf6f33630..4c855bfb477 100644 --- a/instrumentation/net/http/otelhttp/transport.go +++ b/instrumentation/net/http/otelhttp/transport.go @@ -40,7 +40,14 @@ var _ http.RoundTripper = &Transport{} // NewTransport wraps the provided http.RoundTripper with one that // starts a span and injects the span context into the outbound request headers. +// +// If the provided http.RoundTripper is nil, http.DefaultTransport will be used +// as the base http.RoundTripper func NewTransport(base http.RoundTripper, opts ...Option) *Transport { + if base == nil { + base = http.DefaultTransport + } + t := Transport{ rt: base, } diff --git a/instrumentation/net/http/otelhttp/transport_test.go b/instrumentation/net/http/otelhttp/transport_test.go index 3e5add593d7..5631f71f6aa 100644 --- a/instrumentation/net/http/otelhttp/transport_test.go +++ b/instrumentation/net/http/otelhttp/transport_test.go @@ -74,3 +74,51 @@ func TestTransportBasics(t *testing.T) { t.Fatalf("unexpected content: got %s, expected %s", body, content) } } + +func TestNilTransport(t *testing.T) { + prop := propagation.TraceContext{} + provider := oteltest.NewTracerProvider() + content := []byte("Hello, world!") + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := prop.Extract(r.Context(), propagation.HeaderCarrier(r.Header)) + span := trace.RemoteSpanContextFromContext(ctx) + tgtID, err := trace.SpanIDFromHex(fmt.Sprintf("%016x", uint(2))) + if err != nil { + t.Fatalf("Error converting id to SpanID: %s", err.Error()) + } + if span.SpanID() != tgtID { + t.Fatalf("testing remote SpanID: got %s, expected %s", span.SpanID(), tgtID) + } + if _, err := w.Write(content); err != nil { + t.Fatal(err) + } + })) + defer ts.Close() + + r, err := http.NewRequest(http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatal(err) + } + + tr := NewTransport( + nil, + WithTracerProvider(provider), + WithPropagators(prop), + ) + + c := http.Client{Transport: tr} + res, err := c.Do(r) + if err != nil { + t.Fatal(err) + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(body, content) { + t.Fatalf("unexpected content: got %s, expected %s", body, content) + } +}