Skip to content

Commit

Permalink
Generate and display more useful error message (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
misberner authored Nov 20, 2020
1 parent 14f2e96 commit 77f0d86
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 10 deletions.
12 changes: 11 additions & 1 deletion client/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ import (
"github.com/pkg/errors"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"golang.stackrox.io/grpc-http1/internal/grpcproto"
"golang.stackrox.io/grpc-http1/internal/grpcweb"
"golang.stackrox.io/grpc-http1/internal/httputils"
"golang.stackrox.io/grpc-http1/internal/pipeconn"
"golang.stackrox.io/grpc-http1/internal/stringutils"
"google.golang.org/grpc"
Expand All @@ -36,6 +38,13 @@ import (
)

func modifyResponse(resp *http.Response) error {
// Check if the response is an error response right away, and attempt to display a more useful
// message than gRPC does by default. We still delegate to the default gRPC behavior for 200 responses
// which are otherwise invalid.
if err := httputils.ExtractResponseError(resp); err != nil {
return errors.Wrap(err, "receiving gRPC response from remote endpoint")
}

if resp.ContentLength == 0 {
// Make sure headers do not get flushed, as otherwise the gRPC client will complain about missing trailers.
resp.Header.Set(dontFlushHeadersHeaderKey, "true")
Expand Down Expand Up @@ -66,7 +75,8 @@ func writeError(w http.ResponseWriter, err error) {
w.WriteHeader(http.StatusOK)

w.Header().Set("Grpc-Status", fmt.Sprintf("%d", codes.Unavailable))
w.Header().Set("Grpc-Message", errors.Wrap(err, "transport").Error())
errMsg := errors.Wrap(err, "transport").Error()
w.Header().Set("Grpc-Message", grpcproto.EncodeGrpcMessage(errMsg))
}

func createReverseProxy(endpoint string, transport http.RoundTripper, insecure, forceDowngrade bool) *httputil.ReverseProxy {
Expand Down
20 changes: 16 additions & 4 deletions client/ws_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/pkg/errors"
"golang.stackrox.io/grpc-http1/internal/grpcproto"
"golang.stackrox.io/grpc-http1/internal/grpcwebsocket"
"golang.stackrox.io/grpc-http1/internal/httputils"
"golang.stackrox.io/grpc-http1/internal/pipeconn"
"golang.stackrox.io/grpc-http1/internal/size"
"google.golang.org/grpc/codes"
Expand All @@ -42,7 +43,7 @@ const (
)

var (
subprotocols = []string{"grpc-ws"}
subprotocols = []string{grpcwebsocket.SubprotocolName}
)

type http2WebSocketProxy struct {
Expand Down Expand Up @@ -209,7 +210,8 @@ func (c *websocketConn) writeErrorIfNecessary() {
c.w.WriteHeader(http.StatusOK)

c.w.Header().Set("Trailer:Grpc-Status", fmt.Sprintf("%d", codes.Unavailable))
c.w.Header().Set("Trailer:Grpc-Message", errors.Wrap(c.err, "transport").Error())
errMsg := errors.Wrap(c.err, "transport").Error()
c.w.Header().Set("Trailer:Grpc-Message", grpcproto.EncodeGrpcMessage(errMsg))
}

// ServeHTTP handles gRPC-WebSocket traffic.
Expand All @@ -228,16 +230,26 @@ func (h *http2WebSocketProxy) ServeHTTP(w http.ResponseWriter, req *http.Request
url := *req.URL // Copy the value, so we do not overwrite the URL.
url.Scheme = scheme
url.Host = h.endpoint
conn, _, err := websocket.Dial(req.Context(), url.String(), &websocket.DialOptions{
conn, resp, err := websocket.Dial(req.Context(), url.String(), &websocket.DialOptions{
// Add the gRPC headers to the WebSocket handshake request.
HTTPHeader: req.Header,
HTTPClient: h.httpClient,
Subprotocols: subprotocols,
// gRPC already performs compression, so no need for WebSocket to add compression as well.
CompressionMode: websocket.CompressionDisabled,
})
if resp != nil {
// Not strictly necessary because the library already replaces resp.Body with a NopCloser,
// but seems too easy to miss should we switch to a different library.
defer func() { _ = resp.Body.Close() }()
}
if err != nil {
writeError(w, errors.Wrapf(err, "connecting to gRPC server %q", url.String()))
if resp != nil {
if respErr := httputils.ExtractResponseError(resp); respErr != nil {
err = fmt.Errorf("%w; response error: %v", err, respErr)
}
}
writeError(w, errors.Wrapf(err, "connecting to gRPC endpoint %q", url.String()))
return
}
conn.SetReadLimit(64 * size.MB)
Expand Down
65 changes: 65 additions & 0 deletions internal/grpcproto/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package grpcproto

import (
"bytes"
"fmt"
"unicode/utf8"
)

// This code is copied from google.golang.org/[email protected]/internal/transport/http_util.go, ll.443-494,
// and has been adjusted to make the `EncodeGrpcMessage` function exported.
// The original code is Copyright (c) by the gRPC authors and was distributed under the
// Apache License, version 2.0.

const (
spaceByte = ' '
tildeByte = '~'
percentByte = '%'
)

// EncodeGrpcMessage is used to encode status code in header field
// "grpc-message". It does percent encoding and also replaces invalid utf-8
// characters with Unicode replacement character.
//
// It checks to see if each individual byte in msg is an allowable byte, and
// then either percent encoding or passing it through. When percent encoding,
// the byte is converted into hexadecimal notation with a '%' prepended.
func EncodeGrpcMessage(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if !(c >= spaceByte && c <= tildeByte && c != percentByte) {
return encodeGrpcMessageUnchecked(msg)
}
}
return msg
}

func encodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
for len(msg) > 0 {
r, size := utf8.DecodeRuneInString(msg)
for _, b := range []byte(string(r)) {
if size > 1 {
// If size > 1, r is not ascii. Always do percent encoding.
buf.WriteString(fmt.Sprintf("%%%02X", b))
continue
}

// The for loop is necessary even if size == 1. r could be
// utf8.RuneError.
//
// fmt.Sprintf("%%%02X", utf8.RuneError) gives "%FFFD".
if b >= spaceByte && b <= tildeByte && b != percentByte {
buf.WriteByte(b)
} else {
buf.WriteString(fmt.Sprintf("%%%02X", b))
}
}
msg = msg[size:]
}
return buf.String()
}
7 changes: 7 additions & 0 deletions internal/grpcwebsocket/consts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package grpcwebsocket

const (
// SubprotocolName is the subprotocol for gRPC-websocket specified in the Sec-Websocket-Protocol
// header.
SubprotocolName = "grpc-ws"
)
54 changes: 54 additions & 0 deletions internal/httputils/error_message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package httputils

import (
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"regexp"
"strings"
"unicode/utf8"
)

const (
maxBodyBytes = 1024
)

var (
httpHeaderOptSeparatorRegex = regexp.MustCompile(`;\s*`)
)

// ExtractResponseError extracts an error from an HTTP response, reading at most 1024 bytes of the
// response body.
func ExtractResponseError(resp *http.Response) error {
if resp.StatusCode < 400 {
return nil
}
contentTypeFields := httpHeaderOptSeparatorRegex.Split(resp.Header.Get("Content-Type"), 2)
if len(contentTypeFields) == 0 {
return errors.New(resp.Status)
}

if contentTypeFields[0] != "text/plain" {
return fmt.Errorf("%s, content-type %s", resp.Status, contentTypeFields[0])
}

bodyReader := io.LimitReader(resp.Body, maxBodyBytes)
contents, err := ioutil.ReadAll(bodyReader)
contentsStr := strings.TrimSpace(string(contents))
if !utf8.Valid(contents) {
contentsStr = "invalid UTF-8 characters in response"
}
if err != nil {
if contentsStr == "" {
return fmt.Errorf("%s, error reading response body: %v", resp.Status, err)
}
return fmt.Errorf("%s: %s, error reading response body after %d bytes: %v", resp.Status, contentsStr, len(contents), err)
}

if contentsStr == "" {
return errors.New(resp.Status)
}
return fmt.Errorf("%s: %s", resp.Status, contentsStr)
}
21 changes: 16 additions & 5 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package server

import (
"errors"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -173,7 +174,10 @@ func CreateDowngradingHandler(grpcSrv *grpc.Server, httpHandler http.Handler, op
}

return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if isWebSocketUpgrade(req.Header) {
if isUpgrade, err := isWebSocketUpgrade(req.Header); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
} else if isUpgrade {
handleGRPCWS(w, req, grpcSrv)
return
}
Expand All @@ -188,10 +192,17 @@ func CreateDowngradingHandler(grpcSrv *grpc.Server, httpHandler http.Handler, op
})
}

func isWebSocketUpgrade(header http.Header) bool {
return header.Get("Connection") == "Upgrade" &&
header.Get("Upgrade") == "websocket" &&
header.Get("Sec-Websocket-Protocol") == "grpc-ws"
func isWebSocketUpgrade(header http.Header) (bool, error) {
if header.Get("Sec-Websocket-Protocol") != grpcwebsocket.SubprotocolName {
return false, nil
}
if header.Get("Connection") != "Upgrade" {
return false, errors.New("missing 'Connection: Upgrade' header in gRPC-websocket request (this usually means your proxy or load balancer does not support websockets)")
}
if header.Get("Upgrade") != "websocket" {
return false, errors.New("missing 'Upgrade: websocket' header in gRPC-websocket request (this usually means your proxy or load balancer does not support websockets)")
}
return true, nil
}

func spaceOrComma(r rune) bool {
Expand Down

0 comments on commit 77f0d86

Please sign in to comment.