Skip to content

Commit

Permalink
Improve multiline header parsing (#708)
Browse files Browse the repository at this point in the history
- Replace tabs with spaces at line starts to match net/http
- Don't allow multi line header names. See: golang/go#34702
  • Loading branch information
erikdubbelboer committed Dec 14, 2019
1 parent 6a8a72a commit fd55658
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 21 deletions.
26 changes: 26 additions & 0 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -2053,6 +2053,20 @@ func (s *headerScanner) next() bool {
s.nextColon = -1
} else {
n = bytes.IndexByte(s.b, ':')

// There can't be a \n inside the header name, check for this.
x := bytes.IndexByte(s.b, '\n')
if x < 0 {
// A header name should always at some point be followed by a \n
// even if it's the one that terminates the header block.
s.err = errNeedMore
return false
}
if x < n {
// There was a \n before the :
s.err = errInvalidName
return false
}
}
if n < 0 {
s.err = errNeedMore
Expand Down Expand Up @@ -2085,6 +2099,9 @@ func (s *headerScanner) next() bool {
if n+1 >= len(s.b) {
break
}
if s.b[n+1] != ' ' && s.b[n+1] != '\t' {
break
}
d := bytes.IndexByte(s.b[n+1:], '\n')
if d <= 0 {
break
Expand Down Expand Up @@ -2195,11 +2212,19 @@ func normalizeHeaderValue(ov, ob []byte, headerLength int) (nv, nb []byte, nhl i
}
write := 0
shrunk := 0
lineStart := false
for read := 0; read < length; read++ {
c := ov[read]
if c == '\r' || c == '\n' {
shrunk++
if c == '\n' {
lineStart = true
}
continue
} else if lineStart && c == '\t' {
c = ' '
} else {
lineStart = false
}
nv[write] = c
write++
Expand Down Expand Up @@ -2267,6 +2292,7 @@ func AppendNormalizedHeaderKeyBytes(dst, key []byte) []byte {

var (
errNeedMore = errors.New("need more data: cannot find trailing lf")
errInvalidName = errors.New("invalid header name")
errSmallBuffer = errors.New("small read buffer. Increase ReadBufferSize")
)

Expand Down
48 changes: 27 additions & 21 deletions header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,43 @@ import (
func TestResponseHeaderMultiLineValue(t *testing.T) {
s := "HTTP/1.1 200 OK\r\n" +
"EmptyValue1:\r\n" +
"Content-Type: foo/bar;\r\n\tnewline;\r\n another/newline\r\n" + // the '\t' will be kept, won't be removed
"Content-Type: foo/bar;\r\n\tnewline;\r\n another/newline\r\n" +
"Foo: Bar\r\n" +
"Multi-Line: one;\r\n two\r\n" +
"Values: v1;\r\n v2;\r\n v3; v4\r\n" +
"Values: v1;\r\n v2; v3;\r\n v4;\tv5\r\n" +
"\r\n"
expectContentType := "foo/bar;\tnewline; another/newline"
// net/http not only remove "\r\n" but also replace \t to space
expectNetHttpContentType := "foo/bar; newline; another/newline"
expectMultiLine := "one; two"
header := new(ResponseHeader)
_, err := header.parse([]byte(s))
if err != nil {
if _, err := header.parse([]byte(s)); err != nil {
t.Fatalf("parse headers with multi-line values failed, %s", err)
}
gotContentType := header.Peek("Content-Type")
if string(gotContentType) != expectContentType {
t.Fatalf("unexpected content-type: %q. Expecting %q", gotContentType, expectContentType)
}
gotMultiLine := header.Peek("Multi-Line")
if string(gotMultiLine) != expectMultiLine {
t.Fatalf("unexpected multi-line: %q. Expecting %q", gotMultiLine, expectMultiLine)
}
// ensure behave same as net/http
response, err := http.ReadResponse(bufio.NewReader(strings.NewReader(s)), nil)
if err != nil {
t.Fatalf("parse response using net/http failed, %s", err)
}
gotNetHttpContentType := response.Header.Get("Content-Type")
if gotNetHttpContentType != expectNetHttpContentType {
t.Fatalf("unexpected content-type (net/http): %q. Expecting %q",
gotNetHttpContentType, expectNetHttpContentType)

for name, vals := range response.Header {
got := string(header.Peek(name))
want := vals[0]

if got != want {
t.Errorf("unexpected %s got: %q want: %q", name, got, want)
}
}
}

func TestResponseHeaderMultiLineName(t *testing.T) {
s := "HTTP/1.1 200 OK\r\n" +
"Host: golang.org\r\n" +
"Gopher-New-\r\n" +
" Line: This is a header on multiple lines\r\n" +
"\r\n"
header := new(ResponseHeader)
if _, err := header.parse([]byte(s)); err != errInvalidName {
m := make(map[string]string)
header.VisitAll(func(key, value []byte) {
m[string(key)] = string(value)
})
t.Errorf("expected error, got %q (%v)", m, err)
}
}

Expand Down

0 comments on commit fd55658

Please sign in to comment.