Skip to content

Commit

Permalink
Use httpsnoop to wrap ResponseWriter. (#193)
Browse files Browse the repository at this point in the history
Wrapping http.ResponseWriter is fraught with danger. Our compress
handler made sure to implement all the optional ResponseWriter
interfaces, but that made it implement them even if the underlying
writer did not. For example, if the underlying ResponseWriter was
_not_ an http.Hijacker, the compress writer nonetheless appeared to
implement http.Hijacker, but would panic if you called Hijack().

On the other hand, the logging handler checked for certain
combinations of optional interfaces and only implemented them as
appropriate. However, it didn't check for all optional interfaces or
all combinations, so most optional interfaces would still get lost.

Fix both problems by using httpsnoop to do the wrapping. It uses code
generation to ensure correctness, and it handles std lib changes like
the http.Pusher addition in Go 1.8.

Fixes #169.
  • Loading branch information
Muir Manders authored Aug 20, 2020
1 parent 2188616 commit 55df21f
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 145 deletions.
65 changes: 27 additions & 38 deletions compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,30 @@ import (
"io"
"net/http"
"strings"

"github.com/felixge/httpsnoop"
)

const acceptEncoding string = "Accept-Encoding"

type compressResponseWriter struct {
io.Writer
http.ResponseWriter
http.Hijacker
http.Flusher
http.CloseNotifier
}

func (w *compressResponseWriter) WriteHeader(c int) {
w.ResponseWriter.Header().Del("Content-Length")
w.ResponseWriter.WriteHeader(c)
compressor io.Writer
w http.ResponseWriter
}

func (w *compressResponseWriter) Header() http.Header {
return w.ResponseWriter.Header()
func (cw *compressResponseWriter) WriteHeader(c int) {
cw.w.Header().Del("Content-Length")
cw.w.WriteHeader(c)
}

func (w *compressResponseWriter) Write(b []byte) (int, error) {
h := w.ResponseWriter.Header()
func (cw *compressResponseWriter) Write(b []byte) (int, error) {
h := cw.w.Header()
if h.Get("Content-Type") == "" {
h.Set("Content-Type", http.DetectContentType(b))
}
h.Del("Content-Length")

return w.Writer.Write(b)
return cw.compressor.Write(b)
}

type flusher interface {
Expand All @@ -47,12 +42,12 @@ type flusher interface {

func (w *compressResponseWriter) Flush() {
// Flush compressed data if compressor supports it.
if f, ok := w.Writer.(flusher); ok {
if f, ok := w.compressor.(flusher); ok {
f.Flush()
}
// Flush HTTP response.
if w.Flusher != nil {
w.Flusher.Flush()
if f, ok := w.w.(http.Flusher); ok {
f.Flush()
}
}

Expand Down Expand Up @@ -119,28 +114,22 @@ func CompressHandlerLevel(h http.Handler, level int) http.Handler {
w.Header().Set("Content-Encoding", encoding)
r.Header.Del(acceptEncoding)

hijacker, ok := w.(http.Hijacker)
if !ok { /* w is not Hijacker... oh well... */
hijacker = nil
cw := &compressResponseWriter{
w: w,
compressor: encWriter,
}

flusher, ok := w.(http.Flusher)
if !ok {
flusher = nil
}

closeNotifier, ok := w.(http.CloseNotifier)
if !ok {
closeNotifier = nil
}

w = &compressResponseWriter{
Writer: encWriter,
ResponseWriter: w,
Hijacker: hijacker,
Flusher: flusher,
CloseNotifier: closeNotifier,
}
w = httpsnoop.Wrap(w, httpsnoop.Hooks{
Write: func(httpsnoop.WriteFunc) httpsnoop.WriteFunc {
return cw.Write
},
WriteHeader: func(httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
return cw.WriteHeader
},
Flush: func(httpsnoop.FlushFunc) httpsnoop.FlushFunc {
return cw.Flush
},
})

h.ServeHTTP(w, r)
})
Expand Down
38 changes: 31 additions & 7 deletions compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ func compressedRequest(w *httptest.ResponseRecorder, compression string) {
acceptEncoding: []string{compression},
},
})

}

func TestCompressHandlerNoCompression(t *testing.T) {
Expand Down Expand Up @@ -165,6 +164,7 @@ type fullyFeaturedResponseWriter struct{}
func (fullyFeaturedResponseWriter) Header() http.Header {
return http.Header{}
}

func (fullyFeaturedResponseWriter) Write([]byte) (int, error) {
return 0, nil
}
Expand Down Expand Up @@ -193,9 +193,6 @@ func TestCompressHandlerPreserveInterfaces(t *testing.T) {
)
var h http.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
comp := r.Header.Get(acceptEncoding)
if _, ok := rw.(*compressResponseWriter); !ok {
t.Fatalf("ResponseWriter wasn't wrapped by compressResponseWriter, got %T type", rw)
}
if _, ok := rw.(http.Flusher); !ok {
t.Errorf("ResponseWriter lost http.Flusher interface for %q", comp)
}
Expand All @@ -207,9 +204,7 @@ func TestCompressHandlerPreserveInterfaces(t *testing.T) {
}
})
h = CompressHandler(h)
var (
rw fullyFeaturedResponseWriter
)
var rw fullyFeaturedResponseWriter
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create test request: %v", err)
Expand All @@ -220,3 +215,32 @@ func TestCompressHandlerPreserveInterfaces(t *testing.T) {
r.Header.Set(acceptEncoding, "deflate")
h.ServeHTTP(rw, r)
}

type paltryResponseWriter struct{}

func (paltryResponseWriter) Header() http.Header {
return http.Header{}
}

func (paltryResponseWriter) Write([]byte) (int, error) {
return 0, nil
}
func (paltryResponseWriter) WriteHeader(int) {}

func TestCompressHandlerDoesntInventInterfaces(t *testing.T) {
var h http.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if _, ok := rw.(http.Hijacker); ok {
t.Error("ResponseWriter shouldn't implement http.Hijacker")
}
})

h = CompressHandler(h)

var rw paltryResponseWriter
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create test request: %v", err)
}
r.Header.Set(acceptEncoding, "gzip")
h.ServeHTTP(rw, r)
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/gorilla/handlers

go 1.14

require github.com/felixge/httpsnoop v1.0.1
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ=
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
35 changes: 4 additions & 31 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ type responseLogger struct {
size int
}

func (l *responseLogger) Header() http.Header {
return l.w.Header()
}

func (l *responseLogger) Write(b []byte) (int, error) {
size, err := l.w.Write(b)
l.size += size
Expand All @@ -74,39 +70,16 @@ func (l *responseLogger) Size() int {
return l.size
}

func (l *responseLogger) Flush() {
f, ok := l.w.(http.Flusher)
if ok {
f.Flush()
}
}

type hijackLogger struct {
responseLogger
}

func (l *hijackLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
h := l.responseLogger.w.(http.Hijacker)
conn, rw, err := h.Hijack()
if err == nil && l.responseLogger.status == 0 {
func (l *responseLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
conn, rw, err := l.w.(http.Hijacker).Hijack()
if err == nil && l.status == 0 {
// The status will be StatusSwitchingProtocols if there was no error and
// WriteHeader has not been called yet
l.responseLogger.status = http.StatusSwitchingProtocols
l.status = http.StatusSwitchingProtocols
}
return conn, rw, err
}

type closeNotifyWriter struct {
loggingResponseWriter
http.CloseNotifier
}

type hijackCloseNotifier struct {
loggingResponseWriter
http.Hijacker
http.CloseNotifier
}

// isContentType validates the Content-Type header matches the supplied
// contentType. That is, its type and subtype match.
func isContentType(h http.Header, contentType string) bool {
Expand Down
29 changes: 0 additions & 29 deletions handlers_go18.go

This file was deleted.

13 changes: 11 additions & 2 deletions handlers_go18_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ import (
"testing"
)

// *httptest.ResponseRecorder doesn't implement Pusher, so wrap it.
type pushRecorder struct {
*httptest.ResponseRecorder
}

func (pr pushRecorder) Push(target string, opts *http.PushOptions) error {
return nil
}

func TestLoggingHandlerWithPush(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if _, ok := w.(http.Pusher); !ok {
Expand All @@ -18,7 +27,7 @@ func TestLoggingHandlerWithPush(t *testing.T) {
})

logger := LoggingHandler(ioutil.Discard, handler)
logger.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/"))
logger.ServeHTTP(pushRecorder{httptest.NewRecorder()}, newRequest("GET", "/"))
}

func TestCombinedLoggingHandlerWithPush(t *testing.T) {
Expand All @@ -30,5 +39,5 @@ func TestCombinedLoggingHandlerWithPush(t *testing.T) {
})

logger := CombinedLoggingHandler(ioutil.Discard, handler)
logger.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/"))
logger.ServeHTTP(pushRecorder{httptest.NewRecorder()}, newRequest("GET", "/"))
}
7 changes: 0 additions & 7 deletions handlers_pre18.go

This file was deleted.

39 changes: 14 additions & 25 deletions logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"strconv"
"time"
"unicode/utf8"

"github.com/felixge/httpsnoop"
)

// Logging
Expand Down Expand Up @@ -39,10 +41,10 @@ type loggingHandler struct {

func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
t := time.Now()
logger := makeLogger(w)
logger, w := makeLogger(w)
url := *req.URL

h.handler.ServeHTTP(logger, req)
h.handler.ServeHTTP(w, req)
if req.MultipartForm != nil {
req.MultipartForm.RemoveAll()
}
Expand All @@ -58,27 +60,16 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
h.formatter(h.writer, params)
}

func makeLogger(w http.ResponseWriter) loggingResponseWriter {
var logger loggingResponseWriter = &responseLogger{w: w, status: http.StatusOK}
if _, ok := w.(http.Hijacker); ok {
logger = &hijackLogger{responseLogger{w: w, status: http.StatusOK}}
}
h, ok1 := logger.(http.Hijacker)
c, ok2 := w.(http.CloseNotifier)
if ok1 && ok2 {
return hijackCloseNotifier{logger, h, c}
}
if ok2 {
return &closeNotifyWriter{logger, c}
}
return logger
}

type commonLoggingResponseWriter interface {
http.ResponseWriter
http.Flusher
Status() int
Size() int
func makeLogger(w http.ResponseWriter) (*responseLogger, http.ResponseWriter) {
logger := &responseLogger{w: w, status: http.StatusOK}
return logger, httpsnoop.Wrap(w, httpsnoop.Hooks{
Write: func(httpsnoop.WriteFunc) httpsnoop.WriteFunc {
return logger.Write
},
WriteHeader: func(httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
return logger.WriteHeader
},
})
}

const lowerhex = "0123456789abcdef"
Expand Down Expand Up @@ -145,7 +136,6 @@ func appendQuoted(buf []byte, s string) []byte {
}
}
return buf

}

// buildCommonLogLine builds a log entry for req in Apache Common Log Format.
Expand All @@ -160,7 +150,6 @@ func buildCommonLogLine(req *http.Request, url url.URL, ts time.Time, status int
}

host, _, err := net.SplitHostPort(req.RemoteAddr)

if err != nil {
host = req.RemoteAddr
}
Expand Down
Loading

0 comments on commit 55df21f

Please sign in to comment.