diff --git a/config_builder.go b/config_builder.go index afafbed754..d32f01a115 100644 --- a/config_builder.go +++ b/config_builder.go @@ -51,6 +51,7 @@ import ( "github.com/lightningnetwork/lnd/rpcperms" "github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/sqlc" "github.com/lightningnetwork/lnd/sweep" "github.com/lightningnetwork/lnd/walletunlocker" "github.com/lightningnetwork/lnd/watchtower" @@ -60,6 +61,16 @@ import ( "gopkg.in/macaroon-bakery.v2/bakery" ) +const ( + // invoiceMigrationBatchSize is the number of invoices that will be + // migrated in a single batch. + invoiceMigrationBatchSize = 1000 + + // invoiceMigration is the version of the migration that will be used to + // migrate invoices from the kvdb to the sql database. + invoiceMigration = 6 +) + // GrpcRegistrar is an interface that must be satisfied by an external subserver // that wants to be able to register its own gRPC server onto lnd's main // grpc.Server instance. @@ -932,10 +943,10 @@ type DatabaseInstances struct { // the btcwallet's loader. WalletDB btcwallet.LoaderOption - // NativeSQLStore is a pointer to a native SQL store that can be used - // for native SQL queries for tables that already support it. This may - // be nil if the use-native-sql flag was not set. - NativeSQLStore *sqldb.BaseDB + // NativeSQLStore holds a reference to the native SQL store that can + // be used for native SQL queries for tables that already support it. + // This may be nil if the use-native-sql flag was not set. + NativeSQLStore sqldb.DB } // DefaultDatabaseBuilder is a type that builds the default database backends @@ -1038,7 +1049,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase( if err != nil { cleanUp() - err := fmt.Errorf("unable to open graph DB: %w", err) + err = fmt.Errorf("unable to open graph DB: %w", err) d.logger.Error(err) return nil, nil, err @@ -1072,51 +1083,69 @@ func (d *DefaultDatabaseBuilder) BuildDatabase( case err != nil: cleanUp() - err := fmt.Errorf("unable to open graph DB: %w", err) + err = fmt.Errorf("unable to open graph DB: %w", err) d.logger.Error(err) return nil, nil, err } - // Instantiate a native SQL invoice store if the flag is set. + // Instantiate a native SQL store if the flag is set. if d.cfg.DB.UseNativeSQL { - // KV invoice db resides in the same database as the channel - // state DB. Let's query the database to see if we have any - // invoices there. If we do, we won't allow the user to start - // lnd with native SQL enabled, as we don't currently migrate - // the invoices to the new database schema. - invoiceSlice, err := dbs.ChanStateDB.QueryInvoices( - ctx, invoices.InvoiceQuery{ - NumMaxInvoices: 1, - }, - ) - if err != nil { - cleanUp() - d.logger.Errorf("Unable to query KV invoice DB: %v", - err) + migrations := sqldb.GetMigrations() + + // If the user has not explicitly disabled the SQL invoice + // migration, attach the custom migration function to invoice + // migration (version 6). Even if this custom migration is + // disabled, the regular native SQL store migrations will still + // run. If the database version is already above this custom + // migration's version (6), it will be skipped permanently, + // regardless of the flag. + if !d.cfg.DB.SkipSQLInvoiceMigration { + migrationFn := func(tx *sqlc.Queries) error { + return invoices.MigrateInvoicesToSQL( + ctx, dbs.ChanStateDB.Backend, + dbs.ChanStateDB, tx, + invoiceMigrationBatchSize, + ) + } - return nil, nil, err + // Make sure we attach the custom migration function to + // the correct migration version. + for i := 0; i < len(migrations); i++ { + if migrations[i].Version != invoiceMigration { + continue + } + + migrations[i].MigrationFn = migrationFn + } } - if len(invoiceSlice.Invoices) > 0 { + // We need to apply all migrations to the native SQL store + // before we can use it. + err = dbs.NativeSQLStore.ApplyAllMigrations(ctx, migrations) + if err != nil { cleanUp() - err := fmt.Errorf("found invoices in the KV invoice " + - "DB, migration to native SQL is not yet " + - "supported") + err = fmt.Errorf("faild to run migrations for the "+ + "native SQL store: %w", err) d.logger.Error(err) return nil, nil, err } + // With the DB ready and migrations applied, we can now create + // the base DB and transaction executor for the native SQL + // invoice store. + baseDB := dbs.NativeSQLStore.GetBaseDB() executor := sqldb.NewTransactionExecutor( - dbs.NativeSQLStore, - func(tx *sql.Tx) invoices.SQLInvoiceQueries { - return dbs.NativeSQLStore.WithTx(tx) + baseDB, func(tx *sql.Tx) invoices.SQLInvoiceQueries { + return baseDB.WithTx(tx) }, ) - dbs.InvoiceDB = invoices.NewSQLStore( + sqlInvoiceDB := invoices.NewSQLStore( executor, clock.NewDefaultClock(), ) + + dbs.InvoiceDB = sqlInvoiceDB } else { dbs.InvoiceDB = dbs.ChanStateDB } @@ -1129,7 +1158,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase( if err != nil { cleanUp() - err := fmt.Errorf("unable to open %s database: %w", + err = fmt.Errorf("unable to open %s database: %w", lncfg.NSTowerClientDB, err) d.logger.Error(err) return nil, nil, err @@ -1144,7 +1173,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase( if err != nil { cleanUp() - err := fmt.Errorf("unable to open %s database: %w", + err = fmt.Errorf("unable to open %s database: %w", lncfg.NSTowerServerDB, err) d.logger.Error(err) return nil, nil, err diff --git a/docs/release-notes/release-notes-0.19.0.md b/docs/release-notes/release-notes-0.19.0.md index 4fc64028d3..21654cee2a 100644 --- a/docs/release-notes/release-notes-0.19.0.md +++ b/docs/release-notes/release-notes-0.19.0.md @@ -241,6 +241,11 @@ The underlying functionality between those two options remain the same. transactions can run at once, increasing efficiency. Includes several bugfixes to allow this to work properly. +* [Migrate KV invoices to + SQL](https://github.com/lightningnetwork/lnd/pull/8831) as part of a larger + effort to support SQL databases natively in LND. + + ## Code Health * A code refactor that [moves all the graph related DB code out of the @@ -265,6 +270,7 @@ The underlying functionality between those two options remain the same. * Abdullahi Yunus * Alex Akselrod +* Andras Banki-Horvath * Animesh Bilthare * Boris Nagaev * Carla Kirk-Cohen diff --git a/go.mod b/go.mod index 7680509fd3..c660cbb5af 100644 --- a/go.mod +++ b/go.mod @@ -138,7 +138,7 @@ require ( github.com/opencontainers/image-spec v1.0.2 // indirect github.com/opencontainers/runc v1.1.12 // indirect github.com/ory/dockertest/v3 v3.10.0 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/pmezard/go-difflib v1.0.0 github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/common v0.26.0 // indirect github.com/prometheus/procfs v0.6.0 // indirect @@ -207,6 +207,10 @@ replace github.com/gogo/protobuf => github.com/gogo/protobuf v1.3.2 // allows us to specify that as an option. replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-display v1.30.0-hex-display +// Temporary replace until https://github.com/lightningnetwork/lnd/pull/8831 is +// merged. +replace github.com/lightningnetwork/lnd/sqldb => ./sqldb + // If you change this please also update docs/INSTALL.md and GO_VERSION in // Makefile (then run `make lint` to see where else it needs to be updated as // well). diff --git a/go.sum b/go.sum index 5ed9cd2046..72bf21ab6e 100644 --- a/go.sum +++ b/go.sum @@ -464,8 +464,6 @@ github.com/lightningnetwork/lnd/kvdb v1.4.12 h1:Y0WY5Tbjyjn6eCYh068qkWur5oFtioJl github.com/lightningnetwork/lnd/kvdb v1.4.12/go.mod h1:hx9buNcxsZpZwh8m1sjTQwy2SOeBoWWOZ3RnOQkMsxI= github.com/lightningnetwork/lnd/queue v1.1.1 h1:99ovBlpM9B0FRCGYJo6RSFDlt8/vOkQQZznVb18iNMI= github.com/lightningnetwork/lnd/queue v1.1.1/go.mod h1:7A6nC1Qrm32FHuhx/mi1cieAiBZo5O6l8IBIoQxvkz4= -github.com/lightningnetwork/lnd/sqldb v1.0.6 h1:LJdDSVdN33bVBIefsaJlPW9PDAm6GrXlyFucmzSJ3Ts= -github.com/lightningnetwork/lnd/sqldb v1.0.6/go.mod h1:OG09zL/PHPaBJefp4HsPz2YLUJ+zIQHbpgCtLnOx8I4= github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM= github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA= github.com/lightningnetwork/lnd/tlv v1.3.0 h1:exS/KCPEgpOgviIttfiXAPaUqw2rHQrnUOpP7HPBPiY= diff --git a/invoices/invoices.go b/invoices/invoices.go index c48629c583..32164cbe17 100644 --- a/invoices/invoices.go +++ b/invoices/invoices.go @@ -187,6 +187,11 @@ func (r InvoiceRef) Modifier() RefModifier { return r.refModifier } +// IsHashOnly returns true if the invoice ref only contains a payment hash. +func (r InvoiceRef) IsHashOnly() bool { + return r.payHash != nil && r.payAddr == nil && r.setID == nil +} + // String returns a human-readable representation of an InvoiceRef. func (r InvoiceRef) String() string { var ids []string diff --git a/invoices/kv_sql_migration_test.go b/invoices/kv_sql_migration_test.go new file mode 100644 index 0000000000..b3048a17bf --- /dev/null +++ b/invoices/kv_sql_migration_test.go @@ -0,0 +1,203 @@ +package invoices_test + +import ( + "context" + "database/sql" + "os" + "path" + "testing" + "time" + + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/clock" + invpkg "github.com/lightningnetwork/lnd/invoices" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/kvdb/sqlbase" + "github.com/lightningnetwork/lnd/kvdb/sqlite" + "github.com/lightningnetwork/lnd/lncfg" + "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/sqlc" + "github.com/stretchr/testify/require" +) + +// TestMigrationWithChannelDB tests the migration of invoices from a bolt backed +// channel.db to a SQL database. Note that this test does not attempt to be a +// complete migration test for all invoice types but rather is added as a tool +// for developers and users to debug invoice migration issues with an actual +// channel.db file. +func TestMigrationWithChannelDB(t *testing.T) { + // First create a shared Postgres instance so we don't spawn a new + // docker container for each test. + pgFixture := sqldb.NewTestPgFixture( + t, sqldb.DefaultPostgresFixtureLifetime, + ) + t.Cleanup(func() { + pgFixture.TearDown(t) + }) + + makeSQLDB := func(t *testing.T, sqlite bool) (*invpkg.SQLStore, + *sqldb.TransactionExecutor[*sqlc.Queries]) { + + var db *sqldb.BaseDB + if sqlite { + db = sqldb.NewTestSqliteDB(t).BaseDB + } else { + db = sqldb.NewTestPostgresDB(t, pgFixture).BaseDB + } + + invoiceExecutor := sqldb.NewTransactionExecutor( + db, func(tx *sql.Tx) invpkg.SQLInvoiceQueries { + return db.WithTx(tx) + }, + ) + + genericExecutor := sqldb.NewTransactionExecutor( + db, func(tx *sql.Tx) *sqlc.Queries { + return db.WithTx(tx) + }, + ) + + testClock := clock.NewTestClock(time.Unix(1, 0)) + + return invpkg.NewSQLStore(invoiceExecutor, testClock), + genericExecutor + } + + migrationTest := func(t *testing.T, kvStore *channeldb.DB, + sqlite bool) { + + sqlInvoiceStore, sqlStore := makeSQLDB(t, sqlite) + ctxb := context.Background() + + const batchSize = 11 + var opts sqldb.MigrationTxOptions + err := sqlStore.ExecTx( + ctxb, &opts, func(tx *sqlc.Queries) error { + return invpkg.MigrateInvoicesToSQL( + ctxb, kvStore.Backend, kvStore, tx, + batchSize, + ) + }, func() {}, + ) + require.NoError(t, err) + + // MigrateInvoices will check if the inserted invoice equals to + // the migrated one, but as a sanity check, we'll also fetch the + // invoices from the store and compare them to the original + // invoices. + query := invpkg.InvoiceQuery{ + IndexOffset: 0, + // As a sanity check, fetch more invoices than we have + // to ensure that we did not add any extra invoices. + // Note that we don't really have a way to know the + // exact number of invoices in the bolt db without first + // iterating over all of them, but for test purposes + // constant should be enough. + NumMaxInvoices: 9999, + } + result1, err := kvStore.QueryInvoices(ctxb, query) + require.NoError(t, err) + numInvoices := len(result1.Invoices) + + result2, err := sqlInvoiceStore.QueryInvoices(ctxb, query) + require.NoError(t, err) + require.Equal(t, numInvoices, len(result2.Invoices)) + + // Simply zero out the add index so we don't fail on that when + // comparing. + for i := 0; i < numInvoices; i++ { + result1.Invoices[i].AddIndex = 0 + result2.Invoices[i].AddIndex = 0 + + // We need to override the timezone of the invoices as + // the provided DB vs the test runners local time zone + // might be different. + invpkg.OverrideInvoiceTimeZone(&result1.Invoices[i]) + invpkg.OverrideInvoiceTimeZone(&result2.Invoices[i]) + + require.Equal( + t, result1.Invoices[i], result2.Invoices[i], + ) + } + } + + tests := []struct { + name string + dbPath string + }{ + { + "empty", + t.TempDir(), + }, + { + "testdata", + "testdata", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + var kvStore *channeldb.DB + + // First check if we have a channel.sqlite file in the + // testdata directory. If we do, we'll use that as the + // channel db for the migration test. + chanDBPath := path.Join( + test.dbPath, lncfg.SqliteChannelDBName, + ) + + // Just some sane defaults for the sqlite config. + const ( + timeout = 5 * time.Second + maxConns = 50 + ) + + sqliteConfig := &sqlite.Config{ + Timeout: timeout, + BusyTimeout: timeout, + MaxConnections: maxConns, + } + + if fileExists(chanDBPath) { + sqlbase.Init(maxConns) + + sqliteBackend, err := kvdb.Open( + kvdb.SqliteBackendName, + context.Background(), + sqliteConfig, test.dbPath, + lncfg.SqliteChannelDBName, + lncfg.NSChannelDB, + ) + + require.NoError(t, err) + kvStore, err = channeldb.CreateWithBackend( + sqliteBackend, + ) + + require.NoError(t, err) + } else { + kvStore = channeldb.OpenForTesting( + t, test.dbPath, + ) + } + + t.Run("Postgres", func(t *testing.T) { + migrationTest(t, kvStore, false) + }) + + t.Run("SQLite", func(t *testing.T) { + migrationTest(t, kvStore, true) + }) + }) + } +} + +func fileExists(filename string) bool { + info, err := os.Stat(filename) + if os.IsNotExist(err) { + return false + } + + return !info.IsDir() +} diff --git a/invoices/sql_migration.go b/invoices/sql_migration.go new file mode 100644 index 0000000000..86cdd53734 --- /dev/null +++ b/invoices/sql_migration.go @@ -0,0 +1,557 @@ +package invoices + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "reflect" + "strconv" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/sqlc" + "github.com/pmezard/go-difflib/difflib" +) + +var ( + // invoiceBucket is the name of the bucket within the database that + // stores all data related to invoices no matter their final state. + // Within the invoice bucket, each invoice is keyed by its invoice ID + // which is a monotonically increasing uint32. + invoiceBucket = []byte("invoices") + + // paymentHashIndexBucket is the name of the sub-bucket within the + // invoiceBucket which indexes all invoices by their payment hash. The + // payment hash is the sha256 of the invoice's payment preimage. This + // index is used to detect duplicates, and also to provide a fast path + // for looking up incoming HTLCs to determine if we're able to settle + // them fully. + // + // maps: payHash => invoiceKey + invoiceIndexBucket = []byte("paymenthashes") + + // numInvoicesKey is the name of key which houses the auto-incrementing + // invoice ID which is essentially used as a primary key. With each + // invoice inserted, the primary key is incremented by one. This key is + // stored within the invoiceIndexBucket. Within the invoiceBucket + // invoices are uniquely identified by the invoice ID. + numInvoicesKey = []byte("nik") + + // addIndexBucket is an index bucket that we'll use to create a + // monotonically increasing set of add indexes. Each time we add a new + // invoice, this sequence number will be incremented and then populated + // within the new invoice. + // + // In addition to this sequence number, we map: + // + // addIndexNo => invoiceKey + addIndexBucket = []byte("invoice-add-index") + + // ErrMigrationMismatch is returned when the migrated invoice does not + // match the original invoice. + ErrMigrationMismatch = fmt.Errorf("migrated invoice does not match " + + "original invoice") +) + +// createInvoiceHashIndex generates a hash index that contains payment hashes +// for each invoice in the database. Retrieving the payment hash for certain +// invoices, such as those created for spontaneous AMP payments, can be +// challenging because the hash is not directly derivable from the invoice's +// parameters and is stored separately in the `paymenthashes` bucket. This +// bucket maps payment hashes to invoice keys, but for migration purposes, we +// need the ability to query in the reverse direction. This function establishes +// a new index in the SQL database that maps each invoice key to its +// corresponding payment hash. +func createInvoiceHashIndex(ctx context.Context, db kvdb.Backend, + tx *sqlc.Queries) error { + + return db.View(func(kvTx kvdb.RTx) error { + invoices := kvTx.ReadBucket(invoiceBucket) + if invoices == nil { + return ErrNoInvoicesCreated + } + + invoiceIndex := invoices.NestedReadBucket( + invoiceIndexBucket, + ) + if invoiceIndex == nil { + return ErrNoInvoicesCreated + } + + addIndex := invoices.NestedReadBucket(addIndexBucket) + if addIndex == nil { + return ErrNoInvoicesCreated + } + + // First, iterate over all elements in the add index bucket and + // insert the add index value for the corresponding invoice key + // in the payment_hashes table. + err := addIndex.ForEach(func(k, v []byte) error { + // The key is the add index, and the value is + // the invoice key. + addIndexNo := binary.BigEndian.Uint64(k) + invoiceKey := binary.BigEndian.Uint32(v) + + return tx.InsertKVInvoiceKeyAndAddIndex(ctx, + sqlc.InsertKVInvoiceKeyAndAddIndexParams{ + ID: int64(invoiceKey), + AddIndex: int64(addIndexNo), + }, + ) + }) + if err != nil { + return err + } + + // Next, iterate over all hashes in the invoice index bucket and + // set the hash to the corresponding the invoice key in the + // payment_hashes table. + return invoiceIndex.ForEach(func(k, v []byte) error { + // Skip the special numInvoicesKey as that does + // not point to a valid invoice. + if bytes.Equal(k, numInvoicesKey) { + return nil + } + + // The key is the payment hash, and the value + // is the invoice key. + if len(k) != lntypes.HashSize { + return fmt.Errorf("invalid payment "+ + "hash length: expected %v, "+ + "got %v", lntypes.HashSize, + len(k)) + } + + invoiceKey := binary.BigEndian.Uint32(v) + + return tx.SetKVInvoicePaymentHash(ctx, + sqlc.SetKVInvoicePaymentHashParams{ + ID: int64(invoiceKey), + Hash: k, + }, + ) + }) + }, func() {}) +} + +// toInsertMigratedInvoiceParams creates the parameters for inserting a migrated +// invoice into the SQL database. The parameters are derived from the original +// invoice insert parameters. +func toInsertMigratedInvoiceParams(params sqlc.InsertInvoiceParams, +) sqlc.InsertMigratedInvoiceParams { + + return sqlc.InsertMigratedInvoiceParams{ + Hash: params.Hash, + Preimage: params.Preimage, + Memo: params.Memo, + AmountMsat: params.AmountMsat, + CltvDelta: params.CltvDelta, + Expiry: params.Expiry, + PaymentAddr: params.PaymentAddr, + PaymentRequest: params.PaymentRequest, + PaymentRequestHash: params.PaymentRequestHash, + State: params.State, + AmountPaidMsat: params.AmountPaidMsat, + IsAmp: params.IsAmp, + IsHodl: params.IsHodl, + IsKeysend: params.IsKeysend, + CreatedAt: params.CreatedAt, + } +} + +// MigrateSingleInvoice migrates a single invoice to the new SQL schema. Note +// that perfect equality between the old and new schemas is not achievable, as +// the invoice's add index cannot be mapped directly to its ID due to SQL’s +// auto-incrementing primary key. The ID returned from the insert will instead +// serve as the add index in the new schema. +func MigrateSingleInvoice(ctx context.Context, tx SQLInvoiceQueries, + invoice *Invoice, paymentHash lntypes.Hash) error { + + insertInvoiceParams, err := makeInsertInvoiceParams( + invoice, paymentHash, + ) + if err != nil { + return err + } + + // Convert the insert invoice parameters to the migrated invoice insert + // parameters. + insertMigratedInvoiceParams := toInsertMigratedInvoiceParams( + insertInvoiceParams, + ) + + // If the invoice is settled, we'll also set the timestamp and the index + // at which it was settled. + if invoice.State == ContractSettled { + if invoice.SettleIndex == 0 { + return fmt.Errorf("settled invoice %s missing settle "+ + "index", paymentHash) + } + + if invoice.SettleDate.IsZero() { + return fmt.Errorf("settled invoice %s missing settle "+ + "date", paymentHash) + } + + insertMigratedInvoiceParams.SettleIndex = sqldb.SQLInt64( + invoice.SettleIndex, + ) + insertMigratedInvoiceParams.SettledAt = sqldb.SQLTime( + invoice.SettleDate.UTC(), + ) + } + + // First we need to insert the invoice itself so we can use the "add + // index" which in this case is the auto incrementing primary key that + // is returned from the insert. + invoiceID, err := tx.InsertMigratedInvoice( + ctx, insertMigratedInvoiceParams, + ) + if err != nil { + return fmt.Errorf("unable to insert invoice: %w", err) + } + + // Insert the invoice's features. + for feature := range invoice.Terms.Features.Features() { + params := sqlc.InsertInvoiceFeatureParams{ + InvoiceID: invoiceID, + Feature: int32(feature), + } + + err := tx.InsertInvoiceFeature(ctx, params) + if err != nil { + return fmt.Errorf("unable to insert invoice "+ + "feature(%v): %w", feature, err) + } + } + + sqlHtlcIDs := make(map[models.CircuitKey]int64) + + // Now insert the HTLCs of the invoice. We'll also keep track of the SQL + // ID of each HTLC so we can use it when inserting the AMP sub invoices. + for circuitKey, htlc := range invoice.Htlcs { + htlcParams := sqlc.InsertInvoiceHTLCParams{ + HtlcID: int64(circuitKey.HtlcID), + ChanID: strconv.FormatUint( + circuitKey.ChanID.ToUint64(), 10, + ), + AmountMsat: int64(htlc.Amt), + AcceptHeight: int32(htlc.AcceptHeight), + AcceptTime: htlc.AcceptTime.UTC(), + ExpiryHeight: int32(htlc.Expiry), + State: int16(htlc.State), + InvoiceID: invoiceID, + } + + // Leave the MPP amount as NULL if the MPP total amount is zero. + if htlc.MppTotalAmt != 0 { + htlcParams.TotalMppMsat = sqldb.SQLInt64( + int64(htlc.MppTotalAmt), + ) + } + + // Leave the resolve time as NULL if the HTLC is not resolved. + if !htlc.ResolveTime.IsZero() { + htlcParams.ResolveTime = sqldb.SQLTime( + htlc.ResolveTime.UTC(), + ) + } + + sqlID, err := tx.InsertInvoiceHTLC(ctx, htlcParams) + if err != nil { + return fmt.Errorf("unable to insert invoice htlc: %w", + err) + } + + sqlHtlcIDs[circuitKey] = sqlID + + // Store custom records. + for key, value := range htlc.CustomRecords { + err = tx.InsertInvoiceHTLCCustomRecord( + ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{ + Key: int64(key), + Value: value, + HtlcID: sqlID, + }, + ) + if err != nil { + return err + } + } + } + + if !invoice.IsAMP() { + return nil + } + + for setID, ampState := range invoice.AMPState { + // Find the earliest HTLC of the AMP invoice, which will + // be used as the creation date of this sub invoice. + var createdAt time.Time + for circuitKey := range ampState.InvoiceKeys { + htlc := invoice.Htlcs[circuitKey] + if createdAt.IsZero() { + createdAt = htlc.AcceptTime.UTC() + continue + } + + if createdAt.After(htlc.AcceptTime) { + createdAt = htlc.AcceptTime.UTC() + } + } + + params := sqlc.InsertAMPSubInvoiceParams{ + SetID: setID[:], + State: int16(ampState.State), + CreatedAt: createdAt, + InvoiceID: invoiceID, + } + + if ampState.SettleIndex != 0 { + if ampState.SettleDate.IsZero() { + return fmt.Errorf("settled AMP sub invoice %x "+ + "missing settle date", setID) + } + + params.SettledAt = sqldb.SQLTime( + ampState.SettleDate.UTC(), + ) + + params.SettleIndex = sqldb.SQLInt64( + ampState.SettleIndex, + ) + } + + err := tx.InsertAMPSubInvoice(ctx, params) + if err != nil { + return fmt.Errorf("unable to insert AMP sub invoice: "+ + "%w", err) + } + + // Now we can add the AMP HTLCs to the database. + for circuitKey := range ampState.InvoiceKeys { + htlc := invoice.Htlcs[circuitKey] + rootShare := htlc.AMP.Record.RootShare() + + sqlHtlcID, ok := sqlHtlcIDs[circuitKey] + if !ok { + return fmt.Errorf("missing htlc for AMP htlc: "+ + "%v", circuitKey) + } + + params := sqlc.InsertAMPSubInvoiceHTLCParams{ + InvoiceID: invoiceID, + SetID: setID[:], + HtlcID: sqlHtlcID, + RootShare: rootShare[:], + ChildIndex: int64(htlc.AMP.Record.ChildIndex()), + Hash: htlc.AMP.Hash[:], + } + + if htlc.AMP.Preimage != nil { + params.Preimage = htlc.AMP.Preimage[:] + } + + err = tx.InsertAMPSubInvoiceHTLC(ctx, params) + if err != nil { + return fmt.Errorf("unable to insert AMP sub "+ + "invoice: %w", err) + } + } + } + + return nil +} + +// OverrideInvoiceTimeZone overrides the time zone of the invoice to the local +// time zone and chops off the nanosecond part for comparison. This is needed +// because KV database stores times as-is which as an unwanted side effect would +// fail migration due to time comparison expecting both the original and +// migrated invoices to be in the same local time zone and in microsecond +// precision. Note that PostgreSQL stores times in microsecond precision while +// SQLite can store times in nanosecond precision if using TEXT storage class. +func OverrideInvoiceTimeZone(invoice *Invoice) { + fixTime := func(t time.Time) time.Time { + return t.In(time.Local).Truncate(time.Microsecond) + } + + invoice.CreationDate = fixTime(invoice.CreationDate) + + if !invoice.SettleDate.IsZero() { + invoice.SettleDate = fixTime(invoice.SettleDate) + } + + if invoice.IsAMP() { + for setID, ampState := range invoice.AMPState { + if ampState.SettleDate.IsZero() { + continue + } + + ampState.SettleDate = fixTime(ampState.SettleDate) + invoice.AMPState[setID] = ampState + } + } + + for _, htlc := range invoice.Htlcs { + if !htlc.AcceptTime.IsZero() { + htlc.AcceptTime = fixTime(htlc.AcceptTime) + } + + if !htlc.ResolveTime.IsZero() { + htlc.ResolveTime = fixTime(htlc.ResolveTime) + } + } +} + +// MigrateInvoicesToSQL runs the migration of all invoices from the KV database +// to the SQL database. The migration is done in a single transaction to ensure +// that all invoices are migrated or none at all. This function can be run +// multiple times without causing any issues as it will check if the migration +// has already been performed. +func MigrateInvoicesToSQL(ctx context.Context, db kvdb.Backend, + kvStore InvoiceDB, tx *sqlc.Queries, batchSize int) error { + + log.Infof("Starting migration of invoices from KV to SQL") + + offset := uint64(0) + t0 := time.Now() + // Create the hash index which we will use to look up invoice + // payment hashes by their add index during migration. + err := createInvoiceHashIndex(ctx, db, tx) + if err != nil && !errors.Is(err, ErrNoInvoicesCreated) { + log.Errorf("Unable to create invoice hash index: %v", + err) + + return err + } + log.Debugf("Created SQL invoice hash index in %v", time.Since(t0)) + + total := 0 + // Now we can start migrating the invoices. We'll do this in + // batches to reduce memory usage. + for { + t0 = time.Now() + query := InvoiceQuery{ + IndexOffset: offset, + NumMaxInvoices: uint64(batchSize), + } + + queryResult, err := kvStore.QueryInvoices(ctx, query) + if err != nil && !errors.Is(err, ErrNoInvoicesCreated) { + return fmt.Errorf("unable to query invoices: "+ + "%w", err) + } + + if len(queryResult.Invoices) == 0 { + log.Infof("All invoices migrated") + + break + } + + err = migrateInvoices(ctx, tx, queryResult.Invoices) + if err != nil { + return err + } + + offset = queryResult.LastIndexOffset + total += len(queryResult.Invoices) + log.Debugf("Migrated %d KV invoices to SQL in %v\n", total, + time.Since(t0)) + } + + // Clean up the hash index as it's no longer needed. + err = tx.ClearKVInvoiceHashIndex(ctx) + if err != nil { + return fmt.Errorf("unable to clear invoice hash "+ + "index: %w", err) + } + + log.Infof("Migration of %d invoices from KV to SQL completed", total) + + return nil +} + +func migrateInvoices(ctx context.Context, tx *sqlc.Queries, + invoices []Invoice) error { + + for i, invoice := range invoices { + var paymentHash lntypes.Hash + if invoice.Terms.PaymentPreimage != nil { + paymentHash = invoice.Terms.PaymentPreimage.Hash() + } else { + paymentHashBytes, err := + tx.GetKVInvoicePaymentHashByAddIndex( + ctx, int64(invoice.AddIndex), + ) + if err != nil { + // This would be an unexpected inconsistency + // in the kv database. We can't do much here + // so we'll notify the user and continue. + log.Warnf("Cannot migrate invoice, unable to "+ + "fetch payment hash (add_index=%v): %v", + invoice.AddIndex, err) + + continue + } + + copy(paymentHash[:], paymentHashBytes) + } + + err := MigrateSingleInvoice(ctx, tx, &invoices[i], paymentHash) + if err != nil { + return fmt.Errorf("unable to migrate invoice(%v): %w", + paymentHash, err) + } + + migratedInvoice, err := fetchInvoice( + ctx, tx, InvoiceRefByHash(paymentHash), + ) + if err != nil { + return fmt.Errorf("unable to fetch migrated "+ + "invoice(%v): %w", paymentHash, err) + } + + // Override the time zone for comparison. Note that we need to + // override both invoices as the original invoice is coming from + // KV database, it was stored as a binary serialized Go + // time.Time value which has nanosecond precision but might have + // been created in a different time zone. The migrated invoice + // is stored in SQL in UTC and selected in the local time zone, + // however in PostgreSQL it has microsecond precision while in + // SQLite it has nanosecond precision if using TEXT storage + // class. + OverrideInvoiceTimeZone(&invoice) + OverrideInvoiceTimeZone(migratedInvoice) + + // Override the add index before checking for equality. + migratedInvoice.AddIndex = invoice.AddIndex + + if !reflect.DeepEqual(invoice, *migratedInvoice) { + diff := difflib.UnifiedDiff{ + A: difflib.SplitLines( + spew.Sdump(invoice), + ), + B: difflib.SplitLines( + spew.Sdump(migratedInvoice), + ), + FromFile: "Expected", + FromDate: "", + ToFile: "Actual", + ToDate: "", + Context: 3, + } + diffText, _ := difflib.GetUnifiedDiffString(diff) + + return fmt.Errorf("%w: %v.\n%v", ErrMigrationMismatch, + paymentHash, diffText) + } + } + + return nil +} diff --git a/invoices/sql_migration_test.go b/invoices/sql_migration_test.go new file mode 100644 index 0000000000..179097f489 --- /dev/null +++ b/invoices/sql_migration_test.go @@ -0,0 +1,421 @@ +package invoices + +import ( + "context" + crand "crypto/rand" + "database/sql" + "math/rand" + "sync/atomic" + "testing" + "time" + + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/require" + "pgregory.net/rapid" +) + +var ( + // testHtlcIDSequence is a global counter for generating unique HTLC + // IDs. + testHtlcIDSequence uint64 +) + +// randomString generates a random string of a given length using rapid. +func randomStringRapid(t *rapid.T, length int) string { + // Define the character set for the string. + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" //nolint:ll + + // Generate a string by selecting random characters from the charset. + runes := make([]rune, length) + for i := range runes { + // Draw a random index and use it to select a character from the + // charset. + index := rapid.IntRange(0, len(charset)-1).Draw(t, "charIndex") + runes[i] = rune(charset[index]) + } + + return string(runes) +} + +// randTimeBetween generates a random time between min and max. +func randTimeBetween(min, max time.Time) time.Time { + var timeZones = []*time.Location{ + time.UTC, + time.FixedZone("EST", -5*3600), + time.FixedZone("MST", -7*3600), + time.FixedZone("PST", -8*3600), + time.FixedZone("CEST", 2*3600), + } + + // Ensure max is after min + if max.Before(min) { + min, max = max, min + } + + // Calculate the range in nanoseconds + duration := max.Sub(min) + randDuration := time.Duration(rand.Int63n(duration.Nanoseconds())) + + // Generate the random time + randomTime := min.Add(randDuration) + + // Assign a random time zone + randomTimeZone := timeZones[rand.Intn(len(timeZones))] + + // Return the time in the random time zone + return randomTime.In(randomTimeZone) +} + +// randTime generates a random time between 2009 and 2140. +func randTime() time.Time { + min := time.Date(2009, 1, 3, 0, 0, 0, 0, time.UTC) + max := time.Date(2140, 1, 1, 0, 0, 0, 1000, time.UTC) + + return randTimeBetween(min, max) +} + +func randInvoiceTime(invoice *Invoice) time.Time { + return randTimeBetween( + invoice.CreationDate, + invoice.CreationDate.Add(invoice.Terms.Expiry), + ) +} + +// randHTLCRapid generates a random HTLC for an invoice using rapid to randomize +// its parameters. +func randHTLCRapid(t *rapid.T, invoice *Invoice, amt lnwire.MilliSatoshi) ( + models.CircuitKey, *InvoiceHTLC) { + + htlc := &InvoiceHTLC{ + Amt: amt, + AcceptHeight: rapid.Uint32Range(1, 999).Draw(t, "AcceptHeight"), + AcceptTime: randInvoiceTime(invoice), + Expiry: rapid.Uint32Range(1, 999).Draw(t, "Expiry"), + } + + // Set MPP total amount if MPP feature is enabled in the invoice. + if invoice.Terms.Features.HasFeature(lnwire.MPPRequired) { + htlc.MppTotalAmt = invoice.Terms.Value + } + + // Set the HTLC state and resolve time based on the invoice state. + switch invoice.State { + case ContractSettled: + htlc.State = HtlcStateSettled + htlc.ResolveTime = randInvoiceTime(invoice) + + case ContractCanceled: + htlc.State = HtlcStateCanceled + htlc.ResolveTime = randInvoiceTime(invoice) + + case ContractAccepted: + htlc.State = HtlcStateAccepted + } + + // Add randomized custom records to the HTLC. + htlc.CustomRecords = make(record.CustomSet) + numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords") + for i := 0; i < numRecords; i++ { + key := rapid.Uint64Range( + record.CustomTypeStart, 1000+record.CustomTypeStart, + ).Draw(t, "customRecordKey") + value := []byte(randomStringRapid(t, 10)) + htlc.CustomRecords[key] = value + } + + // Generate a unique HTLC ID and assign it to a channel ID. + htlcID := atomic.AddUint64(&testHtlcIDSequence, 1) + randChanID := lnwire.NewShortChanIDFromInt(htlcID % 5) + + circuitKey := models.CircuitKey{ + ChanID: randChanID, + HtlcID: htlcID, + } + + return circuitKey, htlc +} + +// generateInvoiceHTLCsRapid generates all HTLCs for an invoice, including AMP +// HTLCs if applicable, using rapid for randomization of HTLC count and +// distribution. +func generateInvoiceHTLCsRapid(t *rapid.T, invoice *Invoice) { + mpp := invoice.Terms.Features.HasFeature(lnwire.MPPRequired) + + // Use rapid to determine the number of HTLCs based on invoice state and + // MPP feature. + numHTLCs := 1 + if invoice.State == ContractOpen { + numHTLCs = 0 + } else if mpp { + numHTLCs = rapid.IntRange(1, 10).Draw(t, "numHTLCs") + } + + total := invoice.Terms.Value + + // Distribute the total amount across the HTLCs, adding any remainder to + // the last HTLC. + if numHTLCs > 0 { + amt := total / lnwire.MilliSatoshi(numHTLCs) + remainder := total - amt*lnwire.MilliSatoshi(numHTLCs) + + for i := 0; i < numHTLCs; i++ { + if i == numHTLCs-1 { + // Add remainder to the last HTLC. + amt += remainder + } + + // Generate an HTLC with a random circuit key and add it + // to the invoice. + circuitKey, htlc := randHTLCRapid(t, invoice, amt) + invoice.Htlcs[circuitKey] = htlc + } + } +} + +// generateAMPHtlcsRapid generates AMP HTLCs for an invoice using rapid to +// randomize various parameters of the HTLCs in the AMP set. +func generateAMPHtlcsRapid(t *rapid.T, invoice *Invoice) { + // Randomly determine the number of AMP sets (1 to 5). + numSetIDs := rapid.IntRange(1, 5).Draw(t, "numSetIDs") + settledIdx := uint64(1) + + for i := 0; i < numSetIDs; i++ { + var setID SetID + _, err := crand.Read(setID[:]) + require.NoError(t, err) + + // Determine the number of HTLCs in this set (1 to 5). + numHTLCs := rapid.IntRange(1, 5).Draw(t, "numHTLCs") + total := invoice.Terms.Value + invoiceKeys := make(map[CircuitKey]struct{}) + + // Calculate the amount per HTLC and account for remainder in + // the final HTLC. + amt := total / lnwire.MilliSatoshi(numHTLCs) + remainder := total - amt*lnwire.MilliSatoshi(numHTLCs) + + var htlcState HtlcState + for j := 0; j < numHTLCs; j++ { + if j == numHTLCs-1 { + amt += remainder + } + + // Generate HTLC with randomized parameters. + circuitKey, htlc := randHTLCRapid(t, invoice, amt) + htlcState = htlc.State + + var ( + rootShare, hash [32]byte + preimage lntypes.Preimage + ) + + // Randomize AMP data fields. + _, err := crand.Read(rootShare[:]) + require.NoError(t, err) + _, err = crand.Read(hash[:]) + require.NoError(t, err) + _, err = crand.Read(preimage[:]) + require.NoError(t, err) + + record := record.NewAMP(rootShare, setID, uint32(j)) + + htlc.AMP = &InvoiceHtlcAMPData{ + Record: *record, + Hash: hash, + Preimage: &preimage, + } + + invoice.Htlcs[circuitKey] = htlc + invoiceKeys[circuitKey] = struct{}{} + } + + ampState := InvoiceStateAMP{ + State: htlcState, + InvoiceKeys: invoiceKeys, + } + if htlcState == HtlcStateSettled { + ampState.SettleIndex = settledIdx + ampState.SettleDate = randInvoiceTime(invoice) + settledIdx++ + } + + // Set the total amount paid if the AMP set is not canceled. + if htlcState != HtlcStateCanceled { + ampState.AmtPaid = invoice.Terms.Value + } + + invoice.AMPState[setID] = ampState + } +} + +// TestMigrateSingleInvoiceRapid tests the migration of single invoices with +// random data variations using rapid. This test generates a random invoice +// configuration and ensures successful migration. +// +// NOTE: This test may need to be changed if the Invoice or any of the related +// types are modified. +func TestMigrateSingleInvoiceRapid(t *testing.T) { + // Create a shared Postgres instance for efficient testing. + pgFixture := sqldb.NewTestPgFixture( + t, sqldb.DefaultPostgresFixtureLifetime, + ) + t.Cleanup(func() { + pgFixture.TearDown(t) + }) + + makeSQLDB := func(t *testing.T, sqlite bool) *SQLStore { + var db *sqldb.BaseDB + if sqlite { + db = sqldb.NewTestSqliteDB(t).BaseDB + } else { + db = sqldb.NewTestPostgresDB(t, pgFixture).BaseDB + } + + executor := sqldb.NewTransactionExecutor( + db, func(tx *sql.Tx) SQLInvoiceQueries { + return db.WithTx(tx) + }, + ) + + testClock := clock.NewTestClock(time.Unix(1, 0)) + + return NewSQLStore(executor, testClock) + } + + // Define property-based test using rapid. + rapid.Check(t, func(rt *rapid.T) { + // Randomized feature flags for MPP and AMP. + mpp := rapid.Bool().Draw(rt, "mpp") + amp := rapid.Bool().Draw(rt, "amp") + + for _, sqlite := range []bool{true, false} { + store := makeSQLDB(t, sqlite) + testMigrateSingleInvoiceRapid(rt, store, mpp, amp) + } + }) +} + +// testMigrateSingleInvoiceRapid is the primary function for the migration of a +// single invoice with random data in a rapid-based test setup. +func testMigrateSingleInvoiceRapid(t *rapid.T, store *SQLStore, mpp bool, + amp bool) { + + ctxb := context.Background() + invoices := make(map[lntypes.Hash]*Invoice) + + for i := 0; i < 100; i++ { + invoice := generateTestInvoiceRapid(t, mpp, amp) + var hash lntypes.Hash + _, err := crand.Read(hash[:]) + require.NoError(t, err) + + invoices[hash] = invoice + } + + var ops SQLInvoiceQueriesTxOptions + err := store.db.ExecTx(ctxb, &ops, func(tx SQLInvoiceQueries) error { + for hash, invoice := range invoices { + err := MigrateSingleInvoice(ctxb, tx, invoice, hash) + require.NoError(t, err) + } + + return nil + }, func() {}) + require.NoError(t, err) + + // Fetch and compare each migrated invoice from the store with the + // original. + for hash, invoice := range invoices { + sqlInvoice, err := store.LookupInvoice( + ctxb, InvoiceRefByHash(hash), + ) + require.NoError(t, err) + + invoice.AddIndex = sqlInvoice.AddIndex + + OverrideInvoiceTimeZone(invoice) + OverrideInvoiceTimeZone(&sqlInvoice) + + require.Equal(t, *invoice, sqlInvoice) + } +} + +// generateTestInvoiceRapid generates a random invoice with variations based on +// mpp and amp flags. +func generateTestInvoiceRapid(t *rapid.T, mpp bool, amp bool) *Invoice { + var preimage lntypes.Preimage + _, err := crand.Read(preimage[:]) + require.NoError(t, err) + + terms := ContractTerm{ + FinalCltvDelta: rapid.Int32Range(1, 1000).Draw( + t, "FinalCltvDelta", + ), + Expiry: time.Duration( + rapid.IntRange(1, 4444).Draw(t, "Expiry"), + ) * time.Minute, + PaymentPreimage: &preimage, + Value: lnwire.MilliSatoshi( + rapid.Int64Range(1, 9999999).Draw(t, "Value"), + ), + PaymentAddr: [32]byte{}, + Features: lnwire.EmptyFeatureVector(), + } + + if amp { + terms.Features.Set(lnwire.AMPRequired) + } else if mpp { + terms.Features.Set(lnwire.MPPRequired) + } + + created := randTime() + + const maxContractState = 3 + state := ContractState( + rapid.IntRange(0, maxContractState).Draw(t, "ContractState"), + ) + var ( + settled time.Time + settleIndex uint64 + ) + if state == ContractSettled { + settled = randTimeBetween(created, created.Add(terms.Expiry)) + settleIndex = rapid.Uint64Range(1, 999).Draw(t, "SettleIndex") + } + + invoice := &Invoice{ + Memo: []byte(randomStringRapid(t, 10)), + PaymentRequest: []byte( + randomStringRapid(t, MaxPaymentRequestSize), + ), + CreationDate: created, + SettleDate: settled, + Terms: terms, + AddIndex: 0, + SettleIndex: settleIndex, + State: state, + AMPState: make(map[SetID]InvoiceStateAMP), + HodlInvoice: rapid.Bool().Draw(t, "HodlInvoice"), + } + + invoice.Htlcs = make(map[models.CircuitKey]*InvoiceHTLC) + + if invoice.IsAMP() { + generateAMPHtlcsRapid(t, invoice) + } else { + generateInvoiceHTLCsRapid(t, invoice) + } + + for _, htlc := range invoice.Htlcs { + if htlc.State == HtlcStateSettled { + invoice.AmtPaid += htlc.Amt + } + } + + return invoice +} diff --git a/invoices/sql_store.go b/invoices/sql_store.go index 5459ec26c7..55517bfdd4 100644 --- a/invoices/sql_store.go +++ b/invoices/sql_store.go @@ -32,6 +32,10 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64, error) + // TODO(bhandras): remove this once migrations have been separated out. + InsertMigratedInvoice(ctx context.Context, + arg sqlc.InsertMigratedInvoiceParams) (int64, error) + InsertInvoiceFeature(ctx context.Context, arg sqlc.InsertInvoiceFeatureParams) error @@ -47,6 +51,9 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat GetInvoice(ctx context.Context, arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error) + GetInvoiceByHash(ctx context.Context, hash []byte) (sqlc.Invoice, + error) + GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice, error) @@ -79,6 +86,10 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat UpsertAMPSubInvoice(ctx context.Context, arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error) + // TODO(bhandras): remove this once migrations have been separated out. + InsertAMPSubInvoice(ctx context.Context, + arg sqlc.InsertAMPSubInvoiceParams) error + UpdateAMPSubInvoiceState(ctx context.Context, arg sqlc.UpdateAMPSubInvoiceStateParams) error @@ -119,6 +130,19 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat OnAMPSubInvoiceSettled(ctx context.Context, arg sqlc.OnAMPSubInvoiceSettledParams) error + + // Migration specific methods. + // TODO(bhandras): remove this once migrations have been separated out. + InsertKVInvoiceKeyAndAddIndex(ctx context.Context, + arg sqlc.InsertKVInvoiceKeyAndAddIndexParams) error + + SetKVInvoicePaymentHash(ctx context.Context, + arg sqlc.SetKVInvoicePaymentHashParams) error + + GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) ( + []byte, error) + + ClearKVInvoiceHashIndex(ctx context.Context) error } var _ InvoiceDB = (*SQLStore)(nil) @@ -200,6 +224,66 @@ func NewSQLStore(db BatchedSQLInvoiceQueries, } } +func makeInsertInvoiceParams(invoice *Invoice, paymentHash lntypes.Hash) ( + sqlc.InsertInvoiceParams, error) { + + // Precompute the payment request hash so we can use it in the query. + var paymentRequestHash []byte + if len(invoice.PaymentRequest) > 0 { + h := sha256.New() + h.Write(invoice.PaymentRequest) + paymentRequestHash = h.Sum(nil) + } + + params := sqlc.InsertInvoiceParams{ + Hash: paymentHash[:], + AmountMsat: int64(invoice.Terms.Value), + CltvDelta: sqldb.SQLInt32( + invoice.Terms.FinalCltvDelta, + ), + Expiry: int32(invoice.Terms.Expiry.Seconds()), + // Note: keysend invoices don't have a payment request. + PaymentRequest: sqldb.SQLStr(string( + invoice.PaymentRequest), + ), + PaymentRequestHash: paymentRequestHash, + State: int16(invoice.State), + AmountPaidMsat: int64(invoice.AmtPaid), + IsAmp: invoice.IsAMP(), + IsHodl: invoice.HodlInvoice, + IsKeysend: invoice.IsKeysend(), + CreatedAt: invoice.CreationDate.UTC(), + } + + if invoice.Memo != nil { + // Store the memo as a nullable string in the database. Note + // that for compatibility reasons, we store the value as a valid + // string even if it's empty. + params.Memo = sql.NullString{ + String: string(invoice.Memo), + Valid: true, + } + } + + // Some invoices may not have a preimage, like in the case of HODL + // invoices. + if invoice.Terms.PaymentPreimage != nil { + preimage := *invoice.Terms.PaymentPreimage + if preimage == UnknownPreimage { + return sqlc.InsertInvoiceParams{}, + errors.New("cannot use all-zeroes preimage") + } + params.Preimage = preimage[:] + } + + // Some non MPP payments may have the default (invalid) value. + if invoice.Terms.PaymentAddr != BlankPayAddr { + params.PaymentAddr = invoice.Terms.PaymentAddr[:] + } + + return params, nil +} + // AddInvoice inserts the targeted invoice into the database. If the invoice has // *any* payment hashes which already exists within the database, then the // insertion will be aborted and rejected due to the strict policy banning any @@ -220,55 +304,16 @@ func (i *SQLStore) AddInvoice(ctx context.Context, invoiceID int64 ) - // Precompute the payment request hash so we can use it in the query. - var paymentRequestHash []byte - if len(newInvoice.PaymentRequest) > 0 { - h := sha256.New() - h.Write(newInvoice.PaymentRequest) - paymentRequestHash = h.Sum(nil) + insertInvoiceParams, err := makeInsertInvoiceParams( + newInvoice, paymentHash, + ) + if err != nil { + return 0, err } - err := i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error { - params := sqlc.InsertInvoiceParams{ - Hash: paymentHash[:], - Memo: sqldb.SQLStr(string(newInvoice.Memo)), - AmountMsat: int64(newInvoice.Terms.Value), - // Note: BOLT12 invoices don't have a final cltv delta. - CltvDelta: sqldb.SQLInt32( - newInvoice.Terms.FinalCltvDelta, - ), - Expiry: int32(newInvoice.Terms.Expiry.Seconds()), - // Note: keysend invoices don't have a payment request. - PaymentRequest: sqldb.SQLStr(string( - newInvoice.PaymentRequest), - ), - PaymentRequestHash: paymentRequestHash, - State: int16(newInvoice.State), - AmountPaidMsat: int64(newInvoice.AmtPaid), - IsAmp: newInvoice.IsAMP(), - IsHodl: newInvoice.HodlInvoice, - IsKeysend: newInvoice.IsKeysend(), - CreatedAt: newInvoice.CreationDate.UTC(), - } - - // Some invoices may not have a preimage, like in the case of - // HODL invoices. - if newInvoice.Terms.PaymentPreimage != nil { - preimage := *newInvoice.Terms.PaymentPreimage - if preimage == UnknownPreimage { - return errors.New("cannot use all-zeroes " + - "preimage") - } - params.Preimage = preimage[:] - } - - // Some non MPP payments may have the default (invalid) value. - if newInvoice.Terms.PaymentAddr != BlankPayAddr { - params.PaymentAddr = newInvoice.Terms.PaymentAddr[:] - } - + err = i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error { var err error - invoiceID, err = db.InsertInvoice(ctx, params) + invoiceID, err = db.InsertInvoice(ctx, insertInvoiceParams) if err != nil { return fmt.Errorf("unable to insert invoice: %w", err) } @@ -312,22 +357,31 @@ func (i *SQLStore) AddInvoice(ctx context.Context, return newInvoice.AddIndex, nil } -// fetchInvoice fetches the common invoice data and the AMP state for the -// invoice with the given reference. -func (i *SQLStore) fetchInvoice(ctx context.Context, - db SQLInvoiceQueries, ref InvoiceRef) (*Invoice, error) { +// getInvoiceByRef fetches the invoice with the given reference. The reference +// may be a payment hash, a payment address, or a set ID for an AMP sub invoice. +func getInvoiceByRef(ctx context.Context, + db SQLInvoiceQueries, ref InvoiceRef) (sqlc.Invoice, error) { + // If the reference is empty, we can't look up the invoice. if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil { - return nil, ErrInvoiceNotFound + return sqlc.Invoice{}, ErrInvoiceNotFound } - var ( - invoice *Invoice - params sqlc.GetInvoiceParams - ) + // If the reference is a hash only, we can look up the invoice directly + // by the payment hash which is faster. + if ref.IsHashOnly() { + invoice, err := db.GetInvoiceByHash(ctx, ref.PayHash()[:]) + if errors.Is(err, sql.ErrNoRows) { + return sqlc.Invoice{}, ErrInvoiceNotFound + } + + return invoice, err + } + + // Otherwise the reference may include more fields, so we'll need to + // assemble the query parameters based on the fields that are set. + var params sqlc.GetInvoiceParams - // Given all invoices are uniquely identified by their payment hash, - // we can use it to query a specific invoice. if ref.PayHash() != nil { params.Hash = ref.PayHash()[:] } @@ -363,18 +417,34 @@ func (i *SQLStore) fetchInvoice(ctx context.Context, } else { rows, err = db.GetInvoice(ctx, params) } + switch { case len(rows) == 0: - return nil, ErrInvoiceNotFound + return sqlc.Invoice{}, ErrInvoiceNotFound case len(rows) > 1: // In case the reference is ambiguous, meaning it matches more // than one invoice, we'll return an error. - return nil, fmt.Errorf("ambiguous invoice ref: %s: %s", - ref.String(), spew.Sdump(rows)) + return sqlc.Invoice{}, fmt.Errorf("ambiguous invoice ref: "+ + "%s: %s", ref.String(), spew.Sdump(rows)) case err != nil: - return nil, fmt.Errorf("unable to fetch invoice: %w", err) + return sqlc.Invoice{}, fmt.Errorf("unable to fetch invoice: %w", + err) + } + + return rows[0], nil +} + +// fetchInvoice fetches the common invoice data and the AMP state for the +// invoice with the given reference. +func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, ref InvoiceRef) ( + *Invoice, error) { + + // Fetch the invoice from the database. + sqlInvoice, err := getInvoiceByRef(ctx, db, ref) + if err != nil { + return nil, err } var ( @@ -391,8 +461,8 @@ func (i *SQLStore) fetchInvoice(ctx context.Context, fetchAmpHtlcs = true case HtlcSetOnlyModifier: - // In this case we'll fetch all AMP HTLCs for the - // specified set id. + // In this case we'll fetch all AMP HTLCs for the specified set + // id. if ref.SetID() == nil { return nil, fmt.Errorf("set ID is required to use " + "the HTLC set only modifier") @@ -412,8 +482,8 @@ func (i *SQLStore) fetchInvoice(ctx context.Context, } // Fetch the rest of the invoice data and fill the invoice struct. - _, invoice, err = fetchInvoiceData( - ctx, db, rows[0], setID, fetchAmpHtlcs, + _, invoice, err := fetchInvoiceData( + ctx, db, sqlInvoice, setID, fetchAmpHtlcs, ) if err != nil { return nil, err @@ -616,7 +686,7 @@ func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64, invoiceKeys[key] = struct{}{} - if htlc.State != HtlcStateCanceled { //nolint: ll + if htlc.State != HtlcStateCanceled { amtPaid += htlc.Amt } } @@ -646,7 +716,7 @@ func (i *SQLStore) LookupInvoice(ctx context.Context, readTxOpt := NewSQLInvoiceQueryReadTx() txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error { - invoice, err = i.fetchInvoice(ctx, db, ref) + invoice, err = fetchInvoice(ctx, db, ref) return err }, func() {}) @@ -1347,7 +1417,7 @@ func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef, ref.refModifier = HtlcSetOnlyModifier } - invoice, err := i.fetchInvoice(ctx, db, ref) + invoice, err := fetchInvoice(ctx, db, ref) if err != nil { return err } @@ -1506,13 +1576,6 @@ func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries, if len(htlcs) > 0 { invoice.Htlcs = htlcs - var amountPaid lnwire.MilliSatoshi - for _, htlc := range htlcs { - if htlc.State == HtlcStateSettled { - amountPaid += htlc.Amt - } - } - invoice.AmtPaid = amountPaid } return hash, invoice, nil diff --git a/invoices/testdata/channel.db b/invoices/testdata/channel.db new file mode 100644 index 0000000000..69397f529d Binary files /dev/null and b/invoices/testdata/channel.db differ diff --git a/itest/list_on_test.go b/itest/list_on_test.go index be3244fe5e..38cd56d350 100644 --- a/itest/list_on_test.go +++ b/itest/list_on_test.go @@ -626,10 +626,6 @@ var allTestCases = []*lntest.TestCase{ Name: "open channel locked balance", TestFunc: testOpenChannelLockedBalance, }, - { - Name: "nativesql no migration", - TestFunc: testNativeSQLNoMigration, - }, { Name: "sweep cpfp anchor outgoing timeout", TestFunc: testSweepCPFPAnchorOutgoingTimeout, @@ -682,6 +678,10 @@ var allTestCases = []*lntest.TestCase{ Name: "quiescence", TestFunc: testQuiescence, }, + { + Name: "invoice migration", + TestFunc: testInvoiceMigration, + }, } // appendPrefixed is used to add a prefix to each test name in the subtests diff --git a/itest/lnd_invoice_migration_test.go b/itest/lnd_invoice_migration_test.go new file mode 100644 index 0000000000..b4bcfcdc46 --- /dev/null +++ b/itest/lnd_invoice_migration_test.go @@ -0,0 +1,307 @@ +package itest + +import ( + "database/sql" + "path" + "time" + + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/invoices" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/kvdb/postgres" + "github.com/lightningnetwork/lnd/kvdb/sqlbase" + "github.com/lightningnetwork/lnd/kvdb/sqlite" + "github.com/lightningnetwork/lnd/lncfg" + "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lnrpc/routerrpc" + "github.com/lightningnetwork/lnd/lntest" + "github.com/lightningnetwork/lnd/lntest/node" + "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/require" +) + +func openChannelDB(ht *lntest.HarnessTest, hn *node.HarnessNode) *channeldb.DB { + sqlbase.Init(0) + var ( + backend kvdb.Backend + err error + ) + + switch hn.Cfg.DBBackend { + case node.BackendSqlite: + backend, err = kvdb.Open( + kvdb.SqliteBackendName, + ht.Context(), + &sqlite.Config{ + Timeout: defaultTimeout, + BusyTimeout: defaultTimeout, + }, + hn.Cfg.DBDir(), lncfg.SqliteChannelDBName, + lncfg.NSChannelDB, + ) + require.NoError(ht, err) + + case node.BackendPostgres: + backend, err = kvdb.Open( + kvdb.PostgresBackendName, ht.Context(), + &postgres.Config{ + Dsn: hn.Cfg.PostgresDsn, + Timeout: defaultTimeout, + }, lncfg.NSChannelDB, + ) + require.NoError(ht, err) + } + + db, err := channeldb.CreateWithBackend(backend) + require.NoError(ht, err) + + return db +} + +func openNativeSQLInvoiceDB(ht *lntest.HarnessTest, + hn *node.HarnessNode) invoices.InvoiceDB { + + var db *sqldb.BaseDB + + switch hn.Cfg.DBBackend { + case node.BackendSqlite: + sqliteStore, err := sqldb.NewSqliteStore( + &sqldb.SqliteConfig{ + Timeout: defaultTimeout, + BusyTimeout: defaultTimeout, + }, + path.Join( + hn.Cfg.DBDir(), + lncfg.SqliteNativeDBName, + ), + ) + require.NoError(ht, err) + db = sqliteStore.BaseDB + + case node.BackendPostgres: + postgresStore, err := sqldb.NewPostgresStore( + &sqldb.PostgresConfig{ + Dsn: hn.Cfg.PostgresDsn, + Timeout: defaultTimeout, + }, + ) + require.NoError(ht, err) + db = postgresStore.BaseDB + } + + executor := sqldb.NewTransactionExecutor( + db, func(tx *sql.Tx) invoices.SQLInvoiceQueries { + return db.WithTx(tx) + }, + ) + + return invoices.NewSQLStore( + executor, clock.NewDefaultClock(), + ) +} + +// clampTime truncates the time of the passed invoice to the microsecond level. +func clampTime(invoice *invoices.Invoice) { + trunc := func(t time.Time) time.Time { + return t.Truncate(time.Microsecond) + } + + invoice.CreationDate = trunc(invoice.CreationDate) + + if !invoice.SettleDate.IsZero() { + invoice.SettleDate = trunc(invoice.SettleDate) + } + + if invoice.IsAMP() { + for setID, ampState := range invoice.AMPState { + if ampState.SettleDate.IsZero() { + continue + } + + ampState.SettleDate = trunc(ampState.SettleDate) + invoice.AMPState[setID] = ampState + } + } + + for _, htlc := range invoice.Htlcs { + if !htlc.AcceptTime.IsZero() { + htlc.AcceptTime = trunc(htlc.AcceptTime) + } + + if !htlc.ResolveTime.IsZero() { + htlc.ResolveTime = trunc(htlc.ResolveTime) + } + } +} + +// testInvoiceMigration tests that the invoice migration from the old KV store +// to the new native SQL store works as expected. +func testInvoiceMigration(ht *lntest.HarnessTest) { + alice := ht.NewNodeWithCoins("Alice", nil) + bob := ht.NewNodeWithCoins("Bob", nil) + + // Make sure we run the test with SQLite or Postgres. + if bob.Cfg.DBBackend != node.BackendSqlite && + bob.Cfg.DBBackend != node.BackendPostgres { + + ht.Skip("node not running with SQLite or Postgres") + } + + // Skip the test if the node is already running with native SQL. + if bob.Cfg.NativeSQL { + ht.Skip("node already running with native SQL") + } + + ht.EnsureConnected(alice, bob) + cp := ht.OpenChannel( + alice, bob, lntest.OpenChannelParams{ + Amt: 1000000, + PushAmt: 500000, + }, + ) + + // Alice and bob should have one channel open with each other now. + ht.AssertNodeNumChannels(alice, 1) + ht.AssertNodeNumChannels(bob, 1) + + ht.RestartNodeWithExtraArgs(bob, []string{ + "--accept-amp", + }) + + // Step 1: Add 10 normal invoices and pay 5 of them. + normalInvoices := make([]*lnrpc.AddInvoiceResponse, 10) + for i := 0; i < 10; i++ { + invoice := &lnrpc.Invoice{ + Value: int64(1000 + i*100), // Varying amounts + IsAmp: false, + } + + resp := bob.RPC.AddInvoice(invoice) + normalInvoices[i] = resp + } + + for _, inv := range normalInvoices { + sendReq := &routerrpc.SendPaymentRequest{ + PaymentRequest: inv.PaymentRequest, + TimeoutSeconds: 60, + FeeLimitMsat: noFeeLimitMsat, + } + + ht.SendPaymentAssertSettled(alice, sendReq) + } + + // Step 2: Add 10 AMP invoices and send multiple payments to 5 of them. + ampInvoices := make([]*lnrpc.AddInvoiceResponse, 10) + for i := 0; i < 10; i++ { + invoice := &lnrpc.Invoice{ + Value: int64(2000 + i*200), // Varying amounts + IsAmp: true, + } + + resp := bob.RPC.AddInvoice(invoice) + ampInvoices[i] = resp + } + + // Select the first 5 invoices to send multiple AMP payments. + for i := 0; i < 5; i++ { + inv := ampInvoices[i] + + // Send 3 payments to each. + for j := 0; j < 3; j++ { + payReq := &routerrpc.SendPaymentRequest{ + PaymentRequest: inv.PaymentRequest, + TimeoutSeconds: 60, + FeeLimitMsat: noFeeLimitMsat, + Amp: true, + } + + // Send a normal AMP payment first, then a spontaneous + // AMP payment. + ht.SendPaymentAssertSettled(alice, payReq) + + // Generate an external payment address when attempting + // to pseudo-reuse an AMP invoice. When using an + // external payment address, we'll also expect an extra + // invoice to appear in the ListInvoices response, since + // a new invoice will be JIT inserted under a different + // payment address than the one in the invoice. + // + // NOTE: This will only work when the peer has + // spontaneous AMP payments enabled otherwise no invoice + // under a different payment_addr will be found. + payReq.PaymentAddr = ht.Random32Bytes() + ht.SendPaymentAssertSettled(alice, payReq) + } + } + + // We can close the channel now. + ht.CloseChannel(alice, cp) + + // Now stop Bob so we can open the DB for examination. + require.NoError(ht, bob.Stop()) + + // Open the KV channel DB. + db := openChannelDB(ht, bob) + + query := invoices.InvoiceQuery{ + IndexOffset: 0, + // As a sanity check, fetch more invoices than we have + // to ensure that we did not add any extra invoices. + NumMaxInvoices: 9999, + } + + // Fetch all invoices and make sure we have 35 in total. + result1, err := db.QueryInvoices(ht.Context(), query) + require.NoError(ht, err) + + numInvoices := len(result1.Invoices) + + bob.SetExtraArgs([]string{"--db.use-native-sql"}) + + // Now run the migration flow three times to ensure that each run is + // idempotent. + for i := 0; i < 3; i++ { + // Start bob with the native SQL flag set. This will trigger the + // migration to run. + require.NoError(ht, bob.Start(ht.Context())) + + // At this point the migration should have completed and the + // node should be running with native SQL. Now we'll stop Bob + // again so we can safely examine the database. + require.NoError(ht, bob.Stop()) + + // Now we'll open the database with the native SQL backend and + // fetch the invoices again to ensure that they were migrated + // correctly. + sqlInvoiceDB := openNativeSQLInvoiceDB(ht, bob) + result2, err := sqlInvoiceDB.QueryInvoices(ht.Context(), query) + require.NoError(ht, err) + + require.Equal(ht, numInvoices, len(result2.Invoices)) + + // Simply zero out the add index so we don't fail on that when + // comparing. + for i := 0; i < numInvoices; i++ { + result1.Invoices[i].AddIndex = 0 + result2.Invoices[i].AddIndex = 0 + + // Clamp the precision to microseconds. Note that we + // need to override both invoices as the original + // invoice is coming from KV database, it was stored as + // a binary serialized Go time.Time value which has + // nanosecond precision. The migrated invoice is stored + // in SQL in PostgreSQL has microsecond precision while + // in SQLite it has nanosecond precision if using TEXT + // storage class. + clampTime(&result1.Invoices[i]) + clampTime(&result2.Invoices[i]) + require.Equal( + ht, result1.Invoices[i], result2.Invoices[i], + ) + } + } + + // Start Bob again so the test can complete. + require.NoError(ht, bob.Start(ht.Context())) +} diff --git a/itest/lnd_misc_test.go b/itest/lnd_misc_test.go index 30dba0a878..98b1121c6a 100644 --- a/itest/lnd_misc_test.go +++ b/itest/lnd_misc_test.go @@ -1,7 +1,6 @@ package itest import ( - "context" "encoding/hex" "fmt" "os" @@ -1245,44 +1244,6 @@ func testSignVerifyMessageWithAddr(ht *lntest.HarnessTest) { require.False(ht, respValid.Valid, "external signature did validate") } -// testNativeSQLNoMigration tests that nodes that have invoices would not start -// up with native SQL enabled, as we don't currently support migration of KV -// invoices to the new SQL schema. -func testNativeSQLNoMigration(ht *lntest.HarnessTest) { - alice := ht.NewNode("Alice", nil) - - // Make sure we run the test with SQLite or Postgres. - if alice.Cfg.DBBackend != node.BackendSqlite && - alice.Cfg.DBBackend != node.BackendPostgres { - - ht.Skip("node not running with SQLite or Postgres") - } - - // Skip the test if the node is already running with native SQL. - if alice.Cfg.NativeSQL { - ht.Skip("node already running with native SQL") - } - - alice.RPC.AddInvoice(&lnrpc.Invoice{ - Value: 10_000, - }) - - alice.SetExtraArgs([]string{"--db.use-native-sql"}) - - // Restart the node manually as we're really only interested in the - // startup error. - require.NoError(ht, alice.Stop()) - require.NoError(ht, alice.StartLndCmd(context.Background())) - - // We expect the node to fail to start up with native SQL enabled, as we - // have an invoice in the KV store. - require.Error(ht, alice.WaitForProcessExit()) - - // Reset the extra args and restart alice. - alice.SetExtraArgs(nil) - require.NoError(ht, alice.Start(ht.Context())) -} - // testSendSelectedCoins tests that we're able to properly send the selected // coins from the wallet to a single target address. func testSendSelectedCoins(ht *lntest.HarnessTest) { diff --git a/lncfg/db.go b/lncfg/db.go index 3d45bb78b1..bdeadb933d 100644 --- a/lncfg/db.go +++ b/lncfg/db.go @@ -87,6 +87,8 @@ type DB struct { UseNativeSQL bool `long:"use-native-sql" description:"Use native SQL for tables that already support it."` + SkipSQLInvoiceMigration bool `long:"skip-sql-invoice-migration" description:"Do not migrate invoices stored in our key-value database to native SQL."` + NoGraphCache bool `long:"no-graph-cache" description:"Don't use the in-memory graph cache for path finding. Much slower but uses less RAM. Can only be used with a bolt database backend."` PruneRevocation bool `long:"prune-revocation" description:"Run the optional migration that prunes the revocation logs to save disk space."` @@ -115,7 +117,8 @@ func DefaultDB() *DB { MaxConnections: defaultSqliteMaxConnections, BusyTimeout: defaultSqliteBusyTimeout, }, - UseNativeSQL: false, + UseNativeSQL: false, + SkipSQLInvoiceMigration: false, } } @@ -231,10 +234,10 @@ type DatabaseBackends struct { // the underlying wallet database from. WalletDB btcwallet.LoaderOption - // NativeSQLStore is a pointer to a native SQL store that can be used - // for native SQL queries for tables that already support it. This may - // be nil if the use-native-sql flag was not set. - NativeSQLStore *sqldb.BaseDB + // NativeSQLStore holds a reference to the native SQL store that can + // be used for native SQL queries for tables that already support it. + // This may be nil if the use-native-sql flag was not set. + NativeSQLStore sqldb.DB // Remote indicates whether the database backends are remote, possibly // replicated instances or local bbolt or sqlite backed databases. @@ -449,7 +452,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath, } closeFuncs[NSWalletDB] = postgresWalletBackend.Close - var nativeSQLStore *sqldb.BaseDB + var nativeSQLStore sqldb.DB if db.UseNativeSQL { nativePostgresStore, err := sqldb.NewPostgresStore( db.Postgres, @@ -459,7 +462,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath, "native postgres store: %v", err) } - nativeSQLStore = nativePostgresStore.BaseDB + nativeSQLStore = nativePostgresStore closeFuncs[PostgresBackend] = nativePostgresStore.Close } @@ -571,7 +574,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath, } closeFuncs[NSWalletDB] = sqliteWalletBackend.Close - var nativeSQLStore *sqldb.BaseDB + var nativeSQLStore sqldb.DB if db.UseNativeSQL { nativeSQLiteStore, err := sqldb.NewSqliteStore( db.Sqlite, @@ -582,7 +585,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath, "native SQLite store: %v", err) } - nativeSQLStore = nativeSQLiteStore.BaseDB + nativeSQLStore = nativeSQLiteStore closeFuncs[SqliteBackend] = nativeSQLiteStore.Close } diff --git a/sample-lnd.conf b/sample-lnd.conf index 86ea824858..5390ccd77c 100644 --- a/sample-lnd.conf +++ b/sample-lnd.conf @@ -1472,6 +1472,9 @@ ; own risk. ; db.use-native-sql=false +; If set to true, native SQL invoice migration will be skipped. Note that this +; option is intended for users who experience non-resolvable migration errors. +; db.skip-sql-invoice-migration=false [etcd] diff --git a/sqldb/interfaces.go b/sqldb/interfaces.go index 3c042aa5a7..1c5b4878fb 100644 --- a/sqldb/interfaces.go +++ b/sqldb/interfaces.go @@ -355,6 +355,18 @@ func (t *TransactionExecutor[Q]) ExecTx(ctx context.Context, ) } +// DB is an interface that represents a generic SQL database. It provides +// methods to apply migrations and access the underlying database connection. +type DB interface { + // GetBaseDB returns the underlying BaseDB instance. + GetBaseDB() *BaseDB + + // ApplyAllMigrations applies all migrations to the database including + // both sqlc and custom in-code migrations. + ApplyAllMigrations(ctx context.Context, + customMigrations []MigrationConfig) error +} + // BaseDB is the base database struct that each implementation can embed to // gain some common functionality. type BaseDB struct { diff --git a/sqldb/migrations.go b/sqldb/migrations.go index 9d394ceed1..42b316c4a4 100644 --- a/sqldb/migrations.go +++ b/sqldb/migrations.go @@ -2,22 +2,113 @@ package sqldb import ( "bytes" + "context" + "database/sql" "errors" + "fmt" "io" "io/fs" "net/http" "strings" + "time" "github.com/btcsuite/btclog/v2" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" "github.com/golang-migrate/migrate/v4/source/httpfs" + "github.com/lightningnetwork/lnd/sqldb/sqlc" ) +var ( + // migrationConfig defines a list of migrations to be applied to the + // database. Each migration is assigned a version number, determining + // its execution order. + // The schema version, tracked by golang-migrate, ensures migrations are + // applied to the correct schema. For migrations involving only schema + // changes, the migration function can be left nil. For custom + // migrations an implemented migration function is required. + // + // NOTE: The migration function may have runtime dependencies, which + // must be injected during runtime. + migrationConfig = []MigrationConfig{ + { + Name: "000001_invoices", + Version: 1, + SchemaVersion: 1, + }, + { + Name: "000002_amp_invoices", + Version: 2, + SchemaVersion: 2, + }, + { + Name: "000003_invoice_events", + Version: 3, + SchemaVersion: 3, + }, + { + Name: "000004_invoice_expiry_fix", + Version: 4, + SchemaVersion: 4, + }, + { + Name: "000005_migration_tracker", + Version: 5, + SchemaVersion: 5, + }, + { + Name: "000006_invoice_migration", + Version: 6, + SchemaVersion: 6, + // A migration function is may be attached to this + // migration to migrate KV invoices to the native SQL + // schema. This is optional and can be disabled by the + // user. + }, + } +) + +// MigrationConfig is a configuration struct that describes SQL migrations. Each +// migration is associated with a specific schema version and a global database +// version. Migrations are applied in the order of their global database +// version. If a migration includes a non-nil MigrationFn, it is executed after +// the SQL schema has been migrated to the corresponding schema version. +type MigrationConfig struct { + // Name is the name of the migration. + Name string + + // Version represents the "global" database version for this migration. + // Unlike the schema version tracked by golang-migrate, it encompasses + // all migrations, including those managed by golang-migrate as well + // as custom in-code migrations. + Version int + + // SchemaVersion represents the schema version tracked by golang-migrate + // at which the migration is applied. + SchemaVersion int + + // MigrationFn is the function executed for custom migrations at the + // specified version. It is used to handle migrations that cannot be + // performed through SQL alone. If set to nil, no custom migration is + // applied. + MigrationFn func(tx *sqlc.Queries) error +} + // MigrationTarget is a functional option that can be passed to applyMigrations // to specify a target version to migrate to. type MigrationTarget func(mig *migrate.Migrate) error +// MigrationExecutor is an interface that abstracts the migration functionality. +type MigrationExecutor interface { + // CurrentSchemaVersion returns the current schema version of the + // database. + CurrentSchemaVersion() (int, error) + + // ExecuteMigrations runs migrations for the database, depending on the + // target given, either all migrations or up to a given version. + ExecuteMigrations(target MigrationTarget) error +} + var ( // TargetLatest is a MigrationTarget that migrates to the latest // version available. @@ -34,6 +125,14 @@ var ( } ) +// GetMigrations returns a copy of the migration configuration. +func GetMigrations() []MigrationConfig { + migrations := make([]MigrationConfig, len(migrationConfig)) + copy(migrations, migrationConfig) + + return migrations +} + // migrationLogger is a logger that wraps the passed btclog.Logger so it can be // used to log migrations. type migrationLogger struct { @@ -216,3 +315,117 @@ func (t *replacerFile) Close() error { // instance, so there's nothing to do for us here. return nil } + +// MigrationTxOptions is the implementation of the TxOptions interface for +// migration transactions. +type MigrationTxOptions struct { +} + +// ReadOnly returns false to indicate that migration transactions are not read +// only. +func (m *MigrationTxOptions) ReadOnly() bool { + return false +} + +// ApplyMigrations applies the provided migrations to the database in sequence. +// It ensures migrations are executed in the correct order, applying both custom +// migration functions and SQL migrations as needed. +func ApplyMigrations(ctx context.Context, db *BaseDB, + migrator MigrationExecutor, migrations []MigrationConfig) error { + + // Ensure that the migrations are sorted by version. + for i := 0; i < len(migrations); i++ { + if migrations[i].Version != i+1 { + return fmt.Errorf("migration version %d is out of "+ + "order. Expected %d", migrations[i].Version, + i+1) + } + } + // Construct a transaction executor to apply custom migrations. + executor := NewTransactionExecutor(db, func(tx *sql.Tx) *sqlc.Queries { + return db.WithTx(tx) + }) + + currentVersion := 0 + version, err := db.GetDatabaseVersion(ctx) + if !errors.Is(err, sql.ErrNoRows) { + if err != nil { + return fmt.Errorf("error getting current database "+ + "version: %w", err) + } + + currentVersion = int(version) + } + + for _, migration := range migrations { + if migration.Version <= currentVersion { + log.Infof("Skipping migration '%s' (version %d) as it "+ + "has already been applied", migration.Name, + migration.Version) + + continue + } + + log.Infof("Migrating SQL schema to version %d", + migration.SchemaVersion) + + // Execute SQL schema migrations up to the target version. + err = migrator.ExecuteMigrations( + TargetVersion(uint(migration.SchemaVersion)), + ) + if err != nil { + return fmt.Errorf("error executing schema migrations "+ + "to target version %d: %w", + migration.SchemaVersion, err) + } + + var opts MigrationTxOptions + + // Run the custom migration as a transaction to ensure + // atomicity. If successful, mark the migration as complete in + // the migration tracker table. + err = executor.ExecTx(ctx, &opts, func(tx *sqlc.Queries) error { + // Apply the migration function if one is provided. + if migration.MigrationFn != nil { + log.Infof("Applying custom migration '%v' "+ + "(version %d) to schema version %d", + migration.Name, migration.Version, + migration.SchemaVersion) + + err = migration.MigrationFn(tx) + if err != nil { + return fmt.Errorf("error applying "+ + "migration '%v' (version %d) "+ + "to schema version %d: %w", + migration.Name, + migration.Version, + migration.SchemaVersion, err) + } + + log.Infof("Migration '%v' (version %d) "+ + "applied ", migration.Name, + migration.Version) + } + + // Mark the migration as complete by adding the version + // to the migration tracker table along with the current + // timestamp. + err = tx.SetMigration(ctx, sqlc.SetMigrationParams{ + Version: int32(migration.Version), + MigrationTime: time.Now(), + }) + if err != nil { + return fmt.Errorf("error setting migration "+ + "version %d: %w", migration.Version, + err) + } + + return nil + }, func() {}) + if err != nil { + return err + } + } + + return nil +} diff --git a/sqldb/migrations_test.go b/sqldb/migrations_test.go index cd55e92cb8..385840364c 100644 --- a/sqldb/migrations_test.go +++ b/sqldb/migrations_test.go @@ -2,8 +2,15 @@ package sqldb import ( "context" + "database/sql" + "fmt" + "path/filepath" "testing" + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database" + pgx_migrate "github.com/golang-migrate/migrate/v4/database/pgx/v5" + sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite" "github.com/lightningnetwork/lnd/sqldb/sqlc" "github.com/stretchr/testify/require" ) @@ -152,3 +159,296 @@ func testInvoiceExpiryMigration(t *testing.T, makeDB makeMigrationTestDB) { require.NoError(t, err) require.Equal(t, expected, invoices) } + +// TestCustomMigration tests that a custom in-code migrations are correctly +// executed during the migration process. +func TestCustomMigration(t *testing.T) { + var customMigrationLog []string + + logMigration := func(name string) { + customMigrationLog = append(customMigrationLog, name) + } + + // Some migrations to use for both the failure and success tests. Note + // that the migrations are not in order to test that they are executed + // in the correct order. + migrations := []MigrationConfig{ + { + Name: "1", + Version: 1, + SchemaVersion: 1, + MigrationFn: func(*sqlc.Queries) error { + logMigration("1") + + return nil + }, + }, + { + Name: "2", + Version: 2, + SchemaVersion: 1, + MigrationFn: func(*sqlc.Queries) error { + logMigration("2") + + return nil + }, + }, + { + Name: "3", + Version: 3, + SchemaVersion: 2, + MigrationFn: func(*sqlc.Queries) error { + logMigration("3") + + return nil + }, + }, + } + + tests := []struct { + name string + migrations []MigrationConfig + expectedSuccess bool + expectedMigrationLog []string + expectedSchemaVersion int + expectedVersion int + }{ + { + name: "success", + migrations: migrations, + expectedSuccess: true, + expectedMigrationLog: []string{"1", "2", "3"}, + expectedSchemaVersion: 2, + expectedVersion: 3, + }, + { + name: "unordered migrations", + migrations: append([]MigrationConfig{ + { + Name: "4", + Version: 4, + SchemaVersion: 3, + MigrationFn: func(*sqlc.Queries) error { + logMigration("4") + + return nil + }, + }, + }, migrations...), + expectedSuccess: false, + expectedMigrationLog: nil, + expectedSchemaVersion: 0, + }, + { + name: "failure of migration 4", + migrations: append(migrations, MigrationConfig{ + Name: "4", + Version: 4, + SchemaVersion: 3, + MigrationFn: func(*sqlc.Queries) error { + return fmt.Errorf("migration 4 failed") + }, + }), + expectedSuccess: false, + expectedMigrationLog: []string{"1", "2", "3"}, + // Since schema migration is a separate step we expect + // that migrating up to 3 succeeded. + expectedSchemaVersion: 3, + // We still remain on version 3 though. + expectedVersion: 3, + }, + { + name: "success of migration 4", + migrations: append(migrations, MigrationConfig{ + Name: "4", + Version: 4, + SchemaVersion: 3, + MigrationFn: func(*sqlc.Queries) error { + logMigration("4") + + return nil + }, + }), + expectedSuccess: true, + expectedMigrationLog: []string{"1", "2", "3", "4"}, + expectedSchemaVersion: 3, + expectedVersion: 4, + }, + } + + ctxb := context.Background() + for _, test := range tests { + // checkSchemaVersion checks the database schema version against + // the expected version. + getSchemaVersion := func(t *testing.T, + driver database.Driver, dbName string) int { + + sqlMigrate, err := migrate.NewWithInstance( + "migrations", nil, dbName, driver, + ) + require.NoError(t, err) + + version, _, err := sqlMigrate.Version() + if err != migrate.ErrNilVersion { + require.NoError(t, err) + } + + return int(version) + } + + t.Run("SQLite "+test.name, func(t *testing.T) { + customMigrationLog = nil + + // First instantiate the database and run the migrations + // including the custom migrations. + t.Logf("Creating new SQLite DB for testing migrations") + + dbFileName := filepath.Join(t.TempDir(), "tmp.db") + var ( + db *SqliteStore + err error + ) + + // Run the migration 3 times to test that the migrations + // are idempotent. + for i := 0; i < 3; i++ { + db, err = NewSqliteStore(&SqliteConfig{ + SkipMigrations: false, + }, dbFileName) + require.NoError(t, err) + + dbToCleanup := db.DB + t.Cleanup(func() { + require.NoError( + t, dbToCleanup.Close(), + ) + }) + + err = db.ApplyAllMigrations( + ctxb, test.migrations, + ) + if test.expectedSuccess { + require.NoError(t, err) + } else { + require.Error(t, err) + + // Also repoen the DB without migrations + // so we can read versions. + db, err = NewSqliteStore(&SqliteConfig{ + SkipMigrations: true, + }, dbFileName) + require.NoError(t, err) + } + + require.Equal(t, + test.expectedMigrationLog, + customMigrationLog, + ) + + // Create the migration executor to be able to + // query the current schema version. + driver, err := sqlite_migrate.WithInstance( + db.DB, &sqlite_migrate.Config{}, + ) + require.NoError(t, err) + + require.Equal( + t, test.expectedSchemaVersion, + getSchemaVersion(t, driver, ""), + ) + + // Check the migraton version in the database. + version, err := db.GetDatabaseVersion(ctxb) + if test.expectedSchemaVersion != 0 { + require.NoError(t, err) + } else { + require.Equal(t, sql.ErrNoRows, err) + } + + require.Equal( + t, test.expectedVersion, int(version), + ) + } + }) + + t.Run("Postgres "+test.name, func(t *testing.T) { + customMigrationLog = nil + + // First create a temporary Postgres database to run + // the migrations on. + fixture := NewTestPgFixture( + t, DefaultPostgresFixtureLifetime, + ) + t.Cleanup(func() { + fixture.TearDown(t) + }) + + dbName := randomDBName(t) + + // Next instantiate the database and run the migrations + // including the custom migrations. + t.Logf("Creating new Postgres DB '%s' for testing "+ + "migrations", dbName) + + _, err := fixture.db.ExecContext( + context.Background(), "CREATE DATABASE "+dbName, + ) + require.NoError(t, err) + + cfg := fixture.GetConfig(dbName) + var db *PostgresStore + + // Run the migration 3 times to test that the migrations + // are idempotent. + for i := 0; i < 3; i++ { + cfg.SkipMigrations = false + db, err = NewPostgresStore(cfg) + require.NoError(t, err) + + err = db.ApplyAllMigrations( + ctxb, test.migrations, + ) + if test.expectedSuccess { + require.NoError(t, err) + } else { + require.Error(t, err) + + // Also repoen the DB without migrations + // so we can read versions. + cfg.SkipMigrations = true + db, err = NewPostgresStore(cfg) + require.NoError(t, err) + } + + require.Equal(t, + test.expectedMigrationLog, + customMigrationLog, + ) + + // Create the migration executor to be able to + // query the current version. + driver, err := pgx_migrate.WithInstance( + db.DB, &pgx_migrate.Config{}, + ) + require.NoError(t, err) + + require.Equal( + t, test.expectedSchemaVersion, + getSchemaVersion(t, driver, ""), + ) + + // Check the migraton version in the database. + version, err := db.GetDatabaseVersion(ctxb) + if test.expectedSchemaVersion != 0 { + require.NoError(t, err) + } else { + require.Equal(t, sql.ErrNoRows, err) + } + + require.Equal( + t, test.expectedVersion, int(version), + ) + } + }) + } +} diff --git a/sqldb/no_sqlite.go b/sqldb/no_sqlite.go index 9ea35c43c6..9eb016239e 100644 --- a/sqldb/no_sqlite.go +++ b/sqldb/no_sqlite.go @@ -2,7 +2,15 @@ package sqldb -import "fmt" +import ( + "context" + "fmt" +) + +var ( + // Make sure SqliteStore implements the DB interface. + _ DB = (*SqliteStore)(nil) +) // SqliteStore is a database store implementation that uses a sqlite backend. type SqliteStore struct { @@ -16,3 +24,17 @@ type SqliteStore struct { func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) { return nil, fmt.Errorf("SQLite backend not supported in WebAssembly") } + +// GetBaseDB returns the underlying BaseDB instance for the SQLite store. +// It is a trivial helper method to comply with the sqldb.DB interface. +func (s *SqliteStore) GetBaseDB() *BaseDB { + return s.BaseDB +} + +// ApplyAllMigrations applices both the SQLC and custom in-code migrations to +// the SQLite database. +func (s *SqliteStore) ApplyAllMigrations(context.Context, + []MigrationConfig) error { + + return fmt.Errorf("SQLite backend not supported in WebAssembly") +} diff --git a/sqldb/postgres.go b/sqldb/postgres.go index c855391574..a91581cd41 100644 --- a/sqldb/postgres.go +++ b/sqldb/postgres.go @@ -1,6 +1,7 @@ package sqldb import ( + "context" "database/sql" "fmt" "net/url" @@ -32,6 +33,12 @@ var ( "BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY", "TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE", } + + // Make sure PostgresStore implements the MigrationExecutor interface. + _ MigrationExecutor = (*PostgresStore)(nil) + + // Make sure PostgresStore implements the DB interface. + _ DB = (*PostgresStore)(nil) ) // replacePasswordInDSN takes a DSN string and returns it with the password @@ -92,40 +99,81 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) { } log.Infof("Using SQL database '%s'", sanitizedDSN) - rawDB, err := sql.Open("pgx", cfg.Dsn) + db, err := sql.Open("pgx", cfg.Dsn) if err != nil { return nil, err } + // Create the migration tracker table before starting migrations to + // ensure it can be used to track migration progress. Note that a + // corresponding SQLC migration also creates this table, making this + // operation a no-op in that context. Its purpose is to ensure + // compatibility with SQLC query generation. + migrationTrackerSQL := ` + CREATE TABLE IF NOT EXISTS migration_tracker ( + version INTEGER UNIQUE NOT NULL, + migration_time TIMESTAMP NOT NULL + );` + + _, err = db.Exec(migrationTrackerSQL) + if err != nil { + return nil, fmt.Errorf("error creating migration tracker: %w", + err) + } maxConns := defaultMaxConns if cfg.MaxConnections > 0 { maxConns = cfg.MaxConnections } - rawDB.SetMaxOpenConns(maxConns) - rawDB.SetMaxIdleConns(maxConns) - rawDB.SetConnMaxLifetime(connIdleLifetime) + db.SetMaxOpenConns(maxConns) + db.SetMaxIdleConns(maxConns) + db.SetConnMaxLifetime(connIdleLifetime) - queries := sqlc.New(rawDB) + queries := sqlc.New(db) - s := &PostgresStore{ + return &PostgresStore{ cfg: cfg, BaseDB: &BaseDB{ - DB: rawDB, + DB: db, Queries: queries, }, - } + }, nil +} + +// GetBaseDB returns the underlying BaseDB instance for the Postgres store. +// It is a trivial helper method to comply with the sqldb.DB interface. +func (s *PostgresStore) GetBaseDB() *BaseDB { + return s.BaseDB +} + +// ApplyAllMigrations applices both the SQLC and custom in-code migrations to +// the Postgres database. +func (s *PostgresStore) ApplyAllMigrations(ctx context.Context, + migrations []MigrationConfig) error { // Execute migrations unless configured to skip them. - if !cfg.SkipMigrations { - err := s.ExecuteMigrations(TargetLatest) - if err != nil { - return nil, fmt.Errorf("error executing migrations: %w", - err) - } + if s.cfg.SkipMigrations { + return nil + } + + return ApplyMigrations(ctx, s.BaseDB, s, migrations) +} + +// CurrentSchemaVersion returns the current schema version of the Postgres +// database. +func (s *PostgresStore) CurrentSchemaVersion() (int, error) { + driver, err := pgx_migrate.WithInstance(s.DB, &pgx_migrate.Config{}) + if err != nil { + return 0, fmt.Errorf("error creating postgres migrator: %w", + err) + } + + version, _, err := driver.Version() + if err != nil { + return 0, fmt.Errorf("error getting current version: %w", err) } - return s, nil + return version, nil } // ExecuteMigrations runs migrations for the Postgres database, depending on the diff --git a/sqldb/postgres_fixture.go b/sqldb/postgres_fixture.go index da5769c429..ce21aab7d4 100644 --- a/sqldb/postgres_fixture.go +++ b/sqldb/postgres_fixture.go @@ -151,6 +151,10 @@ func NewTestPostgresDB(t *testing.T, fixture *TestPgFixture) *PostgresStore { store, err := NewPostgresStore(cfg) require.NoError(t, err) + require.NoError(t, store.ApplyAllMigrations( + context.Background(), GetMigrations()), + ) + return store } diff --git a/sqldb/sqlc/amp_invoices.sql.go b/sqldb/sqlc/amp_invoices.sql.go index e47b1c803d..182848e146 100644 --- a/sqldb/sqlc/amp_invoices.sql.go +++ b/sqldb/sqlc/amp_invoices.sql.go @@ -235,6 +235,35 @@ func (q *Queries) GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, err return invoice_id, err } +const insertAMPSubInvoice = `-- name: InsertAMPSubInvoice :exec +INSERT INTO amp_sub_invoices ( + set_id, state, created_at, settled_at, settle_index, invoice_id +) VALUES ( + $1, $2, $3, $4, $5, $6 +) +` + +type InsertAMPSubInvoiceParams struct { + SetID []byte + State int16 + CreatedAt time.Time + SettledAt sql.NullTime + SettleIndex sql.NullInt64 + InvoiceID int64 +} + +func (q *Queries) InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error { + _, err := q.db.ExecContext(ctx, insertAMPSubInvoice, + arg.SetID, + arg.State, + arg.CreatedAt, + arg.SettledAt, + arg.SettleIndex, + arg.InvoiceID, + ) + return err +} + const insertAMPSubInvoiceHTLC = `-- name: InsertAMPSubInvoiceHTLC :exec INSERT INTO amp_sub_invoice_htlcs ( invoice_id, set_id, htlc_id, root_share, child_index, hash, preimage diff --git a/sqldb/sqlc/invoices.sql.go b/sqldb/sqlc/invoices.sql.go index 9e31380abb..1cd7dfff4e 100644 --- a/sqldb/sqlc/invoices.sql.go +++ b/sqldb/sqlc/invoices.sql.go @@ -11,6 +11,15 @@ import ( "time" ) +const clearKVInvoiceHashIndex = `-- name: ClearKVInvoiceHashIndex :exec +DELETE FROM invoice_payment_hashes +` + +func (q *Queries) ClearKVInvoiceHashIndex(ctx context.Context) error { + _, err := q.db.ExecContext(ctx, clearKVInvoiceHashIndex) + return err +} + const deleteCanceledInvoices = `-- name: DeleteCanceledInvoices :execresult DELETE FROM invoices @@ -182,11 +191,8 @@ WHERE ( i.hash = $3 OR $3 IS NULL ) AND ( - i.preimage = $4 OR + i.payment_addr = $4 OR $4 IS NULL -) AND ( - i.payment_addr = $5 OR - $5 IS NULL ) GROUP BY i.id LIMIT 2 @@ -196,7 +202,6 @@ type GetInvoiceParams struct { SetID []byte AddIndex sql.NullInt64 Hash []byte - Preimage []byte PaymentAddr []byte } @@ -208,7 +213,6 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi arg.SetID, arg.AddIndex, arg.Hash, - arg.Preimage, arg.PaymentAddr, ) if err != nil { @@ -251,6 +255,38 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi return items, nil } +const getInvoiceByHash = `-- name: GetInvoiceByHash :one +SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at +FROM invoices i +WHERE i.hash = $1 +` + +func (q *Queries) GetInvoiceByHash(ctx context.Context, hash []byte) (Invoice, error) { + row := q.db.QueryRowContext(ctx, getInvoiceByHash, hash) + var i Invoice + err := row.Scan( + &i.ID, + &i.Hash, + &i.Preimage, + &i.SettleIndex, + &i.SettledAt, + &i.Memo, + &i.AmountMsat, + &i.CltvDelta, + &i.Expiry, + &i.PaymentAddr, + &i.PaymentRequest, + &i.PaymentRequestHash, + &i.State, + &i.AmountPaidMsat, + &i.IsAmp, + &i.IsHodl, + &i.IsKeysend, + &i.CreatedAt, + ) + return i, err +} + const getInvoiceBySetID = `-- name: GetInvoiceBySetID :many SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at FROM invoices i @@ -405,6 +441,19 @@ func (q *Queries) GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]Invoi return items, nil } +const getKVInvoicePaymentHashByAddIndex = `-- name: GetKVInvoicePaymentHashByAddIndex :one +SELECT hash +FROM invoice_payment_hashes +WHERE add_index = $1 +` + +func (q *Queries) GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getKVInvoicePaymentHashByAddIndex, addIndex) + var hash []byte + err := row.Scan(&hash) + return hash, err +} + const insertInvoice = `-- name: InsertInvoice :one INSERT INTO invoices ( hash, preimage, memo, amount_msat, cltv_delta, expiry, payment_addr, @@ -533,6 +582,79 @@ func (q *Queries) InsertInvoiceHTLCCustomRecord(ctx context.Context, arg InsertI return err } +const insertKVInvoiceKeyAndAddIndex = `-- name: InsertKVInvoiceKeyAndAddIndex :exec +INSERT INTO invoice_payment_hashes ( + id, add_index +) VALUES ( + $1, $2 +) +` + +type InsertKVInvoiceKeyAndAddIndexParams struct { + ID int64 + AddIndex int64 +} + +func (q *Queries) InsertKVInvoiceKeyAndAddIndex(ctx context.Context, arg InsertKVInvoiceKeyAndAddIndexParams) error { + _, err := q.db.ExecContext(ctx, insertKVInvoiceKeyAndAddIndex, arg.ID, arg.AddIndex) + return err +} + +const insertMigratedInvoice = `-- name: InsertMigratedInvoice :one +INSERT INTO invoices ( + hash, preimage, settle_index, settled_at, memo, amount_msat, cltv_delta, + expiry, payment_addr, payment_request, payment_request_hash, state, + amount_paid_msat, is_amp, is_hodl, is_keysend, created_at +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17 +) RETURNING id +` + +type InsertMigratedInvoiceParams struct { + Hash []byte + Preimage []byte + SettleIndex sql.NullInt64 + SettledAt sql.NullTime + Memo sql.NullString + AmountMsat int64 + CltvDelta sql.NullInt32 + Expiry int32 + PaymentAddr []byte + PaymentRequest sql.NullString + PaymentRequestHash []byte + State int16 + AmountPaidMsat int64 + IsAmp bool + IsHodl bool + IsKeysend bool + CreatedAt time.Time +} + +func (q *Queries) InsertMigratedInvoice(ctx context.Context, arg InsertMigratedInvoiceParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertMigratedInvoice, + arg.Hash, + arg.Preimage, + arg.SettleIndex, + arg.SettledAt, + arg.Memo, + arg.AmountMsat, + arg.CltvDelta, + arg.Expiry, + arg.PaymentAddr, + arg.PaymentRequest, + arg.PaymentRequestHash, + arg.State, + arg.AmountPaidMsat, + arg.IsAmp, + arg.IsHodl, + arg.IsKeysend, + arg.CreatedAt, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + const nextInvoiceSettleIndex = `-- name: NextInvoiceSettleIndex :one UPDATE invoice_sequences SET current_value = current_value + 1 WHERE name = 'settle_index' @@ -546,6 +668,22 @@ func (q *Queries) NextInvoiceSettleIndex(ctx context.Context) (int64, error) { return current_value, err } +const setKVInvoicePaymentHash = `-- name: SetKVInvoicePaymentHash :exec +UPDATE invoice_payment_hashes +SET hash = $2 +WHERE id = $1 +` + +type SetKVInvoicePaymentHashParams struct { + ID int64 + Hash []byte +} + +func (q *Queries) SetKVInvoicePaymentHash(ctx context.Context, arg SetKVInvoicePaymentHashParams) error { + _, err := q.db.ExecContext(ctx, setKVInvoicePaymentHash, arg.ID, arg.Hash) + return err +} + const updateInvoiceAmountPaid = `-- name: UpdateInvoiceAmountPaid :execresult UPDATE invoices SET amount_paid_msat = $2 diff --git a/sqldb/sqlc/migration.sql.go b/sqldb/sqlc/migration.sql.go new file mode 100644 index 0000000000..d65ff74f5d --- /dev/null +++ b/sqldb/sqlc/migration.sql.go @@ -0,0 +1,60 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: migration.sql + +package sqlc + +import ( + "context" + "time" +) + +const getDatabaseVersion = `-- name: GetDatabaseVersion :one +SELECT + version +FROM + migration_tracker +ORDER BY + version DESC +LIMIT 1 +` + +func (q *Queries) GetDatabaseVersion(ctx context.Context) (int32, error) { + row := q.db.QueryRowContext(ctx, getDatabaseVersion) + var version int32 + err := row.Scan(&version) + return version, err +} + +const getMigration = `-- name: GetMigration :one +SELECT + migration_time +FROM + migration_tracker +WHERE + version = $1 +` + +func (q *Queries) GetMigration(ctx context.Context, version int32) (time.Time, error) { + row := q.db.QueryRowContext(ctx, getMigration, version) + var migration_time time.Time + err := row.Scan(&migration_time) + return migration_time, err +} + +const setMigration = `-- name: SetMigration :exec +INSERT INTO + migration_tracker (version, migration_time) +VALUES ($1, $2) +` + +type SetMigrationParams struct { + Version int32 + MigrationTime time.Time +} + +func (q *Queries) SetMigration(ctx context.Context, arg SetMigrationParams) error { + _, err := q.db.ExecContext(ctx, setMigration, arg.Version, arg.MigrationTime) + return err +} diff --git a/sqldb/sqlc/migrations/000005_migration_tracker.down.sql b/sqldb/sqlc/migrations/000005_migration_tracker.down.sql new file mode 100644 index 0000000000..5f86e385c2 --- /dev/null +++ b/sqldb/sqlc/migrations/000005_migration_tracker.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS migration_tracker; diff --git a/sqldb/sqlc/migrations/000005_migration_tracker.up.sql b/sqldb/sqlc/migrations/000005_migration_tracker.up.sql new file mode 100644 index 0000000000..4b556744a8 --- /dev/null +++ b/sqldb/sqlc/migrations/000005_migration_tracker.up.sql @@ -0,0 +1,17 @@ +-- The migration_tracker table keeps track of migrations that have been applied +-- to the database. This table ensures that migrations are idempotent and are +-- only run once. It tracks a global database version that encompasses both +-- schema migrations handled by golang-migrate and custom in-code migrations +-- for more complex data conversions that cannot be expressed in pure SQL. +CREATE TABLE IF NOT EXISTS migration_tracker ( + -- version is the global version of the migration. Note that we + -- intentionally don't set it as PRIMARY KEY as it'd auto increment on + -- SQLite and our sqlc workflow will replace it with an auto incementing + -- SERIAL on Postgres too. UNIQUE achieves the same effect without the + -- auto increment. + version INTEGER UNIQUE NOT NULL, + + -- migration_time is the timestamp at which the migration was run. + migration_time TIMESTAMP NOT NULL +); + diff --git a/sqldb/sqlc/migrations/000006_invoice_migration.down.sql b/sqldb/sqlc/migrations/000006_invoice_migration.down.sql new file mode 100644 index 0000000000..a95d34f3a6 --- /dev/null +++ b/sqldb/sqlc/migrations/000006_invoice_migration.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS invoice_payment_hashes; diff --git a/sqldb/sqlc/migrations/000006_invoice_migration.up.sql b/sqldb/sqlc/migrations/000006_invoice_migration.up.sql new file mode 100644 index 0000000000..b95628d416 --- /dev/null +++ b/sqldb/sqlc/migrations/000006_invoice_migration.up.sql @@ -0,0 +1,17 @@ +-- invoice_payment_hashes table contains the hash of the invoices. This table +-- is used during KV to SQL invoice migration as in our KV representation we +-- don't have a mapping from hash to add index. +CREATE TABLE IF NOT EXISTS invoice_payment_hashes ( + -- id represents is the key of the invoice in the KV store. + id BIGINT NOT NULL PRIMARY KEY, + + -- add_index is the KV add index of the invoice. + add_index BIGINT NOT NULL, + + -- hash is the payment hash for this invoice. + hash BLOB +); + +-- Create an indexes on the add_index and hash columns to speed up lookups. +CREATE INDEX IF NOT EXISTS invoice_payment_hashes_add_index_idx ON invoice_payment_hashes(add_index); +CREATE INDEX IF NOT EXISTS invoice_payment_hashes_hash_idx ON invoice_payment_hashes(hash); diff --git a/sqldb/sqlc/models.go b/sqldb/sqlc/models.go index 83be5a708f..fdc5a83de0 100644 --- a/sqldb/sqlc/models.go +++ b/sqldb/sqlc/models.go @@ -87,7 +87,18 @@ type InvoiceHtlcCustomRecord struct { HtlcID int64 } +type InvoicePaymentHash struct { + ID int64 + AddIndex int64 + Hash []byte +} + type InvoiceSequence struct { Name string CurrentValue int64 } + +type MigrationTracker struct { + Version int32 + MigrationTime time.Time +} diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 04b61c7007..c63f7fadb8 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -7,9 +7,11 @@ package sqlc import ( "context" "database/sql" + "time" ) type Querier interface { + ClearKVInvoiceHashIndex(ctx context.Context) error DeleteCanceledInvoices(ctx context.Context) (sql.Result, error) DeleteInvoice(ctx context.Context, arg DeleteInvoiceParams) (sql.Result, error) FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error) @@ -17,19 +19,26 @@ type Querier interface { FetchSettledAMPSubInvoices(ctx context.Context, arg FetchSettledAMPSubInvoicesParams) ([]FetchSettledAMPSubInvoicesRow, error) FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error) GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error) + GetDatabaseVersion(ctx context.Context) (int32, error) // This method may return more than one invoice if filter using multiple fields // from different invoices. It is the caller's responsibility to ensure that // we bubble up an error in those cases. GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error) + GetInvoiceByHash(ctx context.Context, hash []byte) (Invoice, error) GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error) GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]InvoiceFeature, error) GetInvoiceHTLCCustomRecords(ctx context.Context, invoiceID int64) ([]GetInvoiceHTLCCustomRecordsRow, error) GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error) + GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) ([]byte, error) + GetMigration(ctx context.Context, version int32) (time.Time, error) + InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error InsertInvoice(ctx context.Context, arg InsertInvoiceParams) (int64, error) InsertInvoiceFeature(ctx context.Context, arg InsertInvoiceFeatureParams) error InsertInvoiceHTLC(ctx context.Context, arg InsertInvoiceHTLCParams) (int64, error) InsertInvoiceHTLCCustomRecord(ctx context.Context, arg InsertInvoiceHTLCCustomRecordParams) error + InsertKVInvoiceKeyAndAddIndex(ctx context.Context, arg InsertKVInvoiceKeyAndAddIndexParams) error + InsertMigratedInvoice(ctx context.Context, arg InsertMigratedInvoiceParams) (int64, error) NextInvoiceSettleIndex(ctx context.Context) (int64, error) OnAMPSubInvoiceCanceled(ctx context.Context, arg OnAMPSubInvoiceCanceledParams) error OnAMPSubInvoiceCreated(ctx context.Context, arg OnAMPSubInvoiceCreatedParams) error @@ -37,6 +46,8 @@ type Querier interface { OnInvoiceCanceled(ctx context.Context, arg OnInvoiceCanceledParams) error OnInvoiceCreated(ctx context.Context, arg OnInvoiceCreatedParams) error OnInvoiceSettled(ctx context.Context, arg OnInvoiceSettledParams) error + SetKVInvoicePaymentHash(ctx context.Context, arg SetKVInvoicePaymentHashParams) error + SetMigration(ctx context.Context, arg SetMigrationParams) error UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context, arg UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result, error) UpdateAMPSubInvoiceState(ctx context.Context, arg UpdateAMPSubInvoiceStateParams) error UpdateInvoiceAmountPaid(ctx context.Context, arg UpdateInvoiceAmountPaidParams) (sql.Result, error) diff --git a/sqldb/sqlc/queries/amp_invoices.sql b/sqldb/sqlc/queries/amp_invoices.sql index 1fad75e0da..1184fd2a41 100644 --- a/sqldb/sqlc/queries/amp_invoices.sql +++ b/sqldb/sqlc/queries/amp_invoices.sql @@ -65,3 +65,11 @@ SET preimage = $5 WHERE a.invoice_id = $1 AND a.set_id = $2 AND a.htlc_id = ( SELECT id FROM invoice_htlcs AS i WHERE i.chan_id = $3 AND i.htlc_id = $4 ); + +-- name: InsertAMPSubInvoice :exec +INSERT INTO amp_sub_invoices ( + set_id, state, created_at, settled_at, settle_index, invoice_id +) VALUES ( + $1, $2, $3, $4, $5, $6 +); + diff --git a/sqldb/sqlc/queries/invoices.sql b/sqldb/sqlc/queries/invoices.sql index 2a49553e65..db1f46e617 100644 --- a/sqldb/sqlc/queries/invoices.sql +++ b/sqldb/sqlc/queries/invoices.sql @@ -7,6 +7,16 @@ INSERT INTO invoices ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15 ) RETURNING id; +-- name: InsertMigratedInvoice :one +INSERT INTO invoices ( + hash, preimage, settle_index, settled_at, memo, amount_msat, cltv_delta, + expiry, payment_addr, payment_request, payment_request_hash, state, + amount_paid_msat, is_amp, is_hodl, is_keysend, created_at +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17 +) RETURNING id; + + -- name: InsertInvoiceFeature :exec INSERT INTO invoice_features ( invoice_id, feature @@ -37,9 +47,6 @@ WHERE ( ) AND ( i.hash = sqlc.narg('hash') OR sqlc.narg('hash') IS NULL -) AND ( - i.preimage = sqlc.narg('preimage') OR - sqlc.narg('preimage') IS NULL ) AND ( i.payment_addr = sqlc.narg('payment_addr') OR sqlc.narg('payment_addr') IS NULL @@ -47,6 +54,11 @@ WHERE ( GROUP BY i.id LIMIT 2; +-- name: GetInvoiceByHash :one +SELECT i.* +FROM invoices i +WHERE i.hash = $1; + -- name: GetInvoiceBySetID :many SELECT i.* FROM invoices i @@ -169,3 +181,23 @@ INSERT INTO invoice_htlc_custom_records ( SELECT ihcr.htlc_id, key, value FROM invoice_htlcs ih JOIN invoice_htlc_custom_records ihcr ON ih.id=ihcr.htlc_id WHERE ih.invoice_id = $1; + +-- name: InsertKVInvoiceKeyAndAddIndex :exec +INSERT INTO invoice_payment_hashes ( + id, add_index +) VALUES ( + $1, $2 +); + +-- name: SetKVInvoicePaymentHash :exec +UPDATE invoice_payment_hashes +SET hash = $2 +WHERE id = $1; + +-- name: GetKVInvoicePaymentHashByAddIndex :one +SELECT hash +FROM invoice_payment_hashes +WHERE add_index = $1; + +-- name: ClearKVInvoiceHashIndex :exec +DELETE FROM invoice_payment_hashes; diff --git a/sqldb/sqlc/queries/migration.sql b/sqldb/sqlc/queries/migration.sql new file mode 100644 index 0000000000..aed90d1938 --- /dev/null +++ b/sqldb/sqlc/queries/migration.sql @@ -0,0 +1,21 @@ +-- name: SetMigration :exec +INSERT INTO + migration_tracker (version, migration_time) +VALUES ($1, $2); + +-- name: GetMigration :one +SELECT + migration_time +FROM + migration_tracker +WHERE + version = $1; + +-- name: GetDatabaseVersion :one +SELECT + version +FROM + migration_tracker +ORDER BY + version DESC +LIMIT 1; diff --git a/sqldb/sqlite.go b/sqldb/sqlite.go index 99e55d6eaf..b1063d0a2f 100644 --- a/sqldb/sqlite.go +++ b/sqldb/sqlite.go @@ -3,6 +3,7 @@ package sqldb import ( + "context" "database/sql" "fmt" "net/url" @@ -34,6 +35,12 @@ var ( sqliteSchemaReplacements = map[string]string{ "BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY", } + + // Make sure SqliteStore implements the MigrationExecutor interface. + _ MigrationExecutor = (*SqliteStore)(nil) + + // Make sure SqliteStore implements the DB interface. + _ DB = (*SqliteStore)(nil) ) // SqliteStore is a database store implementation that uses a sqlite backend. @@ -102,6 +109,23 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) { return nil, err } + // Create the migration tracker table before starting migrations to + // ensure it can be used to track migration progress. Note that a + // corresponding SQLC migration also creates this table, making this + // operation a no-op in that context. Its purpose is to ensure + // compatibility with SQLC query generation. + migrationTrackerSQL := ` + CREATE TABLE IF NOT EXISTS migration_tracker ( + version INTEGER UNIQUE NOT NULL, + migration_time TIMESTAMP NOT NULL + );` + + _, err = db.Exec(migrationTrackerSQL) + if err != nil { + return nil, fmt.Errorf("error creating migration tracker: %w", + err) + } + db.SetMaxOpenConns(defaultMaxConns) db.SetMaxIdleConns(defaultMaxConns) db.SetConnMaxLifetime(connIdleLifetime) @@ -115,16 +139,45 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) { }, } + return s, nil +} + +// GetBaseDB returns the underlying BaseDB instance for the SQLite store. +// It is a trivial helper method to comply with the sqldb.DB interface. +func (s *SqliteStore) GetBaseDB() *BaseDB { + return s.BaseDB +} + +// ApplyAllMigrations applices both the SQLC and custom in-code migrations to +// the SQLite database. +func (s *SqliteStore) ApplyAllMigrations(ctx context.Context, + migrations []MigrationConfig) error { + // Execute migrations unless configured to skip them. - if !cfg.SkipMigrations { - if err := s.ExecuteMigrations(TargetLatest); err != nil { - return nil, fmt.Errorf("error executing migrations: "+ - "%w", err) + if s.cfg.SkipMigrations { + return nil + } - } + return ApplyMigrations(ctx, s.BaseDB, s, migrations) +} + +// CurrentSchemaVersion returns the current schema version of the SQLite +// database. +func (s *SqliteStore) CurrentSchemaVersion() (int, error) { + driver, err := sqlite_migrate.WithInstance( + s.DB, &sqlite_migrate.Config{}, + ) + if err != nil { + return 0, fmt.Errorf("error creating SQLite migrator: %w", + err) } - return s, nil + version, _, err := driver.Version() + if err != nil { + return 0, fmt.Errorf("error getting current version: %w", err) + } + + return version, nil } // ExecuteMigrations runs migrations for the sqlite database, depending on the @@ -160,6 +213,10 @@ func NewTestSqliteDB(t *testing.T) *SqliteStore { }, dbFileName) require.NoError(t, err) + require.NoError(t, sqlDB.ApplyAllMigrations( + context.Background(), GetMigrations()), + ) + t.Cleanup(func() { require.NoError(t, sqlDB.DB.Close()) })