Skip to content

Commit

Permalink
[chore]: Avoid private types in public API, remove options on interna…
Browse files Browse the repository at this point in the history
…l funcs (#7870)

Signed-off-by: Bogdan Drutu <bogdandrutu@gmail.com>
  • Loading branch information
bogdandrutu authored Jun 27, 2023
1 parent 5852d09 commit 8278824
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 60 deletions.
67 changes: 28 additions & 39 deletions config/confighttp/compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,66 +65,55 @@ func (r *compressRoundTripper) RoundTrip(req *http.Request) (*http.Response, err
return r.rt.RoundTrip(cReq)
}

type errorHandler func(w http.ResponseWriter, r *http.Request, errorMsg string, statusCode int)

type decompressor struct {
errorHandler
}

type decompressorOption func(d *decompressor)

func withErrorHandlerForDecompressor(e errorHandler) decompressorOption {
return func(d *decompressor) {
d.errorHandler = e
}
errHandler func(w http.ResponseWriter, r *http.Request, errorMsg string, statusCode int)
base http.Handler
}

// httpContentDecompressor offloads the task of handling compressed HTTP requests
// by identifying the compression format in the "Content-Encoding" header and re-writing
// request body so that the handlers further in the chain can work on decompressed data.
// It supports gzip and deflate/zlib compression.
func httpContentDecompressor(h http.Handler, opts ...decompressorOption) http.Handler {
d := &decompressor{}
for _, o := range opts {
o(d)
func httpContentDecompressor(h http.Handler, eh func(w http.ResponseWriter, r *http.Request, errorMsg string, statusCode int)) http.Handler {
errHandler := defaultErrorHandler
if eh != nil {
errHandler = eh
}
if d.errorHandler == nil {
d.errorHandler = defaultErrorHandler
return &decompressor{
errHandler: errHandler,
base: h,
}
return d.wrap(h)
}

func (d *decompressor) wrap(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
newBody, err := newBodyReader(r)
if err != nil {
d.errorHandler(w, r, err.Error(), http.StatusBadRequest)
return
}
if newBody != nil {
defer newBody.Close()
// "Content-Encoding" header is removed to avoid decompressing twice
// in case the next handler(s) have implemented a similar mechanism.
r.Header.Del(headerContentEncoding)
// "Content-Length" is set to -1 as the size of the decompressed body is unknown.
r.Header.Del("Content-Length")
r.ContentLength = -1
r.Body = newBody
}
h.ServeHTTP(w, r)
})
func (d *decompressor) ServeHTTP(w http.ResponseWriter, r *http.Request) {
newBody, err := newBodyReader(r)
if err != nil {
d.errHandler(w, r, err.Error(), http.StatusBadRequest)
return
}
if newBody != nil {
defer newBody.Close()
// "Content-Encoding" header is removed to avoid decompressing twice
// in case the next handler(s) have implemented a similar mechanism.
r.Header.Del("Content-Encoding")
// "Content-Length" is set to -1 as the size of the decompressed body is unknown.
r.Header.Del("Content-Length")
r.ContentLength = -1
r.Body = newBody
}
d.base.ServeHTTP(w, r)
}

func newBodyReader(r *http.Request) (io.ReadCloser, error) {
encoding := r.Header.Get(headerContentEncoding)
switch encoding {
case "gzip":
case string(configcompression.Gzip):
gr, err := gzip.NewReader(r.Body)
if err != nil {
return nil, err
}
return gr, nil
case "deflate", "zlib":
case string(configcompression.Deflate), string(configcompression.Zlib):
zr, err := zlib.NewReader(r.Body)
if err != nil {
return nil, err
Expand Down
26 changes: 13 additions & 13 deletions config/confighttp/compression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestHTTPClientCompression(t *testing.T) {
body, err := io.ReadAll(r.Body)
require.NoError(t, err, "failed to read request body: %v", err)
assert.EqualValues(t, tt.reqBody, body)
w.WriteHeader(200)
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(srv.Close)

Expand Down Expand Up @@ -127,52 +127,52 @@ func TestHTTPContentDecompressionHandler(t *testing.T) {
name: "NoCompression",
encoding: "",
reqBody: bytes.NewBuffer(testBody),
respCode: 200,
respCode: http.StatusOK,
},
{
name: "ValidGzip",
encoding: "gzip",
reqBody: compressGzip(t, testBody),
respCode: 200,
respCode: http.StatusOK,
},
{
name: "ValidZlib",
encoding: "zlib",
reqBody: compressZlib(t, testBody),
respCode: 200,
respCode: http.StatusOK,
},
{
name: "ValidZstd",
encoding: "zstd",
reqBody: compressZstd(t, testBody),
respCode: 200,
respCode: http.StatusOK,
},
{
name: "InvalidGzip",
encoding: "gzip",
reqBody: bytes.NewBuffer(testBody),
respCode: 400,
respCode: http.StatusBadRequest,
respBody: "gzip: invalid header\n",
},
{
name: "InvalidZlib",
encoding: "zlib",
reqBody: bytes.NewBuffer(testBody),
respCode: 400,
respCode: http.StatusBadRequest,
respBody: "zlib: invalid header\n",
},
{
name: "InvalidZstd",
encoding: "zstd",
reqBody: bytes.NewBuffer(testBody),
respCode: 400,
respCode: http.StatusBadRequest,
respBody: "invalid input: magic number mismatch",
},
{
name: "UnsupportedCompression",
encoding: "nosuchcompression",
reqBody: bytes.NewBuffer(testBody),
respCode: 400,
respCode: http.StatusBadRequest,
respBody: "unsupported Content-Encoding: nosuchcompression\n",
},
}
Expand All @@ -189,7 +189,7 @@ func TestHTTPContentDecompressionHandler(t *testing.T) {
require.NoError(t, err, "failed to read request body: %v", err)
assert.EqualValues(t, testBody, string(body))
w.WriteHeader(http.StatusOK)
})))
}), defaultErrorHandler))
t.Cleanup(srv.Close)

req, err := http.NewRequest(http.MethodGet, srv.URL, tt.reqBody)
Expand All @@ -213,7 +213,7 @@ func TestHTTPContentDecompressionHandler(t *testing.T) {
func TestHTTPContentCompressionRequestWithNilBody(t *testing.T) {
compressedGzipBody := compressGzip(t, []byte{})
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.WriteHeader(http.StatusOK)
body, err := io.ReadAll(r.Body)
require.NoError(t, err, "failed to read request body: %v", err)
assert.EqualValues(t, compressedGzipBody.Bytes(), body)
Expand Down Expand Up @@ -247,7 +247,7 @@ func (*copyFailBody) Close() error {

func TestHTTPContentCompressionCopyError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(server.Close)

Expand All @@ -271,7 +271,7 @@ func (*closeFailBody) Close() error {

func TestHTTPContentCompressionRequestBodyCloseError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(server.Close)

Expand Down
11 changes: 4 additions & 7 deletions config/confighttp/confighttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func (hss *HTTPServerSettings) ToListener() (net.Listener, error) {
// toServerOptions has options that change the behavior of the HTTP server
// returned by HTTPServerSettings.ToServer().
type toServerOptions struct {
errorHandler
errHandler func(w http.ResponseWriter, r *http.Request, errorMsg string, statusCode int)
}

// ToServerOption is an option to change the behavior of the HTTP server
Expand All @@ -256,9 +256,9 @@ type ToServerOption func(opts *toServerOptions)

// WithErrorHandler overrides the HTTP error handler that gets invoked
// when there is a failure inside httpContentDecompressor.
func WithErrorHandler(e errorHandler) ToServerOption {
func WithErrorHandler(e func(w http.ResponseWriter, r *http.Request, errorMsg string, statusCode int)) ToServerOption {
return func(opts *toServerOptions) {
opts.errorHandler = e
opts.errHandler = e
}
}

Expand All @@ -271,10 +271,7 @@ func (hss *HTTPServerSettings) ToServer(host component.Host, settings component.
o(serverOpts)
}

handler = httpContentDecompressor(
handler,
withErrorHandlerForDecompressor(serverOpts.errorHandler),
)
handler = httpContentDecompressor(handler, serverOpts.errHandler)

if hss.MaxRequestBodySize > 0 {
handler = maxRequestBodySizeInterceptor(handler, hss.MaxRequestBodySize)
Expand Down
2 changes: 1 addition & 1 deletion config/confighttp/confighttp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ func TestHttpClientHeaders(t *testing.T) {
for k, v := range tt.headers {
assert.Equal(t, r.Header.Get(k), string(v))
}
w.WriteHeader(200)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
serverURL, _ := url.Parse(server.URL)
Expand Down

0 comments on commit 8278824

Please sign in to comment.