Skip to content

Commit

Permalink
refactor: improve naming and key handling in API client (#825)
Browse files Browse the repository at this point in the history
This commit uses a more descriptive name for the code that refreshes
connection info. Instead of "refresher," this commit renames the object
to "adminAPIClient" to match what we're doing in other connectors.

In addition, the RSA key is no longer stored on the cache but instead
passed to the only code that needs the key: the API client.
  • Loading branch information
enocom authored Jun 6, 2024
1 parent b286049 commit 294d77f
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 42 deletions.
9 changes: 4 additions & 5 deletions internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,13 @@ type RefreshAheadCache struct {

connName instance.ConnName
logger debug.ContextLogger
key *rsa.PrivateKey

// refreshTimeout sets the maximum duration a refresh cycle can run
// for.
refreshTimeout time.Duration
// l controls the rate at which refresh cycles are run.
l *rate.Limiter
r refresher
r adminAPIClient

mu sync.RWMutex
useIAMAuthNDial bool
Expand Down Expand Up @@ -133,11 +132,11 @@ func NewRefreshAheadCache(
i := &RefreshAheadCache{
connName: cn,
logger: l,
key: key,
l: rate.NewLimiter(rate.Every(refreshInterval), refreshBurst),
r: newRefresher(
r: newAdminAPIClient(
l,
client,
key,
ts,
dialerID,
),
Expand Down Expand Up @@ -416,7 +415,7 @@ func (i *RefreshAheadCache) scheduleRefresh(d time.Duration) *refreshOperation {
useIAMAuthN = i.useIAMAuthNDial
i.mu.Unlock()
r.result, r.err = i.r.ConnectionInfo(
ctx, i.connName, i.key, useIAMAuthN,
ctx, i.connName, useIAMAuthN,
)
}
switch r.err {
Expand Down
9 changes: 4 additions & 5 deletions internal/cloudsql/lazy.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ import (
type LazyRefreshCache struct {
connName instance.ConnName
logger debug.ContextLogger
key *rsa.PrivateKey
r refresher
r adminAPIClient
mu sync.Mutex
useIAMAuthNDial bool
needsRefresh bool
Expand All @@ -53,10 +52,10 @@ func NewLazyRefreshCache(
return &LazyRefreshCache{
connName: cn,
logger: l,
key: key,
r: newRefresher(
r: newAdminAPIClient(
l,
client,
key,
ts,
dialerID,
),
Expand Down Expand Up @@ -92,7 +91,7 @@ func (c *LazyRefreshCache) ConnectionInfo(
"[%v] Connection info refresh operation started",
c.connName.String(),
)
ci, err := c.r.ConnectionInfo(ctx, c.connName, c.key, c.useIAMAuthNDial)
ci, err := c.r.ConnectionInfo(ctx, c.connName, c.useIAMAuthNDial)
if err != nil {
c.logger.Debugf(
ctx,
Expand Down
30 changes: 17 additions & 13 deletions internal/cloudsql/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,44 +243,48 @@ func fetchEphemeralCert(
return c, nil
}

// newRefresher creates a Refresher.
func newRefresher(
// newAdminAPIClient creates a Refresher.
func newAdminAPIClient(
l debug.ContextLogger,
svc *sqladmin.Service,
key *rsa.PrivateKey,
ts oauth2.TokenSource,
dialerID string,
) refresher {
return refresher{
) adminAPIClient {
return adminAPIClient{
dialerID: dialerID,
logger: l,
key: key,
client: svc,
ts: ts,
}
}

// refresher manages the SQL Admin API access to instance metadata and to
// adminAPIClient manages the SQL Admin API access to instance metadata and to
// ephemeral certificates.
type refresher struct {
type adminAPIClient struct {
// dialerID is the unique ID of the associated dialer.
dialerID string
logger debug.ContextLogger
client *sqladmin.Service
// key is used to generate the client certificate
key *rsa.PrivateKey
client *sqladmin.Service
// ts is the TokenSource used for IAM DB AuthN.
ts oauth2.TokenSource
}

// ConnectionInfo immediately performs a full refresh operation using the Cloud
// SQL Admin API.
func (r refresher) ConnectionInfo(
ctx context.Context, cn instance.ConnName, k *rsa.PrivateKey, iamAuthNDial bool,
func (c adminAPIClient) ConnectionInfo(
ctx context.Context, cn instance.ConnName, iamAuthNDial bool,
) (ci ConnectionInfo, err error) {

var refreshEnd trace.EndSpanFunc
ctx, refreshEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.RefreshConnection",
trace.AddInstanceName(cn.String()),
)
defer func() {
go trace.RecordRefreshResult(context.Background(), cn.String(), r.dialerID, err)
go trace.RecordRefreshResult(context.Background(), cn.String(), c.dialerID, err)
refreshEnd(err)
}()

Expand All @@ -292,7 +296,7 @@ func (r refresher) ConnectionInfo(
mdC := make(chan mdRes, 1)
go func() {
defer close(mdC)
md, err := fetchMetadata(ctx, r.client, cn)
md, err := fetchMetadata(ctx, c.client, cn)
mdC <- mdRes{md, err}
}()

Expand All @@ -306,9 +310,9 @@ func (r refresher) ConnectionInfo(
defer close(ecC)
var iamTS oauth2.TokenSource
if iamAuthNDial {
iamTS = r.ts
iamTS = c.ts
}
ec, err := fetchEphemeralCert(ctx, r.client, cn, k, iamTS)
ec, err := fetchEphemeralCert(ctx, c.client, cn, c.key, iamTS)
ecC <- ecRes{ec, err}
}()

Expand Down
38 changes: 19 additions & 19 deletions internal/cloudsql/refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ func TestRefresh(t *testing.T) {
}
}()

r := newRefresher(nullLogger{}, client, nil, testDialerID)
rr, err := r.ConnectionInfo(context.Background(), cn, RSAKey, false)
r := newAdminAPIClient(nullLogger{}, client, RSAKey, nil, testDialerID)
rr, err := r.ConnectionInfo(context.Background(), cn, false)
if err != nil {
t.Fatalf("PerformRefresh unexpectedly failed with error: %v", err)
}
Expand Down Expand Up @@ -118,8 +118,8 @@ func TestRefreshWithStaticTokenSource(t *testing.T) {
t.Cleanup(func() { _ = cleanup() })

ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "myaccestoken"})
r := newRefresher(nullLogger{}, client, ts, testDialerID)
ci, err := r.ConnectionInfo(context.Background(), cn, RSAKey, true)
r := newAdminAPIClient(nullLogger{}, client, RSAKey, ts, testDialerID)
ci, err := r.ConnectionInfo(context.Background(), cn, true)
if err != nil {
t.Fatalf("PerformRefresh unexpectedly failed with error: %v", err)
}
Expand Down Expand Up @@ -154,8 +154,8 @@ func TestRefreshRetries50xResponses(t *testing.T) {
}
}()

r := newRefresher(nullLogger{}, client, nil, testDialerID)
rr, err := r.ConnectionInfo(context.Background(), cn, RSAKey, false)
r := newAdminAPIClient(nullLogger{}, client, RSAKey, nil, testDialerID)
rr, err := r.ConnectionInfo(context.Background(), cn, false)
if err != nil {
t.Fatalf("PerformRefresh unexpectedly failed with error: %v", err)
}
Expand All @@ -179,16 +179,16 @@ func TestRefreshFailsFast(t *testing.T) {
}
defer cleanup()

r := newRefresher(nullLogger{}, client, nil, testDialerID)
_, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false)
r := newAdminAPIClient(nullLogger{}, client, RSAKey, nil, testDialerID)
_, err = r.ConnectionInfo(context.Background(), cn, false)
if err != nil {
t.Fatalf("expected no error, got = %v", err)
}

ctx, cancel := context.WithCancel(context.Background())
cancel()
// context is canceled
_, err = r.ConnectionInfo(ctx, cn, RSAKey, false)
_, err = r.ConnectionInfo(ctx, cn, false)
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled error, got = %v", err)
}
Expand Down Expand Up @@ -261,8 +261,8 @@ func TestRefreshAdjustsCertExpiry(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
ts := &fakeTokenSource{responses: tc.resps}
r := newRefresher(nullLogger{}, client, ts, testDialerID)
rr, err := r.ConnectionInfo(context.Background(), cn, RSAKey, true)
r := newAdminAPIClient(nullLogger{}, client, RSAKey, ts, testDialerID)
rr, err := r.ConnectionInfo(context.Background(), cn, true)
if err != nil {
t.Fatalf("want no error, got = %v", err)
}
Expand Down Expand Up @@ -307,8 +307,8 @@ func TestRefreshWithIAMAuthErrors(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
ts := &fakeTokenSource{responses: tc.resps}
r := newRefresher(nullLogger{}, client, ts, testDialerID)
_, err := r.ConnectionInfo(context.Background(), cn, RSAKey, true)
r := newAdminAPIClient(nullLogger{}, client, RSAKey, ts, testDialerID)
_, err := r.ConnectionInfo(context.Background(), cn, true)
if err == nil {
t.Fatalf("expected get failed error, got = %v", err)
}
Expand Down Expand Up @@ -367,8 +367,8 @@ func TestRefreshMetadataConfigError(t *testing.T) {
}
defer cleanup()

r := newRefresher(nullLogger{}, client, nil, testDialerID)
_, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false)
r := newAdminAPIClient(nullLogger{}, client, RSAKey, nil, testDialerID)
_, err = r.ConnectionInfo(context.Background(), cn, false)
if !errors.As(err, &tc.wantErr) {
t.Errorf("[%v] PerformRefresh failed with unexpected error, want = %T, got = %v", i, tc.wantErr, err)
}
Expand Down Expand Up @@ -432,8 +432,8 @@ func TestRefreshMetadataRefreshError(t *testing.T) {
}
defer cleanup()

r := newRefresher(nullLogger{}, client, nil, testDialerID)
_, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false)
r := newAdminAPIClient(nullLogger{}, client, RSAKey, nil, testDialerID)
_, err = r.ConnectionInfo(context.Background(), cn, false)
if !errors.As(err, &tc.wantErr) {
t.Errorf("[%v] PerformRefresh failed with unexpected error, want = %T, got = %v", i, tc.wantErr, err)
}
Expand Down Expand Up @@ -497,8 +497,8 @@ func TestRefreshWithFailedEphemeralCertCall(t *testing.T) {
}
defer cleanup()

r := newRefresher(nullLogger{}, client, nil, testDialerID)
_, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false)
r := newAdminAPIClient(nullLogger{}, client, RSAKey, nil, testDialerID)
_, err = r.ConnectionInfo(context.Background(), cn, false)

if !errors.As(err, &tc.wantErr) {
t.Errorf("[%v] PerformRefresh failed with unexpected error, want = %T, got = %v", i, tc.wantErr, err)
Expand Down

0 comments on commit 294d77f

Please sign in to comment.