From bb306f8bd70b469e3f862d36724f0a3ecea41ba2 Mon Sep 17 00:00:00 2001 From: Jonathan Hess Date: Fri, 31 Jan 2025 14:11:57 -0700 Subject: [PATCH] fix: Refresh client cert when it is rejected by the server. --- dialer.go | 51 ++++++++++++++++++++++++------ dialer_test.go | 65 ++++++++++++++++++++++++++++++++++++--- internal/mock/certs.go | 4 +++ internal/mock/cloudsql.go | 15 +++++---- internal/mock/sqladmin.go | 54 ++++++++++++++++---------------- 5 files changed, 143 insertions(+), 46 deletions(-) diff --git a/dialer.go b/dialer.go index 37a8def7..ecd42244 100644 --- a/dialer.go +++ b/dialer.go @@ -418,9 +418,11 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn tlsConn := tls.Client(conn, ci.TLSConfig()) err = tlsConn.HandshakeContext(ctx) if err != nil { + // TLS handshake errors are fatal and require a refresh. Remove the instance + // from the cache so that future calls to Dial() will block until the + // certificate is refreshed successfully. d.logger.Debugf(ctx, "[%v] TLS handshake failed: %v", cn.String(), err) - // refresh the instance info in case it caused the handshake failure - c.ForceRefresh() + d.removeCached(ctx, cn, c, err) _ = tlsConn.Close() // best effort close attempt return nil, errtype.NewDialError("handshake failed", cn.String(), err) } @@ -435,7 +437,22 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn iConn := newInstrumentedConn(tlsConn, func() { n := atomic.AddUint64(c.openConnsCount, ^uint64(0)) // c.openConnsCount = c.openConnsCount - 1 trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String()) - }, d.dialerID, cn.String()) + }, + func(err error) { + // ignore EOF + if err == io.EOF { + return + } + d.logger.Debugf(ctx, "[%v] IO Error on Read or Write: %v", cn.String(), err) + if d.isTLSError(err) { + // TLS handshake errors are fatal. 
Remove the instance from the cache + // so that future calls to Dial() will block until the certificate + // is refreshed successfully. + d.removeCached(ctx, cn, c, err) + _ = tlsConn.Close() // best effort close attempt + } + }, + d.dialerID, cn.String()) // If this connection was opened using a Domain Name, then store it for later // in case it needs to be forcibly closed. @@ -446,12 +463,19 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn } return iConn, nil } +func (d *Dialer) isTLSError(err error) bool { + if nErr, ok := err.(net.Error); ok { + return !nErr.Timeout() && // it's a permanent net error + strings.Contains(nErr.Error(), "tls") // it's a TLS-related error + } + return false +} // removeCached stops all background refreshes and deletes the connection // info cache from the map of caches. func (d *Dialer) removeCached( ctx context.Context, - i instance.ConnName, c connectionInfoCache, err error, + i instance.ConnName, c *monitoredCache, err error, ) { d.logger.Debugf( ctx, @@ -461,8 +485,11 @@ func (d *Dialer) removeCached( ) d.lock.Lock() defer d.lock.Unlock() - c.Close() - delete(d.cache, createKey(i)) + key := createKey(i) + if cachedC, ok := d.cache[key]; ok && cachedC == c { + delete(d.cache, key) + } + c.connectionInfoCache.Close() } // validClientCert checks that the ephemeral client certificate retrieved from @@ -504,7 +531,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error) } ci, err := c.ConnectionInfo(ctx) if err != nil { - d.removeCached(ctx, cn, c.connectionInfoCache, err) + d.removeCached(ctx, cn, c, err) return "", err } return ci.DBVersion, nil @@ -528,17 +555,18 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err } _, err = c.ConnectionInfo(ctx) if err != nil { - d.removeCached(ctx, cn, c.connectionInfoCache, err) + d.removeCached(ctx, cn, c, err) } return err } // newInstrumentedConn initializes an instrumentedConn that on closing will 
// decrement the number of open connects and record the result. -func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName string) *instrumentedConn { +func newInstrumentedConn(conn net.Conn, closeFunc func(), errFunc func(error), dialerID, connName string) *instrumentedConn { return &instrumentedConn{ Conn: conn, closeFunc: closeFunc, + errFunc: errFunc, dialerID: dialerID, connName: connName, } @@ -549,6 +577,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str type instrumentedConn struct { net.Conn closeFunc func() + errFunc func(error) mu sync.RWMutex closed bool dialerID string @@ -561,6 +590,8 @@ func (i *instrumentedConn) Read(b []byte) (int, error) { bytesRead, err := i.Conn.Read(b) if err == nil { go trace.RecordBytesReceived(context.Background(), int64(bytesRead), i.connName, i.dialerID) + } else { + i.errFunc(err) } return bytesRead, err } @@ -571,6 +602,8 @@ func (i *instrumentedConn) Write(b []byte) (int, error) { bytesWritten, err := i.Conn.Write(b) if err == nil { go trace.RecordBytesSent(context.Background(), int64(bytesWritten), i.connName, i.dialerID) + } else { + i.errFunc(err) } return bytesWritten, err } diff --git a/dialer_test.go b/dialer_test.go index f9080071..a6cf17fd 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -1175,7 +1175,7 @@ func TestDialerChecksSubjectAlternativeNameAndFails(t *testing.T) { } } -func TestDialerRefreshesAfterClientCertificateError(t *testing.T) { +func TestDialerRefreshesAfterRotateClientCA(t *testing.T) { inst := mock.NewFakeCSQLInstanceWithSan( "my-project", "my-region", "my-instance", []string{"db.example.com"}, mock.WithDNS("db.example.com"), @@ -1210,14 +1210,13 @@ func TestDialerRefreshesAfterClientCertificateError(t *testing.T) { cancel1() mock.RotateClientCA(inst) - time.Sleep(2 * time.Second) - // Recreate the instance, which generates new server certificates // Start the server with new certificates cancel2 := mock.StartServerProxy(t, inst) defer 
cancel2() // Dial a second time. We expect no error on dial, but TLS error on read. + t.Log("Second attempt should fail...") conn, err := d.Dial(context.Background(), "my-project:my-region:my-instance") if err != nil { t.Fatal("Should be no certificate error after, got ", err) @@ -1230,8 +1229,66 @@ func TestDialerRefreshesAfterClientCertificateError(t *testing.T) { } else { t.Fatal("Want read error, got no error") } + t.Log("Second attempt done") + + // Dial again. This should complete after the refresh. + t.Log("Third attempt...") + testSuccessfulDial( + context.Background(), t, d, + "my-project:my-region:my-instance", + ) + t.Log("Third attempt OK.") +} + +func TestDialerRefreshesAfterRotateServerCA(t *testing.T) { + inst := mock.NewFakeCSQLInstanceWithSan( + "my-project", "my-region", "my-instance", []string{"db.example.com"}, + mock.WithDNS("db.example.com"), + mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"), + ) + + d := setupDialer(t, setupConfig{ + skipServer: true, + testInstance: inst, + reqs: []*mock.Request{ + mock.InstanceGetSuccess(inst, 2), + mock.CreateEphemeralSuccess(inst, 2), + }, + dialerOptions: []Option{ + WithTokenSource(mock.EmptyTokenSource{}), + WithDebugLogger(&dialerTestLogger{t: t}), + WithLazyRefresh(), + // Note: this succeeds with lazy refresh, but fails without it, + // because dialer.ForceRefresh does not block connections while the + // refresh is in progress. + }, + }) + cancel1 := mock.StartServerProxy(t, inst) + t.Log("First attempt...") + testSuccessfulDial( + context.Background(), t, d, + "my-project:my-region:my-instance", + ) + t.Log("First attempt OK. Resetting client cert.") + + // Close the server + cancel1() + + mock.RotateCA(inst) + + // Start the server with new certificates + cancel2 := mock.StartServerProxy(t, inst) + defer cancel2() + + // Dial a second time. We expect an error on dial. 
+ t.Log("Second attempt should fail...") + _, err := d.Dial(context.Background(), "my-project:my-region:my-instance") + if err != nil { + t.Log("Got error on dial as expected.", err) + } else { + t.Fatal("Want dial error, got no error") + } - time.Sleep(2 * time.Second) // Dial again. This should occur after the refresh has completed. t.Log("Third attempt...") testSuccessfulDial( diff --git a/internal/mock/certs.go b/internal/mock/certs.go index 1454beb3..2c5b5056 100644 --- a/internal/mock/certs.go +++ b/internal/mock/certs.go @@ -257,17 +257,21 @@ func (ct *TLSCertificates) serverChain(serverCAMode string) []tls.Certificate { }} } + +// ClientCAPool returns a CertPool with the client CA. func (ct *TLSCertificates) ClientCAPool() *x509.CertPool { clientCa := x509.NewCertPool() clientCa.AddCert(ct.clientSigningCACertificate) return clientCa } +// RotateClientCA rotates only client CA certificates and keys. func (ct *TLSCertificates) RotateClientCA() { ct.clientSigningCaKeyPair = mustGenerateKey() ct.clientSigningCACertificate = mustBuildRootCertificate(signingCaSubject, ct.clientSigningCaKeyPair) } +// RotateCA rotates all certificates and keys. 
func (ct *TLSCertificates) RotateCA() { oneYear := time.Now().AddDate(1, 0, 0) ct.serverCaKeyPair = mustGenerateKey() diff --git a/internal/mock/cloudsql.go b/internal/mock/cloudsql.go index 203d264d..c916cbcf 100644 --- a/internal/mock/cloudsql.go +++ b/internal/mock/cloudsql.go @@ -74,13 +74,14 @@ func (f FakeCSQLInstance) String() string { func (f FakeCSQLInstance) serverCACert() ([]byte, error) { if f.signer != nil { return f.signer(f.Cert, f.Key) - } else { - if f.serverCAMode == "" || f.serverCAMode == "GOOGLE_MANAGED_INTERNAL_CA" { - // legacy server mode, return only the server cert - return toPEMFormat(f.certs.serverCert) - } - return toPEMFormat(f.certs.casServerCertificate, f.certs.serverIntermediateCaCert, f.certs.serverCaCert) } + + if f.serverCAMode == "" || f.serverCAMode == "GOOGLE_MANAGED_INTERNAL_CA" { + // legacy server mode, return only the server cert + return toPEMFormat(f.certs.serverCert) + } + return toPEMFormat(f.certs.casServerCertificate, f.certs.serverIntermediateCaCert, f.certs.serverCaCert) + } // ClientCert creates an ephemeral client certificate signed with the Cloud SQL @@ -297,10 +298,12 @@ func StartServerProxy(t *testing.T, i FakeCSQLInstance) func() { } } +// RotateCA rotates all CA certificates in the instance. func RotateCA(inst FakeCSQLInstance) { inst.certs.RotateCA() } +// RotateClientCA rotates only the client CA certificates in the instance. 
func RotateClientCA(inst FakeCSQLInstance) { inst.certs.RotateClientCA() } diff --git a/internal/mock/sqladmin.go b/internal/mock/sqladmin.go index f8afdfd7..82bb3afc 100644 --- a/internal/mock/sqladmin.go +++ b/internal/mock/sqladmin.go @@ -102,38 +102,38 @@ func (r *Request) matches(hR *http.Request) bool { // // https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/instances/get func InstanceGetSuccess(i FakeCSQLInstance, ct int) *Request { - var ips []*sqladmin.IpMapping - for ipType, addr := range i.ipAddrs { - if ipType == "PUBLIC" { - ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIMARY"}) - continue - } - if ipType == "PRIVATE" { - ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"}) - } - } - - certBytes, err := i.serverCACert() - if err != nil { - panic(err) - } - - db := &sqladmin.ConnectSettings{ - BackendType: i.backendType, - DatabaseVersion: i.dbVersion, - DnsName: i.DNSName, - IpAddresses: ips, - Region: i.region, - ServerCaCert: &sqladmin.SslCert{Cert: string(certBytes)}, - PscEnabled: i.pscEnabled, - ServerCaMode: i.serverCAMode, - } - r := &Request{ reqMethod: http.MethodGet, reqPath: fmt.Sprintf("/sql/v1beta4/projects/%s/instances/%s/connectSettings", i.project, i.name), reqCt: ct, handle: func(resp http.ResponseWriter, _ *http.Request) { + var ips []*sqladmin.IpMapping + for ipType, addr := range i.ipAddrs { + if ipType == "PUBLIC" { + ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIMARY"}) + continue + } + if ipType == "PRIVATE" { + ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"}) + } + } + + certBytes, err := i.serverCACert() + if err != nil { + panic(err) + } + + db := &sqladmin.ConnectSettings{ + BackendType: i.backendType, + DatabaseVersion: i.dbVersion, + DnsName: i.DNSName, + IpAddresses: ips, + Region: i.region, + ServerCaCert: &sqladmin.SslCert{Cert: string(certBytes)}, + PscEnabled: i.pscEnabled, + ServerCaMode: i.serverCAMode, + } + b, err := 
db.MarshalJSON() if err != nil { http.Error(resp, err.Error(), http.StatusInternalServerError)