Skip to content

Commit

Permalink
Timeout middleware implementation for go1.13+ (#1743)
Browse files Browse the repository at this point in the history
Co-authored-by: Ilija Matoski <imatoski@schubergphilis.com>
  • Loading branch information
ilijamt and Ilija Matoski authored Jan 5, 2021
1 parent 02ed3f3 commit 67263b5
Show file tree
Hide file tree
Showing 2 changed files with 258 additions and 0 deletions.
81 changes: 81 additions & 0 deletions middleware/timeout.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// +build go1.13

package middleware

import (
"context"
"github.com/labstack/echo/v4"
"time"
)

type (
// TimeoutConfig defines the config for Timeout middleware.
TimeoutConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// ErrorHandler defines a function which is executed for a timeout
// It can be used to define a custom timeout error
ErrorHandler TimeoutErrorHandlerWithContext
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
Timeout time.Duration
}

// TimeoutErrorHandlerWithContext is an error handler that is used with the timeout middleware so we can
// handle the error as we see fit
TimeoutErrorHandlerWithContext func(error, echo.Context) error
)

var (
// DefaultTimeoutConfig is the default Timeout middleware config.
DefaultTimeoutConfig = TimeoutConfig{
Skipper: DefaultSkipper,
Timeout: 0,
ErrorHandler: nil,
}
)

// Timeout returns a middleware which recovers from panics anywhere in the chain
// and handles the control to the centralized HTTPErrorHandler.
func Timeout() echo.MiddlewareFunc {
return TimeoutWithConfig(DefaultTimeoutConfig)
}

// TimeoutWithConfig returns a Timeout middleware with config.
// See: `Timeout()`.
func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultTimeoutConfig.Skipper
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) || config.Timeout == 0 {
return next(c)
}

ctx, cancel := context.WithTimeout(c.Request().Context(), config.Timeout)
defer cancel()

// this does a deep clone of the context, wondering if there is a better way to do this?
c.SetRequest(c.Request().Clone(ctx))

done := make(chan error, 1)
go func() {
// This goroutine will keep running even if this middleware times out and
// will be stopped when ctx.Done() is called down the next(c) call chain
done <- next(c)
}()

select {
case <-ctx.Done():
if config.ErrorHandler != nil {
return config.ErrorHandler(ctx.Err(), c)
}
return ctx.Err()
case err := <-done:
return err
}
}
}
}
177 changes: 177 additions & 0 deletions middleware/timeout_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// +build go1.13

package middleware

import (
"context"
"errors"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"time"
)

func TestTimeoutSkipper(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Skipper: func(context echo.Context) bool {
return true
},
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
return nil
})(c)

assert.NoError(t, err)
}

func TestTimeoutWithTimeout0(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Timeout: 0,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
return nil
})(c)

assert.NoError(t, err)
}

func TestTimeoutIsCancelable(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Timeout: time.Minute,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
assert.EqualValues(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
return nil
})(c)

assert.NoError(t, err)
}

func TestTimeoutErrorOutInHandler(t *testing.T) {
t.Parallel()
m := Timeout()

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
return errors.New("err")
})(c)

assert.Error(t, err)
}

func TestTimeoutTimesOutAfterPredefinedTimeoutWithErrorHandler(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Timeout: time.Second,
ErrorHandler: func(err error, e echo.Context) error {
assert.EqualError(t, err, context.DeadlineExceeded.Error())
return errors.New("err")
},
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
time.Sleep(time.Minute)
return nil
})(c)

assert.EqualError(t, err, errors.New("err").Error())
}

func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Timeout: time.Second,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
time.Sleep(time.Minute)
return nil
})(c)

assert.EqualError(t, err, context.DeadlineExceeded.Error())
}

func TestTimeoutTestRequestClone(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()

m := TimeoutWithConfig(TimeoutConfig{
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
Timeout: time.Second,
})

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
// Cookie test
cookie, err := c.Request().Cookie("cookie")
if assert.NoError(t, err) {
assert.EqualValues(t, "cookie", cookie.Name)
assert.EqualValues(t, "value", cookie.Value)
}

// Form values
if assert.NoError(t, c.Request().ParseForm()) {
assert.EqualValues(t, "value", c.Request().FormValue("form"))
}

// Query string
assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0])
return nil
})(c)

assert.NoError(t, err)

}

0 comments on commit 67263b5

Please sign in to comment.