tapcfg+tapdb: run post migration checks
guggero committed Nov 15, 2024
1 parent 1d27ee2 commit bd9d22d
Showing 6 changed files with 211 additions and 30 deletions.
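For orientation, here is a minimal, self-contained sketch (not part of the commit) of the control flow this change introduces: applyMigrations now reports whether any migrations were applied, and the Postgres and SQLite stores (see the postgres.go and sqlite.go hunks below) only run the new Go-level post-migration checks when it does. The two stub functions merely stand in for the real tapdb functions shown in the diff.

package main

import (
        "fmt"
        "log"
)

// applyMigrations stands in for tapdb's applyMigrations, which after this
// commit also returns whether any migrations were applied.
func applyMigrations() (bool, error) {
        return true, nil
}

// runPostMigrationChecks stands in for the checks added in
// tapdb/post_migration_checks.go.
func runPostMigrationChecks() error {
        return nil
}

func main() {
        didMigrate, err := applyMigrations()
        if err != nil {
                log.Fatal(fmt.Errorf("error applying migrations: %w", err))
        }

        // Run post-migration checks only if we actually did migrate.
        if didMigrate {
                if err := runPostMigrationChecks(); err != nil {
                        log.Fatal(err)
                }
        }
}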
9 changes: 1 addition & 8 deletions tapcfg/server.go
@@ -26,13 +26,6 @@ import (
"github.com/lightningnetwork/lnd/signal"
)

// databaseBackend is an interface that contains all methods our different
// database backends implement.
type databaseBackend interface {
tapdb.BatchedQuerier
WithTx(tx *sql.Tx) *sqlc.Queries
}

// genServerConfig generates a server config from the given tapd config.
//
// NOTE: The RPCConfig and SignalInterceptor fields must be set by the caller
@@ -43,7 +36,7 @@ func genServerConfig(cfg *Config, cfgLogger btclog.Logger,

var (
err error
db databaseBackend
db tapdb.DatabaseBackend
dbType sqlc.BackendType
)

12 changes: 7 additions & 5 deletions tapdb/assets_store.go
@@ -621,8 +621,8 @@ func parseAssetWitness(input AssetWitness) (asset.Witness, error) {
// dbAssetsToChainAssets maps a set of confirmed assets in the database, and
// the witnesses of those assets to a set of normal ChainAsset structs needed
// by a higher level application.
func (a *AssetStore) dbAssetsToChainAssets(dbAssets []ConfirmedAsset,
witnesses assetWitnesses) ([]*asset.ChainAsset, error) {
func dbAssetsToChainAssets(dbAssets []ConfirmedAsset, witnesses assetWitnesses,
clock clock.Clock) ([]*asset.ChainAsset, error) {

chainAssets := make([]*asset.ChainAsset, len(dbAssets))
for i := range dbAssets {
@@ -826,7 +826,7 @@ func (a *AssetStore) dbAssetsToChainAssets(dbAssets []ConfirmedAsset,
owner := sprout.AnchorLeaseOwner
expiry := sprout.AnchorLeaseExpiry
if len(owner) > 0 && expiry.Valid &&
expiry.Time.UTC().After(a.clock.Now().UTC()) {
expiry.Time.UTC().After(clock.Now().UTC()) {

copy(chainAssets[i].AnchorLeaseOwner[:], owner)
chainAssets[i].AnchorLeaseExpiry = &expiry.Time
@@ -1198,7 +1198,7 @@ func (a *AssetStore) FetchAllAssets(ctx context.Context, includeSpent,
return nil, dbErr
}

return a.dbAssetsToChainAssets(dbAssets, assetWitnesses)
return dbAssetsToChainAssets(dbAssets, assetWitnesses, a.clock)
}

// FetchManagedUTXOs fetches all UTXOs we manage.
@@ -1901,7 +1901,7 @@ func (a *AssetStore) queryChainAssets(ctx context.Context, q ActiveAssetsStore,
if err != nil {
return nil, err
}
matchingAssets, err := a.dbAssetsToChainAssets(dbAssets, assetWitnesses)
matchingAssets, err := dbAssetsToChainAssets(
dbAssets, assetWitnesses, a.clock,
)
if err != nil {
return nil, err
}
46 changes: 31 additions & 15 deletions tapdb/migrations.go
@@ -2,6 +2,7 @@ package tapdb

import (
"bytes"
"database/sql"
"errors"
"fmt"
"io"
@@ -14,6 +15,7 @@ import (
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/source/httpfs"
"github.com/lightninglabs/taproot-assets/fn"
"github.com/lightninglabs/taproot-assets/tapdb/sqlc"
)

const (
@@ -25,6 +27,13 @@ const (
LatestMigrationVersion = 24
)

// DatabaseBackend is an interface that contains all methods our different
// database backends implement.
type DatabaseBackend interface {
BatchedQuerier
WithTx(tx *sql.Tx) *sqlc.Queries
}

// MigrationTarget is a functional option that can be passed to applyMigrations
// to specify a target version to migrate to. `currentDbVersion` is the current
// (migration) version of the database, or None if unknown.
@@ -115,17 +124,18 @@ func (m *migrationLogger) Verbose() bool {

// applyMigrations executes database migration files found in the given file
// system under the given path, using the passed database driver and database
// name, up to or down to the given target version.
// name, up to or down to the given target version. The boolean return value
// indicates whether any migrations were applied.
func applyMigrations(fs fs.FS, driver database.Driver, path, dbName string,
targetVersion MigrationTarget, opts *migrateOptions) error {
targetVersion MigrationTarget, opts *migrateOptions) (bool, error) {

// With the migrate instance open, we'll create a new migration source
// using the embedded file system stored in sqlSchemas. The library
// we're using can't handle a raw file system interface, so we wrap it
// in this intermediate layer.
migrateFileServer, err := httpfs.New(http.FS(fs), path)
if err != nil {
return err
return false, err
}

// Finally, we'll run the migration with our driver above based on the
@@ -135,7 +145,7 @@ func applyMigrations(fs fs.FS, driver database.Driver, path, dbName string,
"migrations", migrateFileServer, dbName, driver,
)
if err != nil {
return err
return false, err
}

migrationVersion, _, _ := sqlMigrate.Version()
@@ -144,38 +154,44 @@
// prevent that without explicit accounting.
latestVersion := opts.latestVersion.UnwrapOr(LatestMigrationVersion)
if migrationVersion > latestVersion {
return fmt.Errorf("%w: database version is newer than the "+
"latest migration version, preventing downgrade: "+
return false, fmt.Errorf("%w: database version is newer than "+
"the latest migration version, preventing downgrade: "+
"db_version=%v, latest_migration_version=%v",
ErrMigrationDowngrade, migrationVersion, latestVersion)
}

// Report the current version of the database before the migration.
currentDbVersion, _, err := driver.Version()
versionBeforeMigration, _, err := driver.Version()
if err != nil {
return fmt.Errorf("unable to get current db version: %w", err)
return false, fmt.Errorf("unable to get current db version: %w",
err)
}
log.Infof("Attempting to apply migration(s) "+
"(current_db_version=%v, latest_migration_version=%v)",
currentDbVersion, latestVersion)
versionBeforeMigration, latestVersion)

// Apply our local logger to the migration instance.
sqlMigrate.Log = &migrationLogger{log}

// Execute the migration based on the target given.
err = targetVersion(sqlMigrate, currentDbVersion, latestVersion)
err = targetVersion(sqlMigrate, versionBeforeMigration, latestVersion)
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
return err
return false, err
}

// If we actually did migrate, we'll now run the Golang based
// post-migration checks that ensure the database is in a consistent
// state, based on properties not fully expressible in SQL.

// Report the current version of the database after the migration.
currentDbVersion, _, err = driver.Version()
versionAfterMigration, _, err := driver.Version()
if err != nil {
return fmt.Errorf("unable to get current db version: %w", err)
return true, fmt.Errorf("unable to get current db version: %w",
err)
}
log.Infof("Database version after migration: %v", currentDbVersion)
log.Infof("Database version after migration: %v", versionAfterMigration)

return nil
return true, nil
}

// replacerFS is an implementation of a fs.FS virtual file system that wraps an
150 changes: 150 additions & 0 deletions tapdb/post_migration_checks.go
@@ -0,0 +1,150 @@
package tapdb

import (
"context"
"database/sql"
"fmt"
"time"

"github.com/lightninglabs/taproot-assets/asset"
"github.com/lightninglabs/taproot-assets/fn"
"github.com/lightninglabs/taproot-assets/tapdb/sqlc"
"github.com/lightninglabs/taproot-assets/tapscript"
"github.com/lightningnetwork/lnd/clock"
)

// postMigrationCheck is a function type for a function that performs a
// post-migration check on the database.
type postMigrationCheck func(context.Context, sqlc.Querier) error

var (
// postMigrationChecks is a list of functions that are run after the
// database migrations have been applied. These functions are used to
// perform additional checks on the database state that are not fully
// expressible in SQL.
postMigrationChecks = []postMigrationCheck{
detectScriptKeyType,
}
)

// runPostMigrationChecks runs a set of post-migration checks on the database
// using the given database backend.
func runPostMigrationChecks(db DatabaseBackend) error {
var (
ctx = context.Background()
txDb = NewTransactionExecutor(
db, func(tx *sql.Tx) sqlc.Querier {
return db.WithTx(tx)
},
)
writeTxOpts AssetStoreTxOptions
)

return txDb.ExecTx(ctx, &writeTxOpts, func(q sqlc.Querier) error {
log.Infof("Running %d post-migration checks",
len(postMigrationChecks))
start := time.Now()

for _, check := range postMigrationChecks {
err := check(ctx, q)
if err != nil {
return err
}
}

log.Infof("Post-migration checks completed in %v",
time.Since(start))

return nil
})
}

// detectScriptKeyType attempts to detect the type of the script keys that don't
// have a type set yet.
func detectScriptKeyType(ctx context.Context, q sqlc.Querier) error {
defaultClock := clock.NewDefaultClock()

// We start by fetching all assets, even the spent ones. We then collect
// a list of the burn keys from the assets (because burn keys can only
// be calculated from the asset's witness).
assetFilter := QueryAssetFilters{
Now: sql.NullTime{
Time: defaultClock.Now().UTC(),
Valid: true,
},
}
dbAssets, assetWitnesses, err := fetchAssetsWithWitness(
ctx, q, assetFilter,
)
if err != nil {
return fmt.Errorf("error fetching assets: %w", err)
}

chainAssets, err := dbAssetsToChainAssets(
dbAssets, assetWitnesses, defaultClock,
)
if err != nil {
return fmt.Errorf("error converting assets: %w", err)
}

burnAssets := fn.Filter(chainAssets, func(a *asset.ChainAsset) bool {
return a.IsBurn()
})
burnKeys := make(map[asset.SerializedKey]struct{})
for _, a := range burnAssets {
serializedKey := asset.ToSerialized(a.ScriptKey.PubKey)
burnKeys[serializedKey] = struct{}{}
}

untypedKeys, err := q.FetchUnknownTypeScriptKeys(ctx)
if err != nil {
return fmt.Errorf("error fetching script keys: %w", err)
}

channelFundingKey := asset.NewScriptKey(
tapscript.NewChannelFundingScriptTree().TaprootKey,
).PubKey

for _, k := range untypedKeys {
scriptKey, err := parseScriptKey(k.InternalKey, k.ScriptKey)
if err != nil {
return fmt.Errorf("error parsing script key: %w", err)
}

serializedKey := asset.ToSerialized(scriptKey.PubKey)
newType := asset.ScriptKeyUnknown

if _, ok := burnKeys[serializedKey]; ok {
newType = asset.ScriptKeyBurn
} else {
guessedType := scriptKey.GuessType()
if guessedType == asset.ScriptKeyBip86 {
newType = asset.ScriptKeyBip86
}

if guessedType == asset.ScriptKeyScriptPathExternal &&
scriptKey.PubKey.IsEqual(channelFundingKey) {

newType = asset.ScriptKeyScriptPathChannel
}
}

// If we were able to identify the key type, we update the key
// in the database.
if newType != asset.ScriptKeyUnknown {
_, err := q.UpsertScriptKey(ctx, NewScriptKey{
InternalKeyID: k.InternalKey.KeyID,
TweakedScriptKey: k.ScriptKey.TweakedScriptKey,
Tweak: k.ScriptKey.Tweak,
DeclaredKnown: k.ScriptKey.DeclaredKnown,
KeyType: sqlInt16(newType),
})
if err != nil {
return fmt.Errorf("error updating script key "+
"type: %w", err)
}
}
}

return nil
}
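
To make the precedence in the loop above easier to follow, here is a condensed, hypothetical restatement of the decision order (not code from this commit): burn keys win, then plain BIP-86 keys, then the well-known channel-funding script path, and everything else stays unknown and is left untouched in the database. The shared asset.ScriptKeyType name for the constants is assumed from the asset package.

package tapdb

import "github.com/lightninglabs/taproot-assets/asset"

// classifyScriptKey restates the decision order of detectScriptKeyType above.
// It is illustration only, not part of this commit.
func classifyScriptKey(isBurn bool, guessed asset.ScriptKeyType,
        isChannelFundingKey bool) asset.ScriptKeyType {

        switch {
        // Burn keys can only be identified from the asset witness, so they
        // take precedence over any guess based on the key material itself.
        case isBurn:
                return asset.ScriptKeyBurn

        // A key that looks like a plain BIP-86 key is recorded as such.
        case guessed == asset.ScriptKeyBip86:
                return asset.ScriptKeyBip86

        // An external script path that matches the well-known channel
        // funding script tree is a channel funding key.
        case guessed == asset.ScriptKeyScriptPathExternal &&
                isChannelFundingKey:

                return asset.ScriptKeyScriptPathChannel

        // Anything else stays unknown and is not updated in the database.
        default:
                return asset.ScriptKeyUnknown
        }
}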
12 changes: 11 additions & 1 deletion tapdb/postgres.go
@@ -158,10 +158,20 @@ func (s *PostgresStore) ExecuteMigrations(target MigrationTarget,
}

postgresFS := newReplacerFS(sqlSchemas, postgresSchemaReplacements)
return applyMigrations(
didMigrate, err := applyMigrations(
postgresFS, driver, "sqlc/migrations", s.cfg.DBName, target,
opts,
)
if err != nil {
return fmt.Errorf("error applying migrations: %w", err)
}

// Run post-migration checks if we actually did migrate.
if didMigrate {
return runPostMigrationChecks(s)
}

return nil
}

// NewTestPostgresDB is a helper function that creates a Postgres database for
12 changes: 11 additions & 1 deletion tapdb/sqlite.go
@@ -244,9 +244,19 @@ func (s *SqliteStore) ExecuteMigrations(target MigrationTarget,
}

sqliteFS := newReplacerFS(sqlSchemas, sqliteSchemaReplacements)
return applyMigrations(
didMigrate, err := applyMigrations(
sqliteFS, driver, "sqlc/migrations", "sqlite", target, opts,
)
if err != nil {
return fmt.Errorf("error applying migrations: %w", err)
}

// Run post-migration checks if we actually did migrate.
if didMigrate {
return runPostMigrationChecks(s)
}

return nil
}

// NewTestSqliteDB is a helper function that creates an SQLite database for
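
One practical consequence, sketched below under assumptions: any caller that ends up in ExecuteMigrations — including the NewTestSqliteDB helper whose doc comment appears above — now also exercises the post-migration checks whenever a migration is applied. The test assumes NewTestSqliteDB keeps a plain *testing.T signature and runs migrations while creating the store; it is illustration only.

package tapdb

import "testing"

// TestFreshDBPassesPostMigrationChecks is a hypothetical sketch: creating a
// fresh test database applies all migrations, which after this commit also
// runs the post-migration checks inside a single transaction.
func TestFreshDBPassesPostMigrationChecks(t *testing.T) {
        NewTestSqliteDB(t)
}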
