diff --git a/client.go b/client.go index 9e7242691a..5e05e131d4 100644 --- a/client.go +++ b/client.go @@ -4,7 +4,6 @@ package fasthttp import ( "bufio" - "bytes" "crypto/tls" "errors" "fmt" @@ -464,11 +463,10 @@ func (c *Client) Do(req *Request, resp *Response) error { host := uri.Host() isTLS := false - scheme := uri.Scheme() - if bytes.Equal(scheme, strHTTPS) { + if uri.isHttps() { isTLS = true - } else if !bytes.Equal(scheme, strHTTP) { - return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme) + } else if !uri.isHttp() { + return fmt.Errorf("unsupported protocol %q. http and https are supported", uri.Scheme()) } startCleaner := false @@ -1363,7 +1361,7 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) req.secureErrorLogMessage = c.SecureErrorLogMessage req.Header.secureErrorLogMessage = c.SecureErrorLogMessage - if c.IsTLS != bytes.Equal(req.uri.Scheme(), strHTTPS) { + if c.IsTLS != req.uri.isHttps() { return false, ErrHostClientRedirectToDifferentScheme } diff --git a/server_test.go b/server_test.go index 1ffe3964fa..4a2951562f 100644 --- a/server_test.go +++ b/server_test.go @@ -1158,9 +1158,8 @@ func TestServerServeTLSEmbed(t *testing.T) { ctx.Error("expecting tls", StatusBadRequest) return } - scheme := ctx.URI().Scheme() - if string(scheme) != "https" { - ctx.Error(fmt.Sprintf("unexpected scheme=%q. Expecting %q", scheme, "https"), StatusBadRequest) + if !ctx.URI().isHttps() { + ctx.Error(fmt.Sprintf("unexpected scheme=%q. Expecting %q", ctx.URI().Scheme(), "https"), StatusBadRequest) return } ctx.WriteString("success") //nolint:errcheck diff --git a/uri.go b/uri.go index 2285f45d0d..d53175f4ec 100644 --- a/uri.go +++ b/uri.go @@ -216,6 +216,14 @@ func (u *URI) SetSchemeBytes(scheme []byte) { lowercaseBytes(u.scheme) } +func (u *URI) isHttps() bool { + return bytes.Equal(u.scheme, strHTTPS) +} + +func (u *URI) isHttp() bool { + return len(u.scheme) == 0 || bytes.Equal(u.scheme, strHTTP) +} + // Reset clears uri. func (u *URI) Reset() { u.pathOriginal = u.pathOriginal[:0] @@ -282,14 +290,13 @@ func (u *URI) parse(host, uri []byte, isTLS bool) error { if len(host) == 0 || bytes.Contains(uri, strColonSlashSlash) { scheme, newHost, newURI := splitHostURI(host, uri) - u.scheme = append(u.scheme, scheme...) - lowercaseBytes(u.scheme) + u.SetSchemeBytes(scheme) host = newHost uri = newURI } if isTLS { - u.scheme = append(u.scheme[:0], strHTTPS...) + u.SetSchemeBytes(strHTTPS) } if n := bytes.IndexByte(host, '@'); n >= 0 { diff --git a/uri_test.go b/uri_test.go index fc0b2ef0cd..a605b46219 100644 --- a/uri_test.go +++ b/uri_test.go @@ -310,6 +310,29 @@ func testURIParseScheme(t *testing.T, uri, expectedScheme, expectedHost, expecte } } +func TestIsHttp(t *testing.T) { + var u URI + if !u.isHttp() || u.isHttps() { + t.Fatalf("http scheme is assumed by default and not https") + } + u.SetSchemeBytes([]byte{}) + if !u.isHttp() || u.isHttps() { + t.Fatalf("empty scheme must be threaten as http and not https") + } + u.SetScheme("http") + if !u.isHttp() || u.isHttps() { + t.Fatalf("scheme must be threaten as http and not https") + } + u.SetScheme("https") + if !u.isHttps() || u.isHttp() { + t.Fatalf("scheme must be threaten as https and not http") + } + u.SetScheme("dav") + if u.isHttps() || u.isHttp() { + t.Fatalf("scheme must be threaten as not http and not https") + } +} + func TestURIParse(t *testing.T) { t.Parallel()