Skip to content

Commit

Permalink
otelhttp: handle nil base http transport (#713)
Browse files Browse the repository at this point in the history
* handle nil base http transport

* update godoc

Co-authored-by: Tyler Yahn <MrAlias@users.noreply.github.com>
  • Loading branch information
kjschnei001 and MrAlias authored Apr 6, 2021
1 parent e8c2192 commit ae2c628
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
7 changes: 7 additions & 0 deletions instrumentation/net/http/otelhttp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,14 @@ var _ http.RoundTripper = &Transport{}

// NewTransport wraps the provided http.RoundTripper with one that
// starts a span and injects the span context into the outbound request headers.
//
// If the provided http.RoundTripper is nil, http.DefaultTransport will be used
// as the base http.RoundTripper
func NewTransport(base http.RoundTripper, opts ...Option) *Transport {
if base == nil {
base = http.DefaultTransport
}

t := Transport{
rt: base,
}
Expand Down
48 changes: 48 additions & 0 deletions instrumentation/net/http/otelhttp/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,51 @@ func TestTransportBasics(t *testing.T) {
t.Fatalf("unexpected content: got %s, expected %s", body, content)
}
}

func TestNilTransport(t *testing.T) {
prop := propagation.TraceContext{}
provider := oteltest.NewTracerProvider()
content := []byte("Hello, world!")

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := prop.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
span := trace.RemoteSpanContextFromContext(ctx)
tgtID, err := trace.SpanIDFromHex(fmt.Sprintf("%016x", uint(2)))
if err != nil {
t.Fatalf("Error converting id to SpanID: %s", err.Error())
}
if span.SpanID() != tgtID {
t.Fatalf("testing remote SpanID: got %s, expected %s", span.SpanID(), tgtID)
}
if _, err := w.Write(content); err != nil {
t.Fatal(err)
}
}))
defer ts.Close()

r, err := http.NewRequest(http.MethodGet, ts.URL, nil)
if err != nil {
t.Fatal(err)
}

tr := NewTransport(
nil,
WithTracerProvider(provider),
WithPropagators(prop),
)

c := http.Client{Transport: tr}
res, err := c.Do(r)
if err != nil {
t.Fatal(err)
}

body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}

if !bytes.Equal(body, content) {
t.Fatalf("unexpected content: got %s, expected %s", body, content)
}
}

0 comments on commit ae2c628

Please sign in to comment.