Skip to content

Commit

Permalink
Add context timeout middleware (#2380)
Browse files Browse the repository at this point in the history
Add context timeout middleware


Co-authored-by: Erhan Akpınar <erhan.akpinar@yemeksepeti.com>
Co-authored-by: @erhanakp
  • Loading branch information
hakankutluay and erhanakp authored Feb 1, 2023
1 parent 08093a4 commit 82a964c
Show file tree
Hide file tree
Showing 2 changed files with 298 additions and 0 deletions.
72 changes: 72 additions & 0 deletions middleware/context_timeout.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package middleware

import (
"context"
"errors"
"time"

"github.com/labstack/echo/v4"
)

// ContextTimeoutConfig defines the config for ContextTimeout middleware.
type ContextTimeoutConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper

// ErrorHandler is a function when error aries in middeware execution.
ErrorHandler func(err error, c echo.Context) error

// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
Timeout time.Duration
}

// ContextTimeout returns a middleware which returns error (503 Service Unavailable error) to client
// when underlying method returns context.DeadlineExceeded error.
func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc {
return ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: timeout})
}

// ContextTimeoutWithConfig returns a Timeout middleware with config.
func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc {
mw, err := config.ToMiddleware()
if err != nil {
panic(err)
}
return mw
}

// ToMiddleware converts Config to middleware.
func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Timeout == 0 {
return nil, errors.New("timeout must be set")
}
if config.Skipper == nil {
config.Skipper = DefaultSkipper
}
if config.ErrorHandler == nil {
config.ErrorHandler = func(err error, c echo.Context) error {
if err != nil && errors.Is(err, context.DeadlineExceeded) {
return echo.ErrServiceUnavailable.WithInternal(err)
}
return err
}
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}

timeoutContext, cancel := context.WithTimeout(c.Request().Context(), config.Timeout)
defer cancel()

c.SetRequest(c.Request().WithContext(timeoutContext))

if err := next(c); err != nil {
return config.ErrorHandler(err, c)
}
return nil
}
}, nil
}
226 changes: 226 additions & 0 deletions middleware/context_timeout_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
package middleware

import (
"context"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)

func TestContextTimeoutSkipper(t *testing.T) {
t.Parallel()
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
Skipper: func(context echo.Context) bool {
return true
},
Timeout: 10 * time.Millisecond,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil {
return err
}

return errors.New("response from handler")
})(c)

// if not skipped we would have not returned error due context timeout logic
assert.EqualError(t, err, "response from handler")
}

func TestContextTimeoutWithTimeout0(t *testing.T) {
t.Parallel()
assert.Panics(t, func() {
ContextTimeout(time.Duration(0))
})
}

func TestContextTimeoutErrorOutInHandler(t *testing.T) {
t.Parallel()
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
Timeout: 10 * time.Millisecond,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

rec.Code = 1 // we want to be sure that even 200 will not be sent
err := m(func(c echo.Context) error {
// this error must not be written to the client response. Middlewares upstream of timeout middleware must be able
// to handle returned error and this can be done only then handler has not yet committed (written status code)
// the response.
return echo.NewHTTPError(http.StatusTeapot, "err")
})(c)

assert.Error(t, err)
assert.EqualError(t, err, "code=418, message=err")
assert.Equal(t, 1, rec.Code)
assert.Equal(t, "", rec.Body.String())
}

func TestContextTimeoutSuccessfulRequest(t *testing.T) {
t.Parallel()
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
Timeout: 10 * time.Millisecond,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
return c.JSON(http.StatusCreated, map[string]string{"data": "ok"})
})(c)

assert.NoError(t, err)
assert.Equal(t, http.StatusCreated, rec.Code)
assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String())
}

func TestContextTimeoutTestRequestClone(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()

m := ContextTimeoutWithConfig(ContextTimeoutConfig{
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
Timeout: 1 * time.Second,
})

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
// Cookie test
cookie, err := c.Request().Cookie("cookie")
if assert.NoError(t, err) {
assert.EqualValues(t, "cookie", cookie.Name)
assert.EqualValues(t, "value", cookie.Value)
}

// Form values
if assert.NoError(t, c.Request().ParseForm()) {
assert.EqualValues(t, "value", c.Request().FormValue("form"))
}

// Query string
assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0])
return nil
})(c)

assert.NoError(t, err)
}

func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) {
t.Parallel()

timeout := 10 * time.Millisecond
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
Timeout: timeout,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil {
return err
}
return c.String(http.StatusOK, "Hello, World!")
})(c)

assert.IsType(t, &echo.HTTPError{}, err)
assert.Error(t, err)
assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message)
}

func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) {
t.Parallel()

timeoutErrorHandler := func(err error, c echo.Context) error {
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return &echo.HTTPError{
Code: http.StatusServiceUnavailable,
Message: "Timeout! change me",
}
}
return err
}
return nil
}

timeout := 10 * time.Millisecond
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
Timeout: timeout,
ErrorHandler: timeoutErrorHandler,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
// difference over 500microseconds (0.5millisecond) response seems to be reliable

if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil {
return err
}

// The Request Context should have a Deadline set by http.ContextTimeoutHandler
if _, ok := c.Request().Context().Deadline(); !ok {
assert.Fail(t, "No timeout set on Request Context")
}
return c.String(http.StatusOK, "Hello, World!")
})(c)

assert.IsType(t, &echo.HTTPError{}, err)
assert.Error(t, err)
assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
assert.Equal(t, "Timeout! change me", err.(*echo.HTTPError).Message)
}

func sleepWithContext(ctx context.Context, d time.Duration) error {
timer := time.NewTimer(d)

defer func() {
_ = timer.Stop()
}()

select {
case <-ctx.Done():
return context.DeadlineExceeded
case <-timer.C:
return nil
}
}

0 comments on commit 82a964c

Please sign in to comment.