Skip to content

Commit

Permalink
Fix etag hijacking. (#115)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickdappollonio authored Sep 9, 2024
1 parent 0dcbc25 commit 80bf87b
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 64 deletions.
76 changes: 47 additions & 29 deletions internal/mw/etag.go
Original file line number Diff line number Diff line change
@@ -1,48 +1,54 @@
package mw

import (
"bytes"
"crypto/sha1"
"encoding/hex"
"fmt"
"hash"
"net/http"
"sync"
)

// etagResponseWriter is a wrapper around http.ResponseWriter that will
// calculate the ETag header for the response as it streams the data.
var bufPool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}

type etagResponseWriter struct {
rw http.ResponseWriter
hash hash.Hash
status int
hash hash.Hash
headers map[string][]string
buf *bytes.Buffer
status int
}

// Header returns the header map that will be sent by WriteHeader.
// Header returns the header map that will be sent by WriteHeader
func (e *etagResponseWriter) Header() http.Header {
return e.rw.Header()
return e.headers
}

// WriteHeader sends an HTTP response header with the provided status code.
// We don't write the status code just yet to the original response writer,
// since our goal is to calculate the ETag then update the status code if
// there was a match.
// WriteHeader sends an HTTP response header with the provided status code
func (e *etagResponseWriter) WriteHeader(status int) {
e.status = status
}

// Write writes the data to the connection as part of an HTTP reply, while
// calculating the ETag on the fly to avoid buffering large amounts of data.
// Write writes the data to the connection as part of an HTTP reply
func (e *etagResponseWriter) Write(p []byte) (int, error) {
// In Go, a call to Write will always
// set the status code to 200 if it's not set
if e.status == 0 {
e.WriteHeader(http.StatusOK)
e.status = http.StatusOK
}

// Write the data to the hash for ETag calculation
e.hash.Write(p)

// Write the data to the actual response writer
return e.rw.Write(p)
return e.buf.Write(p)
}

// Etag is a middleware that will calculate the ETag header for the response.
// Etag middleware
func Etag(enabled bool) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -51,28 +57,40 @@ func Etag(enabled bool) func(next http.Handler) http.Handler {
return
}

ew := &etagResponseWriter{
rw: w,
hash: sha1.New(),
buf := bufPool.Get().(*bytes.Buffer)
defer func() {
buf.Reset()
bufPool.Put(buf)
}()

alternateWriter := &etagResponseWriter{
headers: http.Header{},
buf: buf,
hash: sha1.New(),
}

// Call the next handler and stream the data while hashing
next.ServeHTTP(ew, r)
next.ServeHTTP(alternateWriter, r)

// If status code is in the range of 200-399, calculate ETag
if ew.status >= http.StatusOK && ew.status < http.StatusBadRequest {
etag := hex.EncodeToString(ew.hash.Sum(nil))
w.Header().Set("ETag", `"`+etag+`"`)
// If the status is in the range of 200-399, calculate ETag
if alternateWriter.status >= http.StatusOK && alternateWriter.status < http.StatusBadRequest {
etag := fmt.Sprintf("%q", hex.EncodeToString(alternateWriter.hash.Sum(nil)))
alternateWriter.Header().Set("Etag", etag)

// Check if the ETag matches the client request
if r.Header.Get("If-None-Match") == w.Header().Get("ETag") {
w.WriteHeader(http.StatusNotModified)
return
if r.Header.Get("If-None-Match") == etag {
alternateWriter.WriteHeader(http.StatusNotModified)
}
}

// Write the status code if it hasn't been written yet
w.WriteHeader(ew.status)
// Pass the response to the actual response writer
for key, vals := range alternateWriter.headers {
for _, val := range vals {
w.Header().Add(key, val)
}
}
w.WriteHeader(alternateWriter.status)
w.Write(alternateWriter.buf.Bytes())
})
}
}
63 changes: 28 additions & 35 deletions internal/mw/log_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,54 +8,42 @@ import (
"time"
)

// statusResponseWriter is a wrapper around http.ResponseWriter that
// allows us to capture the status code and bytes written
type statusResponseWriter struct {
http.ResponseWriter
type logResponseWriter struct {
rw http.ResponseWriter
statusCode int
bytesWritten int
bytesWritten int64
}

// newStatusResponseWriter returns a new statusResponseWriter
func newStatusResponseWriter(w http.ResponseWriter) *statusResponseWriter {
return &statusResponseWriter{w, http.StatusOK, 0}
func (lrw *logResponseWriter) Header() http.Header {
return lrw.rw.Header()
}

// WriteHeader implements the http.ResponseWriter interface
func (lrw *statusResponseWriter) WriteHeader(code int) {
lrw.statusCode = code
lrw.ResponseWriter.WriteHeader(code)
func (lrw *logResponseWriter) Write(p []byte) (int, error) {
n, err := lrw.rw.Write(p)
lrw.bytesWritten += int64(n)
return n, err
}

// Write implements the http.ResponseWriter interface
func (lrw *statusResponseWriter) Write(b []byte) (int, error) {
bw, err := lrw.ResponseWriter.Write(b)
lrw.bytesWritten = bw
return bw, err
func (lrw *logResponseWriter) WriteHeader(statusCode int) {
lrw.rw.WriteHeader(statusCode)
lrw.statusCode = statusCode
}

// LogRequest is a middleware that logs specific request data using a predefined
// template format. Available options are:
// - {http_method} the HTTP method
// - {url} the URL
// - {proto} the protocol version
// - {status_code} the HTTP status code
// - {status_text} the HTTP status text
// - {duration} the duration of the request
// - {bytes_written} the number of bytes written
// LogRequest middleware
func LogRequest(output io.Writer, format string, redactedQuerystringFields ...string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Capture starting time
start := time.Now()

// Create a new instance of our custom responsewriter
lrw := newStatusResponseWriter(w)
// Wrap the response writer
lrw := &logResponseWriter{
rw: w,
}

// Serve the next request
// Call the next middleware or handler
next.ServeHTTP(lrw, r)

// Capture path and remove any querystring keys
// Log the request after all middlewares have completed
urlpath := r.URL.Path
if r.URL.Query().Encode() != "" {
querystrings := r.URL.Query()
Expand All @@ -67,18 +55,23 @@ func LogRequest(output io.Writer, format string, redactedQuerystringFields ...st
urlpath = r.URL.Path + "?" + querystrings.Encode()
}

// Generate a string representation of the log message
// Get the status code or 200
statusCode := lrw.statusCode
if statusCode == 0 {
statusCode = http.StatusOK
}

// Log the request details
s := strings.NewReplacer(
"{http_method}", r.Method,
"{url}", urlpath,
"{proto}", r.Proto,
"{status_code}", fmt.Sprintf("%d", lrw.statusCode),
"{status_text}", http.StatusText(lrw.statusCode),
"{status_code}", fmt.Sprintf("%d", statusCode),
"{status_text}", http.StatusText(statusCode),
"{duration}", time.Since(start).String(),
"{bytes_written}", fmt.Sprintf("%d", lrw.bytesWritten),
).Replace(format)

// Print that log message to the output writer
fmt.Fprintln(output, s)
})
}
Expand Down

0 comments on commit 80bf87b

Please sign in to comment.