Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CheckRedirect callback #269

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions client/types/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package types

import (
"context"
"net/http"

"github.com/open-telemetry/opamp-go/protobufs"
)
Expand Down Expand Up @@ -110,6 +111,19 @@ type Callbacks interface {

// OnCommand is called when the Server requests that the connected Agent perform a command.
OnCommand(ctx context.Context, command *protobufs.ServerToAgentCommand) error

// CheckRedirect is called before following a redirect, allowing the client
// the opportunity to observe the redirect chain, and optionally terminate
// following redirects early.
//
// CheckRedirect is intended to be similar, although not exactly equivalent,
// to net/http.Client's CheckRedirect feature. Unlike in net/http, the via
// parameter is a slice of HTTP responses, instead of requests. This gives
// an opportunity to users to know what the exact response headers and
// status were. The request itself can be obtained from the response.
//
// The responses in the via parameter are passed with their bodies closed.
CheckRedirect(req *http.Request, via []*http.Response) error
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it valid to return ErrUseLastResponse? I assume it is not valid at least for WS transport since it typically means there is no WS connection established.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For clarity we should probably also link to https://pkg.go.dev/net/http#Client.CheckRedirect somewhere in the comment above.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This add CheckRedirect to Callbacks but implements it only for wsClient. I think we also need an implementation for httpClient.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tigrannajaryan Thank you for the review!

Regarding ErrUseLastResponse, I would agree it is not applicable here. I think it should result in the same behaviour as returning any other error, however. I'll mention it specifically in the docs.

I agree that we should have an implementation for httpClient. Perhaps the name of this callback should be WSCheckRedirect or similar, denoting that it is only for websocket clients. I don't think the function signature, as implemented, is valid for HTTP clients.

I looked at httpsender and it uses the default HTTP client (with an override method for TLS config). Perhaps httpClient could accept a net/http.Client supplied by the library consumer, instead of having a specific API for things like CheckRedirect. Then the API surface would not grow, and library consumers would be able to use any custom configuration of the HTTP client they desire. Any thoughts on this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that we should have an implementation for httpClient. Perhaps the name of this callback should be WSCheckRedirect or similar, denoting that it is only for websocket clients. I don't think the function signature, as implemented, is valid for HTTP clients.

Why is it not valid for HTTP clients?

Perhaps httpClient could accept a net/http.Client supplied by the library consumer, instead of having a specific API for things like CheckRedirect. Then the API surface would not grow, and library consumers would be able to use any custom configuration of the HTTP client they desire. Any thoughts on this?

Yes, this is a good option and we do something similar in the Server implementation, where you can use your own http.Server:

Attach(settings Settings) (HTTPHandlerFunc, ConnContext, error)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HTTP clients should rightly expect to have the signature func(*http.Request, []*http.Request) error, as that would be 1:1 with net/http.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HTTP clients should rightly expect to have the signature func(*http.Request, []*http.Request) error, as that would be 1:1 with net/http.

Do we have the need to use a different signature for WS? I see it is mentioned in the comment but it is not clear to me how exactly you would use the response for WS and why it is useful for WS and not for http.

Ideally we should use the same signature as the standard lib and make this setting applicable to both WS and plain HTTP implementations.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a need for situations where the opamp-go users wants to know what status the server wrote. 301 vs 302 for instance, where the client may want to cache permanent redirects but not cache temporary redirects. Since the opamp-go user isn't privy to the HTTP response, there is no other opportunity to know that information.

I'm definitely open to another design, as I agree it would be better to maintain parity with net/http in general.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@echlebek Sorry for long pause, I finally got back to this PR.

What do you think if we change CheckRedirect to look like this:

CheckRedirect(req *http.Request, viaReq []*http.Request, viaResp []*http.Response) error

The wsClient will pass both viaReq and viaResp parameters, while httpClient will only pass viaReq and will leave viaResp nil.

}

// CallbacksStruct is a struct that implements Callbacks interface and allows
Expand All @@ -130,6 +144,10 @@ type CallbacksStruct struct {

SaveRemoteConfigStatusFunc func(ctx context.Context, status *protobufs.RemoteConfigStatus)
GetEffectiveConfigFunc func(ctx context.Context) (*protobufs.EffectiveConfig, error)

// CheckRedirectFunc is called before following a redirect. It is similar in
// nature to the CheckRedirect in net/http's Client.
CheckRedirectFunc func(req *http.Request, via []*http.Response) error
}

var _ Callbacks = (*CallbacksStruct)(nil)
Expand Down Expand Up @@ -194,3 +212,11 @@ func (c CallbacksStruct) OnCommand(ctx context.Context, command *protobufs.Serve
}
return nil
}

// CheckRedirect implements Callbacks.CheckRedirect.
func (c CallbacksStruct) CheckRedirect(req *http.Request, via []*http.Response) error {
if fn := c.CheckRedirectFunc; fn != nil {
return fn(req, via)
}
return nil
}
82 changes: 64 additions & 18 deletions client/wsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@
// Network connection timeout used for the WebSocket closing handshake.
// This field is currently only modified during testing.
connShutdownTimeout time.Duration

// responseChain is used for the "via" argument in CheckRedirect.
// It is appended to with every redirect followed, and zeroed on a succesful
// connection. responseChain should only be referred to by the goroutine that
// runs tryConnectOnce and its synchronous callees.
responseChain []*http.Response
}

// NewWebSocket creates a new OpAMP Client that uses WebSocket transport.
Expand Down Expand Up @@ -133,35 +139,75 @@
return c.common.SendCustomMessage(message)
}

// handleRedirect checks a failed websocket upgrade response for a 3xx response
// and a Location header. If found, it sets the URL to the location found in the
// header so that it is tried on the next retry, instead of the current URL.
func (c *wsClient) handleRedirect(ctx context.Context, resp *http.Response) error {
// append to the responseChain so that subsequent redirects will have access
c.responseChain = append(c.responseChain, resp)

// very liberal handling of 3xx that largely ignores HTTP semantics
redirect, err := resp.Location()
if err != nil {
c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err)
return err
}

// It's slightly tricky to make CheckRedirect work. The WS HTTP request is
// formed within the websocket library. To work around that, copy the
// previous request, available in the response, and set the URL to the new
// location. It should then result in the same URL that the websocket
// library will form.
nextRequest := resp.Request.Clone(ctx)
nextRequest.URL = redirect

// if CheckRedirect results in an error, it gets returned, terminating
// redirection. As with stdlib, the error is wrapped in url.Error.
if err := c.common.Callbacks.CheckRedirect(nextRequest, c.responseChain); err != nil {
return &url.Error{
Op: "Get",
URL: nextRequest.URL.String(),
Err: err,
}
}

// rewrite the scheme for the sake of tolerance
if redirect.Scheme == "http" {
redirect.Scheme = "ws"
} else if redirect.Scheme == "https" {
redirect.Scheme = "wss"

Check warning on line 178 in client/wsclient.go

View check run for this annotation

Codecov / codecov/patch

client/wsclient.go#L178

Added line #L178 was not covered by tests
}
c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect)

// Set the URL to the redirect, so that it connects to it on the
// next cycle.
c.url = redirect

return nil
}

// Try to connect once. Returns an error if connection fails and optional retryAfter
// duration to indicate to the caller to retry after the specified time as instructed
// by the Server.
func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinternal.OptionalDuration, err error) {
var resp *http.Response
var redirecting bool
defer func() {
if err != nil && !redirecting {
c.responseChain = nil
if c.common.Callbacks != nil && !c.common.IsStopping() {
c.common.Callbacks.OnConnectFailed(ctx, err)
}
}
}()
conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.requestHeader)
if err != nil {
if c.common.Callbacks != nil && !c.common.IsStopping() {
c.common.Callbacks.OnConnectFailed(ctx, err)
}
if resp != nil {
duration := sharedinternal.ExtractRetryAfterHeader(resp)
if resp.StatusCode >= 300 && resp.StatusCode < 400 {
// very liberal handling of 3xx that largely ignores HTTP semantics
redirect, err := resp.Location()
if err != nil {
c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err)
redirecting = true
if err := c.handleRedirect(ctx, resp); err != nil {
return duration, err
}
// rewrite the scheme for the sake of tolerance
if redirect.Scheme == "http" {
redirect.Scheme = "ws"
} else if redirect.Scheme == "https" {
redirect.Scheme = "wss"
}
c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect)
// Set the URL to the redirect, so that it connects to it on the
// next cycle.
c.url = redirect
} else {
c.common.Logger.Errorf(ctx, "Server responded with status=%v", resp.Status)
}
Expand Down
116 changes: 112 additions & 4 deletions client/wsclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package client

import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
Expand All @@ -13,6 +14,7 @@ import (

"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"

Expand Down Expand Up @@ -195,12 +197,45 @@ func errServer() *httptest.Server {
}))
}

type checkRedirectMock struct {
mock.Mock
t testing.TB
viaLen int
}

func (c *checkRedirectMock) CheckRedirect(req *http.Request, via []*http.Response) error {
if req == nil {
c.t.Error("nil request in CheckRedirect")
return errors.New("nil request in CheckRedirect")
}
if len(via) > c.viaLen {
c.t.Error("via should be shorter than viaLen")
}
location, err := via[len(via)-1].Location()
if err != nil {
c.t.Error(err)
}
// the URL of the request should match the location header of the last response
assert.Equal(c.t, req.URL, location, "request URL should equal the location in the response")
return c.Called(req, via).Error(0)
}

func mockRedirect(t testing.TB, viaLen int, err error) *checkRedirectMock {
m := &checkRedirectMock{
t: t,
viaLen: viaLen,
}
m.On("CheckRedirect", mock.Anything, mock.Anything).Return(err)
return m
}

func TestRedirectWS(t *testing.T) {
redirectee := internal.StartMockServer(t)
tests := []struct {
Name string
Redirector *httptest.Server
ExpError bool
Name string
Redirector *httptest.Server
ExpError bool
MockRedirect *checkRedirectMock
}{
{
Name: "redirect ws scheme",
Expand All @@ -215,6 +250,17 @@ func TestRedirectWS(t *testing.T) {
Redirector: errServer(),
ExpError: true,
},
{
Name: "check redirect",
Redirector: redirectServer("ws://"+redirectee.Endpoint, 302),
MockRedirect: mockRedirect(t, 1, nil),
},
{
Name: "check redirect returns error",
Redirector: redirectServer("ws://"+redirectee.Endpoint, 302),
MockRedirect: mockRedirect(t, 1, errors.New("hello")),
ExpError: true,
},
}

for _, test := range tests {
Expand All @@ -228,7 +274,7 @@ func TestRedirectWS(t *testing.T) {
var connected int64
var connectErr atomic.Value
settings := types.StartSettings{
Callbacks: types.CallbacksStruct{
Callbacks: &types.CallbacksStruct{
OnConnectFunc: func(ctx context.Context) {
atomic.StoreInt64(&connected, 1)
},
Expand All @@ -239,6 +285,9 @@ func TestRedirectWS(t *testing.T) {
},
},
}
if test.MockRedirect != nil {
settings.Callbacks.(*types.CallbacksStruct).CheckRedirectFunc = test.MockRedirect.CheckRedirect
}
reURL, err := url.Parse(test.Redirector.URL)
assert.NoError(t, err)
reURL.Scheme = "ws"
Expand All @@ -261,10 +310,69 @@ func TestRedirectWS(t *testing.T) {
// Stop the client.
err = client.Stop(context.Background())
assert.NoError(t, err)

if test.MockRedirect != nil {
test.MockRedirect.AssertCalled(t, "CheckRedirect", mock.Anything, mock.Anything)
}
})
}
}

func TestRedirectWSFollowChain(t *testing.T) {
// test that redirect following is recursive
redirectee := internal.StartMockServer(t)
middle := redirectServer("http://"+redirectee.Endpoint, 302)
middleURL, err := url.Parse(middle.URL)
if err != nil {
// unlikely
t.Fatal(err)
}
redirector := redirectServer("http://"+middleURL.Host, 302)

var conn atomic.Value
redirectee.OnWSConnect = func(c *websocket.Conn) {
conn.Store(c)
}

// Start an OpAMP/WebSocket client.
var connected int64
var connectErr atomic.Value
mr := mockRedirect(t, 2, nil)
settings := types.StartSettings{
Callbacks: types.CallbacksStruct{
OnConnectFunc: func(ctx context.Context) {
atomic.StoreInt64(&connected, 1)
},
OnConnectFailedFunc: func(ctx context.Context, err error) {
if err != websocket.ErrBadHandshake {
connectErr.Store(err)
}
},
CheckRedirectFunc: mr.CheckRedirect,
},
}
reURL, err := url.Parse(redirector.URL)
if err != nil {
// unlikely
t.Fatal(err)
}
reURL.Scheme = "ws"
settings.OpAMPServerURL = reURL.String()
client := NewWebSocket(nil)
startClient(t, settings, client)

// Wait for connection to be established.
eventually(t, func() bool {
return conn.Load() != nil || connectErr.Load() != nil || client.lastInternalErr.Load() != nil
})

assert.True(t, connectErr.Load() == nil)

// Stop the client.
err = client.Stop(context.Background())
assert.NoError(t, err)
}

func TestHandlesStopBeforeStart(t *testing.T) {
client := NewWebSocket(nil)
require.Error(t, client.Stop(context.Background()))
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/go-cmp v0.5.6 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ github.com/oklog/ulid/v2 v2.1.0/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNs
github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
Expand Down
Loading