Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: use a closure to wrap transactions #469

Merged
merged 15 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
.DS_Store

/.dapper
/.cache
/certs
Expand All @@ -6,3 +8,7 @@
*.swp
.idea
steve

informer_object_cache.db
informer_object_cache.db-shm
informer_object_cache.db-wal
120 changes: 72 additions & 48 deletions pkg/sqlcache/db/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import (
"reflect"
"sync"

"github.com/pkg/errors"
"errors"

"github.com/rancher/steve/pkg/sqlcache/db/transaction"

// needed for drivers
Expand All @@ -29,8 +30,59 @@ const (
informerObjectCachePerms fs.FileMode = 0o600
)

// Client is a database client that provides encrypting, decrypting, and database resetting.
type Client struct {
// Client defines a database client that provides encrypting, decrypting, and database resetting
type Client interface {
WithTransaction(ctx context.Context, forWriting bool, f WithTransactionFunction) error
Prepare(stmt string) *sql.Stmt
QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error)
ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error)
ReadStrings(rows Rows) ([]string, error)
ReadInt(rows Rows) (int, error)
Upsert(tx transaction.Client, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error
CloseStmt(closable Closable) error
NewConnection() error
}

// WithTransaction runs f within a transaction.
//
// If forWriting is true, this method blocks until all other concurrent forWriting
// transactions have either committed or rolled back.
// If forWriting is false, it is assumed the returned transaction will exclusively
// be used for DQL (e.g. SELECT) queries.
// Not respecting the above rule might result in transactions failing with unexpected
// SQLITE_BUSY (5) errors (aka "Runtime error: database is locked").
// See discussion in https://github.com/rancher/lasso/pull/98 for details
//
// The transaction is committed if f returns nil, otherwise it is rolled back.
func (c *client) WithTransaction(ctx context.Context, forWriting bool, f WithTransactionFunction) error {
c.connLock.RLock()
// note: this assumes _txlock=immediate in the connection string, see NewConnection
tx, err := c.conn.BeginTx(ctx, &sql.TxOptions{
ReadOnly: !forWriting,
})
c.connLock.RUnlock()
if err != nil {
return err
}

err = f(transaction.NewClient(tx))

if err != nil {
rerr := tx.Rollback()
err = errors.Join(err, rerr)
moio marked this conversation as resolved.
Show resolved Hide resolved
} else {
cerr := tx.Commit()
err = errors.Join(err, cerr)
}

return err
}

// WithTransactionFunction is a function that uses a transaction
type WithTransactionFunction func(tx transaction.Client) error

// client is the main implementation of Client. Other implementations exist for test purposes
type client struct {
conn Connection
connLock sync.RWMutex
encryptor Encryptor
Expand Down Expand Up @@ -74,15 +126,6 @@ func (e *QueryError) Unwrap() error {
return e.Err
}

// TXClient represents a sql transaction. The TXClient must manage rollbacks as rollback functionality is not exposed.
type TXClient interface {
StmtExec(stmt transaction.Stmt, args ...any) error
Exec(stmt string, args ...any) error
Commit() error
Stmt(stmt *sql.Stmt) transaction.Stmt
Cancel() error
}

// Encryptor encrypts data with a key which is rotated to avoid wear-out.
type Encryptor interface {
// Encrypt encrypts the specified data, returning: the encrypted data, the nonce used to encrypt the data, and an ID identifying the key that was used (as it rotates). On failure error is returned instead.
Expand All @@ -95,9 +138,9 @@ type Decryptor interface {
Decrypt([]byte, []byte, uint32) ([]byte, error)
}

// NewClient returns a Client. If the given connection is nil then a default one will be created.
func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (*Client, error) {
client := &Client{
// NewClient returns a client. If the given connection is nil then a default one will be created.
func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (Client, error) {
client := &client{
encryptor: encryptor,
decryptor: decryptor,
}
Expand All @@ -114,19 +157,19 @@ func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (*Client,
}

// Prepare prepares the given string into a sql statement on the client's connection.
func (c *Client) Prepare(stmt string) *sql.Stmt {
func (c *client) Prepare(stmt string) *sql.Stmt {
c.connLock.RLock()
defer c.connLock.RUnlock()
prepared, err := c.conn.Prepare(stmt)
if err != nil {
panic(errors.Errorf("Error preparing statement: %s\n%v", stmt, err))
panic(fmt.Errorf("Error preparing statement: %s\n%w", stmt, err))
}
return prepared
}

// QueryForRows queries the given stmt with the given params and returns the resulting rows. The query wil be retried
// given a sqlite busy error.
func (c *Client) QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) {
func (c *client) QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()

Expand All @@ -135,13 +178,13 @@ func (c *Client) QueryForRows(ctx context.Context, stmt transaction.Stmt, params

// CloseStmt will call close on the given Closable. It is intended to be used with a sql statement. This function is meant
// to replace stmt.Close which can cause panics when callers unit-test since there usually is no real underlying connection.
func (c *Client) CloseStmt(closable Closable) error {
func (c *client) CloseStmt(closable Closable) error {
return closable.Close()
}

// ReadObjects Scans the given rows, performs any necessary decryption, converts the data to objects of the given type,
// and returns a slice of those objects.
func (c *Client) ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) {
func (c *client) ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()

Expand Down Expand Up @@ -171,7 +214,7 @@ func (c *Client) ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([
}

// ReadStrings scans the given rows into strings, and then returns the strings as a slice.
func (c *Client) ReadStrings(rows Rows) ([]string, error) {
func (c *client) ReadStrings(rows Rows) ([]string, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()

Expand Down Expand Up @@ -199,7 +242,7 @@ func (c *Client) ReadStrings(rows Rows) ([]string, error) {
}

// ReadInt scans the first of the given rows into a single int (eg. for COUNT() queries)
func (c *Client) ReadInt(rows Rows) (int, error) {
func (c *client) ReadInt(rows Rows) (int, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()

Expand All @@ -226,28 +269,7 @@ func (c *Client) ReadInt(rows Rows) (int, error) {
return result, nil
}

// BeginTx attempts to begin a transaction.
// If forWriting is true, this method blocks until all other concurrent forWriting
// transactions have either committed or rolled back.
// If forWriting is false, it is assumed the returned transaction will exclusively
// be used for DQL (e.g. SELECT) queries.
// Not respecting the above rule might result in transactions failing with unexpected
// SQLITE_BUSY (5) errors (aka "Runtime error: database is locked").
// See discussion in https://github.com/rancher/lasso/pull/98 for details
func (c *Client) BeginTx(ctx context.Context, forWriting bool) (TXClient, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
// note: this assumes _txlock=immediate in the connection string, see NewConnection
sqlTx, err := c.conn.BeginTx(ctx, &sql.TxOptions{
ReadOnly: !forWriting,
})
if err != nil {
return nil, err
}
return transaction.NewClient(sqlTx), nil
}

func (c *Client) decryptScan(rows Rows, shouldDecrypt bool) ([]byte, error) {
func (c *client) decryptScan(rows Rows, shouldDecrypt bool) ([]byte, error) {
var data, dataNonce sql.RawBytes
var kid uint32
err := rows.Scan(&data, &dataNonce, &kid)
Expand All @@ -264,8 +286,9 @@ func (c *Client) decryptScan(rows Rows, shouldDecrypt bool) ([]byte, error) {
return data, nil
}

// Upsert used to be called upsertEncrypted in store package before move
func (c *Client) Upsert(tx TXClient, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error {
// Upsert executes an upsert statement encrypting arguments if necessary
// note the statement should have 4 parameters: key, objBytes, dataNonce, kid
func (c *client) Upsert(tx transaction.Client, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error {
objBytes := toBytes(obj)
var dataNonce []byte
var err error
Expand All @@ -277,7 +300,8 @@ func (c *Client) Upsert(tx TXClient, stmt *sql.Stmt, key string, obj any, should
}
}

return tx.StmtExec(tx.Stmt(stmt), key, objBytes, dataNonce, kid)
_, err = tx.Stmt(stmt).Exec(key, objBytes, dataNonce, kid)
return err
}

// toBytes encodes an object to a byte slice
Expand Down Expand Up @@ -312,7 +336,7 @@ func closeRowsOnError(rows Rows, err error) error {

// NewConnection checks for currently existing connection, closes one if it exists, removes any relevant db files, and opens a new connection which subsequently
// creates new files.
func (c *Client) NewConnection() error {
func (c *client) NewConnection() error {
c.connLock.Lock()
defer c.connLock.Unlock()
if c.conn != nil {
Expand Down
80 changes: 15 additions & 65 deletions pkg/sqlcache/db/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ import (
"reflect"
"testing"

"github.com/rancher/steve/pkg/sqlcache/db/transaction"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
)

// Mocks for this test are generated with the following command.
//go:generate mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient
//go:generate mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
//go:generate mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor
//go:generate mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Client,Stmt

type testStoreObject struct {
Id string
Expand All @@ -37,7 +38,7 @@ func TestNewClient(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
expectedClient := &Client{
expectedClient := &client{
conn: c,
encryptor: e,
decryptor: d,
Expand Down Expand Up @@ -389,58 +390,6 @@ func TestReadInt(t *testing.T) {
}
}

func TestBegin(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}

var tests []testCase

// Tests with shouldEncryptSet to false
tests = append(tests, testCase{description: "BeginTx(), with no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)

sqlTx := &sql.Tx{}
c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(sqlTx, nil)
client := SetupClient(t, c, e, d)
txC, err := client.BeginTx(context.Background(), false)
assert.Nil(t, err)
assert.NotNil(t, txC)
},
})
tests = append(tests, testCase{description: "BeginTx(), with forWriting option set", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)

sqlTx := &sql.Tx{}
c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: false}).Return(sqlTx, nil)
client := SetupClient(t, c, e, d)
txC, err := client.BeginTx(context.Background(), true)
assert.Nil(t, err)
assert.NotNil(t, txC)
},
})
tests = append(tests, testCase{description: "BeginTx(), with connection Begin() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)

c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(nil, fmt.Errorf("error"))
client := SetupClient(t, c, e, d)
_, err := client.BeginTx(context.Background(), false)
assert.NotNil(t, err)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}

func TestUpsert(t *testing.T) {
type testCase struct {
description string
Expand All @@ -459,14 +408,14 @@ func TestUpsert(t *testing.T) {
d := SetupMockDecryptor(t)

client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
txC := NewMockClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
stmt := NewMockStmt(gomock.NewController(t))
testObjBytes := toBytes(testObject)
testByteValue := []byte("something")
e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil)
txC.EXPECT().Stmt(sqlStmt).Return(stmt)
txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(nil)
stmt.EXPECT().Exec("somekey", testByteValue, testByteValue, keyID).Return(nil, nil)
err := client.Upsert(txC, sqlStmt, "somekey", testObject, true)
assert.Nil(t, err)
},
Expand All @@ -477,7 +426,7 @@ func TestUpsert(t *testing.T) {
d := SetupMockDecryptor(t)

client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
txC := NewMockClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
testObjBytes := toBytes(testObject)
e.EXPECT().Encrypt(testObjBytes).Return(nil, nil, uint32(0), fmt.Errorf("error"))
Expand All @@ -491,14 +440,14 @@ func TestUpsert(t *testing.T) {
d := SetupMockDecryptor(t)

client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
txC := NewMockClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
stmt := NewMockStmt(gomock.NewController(t))
testObjBytes := toBytes(testObject)
testByteValue := []byte("something")
e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil)
txC.EXPECT().Stmt(sqlStmt).Return(stmt)
txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(fmt.Errorf("error"))
stmt.EXPECT().Exec("somekey", testByteValue, testByteValue, keyID).Return(nil, fmt.Errorf("error"))
err := client.Upsert(txC, sqlStmt, "somekey", testObject, true)
assert.NotNil(t, err)
},
Expand All @@ -509,13 +458,13 @@ func TestUpsert(t *testing.T) {
e := SetupMockEncryptor(t)

client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
txC := NewMockClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
stmt := NewMockStmt(gomock.NewController(t))
var testByteValue []byte
testObjBytes := toBytes(testObject)
txC.EXPECT().Stmt(sqlStmt).Return(stmt)
txC.EXPECT().StmtExec(stmt, "somekey", testObjBytes, testByteValue, uint32(0)).Return(nil)
stmt.EXPECT().Exec("somekey", testObjBytes, testByteValue, uint32(0)).Return(nil, nil)
err := client.Upsert(txC, sqlStmt, "somekey", testObject, false)
assert.Nil(t, err)
},
Expand Down Expand Up @@ -582,9 +531,10 @@ func TestNewConnection(t *testing.T) {
assert.Nil(t, err)

// Create a transaction to ensure that the file is written to disk.
txC, err := client.BeginTx(context.Background(), false)
err = client.WithTransaction(context.Background(), false, func(tx transaction.Client) error {
return nil
})
assert.NoError(t, err)
assert.NoError(t, txC.Commit())

assert.FileExists(t, InformerObjectCacheDBPath)
assertFileHasPermissions(t, InformerObjectCacheDBPath, 0600)
Expand Down Expand Up @@ -630,7 +580,7 @@ func SetupMockRows(t *testing.T) *MockRows {
return MockR
}

func SetupClient(t *testing.T, connection Connection, encryptor Encryptor, decryptor Decryptor) *Client {
func SetupClient(t *testing.T, connection Connection, encryptor Encryptor, decryptor Decryptor) Client {
c, _ := NewClient(connection, encryptor, decryptor)
return c
}
Expand Down
Loading