diff --git a/tapcfg/server.go b/tapcfg/server.go index 31364185d..b260e6d64 100644 --- a/tapcfg/server.go +++ b/tapcfg/server.go @@ -26,13 +26,6 @@ import ( "github.com/lightningnetwork/lnd/signal" ) -// databaseBackend is an interface that contains all methods our different -// database backends implement. -type databaseBackend interface { - tapdb.BatchedQuerier - WithTx(tx *sql.Tx) *sqlc.Queries -} - // genServerConfig generates a server config from the given tapd config. // // NOTE: The RPCConfig and SignalInterceptor fields must be set by the caller @@ -43,7 +36,7 @@ func genServerConfig(cfg *Config, cfgLogger btclog.Logger, var ( err error - db databaseBackend + db tapdb.DatabaseBackend dbType sqlc.BackendType ) diff --git a/tapdb/assets_store.go b/tapdb/assets_store.go index ef6244d8c..362b6d73c 100644 --- a/tapdb/assets_store.go +++ b/tapdb/assets_store.go @@ -621,8 +621,8 @@ func parseAssetWitness(input AssetWitness) (asset.Witness, error) { // dbAssetsToChainAssets maps a set of confirmed assets in the database, and // the witnesses of those assets to a set of normal ChainAsset structs needed // by a higher level application. -func (a *AssetStore) dbAssetsToChainAssets(dbAssets []ConfirmedAsset, - witnesses assetWitnesses) ([]*asset.ChainAsset, error) { +func dbAssetsToChainAssets(dbAssets []ConfirmedAsset, witnesses assetWitnesses, + clock clock.Clock) ([]*asset.ChainAsset, error) { chainAssets := make([]*asset.ChainAsset, len(dbAssets)) for i := range dbAssets { @@ -826,7 +826,7 @@ func (a *AssetStore) dbAssetsToChainAssets(dbAssets []ConfirmedAsset, owner := sprout.AnchorLeaseOwner expiry := sprout.AnchorLeaseExpiry if len(owner) > 0 && expiry.Valid && - expiry.Time.UTC().After(a.clock.Now().UTC()) { + expiry.Time.UTC().After(clock.Now().UTC()) { copy(chainAssets[i].AnchorLeaseOwner[:], owner) chainAssets[i].AnchorLeaseExpiry = &expiry.Time @@ -1198,7 +1198,7 @@ func (a *AssetStore) FetchAllAssets(ctx context.Context, includeSpent, return nil, dbErr } - return a.dbAssetsToChainAssets(dbAssets, assetWitnesses) + return dbAssetsToChainAssets(dbAssets, assetWitnesses, a.clock) } // FetchManagedUTXOs fetches all UTXOs we manage. @@ -1901,7 +1901,9 @@ func (a *AssetStore) queryChainAssets(ctx context.Context, q ActiveAssetsStore, if err != nil { return nil, err } - matchingAssets, err := a.dbAssetsToChainAssets(dbAssets, assetWitnesses) + matchingAssets, err := dbAssetsToChainAssets( + dbAssets, assetWitnesses, a.clock, + ) if err != nil { return nil, err } diff --git a/tapdb/migrations.go b/tapdb/migrations.go index e49bc6e1d..2377cfd29 100644 --- a/tapdb/migrations.go +++ b/tapdb/migrations.go @@ -2,6 +2,7 @@ package tapdb import ( "bytes" + "database/sql" "errors" "fmt" "io" @@ -14,6 +15,7 @@ import ( "github.com/golang-migrate/migrate/v4/database" "github.com/golang-migrate/migrate/v4/source/httpfs" "github.com/lightninglabs/taproot-assets/fn" + "github.com/lightninglabs/taproot-assets/tapdb/sqlc" ) const ( @@ -25,6 +27,13 @@ const ( LatestMigrationVersion = 24 ) +// DatabaseBackend is an interface that contains all methods our different +// Database backends implement. +type DatabaseBackend interface { + BatchedQuerier + WithTx(tx *sql.Tx) *sqlc.Queries +} + // MigrationTarget is a functional option that can be passed to applyMigrations // to specify a target version to migrate to. `currentDbVersion` is the current // (migration) version of the database, or None if unknown. @@ -115,9 +124,10 @@ func (m *migrationLogger) Verbose() bool { // applyMigrations executes database migration files found in the given file // system under the given path, using the passed database driver and database -// name, up to or down to the given target version. +// name, up to or down to the given target version. The boolean return value +// indicates whether any migrations were applied. func applyMigrations(fs fs.FS, driver database.Driver, path, dbName string, - targetVersion MigrationTarget, opts *migrateOptions) error { + targetVersion MigrationTarget, opts *migrateOptions) (bool, error) { // With the migrate instance open, we'll create a new migration source // using the embedded file system stored in sqlSchemas. The library @@ -125,7 +135,7 @@ func applyMigrations(fs fs.FS, driver database.Driver, path, dbName string, // in this intermediate layer. migrateFileServer, err := httpfs.New(http.FS(fs), path) if err != nil { - return err + return false, err } // Finally, we'll run the migration with our driver above based on the @@ -135,7 +145,7 @@ func applyMigrations(fs fs.FS, driver database.Driver, path, dbName string, "migrations", migrateFileServer, dbName, driver, ) if err != nil { - return err + return false, err } migrationVersion, _, _ := sqlMigrate.Version() @@ -144,38 +154,44 @@ func applyMigrations(fs fs.FS, driver database.Driver, path, dbName string, // prevent that without explicit accounting. latestVersion := opts.latestVersion.UnwrapOr(LatestMigrationVersion) if migrationVersion > latestVersion { - return fmt.Errorf("%w: database version is newer than the "+ - "latest migration version, preventing downgrade: "+ + return false, fmt.Errorf("%w: database version is newer than "+ + "the latest migration version, preventing downgrade: "+ "db_version=%v, latest_migration_version=%v", ErrMigrationDowngrade, migrationVersion, latestVersion) } // Report the current version of the database before the migration. - currentDbVersion, _, err := driver.Version() + versionBeforeMigration, _, err := driver.Version() if err != nil { - return fmt.Errorf("unable to get current db version: %w", err) + return false, fmt.Errorf("unable to get current db version: %w", + err) } log.Infof("Attempting to apply migration(s) "+ "(current_db_version=%v, latest_migration_version=%v)", - currentDbVersion, latestVersion) + versionBeforeMigration, latestVersion) // Apply our local logger to the migration instance. sqlMigrate.Log = &migrationLogger{log} // Execute the migration based on the target given. - err = targetVersion(sqlMigrate, currentDbVersion, latestVersion) + err = targetVersion(sqlMigrate, versionBeforeMigration, latestVersion) if err != nil && !errors.Is(err, migrate.ErrNoChange) { - return err + return false, err } + // If we actually did migrate, we'll now run the Golang based + // post-migration checks that ensure the database is in a consistent + // state, based on properties not fully expressible in SQL. + // Report the current version of the database after the migration. - currentDbVersion, _, err = driver.Version() + versionAfterMigration, _, err := driver.Version() if err != nil { - return fmt.Errorf("unable to get current db version: %w", err) + return true, fmt.Errorf("unable to get current db version: %w", + err) } - log.Infof("Database version after migration: %v", currentDbVersion) + log.Infof("Database version after migration: %v", versionAfterMigration) - return nil + return true, nil } // replacerFS is an implementation of a fs.FS virtual file system that wraps an diff --git a/tapdb/post_migration_checks.go b/tapdb/post_migration_checks.go new file mode 100644 index 000000000..f4f382d4c --- /dev/null +++ b/tapdb/post_migration_checks.go @@ -0,0 +1,150 @@ +package tapdb + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/lightninglabs/taproot-assets/asset" + "github.com/lightninglabs/taproot-assets/fn" + "github.com/lightninglabs/taproot-assets/tapdb/sqlc" + "github.com/lightninglabs/taproot-assets/tapscript" + "github.com/lightningnetwork/lnd/clock" +) + +// postMigrationCheck is a function type for a function that performs a +// post-migration check on the database. +type postMigrationCheck func(context.Context, sqlc.Querier) error + +var ( + // postMigrationChecks is a list of functions that are run after the + // database migrations have been applied. These functions are used to + // perform additional checks on the database state that are not fully + // expressible in SQL. + postMigrationChecks = []postMigrationCheck{ + detectScriptKeyType, + } +) + +// runPostMigrationChecks runs a set of post-migration checks on the database +// using the given database backend. +func runPostMigrationChecks(db DatabaseBackend) error { + var ( + ctx = context.Background() + txDb = NewTransactionExecutor( + db, func(tx *sql.Tx) sqlc.Querier { + return db.WithTx(tx) + }, + ) + writeTxOpts AssetStoreTxOptions + ) + + return txDb.ExecTx(ctx, &writeTxOpts, func(q sqlc.Querier) error { + log.Infof("Running %d post-migration checks", + len(postMigrationChecks)) + start := time.Now() + + for _, check := range postMigrationChecks { + err := check(ctx, q) + if err != nil { + return err + } + } + + log.Infof("Post-migration checks completed in %v", + time.Since(start)) + + return nil + }) +} + +// detectScriptKeyType attempts to detect the type of the script keys that don't +// have a type set yet. +func detectScriptKeyType(ctx context.Context, q sqlc.Querier) error { + defaultClock := clock.NewDefaultClock() + + // We start by fetching all assets, even the spent ones. We then collect + // a list of the burn keys from the assets (because burn keys can only + // be calculated from the asset's witness). + assetFilter := QueryAssetFilters{ + Now: sql.NullTime{ + Time: defaultClock.Now().UTC(), + Valid: true, + }, + } + dbAssets, assetWitnesses, err := fetchAssetsWithWitness( + ctx, q, assetFilter, + ) + if err != nil { + return fmt.Errorf("error fetching assets: %w", err) + } + + chainAssets, err := dbAssetsToChainAssets( + dbAssets, assetWitnesses, defaultClock, + ) + if err != nil { + return fmt.Errorf("error converting assets: %w", err) + } + + burnAssets := fn.Filter(chainAssets, func(a *asset.ChainAsset) bool { + return a.IsBurn() + }) + burnKeys := make(map[asset.SerializedKey]struct{}) + for _, a := range burnAssets { + serializedKey := asset.ToSerialized(a.ScriptKey.PubKey) + burnKeys[serializedKey] = struct{}{} + } + + untypedKeys, err := q.FetchUnknownTypeScriptKeys(ctx) + if err != nil { + return fmt.Errorf("error fetching script keys: %w", err) + } + + channelFundingKey := asset.NewScriptKey( + tapscript.NewChannelFundingScriptTree().TaprootKey, + ).PubKey + + for _, k := range untypedKeys { + scriptKey, err := parseScriptKey(k.InternalKey, k.ScriptKey) + if err != nil { + return fmt.Errorf("error parsing script key: %w", err) + } + + serializedKey := asset.ToSerialized(scriptKey.PubKey) + newType := asset.ScriptKeyUnknown + + if _, ok := burnKeys[serializedKey]; ok { + newType = asset.ScriptKeyBurn + } else { + guessedType := scriptKey.GuessType() + if guessedType == asset.ScriptKeyBip86 { + newType = asset.ScriptKeyBip86 + } + + if guessedType == asset.ScriptKeyScriptPathExternal && + scriptKey.PubKey.IsEqual(channelFundingKey) { + + newType = asset.ScriptKeyScriptPathChannel + } + } + + // If we were able to identify the key type, we update the key + // in the database. + if newType != asset.ScriptKeyUnknown { + _, err := q.UpsertScriptKey(ctx, NewScriptKey{ + InternalKeyID: k.InternalKey.KeyID, + TweakedScriptKey: k.ScriptKey.TweakedScriptKey, + Tweak: k.ScriptKey.Tweak, + DeclaredKnown: k.ScriptKey.DeclaredKnown, + KeyType: sqlInt16(newType), + }) + if err != nil { + return fmt.Errorf("error updating script key "+ + "type: %w", err) + } + } + } + + return nil +} diff --git a/tapdb/postgres.go b/tapdb/postgres.go index 12f91ad8f..61b17aeeb 100644 --- a/tapdb/postgres.go +++ b/tapdb/postgres.go @@ -158,10 +158,20 @@ func (s *PostgresStore) ExecuteMigrations(target MigrationTarget, } postgresFS := newReplacerFS(sqlSchemas, postgresSchemaReplacements) - return applyMigrations( + didMigrate, err := applyMigrations( postgresFS, driver, "sqlc/migrations", s.cfg.DBName, target, opts, ) + if err != nil { + return fmt.Errorf("error applying migrations: %w", err) + } + + // Run post-migration checks if we actually did migrate. + if didMigrate { + return runPostMigrationChecks(s) + } + + return nil } // NewTestPostgresDB is a helper function that creates a Postgres database for diff --git a/tapdb/sqlite.go b/tapdb/sqlite.go index 9e3cb4ed8..9c378c799 100644 --- a/tapdb/sqlite.go +++ b/tapdb/sqlite.go @@ -244,9 +244,19 @@ func (s *SqliteStore) ExecuteMigrations(target MigrationTarget, } sqliteFS := newReplacerFS(sqlSchemas, sqliteSchemaReplacements) - return applyMigrations( + didMigrate, err := applyMigrations( sqliteFS, driver, "sqlc/migrations", "sqlite", target, opts, ) + if err != nil { + return fmt.Errorf("error applying migrations: %w", err) + } + + // Run post-migration checks if we actually did migrate. + if didMigrate { + return runPostMigrationChecks(s) + } + + return nil } // NewTestSqliteDB is a helper function that creates an SQLite database for