Skip to content

Commit

Permalink
invoices: add full invoice db sql migration
Browse files Browse the repository at this point in the history
  • Loading branch information
bhandras committed Aug 14, 2024
1 parent ae2e4b0 commit a7bf598
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 2 deletions.
99 changes: 99 additions & 0 deletions invoices/sql_migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ package invoices
import (
"context"
"fmt"
"reflect"
"strconv"
"time"

"github.com/btcsuite/btcd/chaincfg"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/sqldb"
"github.com/lightningnetwork/lnd/sqldb/sqlc"
"github.com/lightningnetwork/lnd/zpay32"
)

// MigrateSingleInvoice migrates a single invoice to the new SQL schema. Note
Expand Down Expand Up @@ -190,3 +193,99 @@ func MigrateSingleInvoice(ctx context.Context, tx SQLInvoiceQueries,

return nil
}

// MigrateInvoices migrates all invoices from the old database to the new SQL
// schema. The migration is done in a single transaction to ensure that all
// invoices are migrated or none at all.
func MigrateInvoices(ctx context.Context, db InvoiceDB,
sqlStore *SQLStore, netParams *chaincfg.Params, batchSize int) error {

offset := uint64(0)
var ops SQLInvoiceQueriesTxOptions
return sqlStore.db.ExecTx(ctx, &ops, func(tx SQLInvoiceQueries) error {
for {
query := InvoiceQuery{
IndexOffset: offset,
NumMaxInvoices: uint64(batchSize),
}

queryResult, err := db.QueryInvoices(ctx, query)
if err != nil {
return fmt.Errorf("unable to query invoices: "+
"%v", err)
}

if len(queryResult.Invoices) == 0 {
log.Infof("All invoices migrated")

return nil
}

err = migrateInvoices(
ctx, tx, sqlStore, queryResult.Invoices,
netParams,
)
if err != nil {
return err
}

offset = queryResult.LastIndexOffset
}
}, func() {})
}

func migrateInvoices(ctx context.Context, tx SQLInvoiceQueries,
sqlStore *SQLStore, invoices []Invoice,
netParams *chaincfg.Params) error {

for i, invoice := range invoices {
var paymentHash lntypes.Hash
if invoice.Terms.PaymentPreimage != nil {
paymentHash = invoice.Terms.PaymentPreimage.Hash()
} else {
paymentRequest, err := zpay32.Decode(
string(invoice.PaymentRequest),
netParams,
)
if err != nil {
return fmt.Errorf("unable to decode payment "+
"request for invoice (add_index=%v): "+
"%v", invoice.AddIndex, err)
}

if paymentRequest.PaymentHash != nil {
copy(
paymentHash[:],
paymentRequest.PaymentHash[:],
)
} else {
log.Warnf("Cannot migrate invoice "+
"(add_index=%v)", invoice.AddIndex)

continue
}
}
err := MigrateSingleInvoice(ctx, tx, &invoices[i], paymentHash)
if err != nil {
return fmt.Errorf("unable to migrate invoice(%v): %w",
paymentHash, err)
}

migratedInvoice, err := sqlStore.fetchInvoice(
ctx, tx, InvoiceRefByHash(paymentHash),
)
if err != nil {
return fmt.Errorf("unable to fetch migrated "+
"invoice(%v): %v", paymentHash, err)
}

// Override the add index before checking for equality.
migratedInvoice.AddIndex = invoice.AddIndex
if !reflect.DeepEqual(invoice, *migratedInvoice) {
return fmt.Errorf("migrated invoice does not match "+
"original invoice: %v", paymentHash)
}
}

return nil
}
95 changes: 93 additions & 2 deletions invoices/sql_migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/btcsuite/btcd/chaincfg"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
Expand Down Expand Up @@ -349,12 +350,18 @@ func TestMigrateSingleInvoice(t *testing.T) {
}

for _, test := range tests {
test := test

t.Run(test.name+"_SQLite", func(t *testing.T) {
t.Parallel()

store := makeSQLDB(t, true)
testMigrateSingleInvoice(t, store, test.mpp, test.amp)
})

t.Run(test.name+"_Postgres", func(t *testing.T) {
t.Parallel()

store := makeSQLDB(t, false)
testMigrateSingleInvoice(t, store, test.mpp, test.amp)
})
Expand All @@ -368,8 +375,6 @@ func TestMigrateSingleInvoice(t *testing.T) {
func testMigrateSingleInvoice(t *testing.T, store *SQLStore, mpp bool,
amp bool) {

t.Parallel()

ctxb := context.Background()
invoices := make(map[lntypes.Hash]*Invoice)

Expand Down Expand Up @@ -406,3 +411,89 @@ func testMigrateSingleInvoice(t *testing.T, store *SQLStore, mpp bool,
require.Equal(t, *invoice, sqlInvoice)
}
}

func TestMigration(t *testing.T) {
// First create a shared Postgres instance so we don't spawn a new
// docker container for each test.
pgFixture := sqldb.NewTestPgFixture(
t, sqldb.DefaultPostgresFixtureLifetime,
)
t.Cleanup(func() {
pgFixture.TearDown(t)
})

makeSQLDB := func(t *testing.T, sqlite bool) *SQLStore {
var db *sqldb.BaseDB
if sqlite {
db = sqldb.NewTestSqliteDB(t).BaseDB
} else {
db = sqldb.NewTestPostgresDB(t, pgFixture).BaseDB
}

executor := sqldb.NewTransactionExecutor(
db, func(tx *sql.Tx) SQLInvoiceQueries {
return db.WithTx(tx)
},
)

testClock := clock.NewTestClock(time.Unix(1, 0))

return NewSQLStore(executor, testClock)
}

// For simplicity we will migrate from one SQL db to another. This is
// because we have a much better control over the data and can easily
// generate random invoices.
// TODO(bhandras): potentially add a test where we migrate from a real
// KV store to the SQL store
store1 := makeSQLDB(t, true)
store2 := makeSQLDB(t, false)
ctxb := context.Background()
const numInvoices = 1111

var ops SQLInvoiceQueriesTxOptions
err := store1.db.ExecTx(ctxb, &ops, func(tx SQLInvoiceQueries) error {
for i := 0; i < numInvoices; i++ {
mpp := rand.Intn(2) == 1
amp := rand.Intn(2) == 1
invoice := generateTestInvoice(t, mpp, amp)
var hash lntypes.Hash
_, err := crand.Read(hash[:])
require.NoError(t, err)

err = MigrateSingleInvoice(ctxb, tx, invoice, hash)
require.NoError(t, err)
}

return nil
}, func() {})
require.NoError(t, err)

err = MigrateInvoices(ctxb, store1, store2, &chaincfg.SimNetParams, 44)
require.NoError(t, err)

// MigrateInvoices will check if the inserted invoice equals to the
// migrated one, but as a sanity check, we'll also fetch the invoices
// from the store and compare them to the original invoices.
query := InvoiceQuery{
IndexOffset: 0,
// As a sanity check, fetch more invoices than we have to ensure
// that we did not add any extra invoices.
NumMaxInvoices: numInvoices * 2,
}
result1, err := store1.QueryInvoices(ctxb, query)
require.NoError(t, err)
require.Equal(t, numInvoices, len(result1.Invoices))

result2, err := store2.QueryInvoices(ctxb, query)
require.NoError(t, err)
require.Equal(t, numInvoices, len(result2.Invoices))

// Simply zero out the add index so we don't fail on that when
// comparing.
for i := 0; i < numInvoices; i++ {
result1.Invoices[i].AddIndex = 0
result2.Invoices[i].AddIndex = 0
require.Equal(t, result1.Invoices[i], result2.Invoices[i])
}
}

0 comments on commit a7bf598

Please sign in to comment.