From 5393d80b035650a98dcbd657c3e92013ed60b9ac Mon Sep 17 00:00:00 2001 From: Ilija Matoski Date: Sun, 3 Jan 2021 13:12:01 +0100 Subject: [PATCH] Refactored with the comments and split the tests --- middleware/timeout.go | 28 ++++++-- middleware/timeout_test.go | 132 +++++++++++++++++++++---------------- 2 files changed, 101 insertions(+), 59 deletions(-) diff --git a/middleware/timeout.go b/middleware/timeout.go index 87a690355..3c9911ab6 100644 --- a/middleware/timeout.go +++ b/middleware/timeout.go @@ -5,7 +5,6 @@ package middleware import ( "context" "github.com/labstack/echo/v4" - "net/http" "time" ) @@ -14,16 +13,22 @@ type ( TimeoutConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper + // ErrorHandler defines a function which is executed for a timeout + // It can be used to define a custom timeout error + ErrorHandler TimeoutErrorHandlerWithContext // Timeout configures a timeout for the middleware, defaults to 0 for no timeout Timeout time.Duration } + + TimeoutErrorHandlerWithContext func(error, echo.Context) error ) var ( // DefaultTimeoutConfig is the default Timeout middleware config. DefaultTimeoutConfig = TimeoutConfig{ - Skipper: DefaultSkipper, - Timeout: 0, + Skipper: DefaultSkipper, + Timeout: 0, + ErrorHandler: nil, } ) @@ -58,9 +63,24 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc { done <- next(c) }() + defer func() { + defer func() { + // we don't care what the error is as this is here to + // close the channel in case it hasn't been closed properly + // this is so we can clean up the running go routine + // in case the go routine has been closed this will panic, and + // this recover function will recover from the panic + _ = recover() + }() + close(done) + }() + select { case <-ctx.Done(): - return c.JSON(http.StatusGatewayTimeout, ctx.Err()) + if config.ErrorHandler != nil { + return config.ErrorHandler(ctx.Err(), c) + } + return ctx.Err() case err := <-done: return err } diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index 9b72dfb29..6106b8b48 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -3,6 +3,7 @@ package middleware import ( + "context" "errors" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" @@ -13,78 +14,99 @@ import ( "time" ) -func TestTimeout(t *testing.T) { - e := echo.New() +func TestTimeoutSkipper(t *testing.T) { + m := TimeoutWithConfig(TimeoutConfig{ + Skipper: func(context echo.Context) bool { + return true + }, + }) - t.Run("Skipper", func(t *testing.T) { - t.Parallel() - m := TimeoutWithConfig(TimeoutConfig{ - Skipper: func(context echo.Context) bool { - return true - }, - }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + e := echo.New() + c := e.NewContext(req, rec) - err := m(func(c echo.Context) error { - assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String()) - return nil - })(c) + err := m(func(c echo.Context) error { + assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String()) + return nil + })(c) + + assert.NoError(t, err) +} - assert.NoError(t, err) +func TestTimeoutWithTimeout0(t *testing.T) { + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: 0, }) - t.Run("Is cancelable", func(t *testing.T) { - t.Parallel() - m := TimeoutWithConfig(TimeoutConfig{ - Timeout: time.Minute, - }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String()) + return nil + })(c) - err := m(func(c echo.Context) error { - assert.EqualValues(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String()) - return nil - })(c) + assert.NoError(t, err) +} - assert.NoError(t, err) +func TestTimeoutIsCancelable(t *testing.T) { + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: time.Minute, }) - t.Run("Times out after the predefined timeout", func(t *testing.T) { - t.Parallel() - m := TimeoutWithConfig(TimeoutConfig{ - Timeout: time.Second, - }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + e := echo.New() + c := e.NewContext(req, rec) - err := m(func(c echo.Context) error { - time.Sleep(time.Minute) - return nil - })(c) + err := m(func(c echo.Context) error { + assert.EqualValues(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String()) + return nil + })(c) - assert.NoError(t, err) - assert.EqualValues(t, http.StatusGatewayTimeout, rec.Code) - }) + assert.NoError(t, err) +} + +func TestTimeoutErrorOutInHandler(t *testing.T) { + m := Timeout() - t.Run("Error out in the handler", func(t *testing.T) { - t.Parallel() - m := Timeout() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + e := echo.New() + c := e.NewContext(req, rec) - err := m(func(c echo.Context) error { - return errors.New("err") - })(c) + err := m(func(c echo.Context) error { + return errors.New("err") + })(c) + + assert.Error(t, err) +} - assert.Error(t, err) +func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) { + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: time.Second, + ErrorHandler: func(err error, e echo.Context) error { + return err + }, }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + time.Sleep(time.Minute) + return nil + })(c) + + assert.EqualError(t, err, context.DeadlineExceeded.Error()) }