diff --git a/.chloggen/codeboten_add-option-for-decompressor.yaml b/.chloggen/codeboten_add-option-for-decompressor.yaml new file mode 100755 index 00000000000..f4fca100888 --- /dev/null +++ b/.chloggen/codeboten_add-option-for-decompressor.yaml @@ -0,0 +1,16 @@ +# One of 'breaking', 'deprecation', 'new_component', 'enhancement', 'bug_fix' +change_type: enhancement + +# The name of the component, or a single word describing the area of concern, (e.g. otlpreceiver) +component: confighttp + +# A brief description of the change. Surround your text with quotes ("") if it needs to start with a backtick (`). +note: "Add support for additional content decoders via `WithDecoder` server option" + +# One or more tracking issues or pull requests related to the change +issues: [7977] + +# (Optional) One or more lines of additional information to render under the primary note. +# These lines will be padded with 2 spaces and then inserted directly into the document. +# Use pipe (|) for multiline entries. +subtext: diff --git a/config/confighttp/compression.go b/config/confighttp/compression.go index 962b6b461f5..faa716c4414 100644 --- a/config/confighttp/compression.go +++ b/config/confighttp/compression.go @@ -68,25 +68,68 @@ func (r *compressRoundTripper) RoundTrip(req *http.Request) (*http.Response, err type decompressor struct { errHandler func(w http.ResponseWriter, r *http.Request, errorMsg string, statusCode int) base http.Handler + decoders map[string]func(body io.ReadCloser) (io.ReadCloser, error) } // httpContentDecompressor offloads the task of handling compressed HTTP requests // by identifying the compression format in the "Content-Encoding" header and re-writing // request body so that the handlers further in the chain can work on decompressed data. // It supports gzip and deflate/zlib compression. -func httpContentDecompressor(h http.Handler, eh func(w http.ResponseWriter, r *http.Request, errorMsg string, statusCode int)) http.Handler { +func httpContentDecompressor(h http.Handler, eh func(w http.ResponseWriter, r *http.Request, errorMsg string, statusCode int), decoders map[string]func(body io.ReadCloser) (io.ReadCloser, error)) http.Handler { errHandler := defaultErrorHandler if eh != nil { errHandler = eh } - return &decompressor{ + + d := &decompressor{ errHandler: errHandler, base: h, + decoders: map[string]func(body io.ReadCloser) (io.ReadCloser, error){ + "": func(body io.ReadCloser) (io.ReadCloser, error) { + // Not a compressed payload. Nothing to do. + return nil, nil + }, + "gzip": func(body io.ReadCloser) (io.ReadCloser, error) { + gr, err := gzip.NewReader(body) + if err != nil { + return nil, err + } + return gr, nil + }, + "zstd": func(body io.ReadCloser) (io.ReadCloser, error) { + zr, err := zstd.NewReader( + body, + // Concurrency 1 disables async decoding. We don't need async decoding, it is pointless + // for our use-case (a server accepting decoding http requests). + // Disabling async improves performance (I benchmarked it previously when working + // on https://github.com/open-telemetry/opentelemetry-collector-contrib/pull/23257). + zstd.WithDecoderConcurrency(1), + ) + if err != nil { + return nil, err + } + return zr.IOReadCloser(), nil + }, + "zlib": func(body io.ReadCloser) (io.ReadCloser, error) { + zr, err := zlib.NewReader(body) + if err != nil { + return nil, err + } + return zr, nil + }, + }, + } + d.decoders["deflate"] = d.decoders["zlib"] + + for key, dec := range decoders { + d.decoders[key] = dec } + + return d } func (d *decompressor) ServeHTTP(w http.ResponseWriter, r *http.Request) { - newBody, err := newBodyReader(r) + newBody, err := d.newBodyReader(r) if err != nil { d.errHandler(w, r, err.Error(), http.StatusBadRequest) return @@ -104,39 +147,13 @@ func (d *decompressor) ServeHTTP(w http.ResponseWriter, r *http.Request) { d.base.ServeHTTP(w, r) } -func newBodyReader(r *http.Request) (io.ReadCloser, error) { +func (d *decompressor) newBodyReader(r *http.Request) (io.ReadCloser, error) { encoding := r.Header.Get(headerContentEncoding) - switch encoding { - case string(configcompression.Gzip): - gr, err := gzip.NewReader(r.Body) - if err != nil { - return nil, err - } - return gr, nil - case string(configcompression.Deflate), string(configcompression.Zlib): - zr, err := zlib.NewReader(r.Body) - if err != nil { - return nil, err - } - return zr, nil - case "zstd": - zr, err := zstd.NewReader( - r.Body, - // Concurrency 1 disables async decoding. We don't need async decoding, it is pointless - // for our use-case (a server accepting decoding http requests). - // Disabling async improves performance (I benchmarked it previously when working - // on https://github.com/open-telemetry/opentelemetry-collector-contrib/pull/23257). - zstd.WithDecoderConcurrency(1), - ) - if err != nil { - return nil, err - } - return zr.IOReadCloser(), nil - case "": - // Not a compressed payload. Nothing to do. - return nil, nil + decoder, ok := d.decoders[encoding] + if !ok { + return nil, fmt.Errorf("unsupported %s: %s", headerContentEncoding, encoding) } - return nil, fmt.Errorf("unsupported %s: %s", headerContentEncoding, encoding) + return decoder(r.Body) } // defaultErrorHandler writes the error message in plain text. diff --git a/config/confighttp/compression_test.go b/config/confighttp/compression_test.go index 0b5e8141f7a..091d9faff12 100644 --- a/config/confighttp/compression_test.go +++ b/config/confighttp/compression_test.go @@ -11,6 +11,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "github.com/golang/snappy" @@ -114,8 +115,44 @@ func TestHTTPClientCompression(t *testing.T) { } } +func TestHTTPCustomDecompression(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(err.Error())) + return + } + + require.NoError(t, err, "failed to read request body: %v", err) + assert.EqualValues(t, "decompressed body", string(body)) + w.WriteHeader(http.StatusOK) + }) + decoders := map[string]func(io.ReadCloser) (io.ReadCloser, error){ + "custom-encoding": func(io.ReadCloser) (io.ReadCloser, error) { // nolint: unparam + return io.NopCloser(strings.NewReader("decompressed body")), nil + }, + } + srv := httptest.NewServer(httpContentDecompressor(handler, defaultErrorHandler, decoders)) + + t.Cleanup(srv.Close) + + req, err := http.NewRequest(http.MethodGet, srv.URL, bytes.NewBuffer([]byte("123decompressed body"))) + require.NoError(t, err, "failed to create request to test handler") + req.Header.Set("Content-Encoding", "custom-encoding") + + client := http.Client{} + res, err := client.Do(req) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, res.StatusCode, "test handler returned unexpected status code ") + _, err = io.ReadAll(res.Body) + require.NoError(t, res.Body.Close(), "failed to close request body: %v", err) +} + func TestHTTPContentDecompressionHandler(t *testing.T) { testBody := []byte("uncompressed_text") + noDecoders := map[string]func(io.ReadCloser) (io.ReadCloser, error){} tests := []struct { name string encoding string @@ -202,7 +239,7 @@ func TestHTTPContentDecompressionHandler(t *testing.T) { require.NoError(t, err, "failed to read request body: %v", err) assert.EqualValues(t, testBody, string(body)) w.WriteHeader(http.StatusOK) - }), defaultErrorHandler)) + }), defaultErrorHandler, noDecoders)) t.Cleanup(srv.Close) req, err := http.NewRequest(http.MethodGet, srv.URL, tt.reqBody) diff --git a/config/confighttp/confighttp.go b/config/confighttp/confighttp.go index 991d1822424..984ada15c47 100644 --- a/config/confighttp/confighttp.go +++ b/config/confighttp/confighttp.go @@ -6,6 +6,7 @@ package confighttp // import "go.opentelemetry.io/collector/config/confighttp" import ( "crypto/tls" "errors" + "io" "net" "net/http" "time" @@ -248,6 +249,7 @@ func (hss *HTTPServerSettings) ToListener() (net.Listener, error) { // returned by HTTPServerSettings.ToServer(). type toServerOptions struct { errHandler func(w http.ResponseWriter, r *http.Request, errorMsg string, statusCode int) + decoders map[string]func(body io.ReadCloser) (io.ReadCloser, error) } // ToServerOption is an option to change the behavior of the HTTP server @@ -262,6 +264,17 @@ func WithErrorHandler(e func(w http.ResponseWriter, r *http.Request, errorMsg st } } +// WithDecoder provides support for additional decoders to be configured +// by the caller. +func WithDecoder(key string, dec func(body io.ReadCloser) (io.ReadCloser, error)) ToServerOption { + return func(opts *toServerOptions) { + if opts.decoders == nil { + opts.decoders = map[string]func(body io.ReadCloser) (io.ReadCloser, error){} + } + opts.decoders[key] = dec + } +} + // ToServer creates an http.Server from settings object. func (hss *HTTPServerSettings) ToServer(host component.Host, settings component.TelemetrySettings, handler http.Handler, opts ...ToServerOption) (*http.Server, error) { internal.WarnOnUnspecifiedHost(settings.Logger, hss.Endpoint) @@ -271,7 +284,7 @@ func (hss *HTTPServerSettings) ToServer(host component.Host, settings component. o(serverOpts) } - handler = httpContentDecompressor(handler, serverOpts.errHandler) + handler = httpContentDecompressor(handler, serverOpts.errHandler, serverOpts.decoders) if hss.MaxRequestBodySize > 0 { handler = maxRequestBodySizeInterceptor(handler, hss.MaxRequestBodySize) diff --git a/config/confighttp/confighttp_test.go b/config/confighttp/confighttp_test.go index a1cfa4f2478..a5460fe6b56 100644 --- a/config/confighttp/confighttp_test.go +++ b/config/confighttp/confighttp_test.go @@ -1116,6 +1116,66 @@ func TestFailedServerAuth(t *testing.T) { assert.Equal(t, response.Result().Status, fmt.Sprintf("%v %s", http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))) } +func TestServerWithErrorHandler(t *testing.T) { + // prepare + hss := HTTPServerSettings{ + Endpoint: "localhost:0", + } + eh := func(w http.ResponseWriter, r *http.Request, errorMsg string, statusCode int) { + assert.Equal(t, statusCode, http.StatusBadRequest) + // custom error handler changes returned status code + http.Error(w, "invalid request", http.StatusInternalServerError) + + } + + srv, err := hss.ToServer( + componenttest.NewNopHost(), + componenttest.NewNopTelemetrySettings(), + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), + WithErrorHandler(eh), + ) + require.NoError(t, err) + // test + response := &httptest.ResponseRecorder{} + + req, err := http.NewRequest(http.MethodGet, srv.Addr, nil) + require.NoError(t, err, "Error creating request: %v", err) + req.Header.Set("Content-Encoding", "something-invalid") + + srv.Handler.ServeHTTP(response, req) + // verify + assert.Equal(t, response.Result().StatusCode, http.StatusInternalServerError) +} + +func TestServerWithDecoder(t *testing.T) { + // prepare + hss := HTTPServerSettings{ + Endpoint: "localhost:0", + } + decoder := func(body io.ReadCloser) (io.ReadCloser, error) { + return body, nil + } + + srv, err := hss.ToServer( + componenttest.NewNopHost(), + componenttest.NewNopTelemetrySettings(), + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), + WithDecoder("something-else", decoder), + ) + require.NoError(t, err) + // test + response := &httptest.ResponseRecorder{} + + req, err := http.NewRequest(http.MethodGet, srv.Addr, nil) + require.NoError(t, err, "Error creating request: %v", err) + req.Header.Set("Content-Encoding", "something-else") + + srv.Handler.ServeHTTP(response, req) + // verify + assert.Equal(t, response.Result().StatusCode, http.StatusOK) + +} + type mockHost struct { component.Host ext map[component.ID]component.Component