Skip to content

Commit

Permalink
Restore scheme/host when recording HTTP(S) (#22102)
Browse files Browse the repository at this point in the history
Things like LROs can use the original HTTP request when polling, so we
must ensure that the *http.Request associated with a *http.Response has
the correct scheme and host.
  • Loading branch information
jhendrixMSFT authored Dec 6, 2023
1 parent f8be4a3 commit 4bc279b
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 9 deletions.
2 changes: 2 additions & 0 deletions sdk/internal/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

### Bugs Fixed

* Recording will restore the original scheme/host after making a successful HTTP(s) call.

### Other Changes

## 1.5.0 (2023-11-02)
Expand Down
21 changes: 15 additions & 6 deletions sdk/internal/recording/recording.go
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,10 @@ func defaultOptions() *RecordingOptions {
}
}

func (r RecordingOptions) ReplaceAuthority(t *testing.T, rawReq *http.Request) *http.Request {
func (r RecordingOptions) ReplaceAuthority(t *testing.T, rawReq *http.Request) (*http.Request, string, string) {
originalURLScheme := rawReq.URL.Scheme
originalURLHost := rawReq.URL.Host
if GetRecordMode() != LiveMode && !IsLiveOnly(t) {
originalURLHost := rawReq.URL.Host

// don't modify the original request
cp := *rawReq
cpURL := *cp.URL
Expand All @@ -556,7 +556,7 @@ func (r RecordingOptions) ReplaceAuthority(t *testing.T, rawReq *http.Request) *
cp.Header.Set(IDHeader, GetRecordingId(t))
rawReq = &cp
}
return rawReq
return rawReq, originalURLScheme, originalURLHost
}

func (r RecordingOptions) host() string {
Expand Down Expand Up @@ -941,8 +941,17 @@ type RecordingHTTPClient struct {
}

func (c RecordingHTTPClient) Do(req *http.Request) (*http.Response, error) {
req = c.options.ReplaceAuthority(c.t, req)
return c.defaultClient.Do(req)
req, origScheme, origHost := c.options.ReplaceAuthority(c.t, req)
resp, err := c.defaultClient.Do(req)
if err != nil {
return nil, err
}
// if the request succeeds, restore the scheme/host with their original values.
// this is imporant for things like LROs that might use the originating URL to
// poll for status and/or fetch the final result.
resp.Request.URL.Scheme = origScheme
resp.Request.URL.Host = origHost
return resp, nil
}

// NewRecordingHTTPClient returns a type that implements `azcore.Transporter`. This will automatically route tests on the `Do` call.
Expand Down
4 changes: 2 additions & 2 deletions sdk/internal/recording/recording_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,8 @@ func (s *recordingTests) TestStartStopRecordingClient() {
require.NoError(err)
require.Equal("https://azsdkengsys.azurecr.io/acr/v1/some_registry/_tags",
data.Entries[0].RequestURI)
require.Equal(resp.Request.URL.String(),
fmt.Sprintf("%s/acr/v1/some_registry/_tags", defaultOptions().baseURL()))
require.Equal("https://azsdkengsys.azurecr.io/acr/v1/some_registry/_tags",
resp.Request.URL.String())
}

func (s *recordingTests) TestStopRecordingNoStart() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"Entries": [],
"Variables": {
"randSeed": "1689722394"
"randSeed": "1701821574"
}
}

0 comments on commit 4bc279b

Please sign in to comment.