diff --git a/header.go b/header.go index a04f4f90f5..8dab12acbe 100644 --- a/header.go +++ b/header.go @@ -2053,6 +2053,20 @@ func (s *headerScanner) next() bool { s.nextColon = -1 } else { n = bytes.IndexByte(s.b, ':') + + // There can't be a \n inside the header name, check for this. + x := bytes.IndexByte(s.b, '\n') + if x < 0 { + // A header name should always at some point be followed by a \n + // even if it's the one that terminates the header block. + s.err = errNeedMore + return false + } + if x < n { + // There was a \n before the : + s.err = errInvalidName + return false + } } if n < 0 { s.err = errNeedMore @@ -2085,6 +2099,9 @@ func (s *headerScanner) next() bool { if n+1 >= len(s.b) { break } + if s.b[n+1] != ' ' && s.b[n+1] != '\t' { + break + } d := bytes.IndexByte(s.b[n+1:], '\n') if d <= 0 { break @@ -2195,11 +2212,19 @@ func normalizeHeaderValue(ov, ob []byte, headerLength int) (nv, nb []byte, nhl i } write := 0 shrunk := 0 + lineStart := false for read := 0; read < length; read++ { c := ov[read] if c == '\r' || c == '\n' { shrunk++ + if c == '\n' { + lineStart = true + } continue + } else if lineStart && c == '\t' { + c = ' ' + } else { + lineStart = false } nv[write] = c write++ @@ -2267,6 +2292,7 @@ func AppendNormalizedHeaderKeyBytes(dst, key []byte) []byte { var ( errNeedMore = errors.New("need more data: cannot find trailing lf") + errInvalidName = errors.New("invalid header name") errSmallBuffer = errors.New("small read buffer. Increase ReadBufferSize") ) diff --git a/header_test.go b/header_test.go index 6f3aa0fe1a..cec008f5fa 100644 --- a/header_test.go +++ b/header_test.go @@ -15,37 +15,43 @@ import ( func TestResponseHeaderMultiLineValue(t *testing.T) { s := "HTTP/1.1 200 OK\r\n" + "EmptyValue1:\r\n" + - "Content-Type: foo/bar;\r\n\tnewline;\r\n another/newline\r\n" + // the '\t' will be kept, won't be removed + "Content-Type: foo/bar;\r\n\tnewline;\r\n another/newline\r\n" + "Foo: Bar\r\n" + "Multi-Line: one;\r\n two\r\n" + - "Values: v1;\r\n v2;\r\n v3; v4\r\n" + + "Values: v1;\r\n v2; v3;\r\n v4;\tv5\r\n" + "\r\n" - expectContentType := "foo/bar;\tnewline; another/newline" - // net/http not only remove "\r\n" but also replace \t to space - expectNetHttpContentType := "foo/bar; newline; another/newline" - expectMultiLine := "one; two" header := new(ResponseHeader) - _, err := header.parse([]byte(s)) - if err != nil { + if _, err := header.parse([]byte(s)); err != nil { t.Fatalf("parse headers with multi-line values failed, %s", err) } - gotContentType := header.Peek("Content-Type") - if string(gotContentType) != expectContentType { - t.Fatalf("unexpected content-type: %q. Expecting %q", gotContentType, expectContentType) - } - gotMultiLine := header.Peek("Multi-Line") - if string(gotMultiLine) != expectMultiLine { - t.Fatalf("unexpected multi-line: %q. Expecting %q", gotMultiLine, expectMultiLine) - } - // ensure behave same as net/http response, err := http.ReadResponse(bufio.NewReader(strings.NewReader(s)), nil) if err != nil { t.Fatalf("parse response using net/http failed, %s", err) } - gotNetHttpContentType := response.Header.Get("Content-Type") - if gotNetHttpContentType != expectNetHttpContentType { - t.Fatalf("unexpected content-type (net/http): %q. Expecting %q", - gotNetHttpContentType, expectNetHttpContentType) + + for name, vals := range response.Header { + got := string(header.Peek(name)) + want := vals[0] + + if got != want { + t.Errorf("unexpected %s got: %q want: %q", name, got, want) + } + } +} + +func TestResponseHeaderMultiLineName(t *testing.T) { + s := "HTTP/1.1 200 OK\r\n" + + "Host: golang.org\r\n" + + "Gopher-New-\r\n" + + " Line: This is a header on multiple lines\r\n" + + "\r\n" + header := new(ResponseHeader) + if _, err := header.parse([]byte(s)); err != errInvalidName { + m := make(map[string]string) + header.VisitAll(func(key, value []byte) { + m[string(key)] = string(value) + }) + t.Errorf("expected error, got %q (%v)", m, err) } }