Skip to content

Commit

Permalink
Validate and encode query parameters (#21296)
Browse files Browse the repository at this point in the history
* Validate and encode query parameters

If the endpoint passed to runtime.NewRequest contains query parameters
they must be validated and encoded.  The typical case is when a paged
operation's nextLink value contains query parameters.

* add test with no query params

* consolidate creating PagingHandler[T].Fetchers

* make EncodeQueryParams its own func

* fix bag merge

* rename
  • Loading branch information
jhendrixMSFT authored Aug 21, 2023
1 parent 12d5832 commit 081c5e3
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sdk/azcore/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Features Added

* Added function `FetcherForNextLink` to the `runtime` package to centralize creation of `Pager[T].Fetcher` from a next link URL.

### Breaking Changes

### Bugs Fixed
Expand Down
24 changes: 24 additions & 0 deletions sdk/azcore/runtime/pager.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"reflect"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing"
)

Expand Down Expand Up @@ -88,3 +90,25 @@ func (p *Pager[T]) NextPage(ctx context.Context) (T, error) {
func (p *Pager[T]) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &p.current)
}

// FetcherForNextLink is a helper containing boilerplate code to simplify creating a PagingHandler[T].Fetcher from a next link URL.
func FetcherForNextLink(ctx context.Context, pl Pipeline, nextLink string, createReq func(context.Context) (*policy.Request, error)) (*http.Response, error) {
var req *policy.Request
var err error
if nextLink == "" {
req, err = createReq(ctx)
} else if nextLink, err = EncodeQueryParams(nextLink); err == nil {
req, err = NewRequest(ctx, http.MethodGet, nextLink)
}
if err != nil {
return nil, err
}
resp, err := pl.Do(req)
if err != nil {
return nil, err
}
if !HasStatusCode(resp, http.StatusOK) {
return nil, NewResponseError(resp)
}
return resp, nil
}
57 changes: 57 additions & 0 deletions sdk/azcore/runtime/pager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -256,3 +257,59 @@ func TestPagerResponderError(t *testing.T) {
require.Error(t, err)
require.Empty(t, page)
}

func TestFetcherForNextLink(t *testing.T) {
srv, close := mock.NewServer()
defer close()
pl := exported.NewPipeline(srv)

srv.AppendResponse()
createReqCalled := false
resp, err := FetcherForNextLink(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) {
createReqCalled = true
return NewRequest(ctx, http.MethodGet, srv.URL())
})
require.NoError(t, err)
require.True(t, createReqCalled)
require.NotNil(t, resp)
require.EqualValues(t, http.StatusOK, resp.StatusCode)

srv.AppendResponse()
createReqCalled = false
resp, err = FetcherForNextLink(context.Background(), pl, srv.URL(), func(ctx context.Context) (*policy.Request, error) {
createReqCalled = true
return NewRequest(ctx, http.MethodGet, srv.URL())
})
require.NoError(t, err)
require.False(t, createReqCalled)
require.NotNil(t, resp)
require.EqualValues(t, http.StatusOK, resp.StatusCode)

resp, err = FetcherForNextLink(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) {
return nil, errors.New("failed")
})
require.Error(t, err)
require.Nil(t, resp)

srv.AppendError(errors.New("failed"))
resp, err = FetcherForNextLink(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) {
createReqCalled = true
return NewRequest(ctx, http.MethodGet, srv.URL())
})
require.Error(t, err)
require.True(t, createReqCalled)
require.Nil(t, resp)

srv.AppendResponse(mock.WithStatusCode(http.StatusBadRequest), mock.WithBody([]byte(`{ "error": { "code": "InvalidResource", "message": "doesn't exist" } }`)))
createReqCalled = false
resp, err = FetcherForNextLink(context.Background(), pl, srv.URL(), func(ctx context.Context) (*policy.Request, error) {
createReqCalled = true
return NewRequest(ctx, http.MethodGet, srv.URL())
})
require.Error(t, err)
var respErr *exported.ResponseError
require.ErrorAs(t, err, &respErr)
require.EqualValues(t, "InvalidResource", respErr.ErrorCode)
require.False(t, createReqCalled)
require.Nil(t, resp)
}
14 changes: 14 additions & 0 deletions sdk/azcore/runtime/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"fmt"
"io"
"mime/multipart"
"net/url"
"os"
"path"
"reflect"
Expand Down Expand Up @@ -43,6 +44,19 @@ func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*polic
return exported.NewRequest(ctx, httpMethod, endpoint)
}

// EncodeQueryParams will parse and encode any query parameters in the specified URL.
func EncodeQueryParams(u string) (string, error) {
before, after, found := strings.Cut(u, "?")
if !found {
return u, nil
}
qp, err := url.ParseQuery(after)
if err != nil {
return "", err
}
return before + "?" + qp.Encode(), nil
}

// JoinPaths concatenates multiple URL path segments into one path,
// inserting path separation characters as required. JoinPaths will preserve
// query parameters in the root path
Expand Down
19 changes: 19 additions & 0 deletions sdk/azcore/runtime/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -740,3 +740,22 @@ func TestSetMultipartFormData(t *testing.T) {
require.Equal(t, "second part", string(second))
require.Equal(t, "third part", string(third))
}

func TestEncodeQueryParams(t *testing.T) {
const testURL = "https://contoso.com/"
nextLink, err := EncodeQueryParams(testURL + "query?$skip=5&$filter='foo eq bar'")
require.NoError(t, err)
require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", nextLink)
nextLink, err = EncodeQueryParams(testURL + "query?%24filter=%27foo+eq+bar%27&%24skip=5")
require.NoError(t, err)
require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", nextLink)
nextLink, err = EncodeQueryParams(testURL + "query?foo=bar&one=two")
require.NoError(t, err)
require.EqualValues(t, testURL+"query?foo=bar&one=two", nextLink)
nextLink, err = EncodeQueryParams(testURL)
require.NoError(t, err)
require.EqualValues(t, testURL, nextLink)
nextLink, err = EncodeQueryParams(testURL + "query?invalid=;semicolon")
require.Error(t, err)
require.Empty(t, nextLink)
}

0 comments on commit 081c5e3

Please sign in to comment.