Skip to content

Commit

Permalink
Add client context propagation (#248)
Browse files Browse the repository at this point in the history
This is a follow up to #237 and #247, adding context propagation for client methods. 

**This involves a breaking change for the client interfaces**
  • Loading branch information
jaronoff97 authored Jan 26, 2024
1 parent 4d07a6a commit fab35cf
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 57 deletions.
12 changes: 6 additions & 6 deletions client/clientimpl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func TestOnConnectFail(t *testing.T) {
var connectErr atomic.Value
settings := createNoServerSettings()
settings.Callbacks = types.CallbacksStruct{
OnConnectFailedFunc: func(err error) {
OnConnectFailedFunc: func(ctx context.Context, err error) {
connectErr.Store(err)
},
}
Expand Down Expand Up @@ -238,7 +238,7 @@ func TestConnectWithServer(t *testing.T) {
var connected int64
settings := types.StartSettings{
Callbacks: types.CallbacksStruct{
OnConnectFunc: func() {
OnConnectFunc: func(ctx context.Context) {
atomic.StoreInt64(&connected, 1)
},
},
Expand Down Expand Up @@ -276,11 +276,11 @@ func TestConnectWithServer503(t *testing.T) {
var connectErr atomic.Value
settings := types.StartSettings{
Callbacks: types.CallbacksStruct{
OnConnectFunc: func() {
OnConnectFunc: func(ctx context.Context) {
atomic.StoreInt64(&clientConnected, 1)
assert.Fail(t, "Client should not be able to connect")
},
OnConnectFailedFunc: func(err error) {
OnConnectFailedFunc: func(ctx context.Context, err error) {
connectErr.Store(err)
},
},
Expand Down Expand Up @@ -405,7 +405,7 @@ func TestFirstStatusReport(t *testing.T) {
var connected, remoteConfigReceived int64
settings := types.StartSettings{
Callbacks: types.CallbacksStruct{
OnConnectFunc: func() {
OnConnectFunc: func(ctx context.Context) {
atomic.AddInt64(&connected, 1)
},
OnMessageFunc: func(ctx context.Context, msg *types.MessageData) {
Expand Down Expand Up @@ -458,7 +458,7 @@ func TestIncludesDetailsOnReconnect(t *testing.T) {
var connected int64
settings := types.StartSettings{
Callbacks: types.CallbacksStruct{
OnConnectFunc: func() {
OnConnectFunc: func(ctx context.Context) {
atomic.AddInt64(&connected, 1)
},
},
Expand Down
4 changes: 2 additions & 2 deletions client/internal/httpsender.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (h *HTTPSender) sendRequestWithRetries(ctx context.Context) (*http.Response
switch resp.StatusCode {
case http.StatusOK:
// We consider it connected if we receive 200 status from the Server.
h.callbacks.OnConnect()
h.callbacks.OnConnect(ctx)
return resp, nil

case http.StatusTooManyRequests, http.StatusServiceUnavailable:
Expand All @@ -195,7 +195,7 @@ func (h *HTTPSender) sendRequestWithRetries(ctx context.Context) (*http.Response
}

h.logger.Errorf(ctx, "Failed to do HTTP request (%v), will retry", err)
h.callbacks.OnConnectFailed(err)
h.callbacks.OnConnectFailed(ctx, err)
}

case <-ctx.Done():
Expand Down
8 changes: 4 additions & 4 deletions client/internal/httpsender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ func TestHTTPSenderRetryForStatusTooManyRequests(t *testing.T) {
}
})
sender.callbacks = types.CallbacksStruct{
OnConnectFunc: func() {
OnConnectFunc: func(ctx context.Context) {
},
OnConnectFailedFunc: func(_ error) {
OnConnectFailedFunc: func(ctx context.Context, _ error) {
},
}
sender.url = url
Expand Down Expand Up @@ -144,9 +144,9 @@ func TestHTTPSenderRetryForFailedRequests(t *testing.T) {
}
})
sender.callbacks = types.CallbacksStruct{
OnConnectFunc: func() {
OnConnectFunc: func(ctx context.Context) {
},
OnConnectFailedFunc: func(_ error) {
OnConnectFailedFunc: func(ctx context.Context, _ error) {
},
}
sender.url = url
Expand Down
13 changes: 7 additions & 6 deletions client/internal/receivedprocessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (r *receivedProcessor) ProcessReceivedMessage(ctx context.Context, msg *pro
// to process.
if msg.Command != nil {
if r.hasCapability(protobufs.AgentCapabilities_AgentCapabilities_AcceptsRestartCommand) {
r.rcvCommand(msg.Command)
r.rcvCommand(ctx, msg.Command)
// If a command message exists, other messages will be ignored
return
} else {
Expand Down Expand Up @@ -198,16 +198,17 @@ func (r *receivedProcessor) rcvOpampConnectionSettings(ctx context.Context, sett
err := r.callbacks.OnOpampConnectionSettings(ctx, settings.Opamp)
if err == nil {
// TODO: verify connection using new settings.
r.callbacks.OnOpampConnectionSettingsAccepted(settings.Opamp)
r.callbacks.OnOpampConnectionSettingsAccepted(ctx, settings.Opamp)
}
} else {
r.logger.Debugf(ctx, "Ignoring Opamp, agent does not have AcceptsOpAMPConnectionSettings capability")
}
}

func (r *receivedProcessor) processErrorResponse(ctx context.Context, body *protobufs.ServerErrorResponse) {
// TODO: implement this.
r.logger.Errorf(ctx, "received an error from server: %s", body.ErrorMessage)
if body != nil {
r.callbacks.OnError(ctx, body)
}
}

func (r *receivedProcessor) rcvAgentIdentification(ctx context.Context, agentId *protobufs.AgentIdentification) error {
Expand All @@ -226,8 +227,8 @@ func (r *receivedProcessor) rcvAgentIdentification(ctx context.Context, agentId
return nil
}

func (r *receivedProcessor) rcvCommand(command *protobufs.ServerToAgentCommand) {
func (r *receivedProcessor) rcvCommand(ctx context.Context, command *protobufs.ServerToAgentCommand) {
if command != nil {
r.callbacks.OnCommand(command)
r.callbacks.OnCommand(ctx, command)
}
}
4 changes: 2 additions & 2 deletions client/internal/wsreceiver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestServerToAgentCommand(t *testing.T) {
action := none

callbacks := types.CallbacksStruct{
OnCommandFunc: func(command *protobufs.ServerToAgentCommand) error {
OnCommandFunc: func(ctx context.Context, command *protobufs.ServerToAgentCommand) error {
switch command.Type {
case protobufs.CommandType_CommandType_Restart:
action = restart
Expand Down Expand Up @@ -132,7 +132,7 @@ func TestServerToAgentCommandExclusive(t *testing.T) {
calledOnMessageConfig := false

callbacks := types.CallbacksStruct{
OnCommandFunc: func(command *protobufs.ServerToAgentCommand) error {
OnCommandFunc: func(ctx context.Context, command *protobufs.ServerToAgentCommand) error {
calledCommand = true
return nil
},
Expand Down
42 changes: 21 additions & 21 deletions client/types/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,24 @@ type MessageData struct {
}

// Callbacks is an interface for the Client to handle messages from the Server.
// Callbacks are expected to honour the context passed to them, meaning they should be aware of cancellations.
type Callbacks interface {
// OnConnect is called when the connection is successfully established to the Server.
// May be called after Start() is called and every time a connection is established to the Server.
// For WebSocket clients this is called after the handshake is completed without any error.
// For HTTP clients this is called for any request if the response status is OK.
OnConnect()
OnConnect(ctx context.Context)

// OnConnectFailed is called when the connection to the Server cannot be established.
// May be called after Start() is called and tries to connect to the Server.
// May also be called if the connection is lost and reconnection attempt fails.
OnConnectFailed(err error)
OnConnectFailed(ctx context.Context, err error)

// OnError is called when the Server reports an error in response to some previously
// sent request. Useful for logging purposes. The Agent should not attempt to process
// the error by reconnecting or retrying previous operations. The client handles the
// ErrorResponse_UNAVAILABLE case internally by performing retries as necessary.
OnError(err *protobufs.ServerErrorResponse)
OnError(ctx context.Context, err *protobufs.ServerErrorResponse)

// OnMessage is called when the Agent receives a message that needs processing.
// See MessageData definition for the data that may be available for processing.
Expand Down Expand Up @@ -94,9 +95,7 @@ type Callbacks interface {
// verified and accepted (OnOpampConnectionSettingsOffer and connection using
// new settings succeeds). The Agent should store the settings and use them
// in the future. Old connection settings should be forgotten.
OnOpampConnectionSettingsAccepted(
settings *protobufs.OpAMPConnectionSettings,
)
OnOpampConnectionSettingsAccepted(ctx context.Context, settings *protobufs.OpAMPConnectionSettings)

// For all methods that accept a context parameter the caller may cancel the
// context if processing takes too long. In that case the method should return
Expand All @@ -115,15 +114,15 @@ type Callbacks interface {
GetEffectiveConfig(ctx context.Context) (*protobufs.EffectiveConfig, error)

// OnCommand is called when the Server requests that the connected Agent perform a command.
OnCommand(command *protobufs.ServerToAgentCommand) error
OnCommand(ctx context.Context, command *protobufs.ServerToAgentCommand) error
}

// CallbacksStruct is a struct that implements Callbacks interface and allows
// to override only the methods that are needed. If a method is not overridden then it is a no-op.
type CallbacksStruct struct {
OnConnectFunc func()
OnConnectFailedFunc func(err error)
OnErrorFunc func(err *protobufs.ServerErrorResponse)
OnConnectFunc func(ctx context.Context)
OnConnectFailedFunc func(ctx context.Context, err error)
OnErrorFunc func(ctx context.Context, err *protobufs.ServerErrorResponse)

OnMessageFunc func(ctx context.Context, msg *MessageData)

Expand All @@ -132,10 +131,11 @@ type CallbacksStruct struct {
settings *protobufs.OpAMPConnectionSettings,
) error
OnOpampConnectionSettingsAcceptedFunc func(
ctx context.Context,
settings *protobufs.OpAMPConnectionSettings,
)

OnCommandFunc func(command *protobufs.ServerToAgentCommand) error
OnCommandFunc func(ctx context.Context, command *protobufs.ServerToAgentCommand) error

SaveRemoteConfigStatusFunc func(ctx context.Context, status *protobufs.RemoteConfigStatus)
GetEffectiveConfigFunc func(ctx context.Context) (*protobufs.EffectiveConfig, error)
Expand All @@ -144,23 +144,23 @@ type CallbacksStruct struct {
var _ Callbacks = (*CallbacksStruct)(nil)

// OnConnect implements Callbacks.OnConnect.
func (c CallbacksStruct) OnConnect() {
func (c CallbacksStruct) OnConnect(ctx context.Context) {
if c.OnConnectFunc != nil {
c.OnConnectFunc()
c.OnConnectFunc(ctx)
}
}

// OnConnectFailed implements Callbacks.OnConnectFailed.
func (c CallbacksStruct) OnConnectFailed(err error) {
func (c CallbacksStruct) OnConnectFailed(ctx context.Context, err error) {
if c.OnConnectFailedFunc != nil {
c.OnConnectFailedFunc(err)
c.OnConnectFailedFunc(ctx, err)
}
}

// OnError implements Callbacks.OnError.
func (c CallbacksStruct) OnError(err *protobufs.ServerErrorResponse) {
func (c CallbacksStruct) OnError(ctx context.Context, err *protobufs.ServerErrorResponse) {
if c.OnErrorFunc != nil {
c.OnErrorFunc(err)
c.OnErrorFunc(ctx, err)
}
}

Expand Down Expand Up @@ -197,16 +197,16 @@ func (c CallbacksStruct) OnOpampConnectionSettings(
}

// OnOpampConnectionSettingsAccepted implements Callbacks.OnOpampConnectionSettingsAccepted.
func (c CallbacksStruct) OnOpampConnectionSettingsAccepted(settings *protobufs.OpAMPConnectionSettings) {
func (c CallbacksStruct) OnOpampConnectionSettingsAccepted(ctx context.Context, settings *protobufs.OpAMPConnectionSettings) {
if c.OnOpampConnectionSettingsAcceptedFunc != nil {
c.OnOpampConnectionSettingsAcceptedFunc(settings)
c.OnOpampConnectionSettingsAcceptedFunc(ctx, settings)
}
}

// OnCommand implements Callbacks.OnCommand.
func (c CallbacksStruct) OnCommand(command *protobufs.ServerToAgentCommand) error {
func (c CallbacksStruct) OnCommand(ctx context.Context, command *protobufs.ServerToAgentCommand) error {
if c.OnCommandFunc != nil {
return c.OnCommandFunc(command)
return c.OnCommandFunc(ctx, command)
}
return nil
}
4 changes: 2 additions & 2 deletions client/wsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (c *wsClient) tryConnectOnce(ctx context.Context) (err error, retryAfter sh
conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.requestHeader)
if err != nil {
if c.common.Callbacks != nil && !c.common.IsStopping() {
c.common.Callbacks.OnConnectFailed(err)
c.common.Callbacks.OnConnectFailed(ctx, err)
}
if resp != nil {
c.common.Logger.Errorf(ctx, "Server responded with status=%v", resp.Status)
Expand All @@ -143,7 +143,7 @@ func (c *wsClient) tryConnectOnce(ctx context.Context) (err error, retryAfter sh
c.conn = conn
c.connMutex.Unlock()
if c.common.Callbacks != nil {
c.common.Callbacks.OnConnect()
c.common.Callbacks.OnConnect(ctx)
}

return nil, sharedinternal.OptionalDuration{Defined: false}
Expand Down
4 changes: 2 additions & 2 deletions client/wsclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ func TestDisconnectWSByServer(t *testing.T) {
var connectErr atomic.Value
settings := types.StartSettings{
Callbacks: types.CallbacksStruct{
OnConnectFunc: func() {
OnConnectFunc: func(ctx context.Context) {
atomic.StoreInt64(&connected, 1)
},
OnConnectFailedFunc: func(err error) {
OnConnectFailedFunc: func(ctx context.Context, err error) {
connectErr.Store(err)
},
},
Expand Down
12 changes: 6 additions & 6 deletions internal/examples/agent/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,14 @@ func (agent *Agent) connect() error {
TLSConfig: tlsConfig,
InstanceUid: agent.instanceId.String(),
Callbacks: types.CallbacksStruct{
OnConnectFunc: func() {
agent.logger.Debugf(context.Background(), "Connected to the server.")
OnConnectFunc: func(ctx context.Context) {
agent.logger.Debugf(ctx, "Connected to the server.")
},
OnConnectFailedFunc: func(err error) {
agent.logger.Errorf(context.Background(), "Failed to connect to the server: %v", err)
OnConnectFailedFunc: func(ctx context.Context, err error) {
agent.logger.Errorf(ctx, "Failed to connect to the server: %v", err)
},
OnErrorFunc: func(err *protobufs.ServerErrorResponse) {
agent.logger.Errorf(context.Background(), "Server returned an error response: %v", err.ErrorMessage)
OnErrorFunc: func(ctx context.Context, err *protobufs.ServerErrorResponse) {
agent.logger.Errorf(ctx, "Server returned an error response: %v", err.ErrorMessage)
},
SaveRemoteConfigStatusFunc: func(_ context.Context, status *protobufs.RemoteConfigStatus) {
agent.remoteConfigStatus = status
Expand Down
12 changes: 6 additions & 6 deletions internal/examples/supervisor/supervisor/supervisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,14 @@ func (s *Supervisor) startOpAMP() error {
},
InstanceUid: s.instanceId.String(),
Callbacks: types.CallbacksStruct{
OnConnectFunc: func() {
s.logger.Debugf(context.Background(), "Connected to the server.")
OnConnectFunc: func(ctx context.Context) {
s.logger.Debugf(ctx, "Connected to the server.")
},
OnConnectFailedFunc: func(err error) {
s.logger.Errorf(context.Background(), "Failed to connect to the server: %v", err)
OnConnectFailedFunc: func(ctx context.Context, err error) {
s.logger.Errorf(ctx, "Failed to connect to the server: %v", err)
},
OnErrorFunc: func(err *protobufs.ServerErrorResponse) {
s.logger.Errorf(context.Background(), "Server returned an error response: %v", err.ErrorMessage)
OnErrorFunc: func(ctx context.Context, err *protobufs.ServerErrorResponse) {
s.logger.Errorf(ctx, "Server returned an error response: %v", err.ErrorMessage)
},
GetEffectiveConfigFunc: func(ctx context.Context) (*protobufs.EffectiveConfig, error) {
return s.createEffectiveConfigMsg(), nil
Expand Down

0 comments on commit fab35cf

Please sign in to comment.