diff --git a/middleware/timeout.go b/middleware/timeout.go new file mode 100644 index 000000000..d146541e6 --- /dev/null +++ b/middleware/timeout.go @@ -0,0 +1,81 @@ +// +build go1.13 + +package middleware + +import ( + "context" + "github.com/labstack/echo/v4" + "time" +) + +type ( + // TimeoutConfig defines the config for Timeout middleware. + TimeoutConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + // ErrorHandler defines a function which is executed for a timeout + // It can be used to define a custom timeout error + ErrorHandler TimeoutErrorHandlerWithContext + // Timeout configures a timeout for the middleware, defaults to 0 for no timeout + Timeout time.Duration + } + + // TimeoutErrorHandlerWithContext is an error handler that is used with the timeout middleware so we can + // handle the error as we see fit + TimeoutErrorHandlerWithContext func(error, echo.Context) error +) + +var ( + // DefaultTimeoutConfig is the default Timeout middleware config. + DefaultTimeoutConfig = TimeoutConfig{ + Skipper: DefaultSkipper, + Timeout: 0, + ErrorHandler: nil, + } +) + +// Timeout returns a middleware which recovers from panics anywhere in the chain +// and handles the control to the centralized HTTPErrorHandler. +func Timeout() echo.MiddlewareFunc { + return TimeoutWithConfig(DefaultTimeoutConfig) +} + +// TimeoutWithConfig returns a Timeout middleware with config. +// See: `Timeout()`. +func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultTimeoutConfig.Skipper + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) || config.Timeout == 0 { + return next(c) + } + + ctx, cancel := context.WithTimeout(c.Request().Context(), config.Timeout) + defer cancel() + + // this does a deep clone of the context, wondering if there is a better way to do this? + c.SetRequest(c.Request().Clone(ctx)) + + done := make(chan error, 1) + go func() { + // This goroutine will keep running even if this middleware times out and + // will be stopped when ctx.Done() is called down the next(c) call chain + done <- next(c) + }() + + select { + case <-ctx.Done(): + if config.ErrorHandler != nil { + return config.ErrorHandler(ctx.Err(), c) + } + return ctx.Err() + case err := <-done: + return err + } + } + } +} diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go new file mode 100644 index 000000000..c0e945933 --- /dev/null +++ b/middleware/timeout_test.go @@ -0,0 +1,177 @@ +// +build go1.13 + +package middleware + +import ( + "context" + "errors" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "strings" + "testing" + "time" +) + +func TestTimeoutSkipper(t *testing.T) { + t.Parallel() + m := TimeoutWithConfig(TimeoutConfig{ + Skipper: func(context echo.Context) bool { + return true + }, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String()) + return nil + })(c) + + assert.NoError(t, err) +} + +func TestTimeoutWithTimeout0(t *testing.T) { + t.Parallel() + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: 0, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String()) + return nil + })(c) + + assert.NoError(t, err) +} + +func TestTimeoutIsCancelable(t *testing.T) { + t.Parallel() + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: time.Minute, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + assert.EqualValues(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String()) + return nil + })(c) + + assert.NoError(t, err) +} + +func TestTimeoutErrorOutInHandler(t *testing.T) { + t.Parallel() + m := Timeout() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + return errors.New("err") + })(c) + + assert.Error(t, err) +} + +func TestTimeoutTimesOutAfterPredefinedTimeoutWithErrorHandler(t *testing.T) { + t.Parallel() + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: time.Second, + ErrorHandler: func(err error, e echo.Context) error { + assert.EqualError(t, err, context.DeadlineExceeded.Error()) + return errors.New("err") + }, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + time.Sleep(time.Minute) + return nil + })(c) + + assert.EqualError(t, err, errors.New("err").Error()) +} + +func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) { + t.Parallel() + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: time.Second, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + time.Sleep(time.Minute) + return nil + })(c) + + assert.EqualError(t, err, context.DeadlineExceeded.Error()) +} + +func TestTimeoutTestRequestClone(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode())) + req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"}) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + + m := TimeoutWithConfig(TimeoutConfig{ + // Timeout has to be defined or the whole flow for timeout middleware will be skipped + Timeout: time.Second, + }) + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + // Cookie test + cookie, err := c.Request().Cookie("cookie") + if assert.NoError(t, err) { + assert.EqualValues(t, "cookie", cookie.Name) + assert.EqualValues(t, "value", cookie.Value) + } + + // Form values + if assert.NoError(t, c.Request().ParseForm()) { + assert.EqualValues(t, "value", c.Request().FormValue("form")) + } + + // Query string + assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0]) + return nil + })(c) + + assert.NoError(t, err) + +}