Skip to content

Commit

Permalink
fix: Refresh client cert when it is rejected by the server.
Browse files Browse the repository at this point in the history
  • Loading branch information
hessjcg committed Jan 31, 2025
1 parent f0537a6 commit 95d468c
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 40 deletions.
51 changes: 42 additions & 9 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,11 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
tlsConn := tls.Client(conn, ci.TLSConfig())
err = tlsConn.HandshakeContext(ctx)
if err != nil {
// TLS handshake errors are fatal and require a refresh. Remove the instance
// from the cache so that future calls to Dial() will block until the
// certificate is refreshed successfully.
d.logger.Debugf(ctx, "[%v] TLS handshake failed: %v", cn.String(), err)
// refresh the instance info in case it caused the handshake failure
c.ForceRefresh()
d.removeCached(ctx, cn, c, err)
_ = tlsConn.Close() // best effort close attempt
return nil, errtype.NewDialError("handshake failed", cn.String(), err)
}
Expand All @@ -435,7 +437,22 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
iConn := newInstrumentedConn(tlsConn, func() {
n := atomic.AddUint64(c.openConnsCount, ^uint64(0)) // c.openConnsCount = c.openConnsCount - 1
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
}, d.dialerID, cn.String())
},
func(err error) {
// ignore EOF
if err == io.EOF {
return
}
d.logger.Debugf(ctx, "[%v] IO Error on Read or Write: %v", cn.String(), err)
if d.isTLSError(err) {
// TLS handshake errors are fatal. Remove the instance from the cache
// so that future calls to Dial() will block until the certificate
// is refreshed successfully.
d.removeCached(ctx, cn, c, err)
_ = tlsConn.Close() // best effort close attempt
}
},
d.dialerID, cn.String())

// If this connection was opened using a Domain Name, then store it for later
// in case it needs to be forcibly closed.
Expand All @@ -446,12 +463,19 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
}
return iConn, nil
}
func (d *Dialer) isTLSError(err error) bool {
if nErr, ok := err.(net.Error); ok {
return !nErr.Timeout() && // it's a permanent net error
strings.Contains(nErr.Error(), "tls") // it's a TLS-related error
}
return false
}

// removeCached stops all background refreshes and deletes the connection
// info cache from the map of caches.
func (d *Dialer) removeCached(
ctx context.Context,
i instance.ConnName, c connectionInfoCache, err error,
i instance.ConnName, c *monitoredCache, err error,
) {
d.logger.Debugf(
ctx,
Expand All @@ -461,8 +485,11 @@ func (d *Dialer) removeCached(
)
d.lock.Lock()
defer d.lock.Unlock()
c.Close()
delete(d.cache, createKey(i))
key := createKey(i)
if cachedC, ok := d.cache[key]; ok && cachedC == c {
delete(d.cache, key)
}
c.connectionInfoCache.Close()
}

// validClientCert checks that the ephemeral client certificate retrieved from
Expand Down Expand Up @@ -504,7 +531,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
}
ci, err := c.ConnectionInfo(ctx)
if err != nil {
d.removeCached(ctx, cn, c.connectionInfoCache, err)
d.removeCached(ctx, cn, c, err)
return "", err
}
return ci.DBVersion, nil
Expand All @@ -528,17 +555,18 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err
}
_, err = c.ConnectionInfo(ctx)
if err != nil {
d.removeCached(ctx, cn, c.connectionInfoCache, err)
d.removeCached(ctx, cn, c, err)
}
return err
}

// newInstrumentedConn initializes an instrumentedConn that on closing will
// decrement the number of open connects and record the result.
func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName string) *instrumentedConn {
func newInstrumentedConn(conn net.Conn, closeFunc func(), errFunc func(error), dialerID, connName string) *instrumentedConn {
return &instrumentedConn{
Conn: conn,
closeFunc: closeFunc,
errFunc: errFunc,
dialerID: dialerID,
connName: connName,
}
Expand All @@ -549,6 +577,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
type instrumentedConn struct {
net.Conn
closeFunc func()
errFunc func(error)
mu sync.RWMutex
closed bool
dialerID string
Expand All @@ -561,6 +590,8 @@ func (i *instrumentedConn) Read(b []byte) (int, error) {
bytesRead, err := i.Conn.Read(b)
if err == nil {
go trace.RecordBytesReceived(context.Background(), int64(bytesRead), i.connName, i.dialerID)
} else {
i.errFunc(err)
}
return bytesRead, err
}
Expand All @@ -571,6 +602,8 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
bytesWritten, err := i.Conn.Write(b)
if err == nil {
go trace.RecordBytesSent(context.Background(), int64(bytesWritten), i.connName, i.dialerID)
} else {
i.errFunc(err)
}
return bytesWritten, err
}
Expand Down
65 changes: 61 additions & 4 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1175,7 +1175,7 @@ func TestDialerChecksSubjectAlternativeNameAndFails(t *testing.T) {
}
}

func TestDialerRefreshesAfterClientCertificateError(t *testing.T) {
func TestDialerRefreshesAfterRotateClientCA(t *testing.T) {
inst := mock.NewFakeCSQLInstanceWithSan(
"my-project", "my-region", "my-instance", []string{"db.example.com"},
mock.WithDNS("db.example.com"),
Expand Down Expand Up @@ -1210,14 +1210,13 @@ func TestDialerRefreshesAfterClientCertificateError(t *testing.T) {
cancel1()

mock.RotateClientCA(inst)
time.Sleep(2 * time.Second)

// Recreate the instance, which generates new server certificates
// Start the server with new certificates
cancel2 := mock.StartServerProxy(t, inst)
defer cancel2()

// Dial a second time. We expect no error on dial, but TLS error on read.
t.Log("Second attempt should fail...")
conn, err := d.Dial(context.Background(), "my-project:my-region:my-instance")
if err != nil {
t.Fatal("Should be no certificate error after, got ", err)
Expand All @@ -1230,8 +1229,66 @@ func TestDialerRefreshesAfterClientCertificateError(t *testing.T) {
} else {
t.Fatal("Want read error, got no error")
}
t.Log("Second attempt done")

// Dial again. This should complete after the refresh.
t.Log("Third attempt...")
testSuccessfulDial(
context.Background(), t, d,
"my-project:my-region:my-instance",
)
t.Log("Third attempt OK.")
}

func TestDialerRefreshesAfterRotateServerCA(t *testing.T) {
inst := mock.NewFakeCSQLInstanceWithSan(
"my-project", "my-region", "my-instance", []string{"db.example.com"},
mock.WithDNS("db.example.com"),
mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"),
)

d := setupDialer(t, setupConfig{
skipServer: true,
testInstance: inst,
reqs: []*mock.Request{
mock.InstanceGetSuccess(inst, 2),
mock.CreateEphemeralSuccess(inst, 2),
},
dialerOptions: []Option{
WithTokenSource(mock.EmptyTokenSource{}),
WithDebugLogger(&dialerTestLogger{t: t}),
WithLazyRefresh(),
// Note: this succeeds with lazy refresh, but fails with lazy.
// because dialer.ForceRefresh does not block connections while the
// refresh is in progress.
},
})
cancel1 := mock.StartServerProxy(t, inst)
t.Log("First attempt...")
testSuccessfulDial(
context.Background(), t, d,
"my-project:my-region:my-instance",
)
t.Log("First attempt OK. Resetting client cert.")

// Close the server
cancel1()

mock.RotateCA(inst)

// Start the server with new certificates
cancel2 := mock.StartServerProxy(t, inst)
defer cancel2()

// Dial a second time. We expect no error on dial, but TLS error on read.
t.Log("Second attempt should fail...")
_, err := d.Dial(context.Background(), "my-project:my-region:my-instance")
if err != nil {
t.Log("Got error on dial as expected.", err)
} else {
t.Fatal("Want dial error, got no error")
}

time.Sleep(2 * time.Second)
// Dial again. This should occur after the refresh has completed.
t.Log("Third attempt...")
testSuccessfulDial(
Expand Down
2 changes: 2 additions & 0 deletions internal/mock/cloudsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@ func (f FakeCSQLInstance) serverCACert() ([]byte, error) {
if f.signer != nil {
return f.signer(f.Cert, f.Key)
}

if f.serverCAMode == "" || f.serverCAMode == "GOOGLE_MANAGED_INTERNAL_CA" {
// legacy server mode, return only the server cert
return toPEMFormat(f.certs.serverCert)
}
return toPEMFormat(f.certs.casServerCertificate, f.certs.serverIntermediateCaCert, f.certs.serverCaCert)

}

// ClientCert creates an ephemeral client certificate signed with the Cloud SQL
Expand Down
54 changes: 27 additions & 27 deletions internal/mock/sqladmin.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,38 +102,38 @@ func (r *Request) matches(hR *http.Request) bool {
//
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/instances/get
func InstanceGetSuccess(i FakeCSQLInstance, ct int) *Request {
var ips []*sqladmin.IpMapping
for ipType, addr := range i.ipAddrs {
if ipType == "PUBLIC" {
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIMARY"})
continue
}
if ipType == "PRIVATE" {
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"})
}
}

certBytes, err := i.serverCACert()
if err != nil {
panic(err)
}

db := &sqladmin.ConnectSettings{
BackendType: i.backendType,
DatabaseVersion: i.dbVersion,
DnsName: i.DNSName,
IpAddresses: ips,
Region: i.region,
ServerCaCert: &sqladmin.SslCert{Cert: string(certBytes)},
PscEnabled: i.pscEnabled,
ServerCaMode: i.serverCAMode,
}

r := &Request{
reqMethod: http.MethodGet,
reqPath: fmt.Sprintf("/sql/v1beta4/projects/%s/instances/%s/connectSettings", i.project, i.name),
reqCt: ct,
handle: func(resp http.ResponseWriter, _ *http.Request) {
var ips []*sqladmin.IpMapping
for ipType, addr := range i.ipAddrs {
if ipType == "PUBLIC" {
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIMARY"})
continue
}
if ipType == "PRIVATE" {
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"})
}
}

certBytes, err := i.serverCACert()
if err != nil {
panic(err)
}

db := &sqladmin.ConnectSettings{
BackendType: i.backendType,
DatabaseVersion: i.dbVersion,
DnsName: i.DNSName,
IpAddresses: ips,
Region: i.region,
ServerCaCert: &sqladmin.SslCert{Cert: string(certBytes)},
PscEnabled: i.pscEnabled,
ServerCaMode: i.serverCAMode,
}

b, err := db.MarshalJSON()
if err != nil {
http.Error(resp, err.Error(), http.StatusInternalServerError)
Expand Down

0 comments on commit 95d468c

Please sign in to comment.