diff --git a/client/clientimpl_test.go b/client/clientimpl_test.go index e41f4557..76ca41aa 100644 --- a/client/clientimpl_test.go +++ b/client/clientimpl_test.go @@ -340,6 +340,79 @@ func TestConnectWithHeader(t *testing.T) { }) } +func TestConnectWithHeaderFunc(t *testing.T) { + testClients(t, func(t *testing.T, client OpAMPClient) { + // Start a server. + srv := internal.StartMockServer(t) + var conn atomic.Value + srv.OnConnect = func(r *http.Request) { + authHdr := r.Header.Get("Authorization") + assert.EqualValues(t, "Bearer 12345678", authHdr) + userAgentHdr := r.Header.Get("User-Agent") + assert.EqualValues(t, "custom-agent/1.0", userAgentHdr) + conn.Store(true) + } + + hf := func(header http.Header) http.Header { + header.Set("Authorization", "Bearer 12345678") + header.Set("User-Agent", "custom-agent/1.0") + return header + } + + // Start a client. + settings := types.StartSettings{ + OpAMPServerURL: "ws://" + srv.Endpoint, + HeaderFunc: hf, + } + startClient(t, settings, client) + + // Wait for connection to be established. + eventually(t, func() bool { return conn.Load() != nil }) + + // Shutdown the Server and the client. + srv.Close() + _ = client.Stop(context.Background()) + }) +} + +func TestConnectWithHeaderAndHeaderFunc(t *testing.T) { + testClients(t, func(t *testing.T, client OpAMPClient) { + // Start a server. + srv := internal.StartMockServer(t) + var conn atomic.Value + srv.OnConnect = func(r *http.Request) { + authHdr := r.Header.Get("Authorization") + assert.EqualValues(t, "Bearer 12345678", authHdr) + userAgentHdr := r.Header.Get("User-Agent") + assert.EqualValues(t, "custom-agent/1.0", userAgentHdr) + conn.Store(true) + } + + baseHeader := http.Header{} + baseHeader.Set("User-Agent", "custom-agent/1.0") + + hf := func(header http.Header) http.Header { + header.Set("Authorization", "Bearer 12345678") + return header + } + + // Start a client. + settings := types.StartSettings{ + OpAMPServerURL: "ws://" + srv.Endpoint, + Header: baseHeader, + HeaderFunc: hf, + } + startClient(t, settings, client) + + // Wait for connection to be established. + eventually(t, func() bool { return conn.Load() != nil }) + + // Shutdown the Server and the client. + srv.Close() + _ = client.Stop(context.Background()) + }) +} + func TestConnectWithTLS(t *testing.T) { testClients(t, func(t *testing.T, client OpAMPClient) { // Start a server. diff --git a/client/httpclient.go b/client/httpclient.go index 666747c0..92259d59 100644 --- a/client/httpclient.go +++ b/client/httpclient.go @@ -44,7 +44,7 @@ func (c *httpClient) Start(ctx context.Context, settings types.StartSettings) er c.opAMPServerURL = settings.OpAMPServerURL // Prepare Server connection settings. - c.sender.SetRequestHeader(settings.Header) + c.sender.SetRequestHeader(settings.Header, settings.HeaderFunc) // Add TLS configuration into httpClient c.sender.AddTLSConfig(settings.TLSConfig) diff --git a/client/internal/httpsender.go b/client/internal/httpsender.go index e48e3ce0..502bf7e4 100644 --- a/client/internal/httpsender.go +++ b/client/internal/httpsender.go @@ -59,7 +59,7 @@ type HTTPSender struct { compressionEnabled bool // Headers to send with all requests. - requestHeader http.Header + getHeader func() http.Header // Processor to handle received messages. receiveProcessor receivedProcessor @@ -75,7 +75,7 @@ func NewHTTPSender(logger types.Logger) *HTTPSender { pollingIntervalMs: defaultPollingIntervalMs, } // initialize the headers with no additional headers - h.SetRequestHeader(nil) + h.SetRequestHeader(nil, nil) return h } @@ -121,12 +121,26 @@ func (h *HTTPSender) Run( // SetRequestHeader sets additional HTTP headers to send with all future requests. // Should not be called concurrently with any other method. -func (h *HTTPSender) SetRequestHeader(header http.Header) { - if header == nil { - header = http.Header{} +func (h *HTTPSender) SetRequestHeader(baseHeaders http.Header, headerFunc func(http.Header) http.Header) { + if baseHeaders == nil { + baseHeaders = http.Header{} + } + + if headerFunc == nil { + headerFunc = func(h http.Header) http.Header { + return h + } + } + + h.getHeader = func() http.Header { + requestHeader := headerFunc(baseHeaders.Clone()) + requestHeader.Set(headerContentType, contentTypeProtobuf) + if h.compressionEnabled { + requestHeader.Set(headerContentEncoding, encodingTypeGZip) + } + + return requestHeader } - h.requestHeader = header - h.requestHeader.Set(headerContentType, contentTypeProtobuf) } // makeOneRequestRoundtrip sends a request and receives a response. @@ -255,7 +269,7 @@ func (h *HTTPSender) prepareRequest(ctx context.Context) (*requestWrapper, error return nil, err } - req.Header = h.requestHeader + req.Header = h.getHeader() return &req, nil } @@ -295,9 +309,10 @@ func (h *HTTPSender) SetPollingInterval(duration time.Duration) { atomic.StoreInt64(&h.pollingIntervalMs, duration.Milliseconds()) } +// EnableCompression enables compression for the sender. +// Should not be called concurrently with Run. func (h *HTTPSender) EnableCompression() { h.compressionEnabled = true - h.requestHeader.Set(headerContentEncoding, encodingTypeGZip) } func (h *HTTPSender) AddTLSConfig(config *tls.Config) { diff --git a/client/types/startsettings.go b/client/types/startsettings.go index 857fd2a4..b968c7ff 100644 --- a/client/types/startsettings.go +++ b/client/types/startsettings.go @@ -17,6 +17,11 @@ type StartSettings struct { // Optional additional HTTP headers to send with all HTTP requests. Header http.Header + // Optional function that can be used to modify the HTTP headers + // before each HTTP request. + // Can modify and return the argument or return the argument without modifying. + HeaderFunc func(http.Header) http.Header + // Optional TLS config for HTTP connection. TLSConfig *tls.Config diff --git a/client/wsclient.go b/client/wsclient.go index 22d434c6..dc703dcb 100644 --- a/client/wsclient.go +++ b/client/wsclient.go @@ -31,7 +31,7 @@ type wsClient struct { url *url.URL // HTTP request headers to use when connecting to OpAMP Server. - requestHeader http.Header + getHeader func() http.Header // Websocket dialer and connection. dialer websocket.Dialer @@ -86,7 +86,21 @@ func (c *wsClient) Start(ctx context.Context, settings types.StartSettings) erro } c.dialer.TLSClientConfig = settings.TLSConfig - c.requestHeader = settings.Header + headerFunc := settings.HeaderFunc + if headerFunc == nil { + headerFunc = func(h http.Header) http.Header { + return h + } + } + + baseHeader := settings.Header + if baseHeader == nil { + baseHeader = http.Header{} + } + + c.getHeader = func() http.Header { + return headerFunc(baseHeader.Clone()) + } c.common.StartConnectAndRun(c.runUntilStopped) @@ -142,7 +156,7 @@ func (c *wsClient) SendCustomMessage(message *protobufs.CustomMessage) (messageS // by the Server. func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinternal.OptionalDuration, err error) { var resp *http.Response - conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.requestHeader) + conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.getHeader()) if err != nil { if c.common.Callbacks != nil && !c.common.IsStopping() { c.common.Callbacks.OnConnectFailed(ctx, err)