Skip to content

Commit

Permalink
Refactored with the comments and split the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ilijamt committed Jan 3, 2021
1 parent cc3baf5 commit 5393d80
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 59 deletions.
28 changes: 24 additions & 4 deletions middleware/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package middleware
import (
"context"
"github.com/labstack/echo/v4"
"net/http"
"time"
)

Expand All @@ -14,16 +13,22 @@ type (
TimeoutConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// ErrorHandler defines a function which is executed for a timeout
// It can be used to define a custom timeout error
ErrorHandler TimeoutErrorHandlerWithContext
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
Timeout time.Duration
}

TimeoutErrorHandlerWithContext func(error, echo.Context) error
)

var (
// DefaultTimeoutConfig is the default Timeout middleware config.
DefaultTimeoutConfig = TimeoutConfig{
Skipper: DefaultSkipper,
Timeout: 0,
Skipper: DefaultSkipper,
Timeout: 0,
ErrorHandler: nil,
}
)

Expand Down Expand Up @@ -58,9 +63,24 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
done <- next(c)
}()

defer func() {
defer func() {
// we don't care what the error is as this is here to
// close the channel in case it hasn't been closed properly
// this is so we can clean up the running go routine
// in case the go routine has been closed this will panic, and
// this recover function will recover from the panic
_ = recover()
}()
close(done)
}()

select {
case <-ctx.Done():
return c.JSON(http.StatusGatewayTimeout, ctx.Err())
if config.ErrorHandler != nil {
return config.ErrorHandler(ctx.Err(), c)
}
return ctx.Err()
case err := <-done:
return err
}
Expand Down
132 changes: 77 additions & 55 deletions middleware/timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package middleware

import (
"context"
"errors"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
Expand All @@ -13,78 +14,99 @@ import (
"time"
)

func TestTimeout(t *testing.T) {
e := echo.New()
func TestTimeoutSkipper(t *testing.T) {
m := TimeoutWithConfig(TimeoutConfig{
Skipper: func(context echo.Context) bool {
return true
},
})

t.Run("Skipper", func(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Skipper: func(context echo.Context) bool {
return true
},
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
return nil
})(c)
err := m(func(c echo.Context) error {
assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
return nil
})(c)

assert.NoError(t, err)
}

assert.NoError(t, err)
func TestTimeoutWithTimeout0(t *testing.T) {
m := TimeoutWithConfig(TimeoutConfig{
Timeout: 0,
})

t.Run("Is cancelable", func(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Timeout: time.Minute,
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
return nil
})(c)

err := m(func(c echo.Context) error {
assert.EqualValues(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
return nil
})(c)
assert.NoError(t, err)
}

assert.NoError(t, err)
func TestTimeoutIsCancelable(t *testing.T) {
m := TimeoutWithConfig(TimeoutConfig{
Timeout: time.Minute,
})

t.Run("Times out after the predefined timeout", func(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Timeout: time.Second,
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
time.Sleep(time.Minute)
return nil
})(c)
err := m(func(c echo.Context) error {
assert.EqualValues(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
return nil
})(c)

assert.NoError(t, err)
assert.EqualValues(t, http.StatusGatewayTimeout, rec.Code)
})
assert.NoError(t, err)
}

func TestTimeoutErrorOutInHandler(t *testing.T) {
m := Timeout()

t.Run("Error out in the handler", func(t *testing.T) {
t.Parallel()
m := Timeout()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
return errors.New("err")
})(c)
err := m(func(c echo.Context) error {
return errors.New("err")
})(c)

assert.Error(t, err)
}

assert.Error(t, err)
func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) {
m := TimeoutWithConfig(TimeoutConfig{
Timeout: time.Second,
ErrorHandler: func(err error, e echo.Context) error {
return err
},
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
time.Sleep(time.Minute)
return nil
})(c)

assert.EqualError(t, err, context.DeadlineExceeded.Error())
}

0 comments on commit 5393d80

Please sign in to comment.