Skip to content

Commit

Permalink
Fix panic in redirect middleware on short host name (fix #1811) (#1813)
Browse files Browse the repository at this point in the history
  • Loading branch information
aldas authored Apr 6, 2021
1 parent 67f6346 commit ae4665c
Show file tree
Hide file tree
Showing 2 changed files with 249 additions and 59 deletions.
45 changes: 22 additions & 23 deletions middleware/redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middleware

import (
"net/http"
"strings"

"github.com/labstack/echo/v4"
)
Expand Down Expand Up @@ -40,11 +41,11 @@ func HTTPSRedirect() echo.MiddlewareFunc {
// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `HTTPSRedirect()`.
func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
if ok = scheme != "https"; ok {
url = "https://" + host + uri
return redirect(config, func(scheme, host, uri string) (bool, string) {
if scheme != "https" {
return true, "https://" + host + uri
}
return
return false, ""
})
}

Expand All @@ -59,11 +60,11 @@ func HTTPSWWWRedirect() echo.MiddlewareFunc {
// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `HTTPSWWWRedirect()`.
func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
if ok = scheme != "https" && host[:4] != www; ok {
url = "https://www." + host + uri
return redirect(config, func(scheme, host, uri string) (bool, string) {
if scheme != "https" && !strings.HasPrefix(host, www) {
return true, "https://www." + host + uri
}
return
return false, ""
})
}

Expand All @@ -79,13 +80,11 @@ func HTTPSNonWWWRedirect() echo.MiddlewareFunc {
// See `HTTPSNonWWWRedirect()`.
func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
if ok = scheme != "https"; ok {
if host[:4] == www {
host = host[4:]
}
url = "https://" + host + uri
if scheme != "https" {
host = strings.TrimPrefix(host, www)
return true, "https://" + host + uri
}
return
return false, ""
})
}

Expand All @@ -100,11 +99,11 @@ func WWWRedirect() echo.MiddlewareFunc {
// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `WWWRedirect()`.
func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
if ok = host[:4] != www; ok {
url = scheme + "://www." + host + uri
return redirect(config, func(scheme, host, uri string) (bool, string) {
if !strings.HasPrefix(host, www) {
return true, scheme + "://www." + host + uri
}
return
return false, ""
})
}

Expand All @@ -119,17 +118,17 @@ func NonWWWRedirect() echo.MiddlewareFunc {
// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `NonWWWRedirect()`.
func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
if ok = host[:4] == www; ok {
url = scheme + "://" + host[4:] + uri
return redirect(config, func(scheme, host, uri string) (bool, string) {
if strings.HasPrefix(host, www) {
return true, scheme + "://" + host[4:] + uri
}
return
return false, ""
})
}

func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc {
if config.Skipper == nil {
config.Skipper = DefaultTrailingSlashConfig.Skipper
config.Skipper = DefaultRedirectConfig.Skipper
}
if config.Code == 0 {
config.Code = DefaultRedirectConfig.Code
Expand Down
263 changes: 227 additions & 36 deletions middleware/redirect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,62 +12,253 @@ import (
type middlewareGenerator func() echo.MiddlewareFunc

func TestRedirectHTTPSRedirect(t *testing.T) {
res := redirectTest(HTTPSRedirect, "labstack.com", nil)

assert.Equal(t, http.StatusMovedPermanently, res.Code)
assert.Equal(t, "https://labstack.com/", res.Header().Get(echo.HeaderLocation))
}
var testCases = []struct {
whenHost string
whenHeader http.Header
expectLocation string
expectStatusCode int
}{
{
whenHost: "labstack.com",
expectLocation: "https://labstack.com/",
expectStatusCode: http.StatusMovedPermanently,
},
{
whenHost: "labstack.com",
whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
expectLocation: "",
expectStatusCode: http.StatusOK,
},
}

func TestHTTPSRedirectBehindTLSTerminationProxy(t *testing.T) {
header := http.Header{}
header.Set(echo.HeaderXForwardedProto, "https")
res := redirectTest(HTTPSRedirect, "labstack.com", header)
for _, tc := range testCases {
t.Run(tc.whenHost, func(t *testing.T) {
res := redirectTest(HTTPSRedirect, tc.whenHost, tc.whenHeader)

assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, tc.expectStatusCode, res.Code)
assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
})
}
}

func TestRedirectHTTPSWWWRedirect(t *testing.T) {
res := redirectTest(HTTPSWWWRedirect, "labstack.com", nil)

assert.Equal(t, http.StatusMovedPermanently, res.Code)
assert.Equal(t, "https://www.labstack.com/", res.Header().Get(echo.HeaderLocation))
}
var testCases = []struct {
whenHost string
whenHeader http.Header
expectLocation string
expectStatusCode int
}{
{
whenHost: "labstack.com",
expectLocation: "https://www.labstack.com/",
expectStatusCode: http.StatusMovedPermanently,
},
{
whenHost: "www.labstack.com",
expectLocation: "",
expectStatusCode: http.StatusOK,
},
{
whenHost: "a.com",
expectLocation: "https://www.a.com/",
expectStatusCode: http.StatusMovedPermanently,
},
{
whenHost: "ip",
expectLocation: "https://www.ip/",
expectStatusCode: http.StatusMovedPermanently,
},
{
whenHost: "labstack.com",
whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
expectLocation: "",
expectStatusCode: http.StatusOK,
},
}

func TestRedirectHTTPSWWWRedirectBehindTLSTerminationProxy(t *testing.T) {
header := http.Header{}
header.Set(echo.HeaderXForwardedProto, "https")
res := redirectTest(HTTPSWWWRedirect, "labstack.com", header)
for _, tc := range testCases {
t.Run(tc.whenHost, func(t *testing.T) {
res := redirectTest(HTTPSWWWRedirect, tc.whenHost, tc.whenHeader)

assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, tc.expectStatusCode, res.Code)
assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
})
}
}

func TestRedirectHTTPSNonWWWRedirect(t *testing.T) {
res := redirectTest(HTTPSNonWWWRedirect, "www.labstack.com", nil)

assert.Equal(t, http.StatusMovedPermanently, res.Code)
assert.Equal(t, "https://labstack.com/", res.Header().Get(echo.HeaderLocation))
}
var testCases = []struct {
whenHost string
whenHeader http.Header
expectLocation string
expectStatusCode int
}{
{
whenHost: "www.labstack.com",
expectLocation: "https://labstack.com/",
expectStatusCode: http.StatusMovedPermanently,
},
{
whenHost: "a.com",
expectLocation: "https://a.com/",
expectStatusCode: http.StatusMovedPermanently,
},
{
whenHost: "ip",
expectLocation: "https://ip/",
expectStatusCode: http.StatusMovedPermanently,
},
{
whenHost: "www.labstack.com",
whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
expectLocation: "",
expectStatusCode: http.StatusOK,
},
}

func TestRedirectHTTPSNonWWWRedirectBehindTLSTerminationProxy(t *testing.T) {
header := http.Header{}
header.Set(echo.HeaderXForwardedProto, "https")
res := redirectTest(HTTPSNonWWWRedirect, "www.labstack.com", header)
for _, tc := range testCases {
t.Run(tc.whenHost, func(t *testing.T) {
res := redirectTest(HTTPSNonWWWRedirect, tc.whenHost, tc.whenHeader)

assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, tc.expectStatusCode, res.Code)
assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
})
}
}

func TestRedirectWWWRedirect(t *testing.T) {
res := redirectTest(WWWRedirect, "labstack.com", nil)
var testCases = []struct {
whenHost string
whenHeader http.Header
expectLocation string
expectStatusCode int
}{
{
whenHost: "labstack.com",
expectLocation: "http://www.labstack.com/",
expectStatusCode: http.StatusMovedPermanently,
},
{
whenHost: "a.com",
expectLocation: "http://www.a.com/",
expectStatusCode: http.StatusMovedPermanently,
},
{
whenHost: "ip",
expectLocation: "http://www.ip/",
expectStatusCode: http.StatusMovedPermanently,
},
{
whenHost: "a.com",
whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
expectLocation: "https://www.a.com/",
expectStatusCode: http.StatusMovedPermanently,
},
{
whenHost: "www.ip",
expectLocation: "",
expectStatusCode: http.StatusOK,
},
}

for _, tc := range testCases {
t.Run(tc.whenHost, func(t *testing.T) {
res := redirectTest(WWWRedirect, tc.whenHost, tc.whenHeader)

assert.Equal(t, http.StatusMovedPermanently, res.Code)
assert.Equal(t, "http://www.labstack.com/", res.Header().Get(echo.HeaderLocation))
assert.Equal(t, tc.expectStatusCode, res.Code)
assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
})
}
}

func TestRedirectNonWWWRedirect(t *testing.T) {
res := redirectTest(NonWWWRedirect, "www.labstack.com", nil)
var testCases = []struct {
whenHost string
whenHeader http.Header
expectLocation string
expectStatusCode int
}{
{
whenHost: "www.labstack.com",
expectLocation: "http://labstack.com/",
expectStatusCode: http.StatusMovedPermanently,
},
{
whenHost: "www.a.com",
expectLocation: "http://a.com/",
expectStatusCode: http.StatusMovedPermanently,
},
{
whenHost: "www.a.com",
whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
expectLocation: "https://a.com/",
expectStatusCode: http.StatusMovedPermanently,
},
{
whenHost: "ip",
expectLocation: "",
expectStatusCode: http.StatusOK,
},
}

for _, tc := range testCases {
t.Run(tc.whenHost, func(t *testing.T) {
res := redirectTest(NonWWWRedirect, tc.whenHost, tc.whenHeader)

assert.Equal(t, tc.expectStatusCode, res.Code)
assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
})
}
}

func TestNonWWWRedirectWithConfig(t *testing.T) {
var testCases = []struct {
name string
givenCode int
givenSkipFunc func(c echo.Context) bool
whenHost string
whenHeader http.Header
expectLocation string
expectStatusCode int
}{
{
name: "usual redirect",
whenHost: "www.labstack.com",
expectLocation: "http://labstack.com/",
expectStatusCode: http.StatusMovedPermanently,
},
{
name: "redirect is skipped",
givenSkipFunc: func(c echo.Context) bool {
return true // skip always
},
whenHost: "www.labstack.com",
expectLocation: "",
expectStatusCode: http.StatusOK,
},
{
name: "redirect with custom status code",
givenCode: http.StatusSeeOther,
whenHost: "www.labstack.com",
expectLocation: "http://labstack.com/",
expectStatusCode: http.StatusSeeOther,
},
}

for _, tc := range testCases {
t.Run(tc.whenHost, func(t *testing.T) {
middleware := func() echo.MiddlewareFunc {
return NonWWWRedirectWithConfig(RedirectConfig{
Skipper: tc.givenSkipFunc,
Code: tc.givenCode,
})
}
res := redirectTest(middleware, tc.whenHost, tc.whenHeader)

assert.Equal(t, http.StatusMovedPermanently, res.Code)
assert.Equal(t, "http://labstack.com/", res.Header().Get(echo.HeaderLocation))
assert.Equal(t, tc.expectStatusCode, res.Code)
assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
})
}
}

func redirectTest(fn middlewareGenerator, host string, header http.Header) *httptest.ResponseRecorder {
Expand Down

0 comments on commit ae4665c

Please sign in to comment.