From 4c2fd1fb042b122e2f96830ddb58aee6c9f90bf3 Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 9 Mar 2021 14:22:11 +0200 Subject: [PATCH] Allow proxy middleware to use query part in rewrite (fix #1798) (#1802) --- middleware/middleware.go | 46 ++++++++++++++++++++--------------- middleware/middleware_test.go | 29 +++++++++++++++++++--- middleware/proxy.go | 5 ++-- middleware/proxy_test.go | 27 ++++++++++++++------ middleware/rewrite.go | 6 ++--- middleware/rewrite_test.go | 9 +++++-- 6 files changed, 85 insertions(+), 37 deletions(-) diff --git a/middleware/middleware.go b/middleware/middleware.go index 6bdb0eb79..a7ad73a5c 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -2,7 +2,6 @@ package middleware import ( "net/http" - "net/url" "regexp" "strconv" "strings" @@ -49,30 +48,39 @@ func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { return rulesRegex } -func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) { - for k, v := range rewriteRegex { - rawPath := req.URL.RawPath - if rawPath != "" { - // RawPath is only set when there has been escaping done. In that case Path must be deduced from rewritten RawPath - // because encoded Path could match rules that RawPath did not - if replacer := captureTokens(k, rawPath); replacer != nil { - rawPath = replacer.Replace(v) - - req.URL.RawPath = rawPath - req.URL.Path, _ = url.PathUnescape(rawPath) - - return // rewrite only once - } +func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error { + if len(rewriteRegex) == 0 { + return nil + } - continue + // Depending how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path. + // We only want to use path part for rewriting and therefore trim prefix if it exists + rawURI := req.RequestURI + if rawURI != "" && rawURI[0] != '/' { + prefix := "" + if req.URL.Scheme != "" { + prefix = req.URL.Scheme + "://" } + if req.URL.Host != "" { + prefix += req.URL.Host // host or host:port + } + if prefix != "" { + rawURI = strings.TrimPrefix(rawURI, prefix) + } + } - if replacer := captureTokens(k, req.URL.Path); replacer != nil { - req.URL.Path = replacer.Replace(v) + for k, v := range rewriteRegex { + if replacer := captureTokens(k, rawURI); replacer != nil { + url, err := req.URL.Parse(replacer.Replace(v)) + if err != nil { + return err + } + req.URL = url - return // rewrite only once + return nil // rewrite only once } } + return nil } // DefaultSkipper returns false which processes the middleware. diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index bc14c531d..44f44142c 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -8,11 +8,13 @@ import ( "testing" ) -func TestRewritePath(t *testing.T) { +func TestRewriteURL(t *testing.T) { var testCases = []struct { whenURL string expectPath string expectRawPath string + expectQuery string + expectErr string }{ { whenURL: "http://localhost:8080/old", @@ -28,6 +30,7 @@ func TestRewritePath(t *testing.T) { whenURL: "http://localhost:8080/users/+_+/orders/___++++?test=1", expectPath: "/user/+_+/order/___++++", expectRawPath: "", + expectQuery: "test=1", }, { whenURL: "http://localhost:8080/users/%20a/orders/%20aa", @@ -35,9 +38,10 @@ func TestRewritePath(t *testing.T) { expectRawPath: "", }, { - whenURL: "http://localhost:8080/%47%6f%2f", + whenURL: "http://localhost:8080/%47%6f%2f?test=1", expectPath: "/Go/", expectRawPath: "/%47%6f%2f", + expectQuery: "test=1", }, { whenURL: "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F", @@ -49,21 +53,40 @@ func TestRewritePath(t *testing.T) { expectPath: "/user/jill/order/T/cO4lW/t/Vp/", expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", }, + { + whenURL: "http://localhost:8080/static", + expectPath: "/static/path", + expectRawPath: "", + expectQuery: "role=AUTHOR&limit=1000", + }, + { + whenURL: "/static", + expectPath: "/static/path", + expectRawPath: "", + expectQuery: "role=AUTHOR&limit=1000", + }, } rules := map[*regexp.Regexp]string{ regexp.MustCompile("^/old$"): "/new", regexp.MustCompile("^/users/(.*?)/orders/(.*?)$"): "/user/$1/order/$2", + regexp.MustCompile("^/static$"): "/static/path?role=AUTHOR&limit=1000", } for _, tc := range testCases { t.Run(tc.whenURL, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) - rewritePath(rules, req) + err := rewriteURL(rules, req) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } assert.Equal(t, tc.expectPath, req.URL.Path) // Path field is stored in decoded form: /%47%6f%2f becomes /Go/. assert.Equal(t, tc.expectRawPath, req.URL.RawPath) // RawPath, an optional field which only gets set if the default encoding is different from Path. + assert.Equal(t, tc.expectQuery, req.URL.RawQuery) }) } } diff --git a/middleware/proxy.go b/middleware/proxy.go index 63eec5a20..6f01f3a7c 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -231,8 +231,9 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { tgt := config.Balancer.Next(c) c.Set(config.ContextKey, tgt) - // Set rewrite path and raw path - rewritePath(config.RegexRewrite, req) + if err := rewriteURL(config.RegexRewrite, req); err != nil { + return err + } // Fix header // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream. diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 591981e7f..93daf735e 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -245,12 +245,16 @@ func TestProxyRewrite(t *testing.T) { func TestProxyRewriteRegex(t *testing.T) { // Setup - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + receivedRequestURI := make(chan string, 1) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server + // we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic + // if original request had `%2F` we should not magically decode it to `/` as it would change what was requested + receivedRequestURI <- r.RequestURI + })) defer upstream.Close() - url, _ := url.Parse(upstream.URL) - rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() + tmpUrL, _ := url.Parse(upstream.URL) + rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: tmpUrL}}) // Rewrite e := echo.New() @@ -279,14 +283,21 @@ func TestProxyRewriteRegex(t *testing.T) { {"/c/ignore1/test/this", http.StatusOK, "/v3/test/this"}, {"/x/ignore/test", http.StatusOK, "/v4/test"}, {"/y/foo/bar", http.StatusOK, "/v5/bar/foo"}, + // NB: fragment is not added by golang httputil.NewSingleHostReverseProxy implementation + // $2 = `bar?q=1#frag`, $1 = `foo`. replaced uri = `/v5/bar?q=1#frag/foo` but httputil.NewSingleHostReverseProxy does not send `#frag/foo` (currently) + {"/y/foo/bar?q=1#frag", http.StatusOK, "/v5/bar?q=1"}, } for _, tc := range testCases { t.Run(tc.requestPath, func(t *testing.T) { - req.URL, _ = url.Parse(tc.requestPath) - rec = httptest.NewRecorder() + targetURL, _ := url.Parse(tc.requestPath) + req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) - assert.Equal(t, tc.expectPath, req.URL.EscapedPath()) + + actualRequestURI := <-receivedRequestURI + assert.Equal(t, tc.expectPath, actualRequestURI) assert.Equal(t, tc.statusCode, rec.Code) }) } diff --git a/middleware/rewrite.go b/middleware/rewrite.go index c05d5d84f..e5b0a6b56 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -72,9 +72,9 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { return next(c) } - req := c.Request() - // Set rewrite path and raw path - rewritePath(config.RegexRules, req) + if err := rewriteURL(config.RegexRules, c.Request()); err != nil { + return err + } return next(c) } } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index cff2714d7..0ac04bb2f 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -220,6 +220,8 @@ func TestEchoRewriteWithRegexRules(t *testing.T) { func TestEchoRewriteReplacementEscaping(t *testing.T) { e := echo.New() + // NOTE: these are incorrect regexps as they do not factor in that URI we are replacing could contain ? (query) and # (fragment) parts + // so in reality they append query and fragment part as `$1` matches everything after that prefix e.Pre(RewriteWithConfig(RewriteConfig{ Rules: map[string]string{ "^/a/*": "/$1?query=param", @@ -228,6 +230,7 @@ func TestEchoRewriteReplacementEscaping(t *testing.T) { RegexRules: map[*regexp.Regexp]string{ regexp.MustCompile("^/x/(.*)"): "/$1?query=param", regexp.MustCompile("^/y/(.*)"): "/$1;part#one", + regexp.MustCompile("^/z/(.*)"): "/$1?test=1#escaped%20test", }, })) @@ -236,13 +239,15 @@ func TestEchoRewriteReplacementEscaping(t *testing.T) { testCases := []struct { requestPath string - expectPath string + expect string }{ {"/unmatched", "/unmatched"}, {"/a/test", "/test?query=param"}, {"/b/foo/bar", "/foo/bar;part#one"}, {"/x/test", "/test?query=param"}, {"/y/foo/bar", "/foo/bar;part#one"}, + {"/z/foo/b%20ar", "/foo/b%20ar?test=1#escaped%20test"}, + {"/z/foo/b%20ar?nope=1#yes", "/foo/b%20ar?nope=1#yes?test=1%23escaped%20test"}, // example of appending } for _, tc := range testCases { @@ -250,7 +255,7 @@ func TestEchoRewriteReplacementEscaping(t *testing.T) { req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, tc.expectPath, req.URL.Path) + assert.Equal(t, tc.expect, req.URL.String()) }) } }