diff --git a/experimental/status/status_source.go b/experimental/status/status_source.go index eda6b9858..487e5c1e7 100644 --- a/experimental/status/status_source.go +++ b/experimental/status/status_source.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "errors" "fmt" + "io" "net" "net/http" "net/url" @@ -142,7 +143,8 @@ func IsDownstreamHTTPError(err error) bool { return IsDownstreamError(err) || isConnectionResetOrRefusedError(err) || isDNSNotFoundError(err) || - isTLSCertificateVerificationError(err) + isTLSCertificateVerificationError(err) || + isHTTPEOFError(err) } // InCancelledError returns true if err is context.Canceled or is gRPC status Canceled. @@ -202,6 +204,20 @@ func isTLSCertificateVerificationError(err error) bool { return false } +// isHTTPEOFError returns true if the error is an EOF error inside of url.Error or net.OpError, indicating the connection was closed prematurely by server +func isHTTPEOFError(err error) bool { + var netErr *net.OpError + if errors.As(err, &netErr) { + return errors.Is(netErr.Err, io.EOF) + } + + var urlErr *url.Error + if errors.As(err, &urlErr) { + return errors.Is(urlErr.Err, io.EOF) + } + return false +} + type sourceCtxKey struct{} // SourceFromContext returns the source stored in the context. diff --git a/experimental/status/status_source_test.go b/experimental/status/status_source_test.go index 905b8f622..8537e8f2a 100644 --- a/experimental/status/status_source_test.go +++ b/experimental/status/status_source_test.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "errors" "fmt" + "io" "net" "net/url" "os" @@ -248,6 +249,36 @@ func TestIsDownstreamHTTPError(t *testing.T) { err: x509.UnknownAuthorityError{}, expected: true, }, + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "io.EOF error", + err: io.EOF, + expected: false, + }, + { + name: "url io.EOF error", + err: &url.Error{Op: "Get", URL: "https://example.com", Err: io.EOF}, + expected: true, + }, + { + name: "net op io.EOF error", + err: &net.OpError{Err: io.EOF}, + expected: true, + }, + { + name: "wrapped url io.EOF error", + err: fmt.Errorf("wrapped: %w", &url.Error{Op: "Get", URL: "https://example.com", Err: io.EOF}), + expected: true, + }, + { + name: "joined error with io.EOF", + err: errors.Join(io.EOF, &url.Error{Op: "Get", URL: "https://example.com", Err: io.EOF}), + expected: true, + }, } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) {