From ae4665cf7a215d14b3ba769bfa355c5420ce10ef Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 6 Apr 2021 10:11:31 +0300 Subject: [PATCH] Fix panic in redirect middleware on short host name (fix #1811) (#1813) --- middleware/redirect.go | 45 +++--- middleware/redirect_test.go | 263 +++++++++++++++++++++++++++++++----- 2 files changed, 249 insertions(+), 59 deletions(-) diff --git a/middleware/redirect.go b/middleware/redirect.go index 813e5b856..13877db38 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -2,6 +2,7 @@ package middleware import ( "net/http" + "strings" "github.com/labstack/echo/v4" ) @@ -40,11 +41,11 @@ func HTTPSRedirect() echo.MiddlewareFunc { // HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config. // See `HTTPSRedirect()`. func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if ok = scheme != "https"; ok { - url = "https://" + host + uri + return redirect(config, func(scheme, host, uri string) (bool, string) { + if scheme != "https" { + return true, "https://" + host + uri } - return + return false, "" }) } @@ -59,11 +60,11 @@ func HTTPSWWWRedirect() echo.MiddlewareFunc { // HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // See `HTTPSWWWRedirect()`. func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if ok = scheme != "https" && host[:4] != www; ok { - url = "https://www." + host + uri + return redirect(config, func(scheme, host, uri string) (bool, string) { + if scheme != "https" && !strings.HasPrefix(host, www) { + return true, "https://www." + host + uri } - return + return false, "" }) } @@ -79,13 +80,11 @@ func HTTPSNonWWWRedirect() echo.MiddlewareFunc { // See `HTTPSNonWWWRedirect()`. func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if ok = scheme != "https"; ok { - if host[:4] == www { - host = host[4:] - } - url = "https://" + host + uri + if scheme != "https" { + host = strings.TrimPrefix(host, www) + return true, "https://" + host + uri } - return + return false, "" }) } @@ -100,11 +99,11 @@ func WWWRedirect() echo.MiddlewareFunc { // WWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // See `WWWRedirect()`. func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if ok = host[:4] != www; ok { - url = scheme + "://www." + host + uri + return redirect(config, func(scheme, host, uri string) (bool, string) { + if !strings.HasPrefix(host, www) { + return true, scheme + "://www." + host + uri } - return + return false, "" }) } @@ -119,17 +118,17 @@ func NonWWWRedirect() echo.MiddlewareFunc { // NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // See `NonWWWRedirect()`. func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if ok = host[:4] == www; ok { - url = scheme + "://" + host[4:] + uri + return redirect(config, func(scheme, host, uri string) (bool, string) { + if strings.HasPrefix(host, www) { + return true, scheme + "://" + host[4:] + uri } - return + return false, "" }) } func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { if config.Skipper == nil { - config.Skipper = DefaultTrailingSlashConfig.Skipper + config.Skipper = DefaultRedirectConfig.Skipper } if config.Code == 0 { config.Code = DefaultRedirectConfig.Code diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go index 082609574..9d1b56205 100644 --- a/middleware/redirect_test.go +++ b/middleware/redirect_test.go @@ -12,62 +12,253 @@ import ( type middlewareGenerator func() echo.MiddlewareFunc func TestRedirectHTTPSRedirect(t *testing.T) { - res := redirectTest(HTTPSRedirect, "labstack.com", nil) - - assert.Equal(t, http.StatusMovedPermanently, res.Code) - assert.Equal(t, "https://labstack.com/", res.Header().Get(echo.HeaderLocation)) -} + var testCases = []struct { + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + whenHost: "labstack.com", + expectLocation: "https://labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "labstack.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + } -func TestHTTPSRedirectBehindTLSTerminationProxy(t *testing.T) { - header := http.Header{} - header.Set(echo.HeaderXForwardedProto, "https") - res := redirectTest(HTTPSRedirect, "labstack.com", header) + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + res := redirectTest(HTTPSRedirect, tc.whenHost, tc.whenHeader) - assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } } func TestRedirectHTTPSWWWRedirect(t *testing.T) { - res := redirectTest(HTTPSWWWRedirect, "labstack.com", nil) - - assert.Equal(t, http.StatusMovedPermanently, res.Code) - assert.Equal(t, "https://www.labstack.com/", res.Header().Get(echo.HeaderLocation)) -} + var testCases = []struct { + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + whenHost: "labstack.com", + expectLocation: "https://www.labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.labstack.com", + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + { + whenHost: "a.com", + expectLocation: "https://www.a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "ip", + expectLocation: "https://www.ip/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "labstack.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + } -func TestRedirectHTTPSWWWRedirectBehindTLSTerminationProxy(t *testing.T) { - header := http.Header{} - header.Set(echo.HeaderXForwardedProto, "https") - res := redirectTest(HTTPSWWWRedirect, "labstack.com", header) + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + res := redirectTest(HTTPSWWWRedirect, tc.whenHost, tc.whenHeader) - assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } } func TestRedirectHTTPSNonWWWRedirect(t *testing.T) { - res := redirectTest(HTTPSNonWWWRedirect, "www.labstack.com", nil) - - assert.Equal(t, http.StatusMovedPermanently, res.Code) - assert.Equal(t, "https://labstack.com/", res.Header().Get(echo.HeaderLocation)) -} + var testCases = []struct { + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + whenHost: "www.labstack.com", + expectLocation: "https://labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "a.com", + expectLocation: "https://a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "ip", + expectLocation: "https://ip/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.labstack.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + } -func TestRedirectHTTPSNonWWWRedirectBehindTLSTerminationProxy(t *testing.T) { - header := http.Header{} - header.Set(echo.HeaderXForwardedProto, "https") - res := redirectTest(HTTPSNonWWWRedirect, "www.labstack.com", header) + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + res := redirectTest(HTTPSNonWWWRedirect, tc.whenHost, tc.whenHeader) - assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } } func TestRedirectWWWRedirect(t *testing.T) { - res := redirectTest(WWWRedirect, "labstack.com", nil) + var testCases = []struct { + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + whenHost: "labstack.com", + expectLocation: "http://www.labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "a.com", + expectLocation: "http://www.a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "ip", + expectLocation: "http://www.ip/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "a.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "https://www.a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.ip", + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + res := redirectTest(WWWRedirect, tc.whenHost, tc.whenHeader) - assert.Equal(t, http.StatusMovedPermanently, res.Code) - assert.Equal(t, "http://www.labstack.com/", res.Header().Get(echo.HeaderLocation)) + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } } func TestRedirectNonWWWRedirect(t *testing.T) { - res := redirectTest(NonWWWRedirect, "www.labstack.com", nil) + var testCases = []struct { + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + whenHost: "www.labstack.com", + expectLocation: "http://labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.a.com", + expectLocation: "http://a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.a.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "https://a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "ip", + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + res := redirectTest(NonWWWRedirect, tc.whenHost, tc.whenHeader) + + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } +} + +func TestNonWWWRedirectWithConfig(t *testing.T) { + var testCases = []struct { + name string + givenCode int + givenSkipFunc func(c echo.Context) bool + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + name: "usual redirect", + whenHost: "www.labstack.com", + expectLocation: "http://labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + name: "redirect is skipped", + givenSkipFunc: func(c echo.Context) bool { + return true // skip always + }, + whenHost: "www.labstack.com", + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + { + name: "redirect with custom status code", + givenCode: http.StatusSeeOther, + whenHost: "www.labstack.com", + expectLocation: "http://labstack.com/", + expectStatusCode: http.StatusSeeOther, + }, + } + + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + middleware := func() echo.MiddlewareFunc { + return NonWWWRedirectWithConfig(RedirectConfig{ + Skipper: tc.givenSkipFunc, + Code: tc.givenCode, + }) + } + res := redirectTest(middleware, tc.whenHost, tc.whenHeader) - assert.Equal(t, http.StatusMovedPermanently, res.Code) - assert.Equal(t, "http://labstack.com/", res.Header().Get(echo.HeaderLocation)) + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } } func redirectTest(fn middlewareGenerator, host string, header http.Header) *httptest.ResponseRecorder {