diff --git a/autopilot/graph.go b/autopilot/graph.go index 2ce49c1272..b4e415077f 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -3,6 +3,7 @@ package autopilot import ( "bytes" "encoding/hex" + "errors" "net" "sort" "sync/atomic" @@ -11,8 +12,8 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -35,7 +36,7 @@ var ( // // TODO(roasbeef): move inmpl to main package? type databaseChannelGraph struct { - db *channeldb.ChannelGraph + db *graphdb.ChannelGraph } // A compile time assertion to ensure databaseChannelGraph meets the @@ -44,7 +45,7 @@ var _ ChannelGraph = (*databaseChannelGraph)(nil) // ChannelGraphFromDatabase returns an instance of the autopilot.ChannelGraph // backed by a live, open channeldb instance. -func ChannelGraphFromDatabase(db *channeldb.ChannelGraph) ChannelGraph { +func ChannelGraphFromDatabase(db *graphdb.ChannelGraph) ChannelGraph { return &databaseChannelGraph{ db: db, } @@ -54,11 +55,11 @@ func ChannelGraphFromDatabase(db *channeldb.ChannelGraph) ChannelGraph { // channeldb.LightningNode. The wrapper method implement the autopilot.Node // interface. type dbNode struct { - db *channeldb.ChannelGraph + db *graphdb.ChannelGraph tx kvdb.RTx - node *channeldb.LightningNode + node *models.LightningNode } // A compile time assertion to ensure dbNode meets the autopilot.Node @@ -134,7 +135,9 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { // // NOTE: Part of the autopilot.ChannelGraph interface. func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error { - return d.db.ForEachNode(func(tx kvdb.RTx, n *channeldb.LightningNode) error { + return d.db.ForEachNode(func(tx kvdb.RTx, + n *models.LightningNode) error { + // We'll skip over any node that doesn't have any advertised // addresses. As we won't be able to reach them to actually // open any channels. @@ -157,7 +160,7 @@ func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error { func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, capacity btcutil.Amount) (*ChannelEdge, *ChannelEdge, error) { - fetchNode := func(pub *btcec.PublicKey) (*channeldb.LightningNode, error) { + fetchNode := func(pub *btcec.PublicKey) (*models.LightningNode, error) { if pub != nil { vertex, err := route.NewVertexFromBytes( pub.SerializeCompressed(), @@ -168,10 +171,10 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, dbNode, err := d.db.FetchLightningNode(vertex) switch { - case err == channeldb.ErrGraphNodeNotFound: + case errors.Is(err, graphdb.ErrGraphNodeNotFound): fallthrough - case err == channeldb.ErrGraphNotFound: - graphNode := &channeldb.LightningNode{ + case errors.Is(err, graphdb.ErrGraphNotFound): + graphNode := &models.LightningNode{ HaveNodeAnnouncement: true, Addresses: []net.Addr{ &net.TCPAddr{ @@ -198,7 +201,7 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, if err != nil { return nil, err } - dbNode := &channeldb.LightningNode{ + dbNode := &models.LightningNode{ HaveNodeAnnouncement: true, Addresses: []net.Addr{ &net.TCPAddr{ @@ -302,7 +305,7 @@ func (d *databaseChannelGraph) addRandNode() (*btcec.PublicKey, error) { if err != nil { return nil, err } - dbNode := &channeldb.LightningNode{ + dbNode := &models.LightningNode{ HaveNodeAnnouncement: true, Addresses: []net.Addr{ &net.TCPAddr{ @@ -478,7 +481,7 @@ func (m *memChannelGraph) addRandNode() (*btcec.PublicKey, error) { // databaseChannelGraphCached wraps a channeldb.ChannelGraph instance with the // necessary API to properly implement the autopilot.ChannelGraph interface. type databaseChannelGraphCached struct { - db *channeldb.ChannelGraph + db *graphdb.ChannelGraph } // A compile time assertion to ensure databaseChannelGraphCached meets the @@ -487,7 +490,7 @@ var _ ChannelGraph = (*databaseChannelGraphCached)(nil) // ChannelGraphFromCachedDatabase returns an instance of the // autopilot.ChannelGraph backed by a live, open channeldb instance. -func ChannelGraphFromCachedDatabase(db *channeldb.ChannelGraph) ChannelGraph { +func ChannelGraphFromCachedDatabase(db *graphdb.ChannelGraph) ChannelGraph { return &databaseChannelGraphCached{ db: db, } @@ -498,7 +501,7 @@ func ChannelGraphFromCachedDatabase(db *channeldb.ChannelGraph) ChannelGraph { // interface. type dbNodeCached struct { node route.Vertex - channels map[uint64]*channeldb.DirectedChannel + channels map[uint64]*graphdb.DirectedChannel } // A compile time assertion to ensure dbNodeCached meets the autopilot.Node @@ -552,7 +555,7 @@ func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error { // NOTE: Part of the autopilot.ChannelGraph interface. func (dc *databaseChannelGraphCached) ForEachNode(cb func(Node) error) error { return dc.db.ForEachNodeCached(func(n route.Vertex, - channels map[uint64]*channeldb.DirectedChannel) error { + channels map[uint64]*graphdb.DirectedChannel) error { if len(channels) > 0 { node := dbNodeCached{ diff --git a/autopilot/prefattach_test.go b/autopilot/prefattach_test.go index 524116aa4c..ab52c55f61 100644 --- a/autopilot/prefattach_test.go +++ b/autopilot/prefattach_test.go @@ -8,7 +8,8 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/channeldb" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/kvdb" "github.com/stretchr/testify/require" ) @@ -24,17 +25,21 @@ type testGraph interface { } func newDiskChanGraph(t *testing.T) (testGraph, error) { - // Next, create channeldb for the first time. - cdb, err := channeldb.Open(t.TempDir()) - if err != nil { - return nil, err - } - t.Cleanup(func() { - require.NoError(t, cdb.Close()) + backend, err := kvdb.GetBoltBackend(&kvdb.BoltBackendConfig{ + DBPath: t.TempDir(), + DBFileName: "graph.db", + NoFreelistSync: true, + AutoCompact: false, + AutoCompactMinAge: kvdb.DefaultBoltAutoCompactMinAge, + DBTimeout: kvdb.DefaultDBTimeout, }) + require.NoError(t, err) + + graphDB, err := graphdb.NewChannelGraph(backend) + require.NoError(t, err) return &databaseChannelGraph{ - db: cdb.ChannelGraph(), + db: graphDB, }, nil } diff --git a/chainntnfs/bitcoindnotify/bitcoind_test.go b/chainntnfs/bitcoindnotify/bitcoind_test.go index be336ef6cd..fa51efa5ef 100644 --- a/chainntnfs/bitcoindnotify/bitcoind_test.go +++ b/chainntnfs/bitcoindnotify/bitcoind_test.go @@ -37,11 +37,7 @@ var ( func initHintCache(t *testing.T) *channeldb.HeightHintCache { t.Helper() - db, err := channeldb.Open(t.TempDir()) - require.NoError(t, err, "unable to create db") - t.Cleanup(func() { - require.NoError(t, db.Close()) - }) + db := channeldb.OpenForTesting(t, t.TempDir()) testCfg := channeldb.CacheConfig{ QueryDisable: false, diff --git a/chainntnfs/btcdnotify/btcd_test.go b/chainntnfs/btcdnotify/btcd_test.go index 1cfbff731f..6a1b978548 100644 --- a/chainntnfs/btcdnotify/btcd_test.go +++ b/chainntnfs/btcdnotify/btcd_test.go @@ -33,11 +33,7 @@ var ( func initHintCache(t *testing.T) *channeldb.HeightHintCache { t.Helper() - db, err := channeldb.Open(t.TempDir()) - require.NoError(t, err, "unable to create db") - t.Cleanup(func() { - require.NoError(t, db.Close()) - }) + db := channeldb.OpenForTesting(t, t.TempDir()) testCfg := channeldb.CacheConfig{ QueryDisable: false, diff --git a/chainntnfs/test/test_interface.go b/chainntnfs/test/test_interface.go index 35e63a45e9..99daf54f1e 100644 --- a/chainntnfs/test/test_interface.go +++ b/chainntnfs/test/test_interface.go @@ -1906,10 +1906,8 @@ func TestInterfaces(t *testing.T, targetBackEnd string) { // Initialize a height hint cache for each notifier. tempDir := t.TempDir() - db, err := channeldb.Open(tempDir) - if err != nil { - t.Fatalf("unable to create db: %v", err) - } + db := channeldb.OpenForTesting(t, tempDir) + testCfg := channeldb.CacheConfig{ QueryDisable: false, } diff --git a/chainreg/chainregistry.go b/chainreg/chainregistry.go index da1f8e08ad..e30eafc677 100644 --- a/chainreg/chainregistry.go +++ b/chainreg/chainregistry.go @@ -23,8 +23,8 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs/btcdnotify" "github.com/lightningnetwork/lnd/chainntnfs/neutrinonotify" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" diff --git a/chainreg/no_chain_backend.go b/chainreg/no_chain_backend.go index 303c8f4cdb..f68202ea9c 100644 --- a/chainreg/no_chain_backend.go +++ b/chainreg/no_chain_backend.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcwallet/chain" "github.com/btcsuite/btcwallet/waddrmgr" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/channeldb" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/routing/chainview" ) @@ -94,7 +94,7 @@ func (n *NoChainBackend) DisconnectedBlocks() <-chan *chainview.FilteredBlock { return make(chan *chainview.FilteredBlock) } -func (n *NoChainBackend) UpdateFilter([]channeldb.EdgePoint, uint32) error { +func (n *NoChainBackend) UpdateFilter([]graphdb.EdgePoint, uint32) error { return nil } diff --git a/chanbackup/backup.go b/chanbackup/backup.go index 5d9d769e87..5853b37e45 100644 --- a/chanbackup/backup.go +++ b/chanbackup/backup.go @@ -2,13 +2,10 @@ package chanbackup import ( "fmt" - "net" - "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn" - "github.com/lightningnetwork/lnd/kvdb" ) // LiveChannelSource is an interface that allows us to query for the set of @@ -20,23 +17,14 @@ type LiveChannelSource interface { // FetchChannel attempts to locate a live channel identified by the // passed chanPoint. Optionally an existing db tx can be supplied. - FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( - *channeldb.OpenChannel, error) -} - -// AddressSource is an interface that allows us to query for the set of -// addresses a node can be connected to. -type AddressSource interface { - // AddrsForNode returns all known addresses for the target node public - // key. - AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) + FetchChannel(chanPoint wire.OutPoint) (*channeldb.OpenChannel, error) } // assembleChanBackup attempts to assemble a static channel backup for the // passed open channel. The backup includes all information required to restore // the channel, as well as addressing information so we can find the peer and // reconnect to them to initiate the protocol. -func assembleChanBackup(addrSource AddressSource, +func assembleChanBackup(addrSource channeldb.AddrSource, openChan *channeldb.OpenChannel) (*Single, error) { log.Debugf("Crafting backup for ChannelPoint(%v)", @@ -44,10 +32,13 @@ func assembleChanBackup(addrSource AddressSource, // First, we'll query the channel source to obtain all the addresses // that are associated with the peer for this channel. - nodeAddrs, err := addrSource.AddrsForNode(openChan.IdentityPub) + known, nodeAddrs, err := addrSource.AddrsForNode(openChan.IdentityPub) if err != nil { return nil, err } + if !known { + return nil, fmt.Errorf("node unknown by address source") + } single := NewSingle(openChan, nodeAddrs) @@ -100,11 +91,11 @@ func buildCloseTxInputs( // the target channel identified by its channel point. If we're unable to find // the target channel, then an error will be returned. func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource, - addrSource AddressSource) (*Single, error) { + addrSource channeldb.AddrSource) (*Single, error) { // First, we'll query the channel source to see if the channel is known // and open within the database. - targetChan, err := chanSource.FetchChannel(nil, chanPoint) + targetChan, err := chanSource.FetchChannel(chanPoint) if err != nil { // If we can't find the channel, then we return with an error, // as we have nothing to backup. @@ -124,7 +115,7 @@ func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource, // FetchStaticChanBackups will return a plaintext static channel back up for // all known active/open channels within the passed channel source. func FetchStaticChanBackups(chanSource LiveChannelSource, - addrSource AddressSource) ([]Single, error) { + addrSource channeldb.AddrSource) ([]Single, error) { // First, we'll query the backup source for information concerning all // currently open and available channels. diff --git a/chanbackup/backup_test.go b/chanbackup/backup_test.go index 511b1081dc..46ccf4c244 100644 --- a/chanbackup/backup_test.go +++ b/chanbackup/backup_test.go @@ -8,7 +8,6 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/kvdb" "github.com/stretchr/testify/require" ) @@ -40,7 +39,7 @@ func (m *mockChannelSource) FetchAllChannels() ([]*channeldb.OpenChannel, error) return chans, nil } -func (m *mockChannelSource) FetchChannel(_ kvdb.RTx, chanPoint wire.OutPoint) ( +func (m *mockChannelSource) FetchChannel(chanPoint wire.OutPoint) ( *channeldb.OpenChannel, error) { if m.failQuery { @@ -62,20 +61,19 @@ func (m *mockChannelSource) addAddrsForNode(nodePub *btcec.PublicKey, addrs []ne m.addrs[nodeKey] = addrs } -func (m *mockChannelSource) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { +func (m *mockChannelSource) AddrsForNode(nodePub *btcec.PublicKey) (bool, + []net.Addr, error) { + if m.failQuery { - return nil, fmt.Errorf("fail") + return false, nil, fmt.Errorf("fail") } var nodeKey [33]byte copy(nodeKey[:], nodePub.SerializeCompressed()) addrs, ok := m.addrs[nodeKey] - if !ok { - return nil, fmt.Errorf("can't find addr") - } - return addrs, nil + return ok, addrs, nil } // TestFetchBackupForChan tests that we're able to construct a single channel diff --git a/channel_notifier.go b/channel_notifier.go index 5995eb4097..88a05ac4ce 100644 --- a/channel_notifier.go +++ b/channel_notifier.go @@ -2,24 +2,13 @@ package lnd import ( "fmt" - "net" - "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chanbackup" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" ) -// addrSource is an interface that allow us to get the addresses for a target -// node. We'll need this in order to be able to properly proxy the -// notifications to create SCBs. -type addrSource interface { - // AddrsForNode returns all known addresses for the target node public - // key. - AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) -} - // channelNotifier is an implementation of the chanbackup.ChannelNotifier // interface using the existing channelnotifier.ChannelNotifier struct. This // implementation allows us to satisfy all the dependencies of the @@ -32,7 +21,7 @@ type channelNotifier struct { // addrs is an implementation of the addrSource interface that allows // us to get the latest set of addresses for a given node. We'll need // this to be able to create an SCB for new channels. - addrs addrSource + addrs channeldb.AddrSource } // SubscribeChans requests a new channel subscription relative to the initial @@ -56,7 +45,7 @@ func (c *channelNotifier) SubscribeChans(startingChans map[wire.OutPoint]struct{ // chanUpdates channel to inform subscribers about new pending or // confirmed channels. sendChanOpenUpdate := func(newOrPendingChan *channeldb.OpenChannel) { - nodeAddrs, err := c.addrs.AddrsForNode( + _, nodeAddrs, err := c.addrs.AddrsForNode( newOrPendingChan.IdentityPub, ) if err != nil { diff --git a/channeldb/addr_source.go b/channeldb/addr_source.go new file mode 100644 index 0000000000..de933ed496 --- /dev/null +++ b/channeldb/addr_source.go @@ -0,0 +1,83 @@ +package channeldb + +import ( + "errors" + "net" + + "github.com/btcsuite/btcd/btcec/v2" +) + +// AddrSource is an interface that allow us to get the addresses for a target +// node. It may combine the results of multiple address sources. +type AddrSource interface { + // AddrsForNode returns all known addresses for the target node public + // key. The returned boolean must indicate if the given node is unknown + // to the backing source. + AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, error) +} + +// multiAddrSource is an implementation of AddrSource which gathers all the +// known addresses for a given node from multiple backends and de-duplicates the +// results. +type multiAddrSource struct { + sources []AddrSource +} + +// NewMultiAddrSource constructs a new AddrSource which will query all the +// provided sources for a node's addresses and will then de-duplicate the +// results. +func NewMultiAddrSource(sources ...AddrSource) AddrSource { + return &multiAddrSource{ + sources: sources, + } +} + +// AddrsForNode returns all known addresses for the target node public key. It +// queries all the address sources provided and de-duplicates the results. The +// returned boolean is false only if none of the backing sources know of the +// node. +// +// NOTE: this implements the AddrSource interface. +func (c *multiAddrSource) AddrsForNode(nodePub *btcec.PublicKey) (bool, + []net.Addr, error) { + + if len(c.sources) == 0 { + return false, nil, errors.New("no address sources") + } + + // The multiple address sources will likely contain duplicate addresses, + // so we use a map here to de-dup them. + dedupedAddrs := make(map[string]net.Addr) + + // known will be set to true if any backing source is aware of the node. + var known bool + + // Iterate over all the address sources and query each one for the + // addresses it has for the node in question. + for _, src := range c.sources { + isKnown, addrs, err := src.AddrsForNode(nodePub) + if err != nil { + return false, nil, err + } + + if isKnown { + known = true + } + + for _, addr := range addrs { + dedupedAddrs[addr.String()] = addr + } + } + + // Convert the map into a list we can return. + addrs := make([]net.Addr, 0, len(dedupedAddrs)) + for _, addr := range dedupedAddrs { + addrs = append(addrs, addr) + } + + return known, addrs, nil +} + +// A compile-time check to ensure that multiAddrSource implements the AddrSource +// interface. +var _ AddrSource = (*multiAddrSource)(nil) diff --git a/channeldb/addr_source_test.go b/channeldb/addr_source_test.go new file mode 100644 index 0000000000..85ee30bf53 --- /dev/null +++ b/channeldb/addr_source_test.go @@ -0,0 +1,149 @@ +package channeldb + +import ( + "net" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +var ( + addr1 = &net.TCPAddr{IP: (net.IP)([]byte{0x1}), Port: 1} + addr2 = &net.TCPAddr{IP: (net.IP)([]byte{0x2}), Port: 2} + addr3 = &net.TCPAddr{IP: (net.IP)([]byte{0x3}), Port: 3} +) + +// TestMultiAddrSource tests that the multiAddrSource correctly merges and +// deduplicates the results of a set of AddrSource implementations. +func TestMultiAddrSource(t *testing.T) { + t.Parallel() + + var pk1 = newTestPubKey(t) + + t.Run("both sources have results", func(t *testing.T) { + t.Parallel() + + var ( + src1 = newMockAddrSource(t) + src2 = newMockAddrSource(t) + ) + t.Cleanup(func() { + src1.AssertExpectations(t) + src2.AssertExpectations(t) + }) + + // Let source 1 know of 2 addresses (addr 1 and 2) for node 1. + src1.On("AddrsForNode", pk1).Return( + true, []net.Addr{addr1, addr2}, nil, + ).Once() + + // Let source 2 know of 2 addresses (addr 2 and 3) for node 1. + src2.On("AddrsForNode", pk1).Return( + true, []net.Addr{addr2, addr3}, nil, + []net.Addr{addr2, addr3}, nil, + ).Once() + + // Create a multi-addr source that consists of both source 1 + // and 2. + multiSrc := NewMultiAddrSource(src1, src2) + + // Query it for the addresses known for node 1. The results + // should contain addr 1, 2 and 3. + known, addrs, err := multiSrc.AddrsForNode(pk1) + require.NoError(t, err) + require.True(t, known) + require.ElementsMatch(t, addrs, []net.Addr{addr1, addr2, addr3}) + }) + + t.Run("only once source has results", func(t *testing.T) { + t.Parallel() + + var ( + src1 = newMockAddrSource(t) + src2 = newMockAddrSource(t) + ) + t.Cleanup(func() { + src1.AssertExpectations(t) + src2.AssertExpectations(t) + }) + + // Let source 1 know of address 1 for node 1. + src1.On("AddrsForNode", pk1).Return( + true, []net.Addr{addr1}, nil, + ).Once() + src2.On("AddrsForNode", pk1).Return(false, nil, nil).Once() + + // Create a multi-addr source that consists of both source 1 + // and 2. + multiSrc := NewMultiAddrSource(src1, src2) + + // Query it for the addresses known for node 1. The results + // should contain addr 1. + known, addrs, err := multiSrc.AddrsForNode(pk1) + require.NoError(t, err) + require.True(t, known) + require.ElementsMatch(t, addrs, []net.Addr{addr1}) + }) + + t.Run("unknown address", func(t *testing.T) { + t.Parallel() + + var ( + src1 = newMockAddrSource(t) + src2 = newMockAddrSource(t) + ) + t.Cleanup(func() { + src1.AssertExpectations(t) + src2.AssertExpectations(t) + }) + + // Create a multi-addr source that consists of both source 1 + // and 2. Neither source known of node 1. + multiSrc := NewMultiAddrSource(src1, src2) + + src1.On("AddrsForNode", pk1).Return(false, nil, nil).Once() + src2.On("AddrsForNode", pk1).Return(false, nil, nil).Once() + + // Query it for the addresses known for node 1. It should return + // false to indicate that the node is unknown to all backing + // sources. + known, addrs, err := multiSrc.AddrsForNode(pk1) + require.NoError(t, err) + require.False(t, known) + require.Empty(t, addrs) + }) +} + +type mockAddrSource struct { + t *testing.T + mock.Mock +} + +var _ AddrSource = (*mockAddrSource)(nil) + +func newMockAddrSource(t *testing.T) *mockAddrSource { + return &mockAddrSource{t: t} +} + +func (m *mockAddrSource) AddrsForNode(pub *btcec.PublicKey) (bool, []net.Addr, + error) { + + args := m.Called(pub) + if args.Get(1) == nil { + return args.Bool(0), nil, args.Error(2) + } + + addrs, ok := args.Get(1).([]net.Addr) + require.True(m.t, ok) + + return args.Bool(0), addrs, args.Error(2) +} + +func newTestPubKey(t *testing.T) *btcec.PublicKey { + priv, err := btcec.NewPrivateKey() + require.NoError(t, err) + + return priv.PubKey() +} diff --git a/channeldb/channel.go b/channeldb/channel.go index c21716a456..200dc39ff5 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -19,8 +19,9 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/walletdb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -1330,7 +1331,7 @@ func fetchChanBucket(tx kvdb.RTx, nodeKey *btcec.PublicKey, // With the bucket for the node and chain fetched, we can now go down // another level, for this channel itself. var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { return nil, err } chanBucket := chainBucket.NestedReadBucket(chanPointBuf.Bytes()) @@ -1377,7 +1378,7 @@ func fetchChanBucketRw(tx kvdb.RwTx, nodeKey *btcec.PublicKey, // With the bucket for the node and chain fetched, we can now go down // another level, for this channel itself. var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { return nil, err } chanBucket := chainBucket.NestedReadWriteBucket(chanPointBuf.Bytes()) @@ -1422,7 +1423,8 @@ func (c *OpenChannel) fullSync(tx kvdb.RwTx) error { } var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, &c.FundingOutpoint); err != nil { + err := graphdb.WriteOutpoint(&chanPointBuf, &c.FundingOutpoint) + if err != nil { return err } @@ -3822,7 +3824,7 @@ func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary, } var chanPointBuf bytes.Buffer - err := writeOutpoint(&chanPointBuf, &c.FundingOutpoint) + err := graphdb.WriteOutpoint(&chanPointBuf, &c.FundingOutpoint) if err != nil { return err } diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index e92692201d..2cac0baced 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -2,6 +2,7 @@ package channeldb import ( "bytes" + "encoding/hex" "math/rand" "net" "reflect" @@ -10,20 +11,22 @@ import ( "testing" "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" _ "github.com/btcsuite/btcwallet/walletdb/bdb" "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnmock" "github.com/lightningnetwork/lnd/lntest/channels" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/shachain" "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" @@ -43,8 +46,20 @@ var ( } privKey, pubKey = btcec.PrivKeyFromBytes(key[:]) + testRBytes, _ = hex.DecodeString("8ce2bc69281ce27da07e6683571319d18e" + + "949ddfa2965fb6caa1bf0314f882d7") + testSBytes, _ = hex.DecodeString("299105481d63e0f4bc2a88121167221b67" + + "00d72a0ead154c03be696a292d24ae") + testRScalar = new(btcec.ModNScalar) + testSScalar = new(btcec.ModNScalar) + _ = testRScalar.SetByteSlice(testRBytes) + _ = testSScalar.SetByteSlice(testSBytes) + testSig = ecdsa.NewSignature(testRScalar, testSScalar) + wireSig, _ = lnwire.NewSigFromSignature(testSig) + testPub = route.Vertex{2, 202, 4} + testClock = clock.NewTestClock(testNow) // defaultPendingHeight is the default height at which we set diff --git a/channeldb/codec.go b/channeldb/codec.go index 07f125742f..8c39f4d731 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -11,38 +11,13 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" "github.com/lightningnetwork/lnd/tlv" ) -// writeOutpoint writes an outpoint to the passed writer using the minimal -// amount of bytes possible. -func writeOutpoint(w io.Writer, o *wire.OutPoint) error { - if _, err := w.Write(o.Hash[:]); err != nil { - return err - } - if err := binary.Write(w, byteOrder, o.Index); err != nil { - return err - } - - return nil -} - -// readOutpoint reads an outpoint from the passed reader that was previously -// written using the writeOutpoint struct. -func readOutpoint(r io.Reader, o *wire.OutPoint) error { - if _, err := io.ReadFull(r, o.Hash[:]); err != nil { - return err - } - if err := binary.Read(r, byteOrder, &o.Index); err != nil { - return err - } - - return nil -} - // UnknownElementType is an error returned when the codec is unable to encode or // decode a particular type. type UnknownElementType struct { @@ -98,7 +73,7 @@ func WriteElement(w io.Writer, element interface{}) error { } case wire.OutPoint: - return writeOutpoint(w, &e) + return graphdb.WriteOutpoint(w, &e) case lnwire.ShortChannelID: if err := binary.Write(w, byteOrder, e.ToUint64()); err != nil { @@ -218,7 +193,7 @@ func WriteElement(w io.Writer, element interface{}) error { } case net.Addr: - if err := serializeAddr(w, e); err != nil { + if err := graphdb.SerializeAddr(w, e); err != nil { return err } @@ -228,7 +203,7 @@ func WriteElement(w io.Writer, element interface{}) error { } for _, addr := range e { - if err := serializeAddr(w, addr); err != nil { + if err := graphdb.SerializeAddr(w, addr); err != nil { return err } } @@ -288,7 +263,7 @@ func ReadElement(r io.Reader, element interface{}) error { } case *wire.OutPoint: - return readOutpoint(r, e) + return graphdb.ReadOutpoint(r, e) case *lnwire.ShortChannelID: var a uint64 @@ -451,7 +426,7 @@ func ReadElement(r io.Reader, element interface{}) error { } case *net.Addr: - addr, err := deserializeAddr(r) + addr, err := graphdb.DeserializeAddr(r) if err != nil { return err } @@ -465,7 +440,7 @@ func ReadElement(r io.Reader, element interface{}) error { *e = make([]net.Addr, numAddrs) for i := uint32(0); i < numAddrs; i++ { - addr, err := deserializeAddr(r) + addr, err := graphdb.DeserializeAddr(r) if err != nil { return err } diff --git a/channeldb/db.go b/channeldb/db.go index 92e0498ece..bf7909ba52 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -30,10 +30,11 @@ import ( "github.com/lightningnetwork/lnd/channeldb/migration33" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/clock" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/routing/route" + "github.com/stretchr/testify/require" ) const ( @@ -335,7 +336,6 @@ type DB struct { channelStateDB *ChannelStateDB dbPath string - graph *ChannelGraph clock clock.Clock dryRun bool keepFailedPaymentAttempts bool @@ -346,38 +346,37 @@ type DB struct { noRevLogAmtData bool } -// Open opens or creates channeldb. Any necessary schemas migrations due -// to updates will take place as necessary. -// TODO(bhandras): deprecate this function. -func Open(dbPath string, modifiers ...OptionModifier) (*DB, error) { - opts := DefaultOptions() - for _, modifier := range modifiers { - modifier(&opts) - } +// OpenForTesting opens or creates a channeldb to be used for tests. Any +// necessary schemas migrations due to updates will take place as necessary. +func OpenForTesting(t testing.TB, dbPath string, + modifiers ...OptionModifier) *DB { backend, err := kvdb.GetBoltBackend(&kvdb.BoltBackendConfig{ DBPath: dbPath, DBFileName: dbName, - NoFreelistSync: opts.NoFreelistSync, - AutoCompact: opts.AutoCompact, - AutoCompactMinAge: opts.AutoCompactMinAge, - DBTimeout: opts.DBTimeout, + NoFreelistSync: true, + AutoCompact: false, + AutoCompactMinAge: kvdb.DefaultBoltAutoCompactMinAge, + DBTimeout: kvdb.DefaultDBTimeout, }) - if err != nil { - return nil, err - } + require.NoError(t, err) db, err := CreateWithBackend(backend, modifiers...) - if err == nil { - db.dbPath = dbPath - } - return db, err + require.NoError(t, err) + + db.dbPath = dbPath + + t.Cleanup(func() { + require.NoError(t, db.Close()) + }) + + return db } // CreateWithBackend creates channeldb instance using the passed kvdb.Backend. // Any necessary schemas migrations due to updates will take place as necessary. -func CreateWithBackend(backend kvdb.Backend, - modifiers ...OptionModifier) (*DB, error) { +func CreateWithBackend(backend kvdb.Backend, modifiers ...OptionModifier) (*DB, + error) { opts := DefaultOptions() for _, modifier := range modifiers { @@ -408,16 +407,6 @@ func CreateWithBackend(backend kvdb.Backend, // Set the parent pointer (only used in tests). chanDB.channelStateDB.parent = chanDB - var err error - chanDB.graph, err = NewChannelGraph( - backend, opts.RejectCacheSize, opts.ChannelCacheSize, - opts.BatchCommitInterval, opts.PreAllocCacheNumNodes, - opts.UseGraphCache, opts.NoMigration, - ) - if err != nil { - return nil, err - } - // Synchronize the version of database and apply migrations if needed. if !opts.NoMigration { if err := chanDB.syncVersions(dbVersions); err != nil { @@ -646,7 +635,9 @@ func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) ( chanBucket := chainBucket.NestedReadBucket(chanPoint) var outPoint wire.OutPoint - err := readOutpoint(bytes.NewReader(chanPoint), &outPoint) + err := graphdb.ReadOutpoint( + bytes.NewReader(chanPoint), &outPoint, + ) if err != nil { return err } @@ -670,12 +661,12 @@ func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) ( // FetchChannel attempts to locate a channel specified by the passed channel // point. If the channel cannot be found, then an error will be returned. -// Optionally an existing db tx can be supplied. -func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( - *OpenChannel, error) { +func (c *ChannelStateDB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, + error) { var targetChanPoint bytes.Buffer - if err := writeOutpoint(&targetChanPoint, &chanPoint); err != nil { + err := graphdb.WriteOutpoint(&targetChanPoint, &chanPoint) + if err != nil { return nil, err } @@ -686,7 +677,7 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( return targetChanPointBytes, &chanPoint, nil } - return c.channelScanner(tx, selector) + return c.channelScanner(nil, selector) } // FetchChannelByID attempts to locate a channel specified by the passed channel @@ -709,7 +700,9 @@ func (c *ChannelStateDB) FetchChannelByID(tx kvdb.RTx, id lnwire.ChannelID) ( ) err := chainBkt.ForEach(func(k, _ []byte) error { var outPoint wire.OutPoint - err := readOutpoint(bytes.NewReader(k), &outPoint) + err := graphdb.ReadOutpoint( + bytes.NewReader(k), &outPoint, + ) if err != nil { return err } @@ -1089,7 +1082,7 @@ func (c *ChannelStateDB) FetchClosedChannel(chanID *wire.OutPoint) ( var b bytes.Buffer var err error - if err = writeOutpoint(&b, chanID); err != nil { + if err = graphdb.WriteOutpoint(&b, chanID); err != nil { return err } @@ -1131,7 +1124,9 @@ func (c *ChannelStateDB) FetchClosedChannelForID(cid lnwire.ChannelID) ( // We scan over all possible candidates for this channel ID. for ; op != nil && bytes.Compare(cid[:30], op[:30]) <= 0; op, c = cursor.Next() { var outPoint wire.OutPoint - err := readOutpoint(bytes.NewReader(op), &outPoint) + err := graphdb.ReadOutpoint( + bytes.NewReader(op), &outPoint, + ) if err != nil { return err } @@ -1173,7 +1168,7 @@ func (c *ChannelStateDB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { ) err := kvdb.Update(c.backend, func(tx kvdb.RwTx) error { var b bytes.Buffer - if err := writeOutpoint(&b, chanPoint); err != nil { + if err := graphdb.WriteOutpoint(&b, chanPoint); err != nil { return err } @@ -1344,48 +1339,24 @@ func (c *ChannelStateDB) RestoreChannelShells(channelShells ...*ChannelShell) er return nil } -// AddrsForNode consults the graph and channel database for all addresses known -// to the passed node public key. -func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, - error) { - +// AddrsForNode consults the channel database for all addresses known to the +// passed node public key. The returned boolean indicates if the given node is +// unknown to the channel DB or not. +// +// NOTE: this is part of the AddrSource interface. +func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, error) { linkNode, err := d.channelStateDB.linkNodeDB.FetchLinkNode(nodePub) - if err != nil { - return nil, err - } - - // We'll also query the graph for this peer to see if they have any - // addresses that we don't currently have stored within the link node - // database. - pubKey, err := route.NewVertexFromBytes(nodePub.SerializeCompressed()) - if err != nil { - return nil, err - } - graphNode, err := d.graph.FetchLightningNode(pubKey) - if err != nil && err != ErrGraphNodeNotFound { - return nil, err - } else if err == ErrGraphNodeNotFound { - // If the node isn't found, then that's OK, as we still have the - // link node data. But any other error needs to be returned. - graphNode = &LightningNode{} - } + // Only if the error is something other than ErrNodeNotFound do we + // return it. + switch { + case err != nil && !errors.Is(err, ErrNodeNotFound): + return false, nil, err - // Now that we have both sources of addrs for this node, we'll use a - // map to de-duplicate any addresses between the two sources, and - // produce a final list of the combined addrs. - addrs := make(map[string]net.Addr) - for _, addr := range linkNode.Addresses { - addrs[addr.String()] = addr - } - for _, addr := range graphNode.Addresses { - addrs[addr.String()] = addr - } - dedupedAddrs := make([]net.Addr, 0, len(addrs)) - for _, addr := range addrs { - dedupedAddrs = append(dedupedAddrs, addr) + case errors.Is(err, ErrNodeNotFound): + return false, nil, nil } - return dedupedAddrs, nil + return true, linkNode.Addresses, nil } // AbandonChannel attempts to remove the target channel from the open channel @@ -1398,7 +1369,7 @@ func (c *ChannelStateDB) AbandonChannel(chanPoint *wire.OutPoint, // With the chanPoint constructed, we'll attempt to find the target // channel in the database. If we can't find the channel, then we'll // return the error back to the caller. - dbChan, err := c.FetchChannel(nil, *chanPoint) + dbChan, err := c.FetchChannel(*chanPoint) switch { // If the channel wasn't found, then it's possible that it was already // abandoned from the database. @@ -1638,11 +1609,6 @@ func (d *DB) applyOptionalVersions(cfg OptionalMiragtionConfig) error { return nil } -// ChannelGraph returns the current instance of the directed channel graph. -func (d *DB) ChannelGraph() *ChannelGraph { - return d.graph -} - // ChannelStateDB returns the sub database that is concerned with the channel // state. func (d *DB) ChannelStateDB() *ChannelStateDB { @@ -1693,7 +1659,7 @@ func fetchHistoricalChanBucket(tx kvdb.RTx, // With the bucket for the node and chain fetched, we can now go down // another level, for the channel itself. var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { return nil, err } chanBucket := historicalChanBucket.NestedReadBucket( diff --git a/channeldb/db_test.go b/channeldb/db_test.go index d8113db830..9fed9934ba 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -1,17 +1,21 @@ package channeldb import ( + "image/color" "math" "math/rand" "net" "path/filepath" "reflect" "testing" + "time" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" @@ -20,6 +24,16 @@ import ( "github.com/stretchr/testify/require" ) +var ( + testAddr = &net.TCPAddr{IP: (net.IP)([]byte{0xA, 0x0, 0x0, 0x1}), + Port: 9000} + anotherAddr, _ = net.ResolveTCPAddr("tcp", + "[2001:db8:85a3:0:0:8a2e:370:7334]:80") + testAddrs = []net.Addr{testAddr} + + testFeatures = lnwire.NewFeatureVector(nil, lnwire.Features) +) + func TestOpenWithCreate(t *testing.T) { t.Parallel() @@ -51,11 +65,7 @@ func TestOpenWithCreate(t *testing.T) { // Now, reopen the same db in dry run migration mode. Since we have not // applied any migrations, this should ignore the flag and not fail. - cdb, err = Open(dbPath, OptionDryRunMigration(true)) - require.NoError(t, err, "unable to create channeldb") - if err := cdb.Close(); err != nil { - t.Fatalf("unable to close channeldb: %v", err) - } + OpenForTesting(t, dbPath, OptionDryRunMigration(true)) } // TestWipe tests that the database wipe operation completes successfully @@ -166,25 +176,25 @@ func TestFetchClosedChannelForID(t *testing.T) { } } -// TestAddrsForNode tests the we're able to properly obtain all the addresses -// for a target node. -func TestAddrsForNode(t *testing.T) { +// TestMultiSourceAddrsForNode tests the we're able to properly obtain all the +// addresses for a target node from multiple backends - in this case, the +// channel db and graph db. +func TestMultiSourceAddrsForNode(t *testing.T) { t.Parallel() fullDB, err := MakeTestDB(t) require.NoError(t, err, "unable to make test database") - graph := fullDB.ChannelGraph() + graph, err := graphdb.MakeTestGraph(t) + require.NoError(t, err) // We'll make a test vertex to insert into the database, as the source // node, but this node will only have half the number of addresses it // usually does. - testNode, err := createTestVertex(fullDB) + testNode := createTestVertex(t) require.NoError(t, err, "unable to create test node") testNode.Addresses = []net.Addr{testAddr} - if err := graph.SetSourceNode(testNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } + require.NoError(t, graph.SetSourceNode(testNode)) // Next, we'll make a link node with the same pubkey, but with an // additional address. @@ -194,28 +204,27 @@ func TestAddrsForNode(t *testing.T) { fullDB.channelStateDB.linkNodeDB, wire.MainNet, nodePub, anotherAddr, ) - if err := linkNode.Sync(); err != nil { - t.Fatalf("unable to sync link node: %v", err) - } + require.NoError(t, linkNode.Sync()) + + // Create a multi-backend address source from the channel db and graph + // db. + addrSource := NewMultiAddrSource(fullDB, graph) // Now that we've created a link node, as well as a vertex for the // node, we'll query for all its addresses. - nodeAddrs, err := fullDB.AddrsForNode(nodePub) + known, nodeAddrs, err := addrSource.AddrsForNode(nodePub) require.NoError(t, err, "unable to obtain node addrs") + require.True(t, known) expectedAddrs := make(map[string]struct{}) expectedAddrs[testAddr.String()] = struct{}{} expectedAddrs[anotherAddr.String()] = struct{}{} // Finally, ensure that all the expected addresses are found. - if len(nodeAddrs) != len(expectedAddrs) { - t.Fatalf("expected %v addrs, got %v", - len(expectedAddrs), len(nodeAddrs)) - } + require.Len(t, nodeAddrs, len(expectedAddrs)) + for _, addr := range nodeAddrs { - if _, ok := expectedAddrs[addr.String()]; !ok { - t.Fatalf("unexpected addr: %v", addr) - } + require.Contains(t, expectedAddrs, addr.String()) } } @@ -233,7 +242,7 @@ func TestFetchChannel(t *testing.T) { channelState := createTestChannel(t, cdb, openChannelOption()) // Next, attempt to fetch the channel by its chan point. - dbChannel, err := cdb.FetchChannel(nil, channelState.FundingOutpoint) + dbChannel, err := cdb.FetchChannel(channelState.FundingOutpoint) require.NoError(t, err, "unable to fetch channel") // The decoded channel state should be identical to what we stored @@ -257,7 +266,7 @@ func TestFetchChannel(t *testing.T) { uniqueOutputIndex.Add(1) channelState2.FundingOutpoint.Index = uniqueOutputIndex.Load() - _, err = cdb.FetchChannel(nil, channelState2.FundingOutpoint) + _, err = cdb.FetchChannel(channelState2.FundingOutpoint) require.ErrorIs(t, err, ErrChannelNotFound) chanID2 := lnwire.NewChanIDFromOutPoint(channelState2.FundingOutpoint) @@ -397,7 +406,7 @@ func TestRestoreChannelShells(t *testing.T) { // We should also be able to find the channel if we query for it // directly. - _, err = cdb.FetchChannel(nil, channelShell.Chan.FundingOutpoint) + _, err = cdb.FetchChannel(channelShell.Chan.FundingOutpoint) require.NoError(t, err, "unable to fetch channel") // We should also be able to find the link node that was inserted by @@ -446,7 +455,7 @@ func TestAbandonChannel(t *testing.T) { // At this point, the channel should no longer be found in the set of // open channels. - _, err = cdb.FetchChannel(nil, chanState.FundingOutpoint) + _, err = cdb.FetchChannel(chanState.FundingOutpoint) if err != ErrChannelNotFound { t.Fatalf("channel should not have been found: %v", err) } @@ -711,3 +720,28 @@ func TestFetchHistoricalChannel(t *testing.T) { t.Fatalf("expected chan not found, got: %v", err) } } + +func createLightningNode(priv *btcec.PrivateKey) *models.LightningNode { + updateTime := rand.Int63() + + pub := priv.PubKey().SerializeCompressed() + n := &models.LightningNode{ + HaveNodeAnnouncement: true, + AuthSigBytes: testSig.Serialize(), + LastUpdate: time.Unix(updateTime, 0), + Color: color.RGBA{1, 2, 3, 0}, + Alias: "kek" + string(pub), + Features: testFeatures, + Addresses: testAddrs, + } + copy(n.PubKeyBytes[:], priv.PubKey().SerializeCompressed()) + + return n +} + +func createTestVertex(t *testing.T) *models.LightningNode { + priv, err := btcec.NewPrivateKey() + require.NoError(t, err) + + return createLightningNode(priv) +} diff --git a/channeldb/error.go b/channeldb/error.go index 629cd93c6f..c2b2dde0d7 100644 --- a/channeldb/error.go +++ b/channeldb/error.go @@ -1,7 +1,6 @@ package channeldb import ( - "errors" "fmt" ) @@ -43,57 +42,6 @@ var ( // created. ErrMetaNotFound = fmt.Errorf("unable to locate meta information") - // ErrClosedScidsNotFound is returned when the closed scid bucket - // hasn't been created. - ErrClosedScidsNotFound = fmt.Errorf("closed scid bucket doesn't exist") - - // ErrGraphNotFound is returned when at least one of the components of - // graph doesn't exist. - ErrGraphNotFound = fmt.Errorf("graph bucket not initialized") - - // ErrGraphNeverPruned is returned when graph was never pruned. - ErrGraphNeverPruned = fmt.Errorf("graph never pruned") - - // ErrSourceNodeNotSet is returned if the source node of the graph - // hasn't been added The source node is the center node within a - // star-graph. - ErrSourceNodeNotSet = fmt.Errorf("source node does not exist") - - // ErrGraphNodesNotFound is returned in case none of the nodes has - // been added in graph node bucket. - ErrGraphNodesNotFound = fmt.Errorf("no graph nodes exist") - - // ErrGraphNoEdgesFound is returned in case of none of the channel/edges - // has been added in graph edge bucket. - ErrGraphNoEdgesFound = fmt.Errorf("no graph edges exist") - - // ErrGraphNodeNotFound is returned when we're unable to find the target - // node. - ErrGraphNodeNotFound = fmt.Errorf("unable to find node") - - // ErrEdgeNotFound is returned when an edge for the target chanID - // can't be found. - ErrEdgeNotFound = fmt.Errorf("edge not found") - - // ErrZombieEdge is an error returned when we attempt to look up an edge - // but it is marked as a zombie within the zombie index. - ErrZombieEdge = errors.New("edge marked as zombie") - - // ErrZombieEdgeNotFound is an error returned when we attempt to find an - // edge in the zombie index which is not there. - ErrZombieEdgeNotFound = errors.New("edge not found in zombie index") - - // ErrEdgeAlreadyExist is returned when edge with specific - // channel id can't be added because it already exist. - ErrEdgeAlreadyExist = fmt.Errorf("edge already exist") - - // ErrNodeAliasNotFound is returned when alias for node can't be found. - ErrNodeAliasNotFound = fmt.Errorf("alias for node not found") - - // ErrUnknownAddressType is returned when a node's addressType is not - // an expected value. - ErrUnknownAddressType = fmt.Errorf("address type cannot be resolved") - // ErrNoClosedChannels is returned when a node is queries for all the // channels it has closed, but it hasn't yet closed any channels. ErrNoClosedChannels = fmt.Errorf("no channel have been closed yet") @@ -102,24 +50,8 @@ var ( // to the log not having any recorded events. ErrNoForwardingEvents = fmt.Errorf("no recorded forwarding events") - // ErrEdgePolicyOptionalFieldNotFound is an error returned if a channel - // policy field is not found in the db even though its message flags - // indicate it should be. - ErrEdgePolicyOptionalFieldNotFound = fmt.Errorf("optional field not " + - "present") - // ErrChanAlreadyExists is return when the caller attempts to create a // channel with a channel point that is already present in the // database. ErrChanAlreadyExists = fmt.Errorf("channel already exists") ) - -// ErrTooManyExtraOpaqueBytes creates an error which should be returned if the -// caller attempts to write an announcement message which bares too many extra -// opaque bytes. We limit this value in order to ensure that we don't waste -// disk space due to nodes unnecessarily padding out their announcements with -// garbage data. -func ErrTooManyExtraOpaqueBytes(numBytes int) error { - return fmt.Errorf("max allowed number of opaque bytes is %v, received "+ - "%v bytes", MaxAllowedExtraOpaqueBytes, numBytes) -} diff --git a/channeldb/forwarding_policy.go b/channeldb/forwarding_policy.go index 2a41230457..2df2e308f8 100644 --- a/channeldb/forwarding_policy.go +++ b/channeldb/forwarding_policy.go @@ -1,7 +1,7 @@ package channeldb import ( - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/channeldb/height_hint_test.go b/channeldb/height_hint_test.go index 3d98707e55..1549ee5f47 100644 --- a/channeldb/height_hint_test.go +++ b/channeldb/height_hint_test.go @@ -23,15 +23,11 @@ func initHintCache(t *testing.T) *HeightHintCache { func initHintCacheWithConfig(t *testing.T, cfg CacheConfig) *HeightHintCache { t.Helper() - db, err := Open(t.TempDir()) - require.NoError(t, err, "unable to create db") + db := OpenForTesting(t, t.TempDir()) + hintCache, err := NewHeightHintCache(cfg, db.Backend) require.NoError(t, err, "unable to create hint cache") - t.Cleanup(func() { - require.NoError(t, db.Close()) - }) - return hintCache } diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 3fe6b668e8..1a4409f2db 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" invpkg "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 9da504a5d8..669da16080 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -9,7 +9,7 @@ import ( "io" "time" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" invpkg "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/kvdb" diff --git a/channeldb/meta_test.go b/channeldb/meta_test.go index defa9f3291..5b6bd29a94 100644 --- a/channeldb/meta_test.go +++ b/channeldb/meta_test.go @@ -6,6 +6,7 @@ import ( "github.com/btcsuite/btcwallet/walletdb" "github.com/go-errors/errors" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/kvdb" "github.com/stretchr/testify/require" ) @@ -22,14 +23,11 @@ func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), cdb.dryRun = dryRun // Create a test node that will be our source node. - testNode, err := createTestVertex(cdb) - if err != nil { - t.Fatal(err) - } - graph := cdb.ChannelGraph() - if err := graph.SetSourceNode(testNode); err != nil { - t.Fatal(err) - } + testNode := createTestVertex(t) + + graph, err := graphdb.MakeTestGraph(t) + require.NoError(t, err) + require.NoError(t, graph.SetSourceNode(testNode)) // beforeMigration usually used for populating the database // with test data. diff --git a/channeldb/nodes.go b/channeldb/nodes.go index b6e1573cff..b17d5c360d 100644 --- a/channeldb/nodes.go +++ b/channeldb/nodes.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/kvdb" ) @@ -273,7 +274,7 @@ func serializeLinkNode(w io.Writer, l *LinkNode) error { } for _, addr := range l.Addresses { - if err := serializeAddr(w, addr); err != nil { + if err := graphdb.SerializeAddr(w, addr); err != nil { return err } } @@ -315,7 +316,7 @@ func deserializeLinkNode(r io.Reader) (*LinkNode, error) { node.Addresses = make([]net.Addr, numAddrs) for i := uint32(0); i < numAddrs; i++ { - addr, err := deserializeAddr(r) + addr, err := graphdb.DeserializeAddr(r) if err != nil { return nil, err } diff --git a/channeldb/options.go b/channeldb/options.go index 3f5b472f61..6e631e2cb9 100644 --- a/channeldb/options.go +++ b/channeldb/options.go @@ -1,10 +1,7 @@ package channeldb import ( - "time" - "github.com/lightningnetwork/lnd/clock" - "github.com/lightningnetwork/lnd/kvdb" ) const ( @@ -35,30 +32,8 @@ type OptionalMiragtionConfig struct { // Options holds parameters for tuning and customizing a channeldb.DB. type Options struct { - kvdb.BoltBackendConfig OptionalMiragtionConfig - // RejectCacheSize is the maximum number of rejectCacheEntries to hold - // in the rejection cache. - RejectCacheSize int - - // ChannelCacheSize is the maximum number of ChannelEdges to hold in the - // channel cache. - ChannelCacheSize int - - // BatchCommitInterval is the maximum duration the batch schedulers will - // wait before attempting to commit a pending set of updates. - BatchCommitInterval time.Duration - - // PreAllocCacheNumNodes is the number of nodes we expect to be in the - // graph cache, so we can pre-allocate the map accordingly. - PreAllocCacheNumNodes int - - // UseGraphCache denotes whether the in-memory graph cache should be - // used or a fallback version that uses the underlying database for - // path finding. - UseGraphCache bool - // NoMigration specifies that underlying backend was opened in read-only // mode and migrations shouldn't be performed. This can be useful for // applications that use the channeldb package as a library. @@ -87,17 +62,7 @@ type Options struct { // DefaultOptions returns an Options populated with default values. func DefaultOptions() Options { return Options{ - BoltBackendConfig: kvdb.BoltBackendConfig{ - NoFreelistSync: true, - AutoCompact: false, - AutoCompactMinAge: kvdb.DefaultBoltAutoCompactMinAge, - DBTimeout: kvdb.DefaultDBTimeout, - }, OptionalMiragtionConfig: OptionalMiragtionConfig{}, - RejectCacheSize: DefaultRejectCacheSize, - ChannelCacheSize: DefaultChannelCacheSize, - PreAllocCacheNumNodes: DefaultPreAllocCacheNumNodes, - UseGraphCache: true, NoMigration: false, clock: clock.NewDefaultClock(), } @@ -106,34 +71,6 @@ func DefaultOptions() Options { // OptionModifier is a function signature for modifying the default Options. type OptionModifier func(*Options) -// OptionSetRejectCacheSize sets the RejectCacheSize to n. -func OptionSetRejectCacheSize(n int) OptionModifier { - return func(o *Options) { - o.RejectCacheSize = n - } -} - -// OptionSetChannelCacheSize sets the ChannelCacheSize to n. -func OptionSetChannelCacheSize(n int) OptionModifier { - return func(o *Options) { - o.ChannelCacheSize = n - } -} - -// OptionSetPreAllocCacheNumNodes sets the PreAllocCacheNumNodes to n. -func OptionSetPreAllocCacheNumNodes(n int) OptionModifier { - return func(o *Options) { - o.PreAllocCacheNumNodes = n - } -} - -// OptionSetUseGraphCache sets the UseGraphCache option to the given value. -func OptionSetUseGraphCache(use bool) OptionModifier { - return func(o *Options) { - o.UseGraphCache = use - } -} - // OptionNoRevLogAmtData sets the NoRevLogAmtData option to the given value. If // it is set to true then amount data will not be stored in the revocation log. func OptionNoRevLogAmtData(noAmtData bool) OptionModifier { @@ -142,36 +79,6 @@ func OptionNoRevLogAmtData(noAmtData bool) OptionModifier { } } -// OptionSetSyncFreelist allows the database to sync its freelist. -func OptionSetSyncFreelist(b bool) OptionModifier { - return func(o *Options) { - o.NoFreelistSync = !b - } -} - -// OptionAutoCompact turns on automatic database compaction on startup. -func OptionAutoCompact() OptionModifier { - return func(o *Options) { - o.AutoCompact = true - } -} - -// OptionAutoCompactMinAge sets the minimum age for automatic database -// compaction. -func OptionAutoCompactMinAge(minAge time.Duration) OptionModifier { - return func(o *Options) { - o.AutoCompactMinAge = minAge - } -} - -// OptionSetBatchCommitInterval sets the batch commit interval for the internval -// batch schedulers. -func OptionSetBatchCommitInterval(interval time.Duration) OptionModifier { - return func(o *Options) { - o.BatchCommitInterval = interval - } -} - // OptionNoMigration allows the database to be opened in read only mode by // disabling migrations. func OptionNoMigration(b bool) OptionModifier { diff --git a/channeldb/options_test.go b/channeldb/options_test.go deleted file mode 100644 index e60c1cfcfc..0000000000 --- a/channeldb/options_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package channeldb_test - -import ( - "testing" - - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/kvdb" - "github.com/stretchr/testify/require" -) - -// TestDefaultOptions tests the default options are created as intended. -func TestDefaultOptions(t *testing.T) { - opts := channeldb.DefaultOptions() - - require.True(t, opts.NoFreelistSync) - require.False(t, opts.AutoCompact) - require.Equal( - t, kvdb.DefaultBoltAutoCompactMinAge, opts.AutoCompactMinAge, - ) - require.Equal(t, kvdb.DefaultDBTimeout, opts.DBTimeout) - require.Equal( - t, channeldb.DefaultRejectCacheSize, opts.RejectCacheSize, - ) - require.Equal( - t, channeldb.DefaultChannelCacheSize, opts.ChannelCacheSize, - ) -} diff --git a/channeldb/reports.go b/channeldb/reports.go index c4e58d81e6..4f46bd9e1f 100644 --- a/channeldb/reports.go +++ b/channeldb/reports.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/tlv" ) @@ -164,7 +165,7 @@ func putReport(tx kvdb.RwTx, chainHash chainhash.Hash, // Finally write our outpoint to be used as the key for this record. var keyBuf bytes.Buffer - if err := writeOutpoint(&keyBuf, &report.OutPoint); err != nil { + if err := graphdb.WriteOutpoint(&keyBuf, &report.OutPoint); err != nil { return err } @@ -317,7 +318,7 @@ func fetchReportWriteBucket(tx kvdb.RwTx, chainHash chainhash.Hash, } var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { return nil, err } @@ -341,7 +342,7 @@ func fetchReportReadBucket(tx kvdb.RTx, chainHash chainhash.Hash, // With the bucket for the node and chain fetched, we can now go down // another level, for the channel itself. var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil { return nil, err } diff --git a/channeldb/reports_test.go b/channeldb/reports_test.go index 48a41914fa..1148fdf03e 100644 --- a/channeldb/reports_test.go +++ b/channeldb/reports_test.go @@ -6,6 +6,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/kvdb" "github.com/stretchr/testify/require" ) @@ -137,7 +138,7 @@ func TestFetchChannelWriteBucket(t *testing.T) { error) { var chanPointBuf bytes.Buffer - err := writeOutpoint(&chanPointBuf, &testChanPoint1) + err := graphdb.WriteOutpoint(&chanPointBuf, &testChanPoint1) require.NoError(t, err) return chainHash.CreateBucketIfNotExists(chanPointBuf.Bytes()) diff --git a/channelnotifier/channelnotifier.go b/channelnotifier/channelnotifier.go index 2b39396df8..b03d913130 100644 --- a/channelnotifier/channelnotifier.go +++ b/channelnotifier/channelnotifier.go @@ -144,7 +144,7 @@ func (c *ChannelNotifier) NotifyPendingOpenChannelEvent(chanPoint wire.OutPoint, // channel has gone from pending open to open. func (c *ChannelNotifier) NotifyOpenChannelEvent(chanPoint wire.OutPoint) { // Fetch the relevant channel from the database. - channel, err := c.chanDB.FetchChannel(nil, chanPoint) + channel, err := c.chanDB.FetchChannel(chanPoint) if err != nil { log.Warnf("Unable to fetch open channel from the db: %v", err) } diff --git a/config_builder.go b/config_builder.go index 7cc1a112d2..42650bb68b 100644 --- a/config_builder.go +++ b/config_builder.go @@ -35,6 +35,7 @@ import ( "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/funding" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -900,18 +901,10 @@ func (d *RPCSignerWalletImpl) BuildChainControl( type DatabaseInstances struct { // GraphDB is the database that stores the channel graph used for path // finding. - // - // NOTE/TODO: This currently _needs_ to be the same instance as the - // ChanStateDB below until the separation of the two databases is fully - // complete! - GraphDB *channeldb.DB + GraphDB *graphdb.ChannelGraph // ChanStateDB is the database that stores all of our node's channel // state. - // - // NOTE/TODO: This currently _needs_ to be the same instance as the - // GraphDB above until the separation of the two databases is fully - // complete! ChanStateDB *channeldb.DB // HeightHintDB is the database that stores height hints for spends. @@ -1022,16 +1015,37 @@ func (d *DefaultDatabaseBuilder) BuildDatabase( "instances") } + graphDBOptions := []graphdb.OptionModifier{ + graphdb.WithRejectCacheSize(cfg.Caches.RejectCacheSize), + graphdb.WithChannelCacheSize(cfg.Caches.ChannelCacheSize), + graphdb.WithBatchCommitInterval(cfg.DB.BatchCommitInterval), + graphdb.WithUseGraphCache(!cfg.DB.NoGraphCache), + } + + // We want to pre-allocate the channel graph cache according to what we + // expect for mainnet to speed up memory allocation. + if cfg.ActiveNetParams.Name == chaincfg.MainNetParams.Name { + graphDBOptions = append( + graphDBOptions, graphdb.WithPreAllocCacheNumNodes( + graphdb.DefaultPreAllocCacheNumNodes, + ), + ) + } + + dbs.GraphDB, err = graphdb.NewChannelGraph( + databaseBackends.GraphDB, graphDBOptions..., + ) + if err != nil { + cleanUp() + + err := fmt.Errorf("unable to open graph DB: %w", err) + d.logger.Error(err) + + return nil, nil, err + } + dbOptions := []channeldb.OptionModifier{ - channeldb.OptionSetRejectCacheSize(cfg.Caches.RejectCacheSize), - channeldb.OptionSetChannelCacheSize( - cfg.Caches.ChannelCacheSize, - ), - channeldb.OptionSetBatchCommitInterval( - cfg.DB.BatchCommitInterval, - ), channeldb.OptionDryRunMigration(cfg.DryRunMigration), - channeldb.OptionSetUseGraphCache(!cfg.DB.NoGraphCache), channeldb.OptionKeepFailedPaymentAttempts( cfg.KeepFailedPaymentAttempts, ), @@ -1042,27 +1056,17 @@ func (d *DefaultDatabaseBuilder) BuildDatabase( channeldb.OptionNoRevLogAmtData(cfg.DB.NoRevLogAmtData), } - // We want to pre-allocate the channel graph cache according to what we - // expect for mainnet to speed up memory allocation. - if cfg.ActiveNetParams.Name == chaincfg.MainNetParams.Name { - dbOptions = append( - dbOptions, channeldb.OptionSetPreAllocCacheNumNodes( - channeldb.DefaultPreAllocCacheNumNodes, - ), - ) - } - // Otherwise, we'll open two instances, one for the state we only need // locally, and the other for things we want to ensure are replicated. - dbs.GraphDB, err = channeldb.CreateWithBackend( - databaseBackends.GraphDB, dbOptions..., + dbs.ChanStateDB, err = channeldb.CreateWithBackend( + databaseBackends.ChanStateDB, dbOptions..., ) switch { // Give the DB a chance to dry run the migration. Since we know that // both the channel state and graph DBs are still always behind the same // backend, we know this would be applied to both of those DBs. case err == channeldb.ErrDryRunMigrationOK: - d.logger.Infof("Graph DB dry run migration successful") + d.logger.Infof("Channel DB dry run migration successful") return nil, nil, err case err != nil: @@ -1073,27 +1077,14 @@ func (d *DefaultDatabaseBuilder) BuildDatabase( return nil, nil, err } - // For now, we don't _actually_ split the graph and channel state DBs on - // the code level. Since they both are based upon the *channeldb.DB - // struct it will require more refactoring to fully separate them. With - // the full remote mode we at least know for now that they both point to - // the same DB backend (and also namespace within that) so we only need - // to apply any migration once. - // - // TODO(guggero): Once the full separation of anything graph related - // from the channeldb.DB is complete, the decorated instance of the - // channel state DB should be created here individually instead of just - // using the same struct (and DB backend) instance. - dbs.ChanStateDB = dbs.GraphDB - // Instantiate a native SQL invoice store if the flag is set. if d.cfg.DB.UseNativeSQL { - // KV invoice db resides in the same database as the graph and - // channel state DB. Let's query the database to see if we have - // any invoices there. If we do, we won't allow the user to - // start lnd with native SQL enabled, as we don't currently - // migrate the invoices to the new database schema. - invoiceSlice, err := dbs.GraphDB.QueryInvoices( + // KV invoice db resides in the same database as the channel + // state DB. Let's query the database to see if we have any + // invoices there. If we do, we won't allow the user to start + // lnd with native SQL enabled, as we don't currently migrate + // the invoices to the new database schema. + invoiceSlice, err := dbs.ChanStateDB.QueryInvoices( ctx, invoices.InvoiceQuery{ NumMaxInvoices: 1, }, @@ -1127,7 +1118,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase( executor, clock.NewDefaultClock(), ) } else { - dbs.InvoiceDB = dbs.GraphDB + dbs.InvoiceDB = dbs.ChanStateDB } // Wrap the watchtower client DB and make sure we clean up. diff --git a/contractcourt/breach_arbitrator.go b/contractcourt/breach_arbitrator.go index a8154d0e61..89c596f7ae 100644 --- a/contractcourt/breach_arbitrator.go +++ b/contractcourt/breach_arbitrator.go @@ -16,6 +16,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/labels" @@ -1856,7 +1857,8 @@ func (rs *RetributionStore) Add(ret *retributionInfo) error { } var outBuf bytes.Buffer - if err := writeOutpoint(&outBuf, &ret.chanPoint); err != nil { + err = graphdb.WriteOutpoint(&outBuf, &ret.chanPoint) + if err != nil { return err } @@ -1907,7 +1909,8 @@ func (rs *RetributionStore) IsBreached(chanPoint *wire.OutPoint) (bool, error) { } var chanBuf bytes.Buffer - if err := writeOutpoint(&chanBuf, chanPoint); err != nil { + err := graphdb.WriteOutpoint(&chanBuf, chanPoint) + if err != nil { return err } @@ -1947,7 +1950,8 @@ func (rs *RetributionStore) Remove(chanPoint *wire.OutPoint) error { // Serialize the channel point we are intending to remove. var chanBuf bytes.Buffer - if err := writeOutpoint(&chanBuf, chanPoint); err != nil { + err = graphdb.WriteOutpoint(&chanBuf, chanPoint) + if err != nil { return err } chanBytes := chanBuf.Bytes() @@ -2017,7 +2021,7 @@ func (ret *retributionInfo) Encode(w io.Writer) error { return err } - if err := writeOutpoint(w, &ret.chanPoint); err != nil { + if err := graphdb.WriteOutpoint(w, &ret.chanPoint); err != nil { return err } @@ -2057,7 +2061,7 @@ func (ret *retributionInfo) Decode(r io.Reader) error { } ret.commitHash = *hash - if err := readOutpoint(r, &ret.chanPoint); err != nil { + if err := graphdb.ReadOutpoint(r, &ret.chanPoint); err != nil { return err } @@ -2100,7 +2104,7 @@ func (bo *breachedOutput) Encode(w io.Writer) error { return err } - if err := writeOutpoint(w, &bo.outpoint); err != nil { + if err := graphdb.WriteOutpoint(w, &bo.outpoint); err != nil { return err } @@ -2131,7 +2135,7 @@ func (bo *breachedOutput) Decode(r io.Reader) error { } bo.amt = btcutil.Amount(binary.BigEndian.Uint64(scratch[:8])) - if err := readOutpoint(r, &bo.outpoint); err != nil { + if err := graphdb.ReadOutpoint(r, &bo.outpoint); err != nil { return err } diff --git a/contractcourt/breach_arbitrator_test.go b/contractcourt/breach_arbitrator_test.go index bd4ad85683..576009eda4 100644 --- a/contractcourt/breach_arbitrator_test.go +++ b/contractcourt/breach_arbitrator_test.go @@ -635,15 +635,6 @@ func TestMockRetributionStore(t *testing.T) { } } -func makeTestChannelDB(t *testing.T) (*channeldb.DB, error) { - db, err := channeldb.Open(t.TempDir()) - if err != nil { - return nil, err - } - - return db, nil -} - // TestChannelDBRetributionStore instantiates a retributionStore backed by a // channeldb.DB, and tests its behavior using the general RetributionStore test // suite. @@ -654,25 +645,19 @@ func TestChannelDBRetributionStore(t *testing.T) { t.Run( "channeldbDBRetributionStore."+test.name, func(tt *testing.T) { - db, err := makeTestChannelDB(t) - if err != nil { - t.Fatalf("unable to open channeldb: %v", err) - } - defer db.Close() + db := channeldb.OpenForTesting(t, t.TempDir()) restartDb := func() RetributionStorer { // Close and reopen channeldb - if err = db.Close(); err != nil { + if err := db.Close(); err != nil { t.Fatalf("unable to close "+ "channeldb during "+ "restart: %v", err) } - db, err = channeldb.Open(db.Path()) - if err != nil { - t.Fatalf("unable to open "+ - "channeldb: %v", err) - } + db = channeldb.OpenForTesting( + t, db.Path(), + ) return NewRetributionStore(db) } @@ -2279,21 +2264,8 @@ func createInitChannels(t *testing.T) ( return nil, nil, err } - dbAlice, err := channeldb.Open(t.TempDir()) - if err != nil { - return nil, nil, err - } - t.Cleanup(func() { - require.NoError(t, dbAlice.Close()) - }) - - dbBob, err := channeldb.Open(t.TempDir()) - if err != nil { - return nil, nil, err - } - t.Cleanup(func() { - require.NoError(t, dbBob.Close()) - }) + dbAlice := channeldb.OpenForTesting(t, t.TempDir()) + dbBob := channeldb.OpenForTesting(t, t.TempDir()) estimator := chainfee.NewStaticEstimator(12500, 0) feePerKw, err := estimator.EstimateFeePerKW(1) diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index d7d10ba252..78a79a3c2f 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -13,9 +13,9 @@ import ( "github.com/btcsuite/btcwallet/walletdb" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/labels" @@ -304,9 +304,7 @@ func (a *arbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions, // same instance that is used by the link. chanPoint := a.channel.FundingOutpoint - channel, err := a.c.chanSource.ChannelStateDB().FetchChannel( - nil, chanPoint, - ) + channel, err := a.c.chanSource.ChannelStateDB().FetchChannel(chanPoint) if err != nil { return nil, err } @@ -359,9 +357,7 @@ func (a *arbChannel) ForceCloseChan() (*wire.MsgTx, error) { // Now that we know the link can't mutate the channel // state, we'll read the channel from disk the target // channel according to its channel point. - channel, err := a.c.chanSource.ChannelStateDB().FetchChannel( - nil, chanPoint, - ) + channel, err := a.c.chanSource.ChannelStateDB().FetchChannel(chanPoint) if err != nil { return nil, err } diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index abaca5c2ba..fe2603ca5a 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -8,8 +8,8 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" @@ -22,13 +22,7 @@ import ( func TestChainArbitratorRepublishCloses(t *testing.T) { t.Parallel() - db, err := channeldb.Open(t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - require.NoError(t, db.Close()) - }) + db := channeldb.OpenForTesting(t, t.TempDir()) // Create 10 test channels and sync them to the database. const numChans = 10 @@ -139,11 +133,7 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { func TestResolveContract(t *testing.T) { t.Parallel() - db, err := channeldb.Open(t.TempDir()) - require.NoError(t, err, "unable to open db") - t.Cleanup(func() { - require.NoError(t, db.Close()) - }) + db := channeldb.OpenForTesting(t, t.TempDir()) // With the DB created, we'll make a new channel, and mark it as // pending open within the database. diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index ffa4a5d6e2..3a7c2cfe93 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -15,8 +15,8 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/invoices" diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index ac5253787d..3b367bf548 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -15,9 +15,9 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/mock" diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go index f2b43b0f80..077fb8f82c 100644 --- a/contractcourt/commit_sweep_resolver_test.go +++ b/contractcourt/commit_sweep_resolver_test.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/mock" diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index 6bda4e398b..e7e21fff68 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -10,8 +10,8 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" diff --git a/contractcourt/htlc_incoming_contest_resolver_test.go b/contractcourt/htlc_incoming_contest_resolver_test.go index 55d93a6fb3..22280f953e 100644 --- a/contractcourt/htlc_incoming_contest_resolver_test.go +++ b/contractcourt/htlc_incoming_contest_resolver_test.go @@ -8,7 +8,7 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/kvdb" diff --git a/contractcourt/htlc_outgoing_contest_resolver_test.go b/contractcourt/htlc_outgoing_contest_resolver_test.go index f67c34ff4e..6608a6fb51 100644 --- a/contractcourt/htlc_outgoing_contest_resolver_test.go +++ b/contractcourt/htlc_outgoing_contest_resolver_test.go @@ -7,7 +7,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnmock" diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 4c9d2b200b..159b642dde 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -12,8 +12,8 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnutils" diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go index b9182500bb..23023729fa 100644 --- a/contractcourt/htlc_success_resolver_test.go +++ b/contractcourt/htlc_success_resolver_test.go @@ -12,8 +12,8 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnmock" diff --git a/contractcourt/htlc_timeout_resolver_test.go b/contractcourt/htlc_timeout_resolver_test.go index 47be71d3ec..92cc587fc1 100644 --- a/contractcourt/htlc_timeout_resolver_test.go +++ b/contractcourt/htlc_timeout_resolver_test.go @@ -14,8 +14,8 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" diff --git a/contractcourt/interfaces.go b/contractcourt/interfaces.go index 90ad2b1c87..75b81e9dd4 100644 --- a/contractcourt/interfaces.go +++ b/contractcourt/interfaces.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/invoices" diff --git a/contractcourt/mock_htlcnotifier_test.go b/contractcourt/mock_htlcnotifier_test.go index 52bd18676a..6b8f40659d 100644 --- a/contractcourt/mock_htlcnotifier_test.go +++ b/contractcourt/mock_htlcnotifier_test.go @@ -2,7 +2,7 @@ package contractcourt import ( "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" ) type mockHTLCNotifier struct { diff --git a/contractcourt/mock_registry_test.go b/contractcourt/mock_registry_test.go index 5c75185623..5bba11afcb 100644 --- a/contractcourt/mock_registry_test.go +++ b/contractcourt/mock_registry_test.go @@ -3,7 +3,7 @@ package contractcourt import ( "context" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" diff --git a/contractcourt/nursery_store.go b/contractcourt/nursery_store.go index c668f22b62..428b37f97b 100644 --- a/contractcourt/nursery_store.go +++ b/contractcourt/nursery_store.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/kvdb" ) @@ -221,7 +222,7 @@ func prefixOutputKey(statePrefix []byte, return nil, err } - err := writeOutpoint(&pfxOutputBuffer, &outpoint) + err := graphdb.WriteOutpoint(&pfxOutputBuffer, &outpoint) if err != nil { return nil, err } @@ -738,7 +739,9 @@ func (ns *NurseryStore) ListChannels() ([]wire.OutPoint, error) { return chanIndex.ForEach(func(chanBytes, _ []byte) error { var chanPoint wire.OutPoint - err := readOutpoint(bytes.NewReader(chanBytes), &chanPoint) + err := graphdb.ReadOutpoint( + bytes.NewReader(chanBytes), &chanPoint, + ) if err != nil { return err } @@ -804,12 +807,13 @@ func (ns *NurseryStore) RemoveChannel(chanPoint *wire.OutPoint) error { // Serialize the provided channel point, such that we can delete // the mature channel bucket. var chanBuffer bytes.Buffer - if err := writeOutpoint(&chanBuffer, chanPoint); err != nil { + err := graphdb.WriteOutpoint(&chanBuffer, chanPoint) + if err != nil { return err } chanBytes := chanBuffer.Bytes() - err := ns.forChanOutputs(tx, chanPoint, func(k, v []byte) error { + err = ns.forChanOutputs(tx, chanPoint, func(k, v []byte) error { if !bytes.HasPrefix(k, gradPrefix) { return ErrImmatureChannel } @@ -959,7 +963,7 @@ func (ns *NurseryStore) createChannelBucket(tx kvdb.RwTx, // Serialize the provided channel point, as this provides the name of // the channel bucket of interest. var chanBuffer bytes.Buffer - if err := writeOutpoint(&chanBuffer, chanPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanBuffer, chanPoint); err != nil { return nil, err } @@ -989,7 +993,7 @@ func (ns *NurseryStore) getChannelBucket(tx kvdb.RTx, // Serialize the provided channel point and return the bucket matching // the serialized key. var chanBuffer bytes.Buffer - if err := writeOutpoint(&chanBuffer, chanPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanBuffer, chanPoint); err != nil { return nil } @@ -1017,7 +1021,7 @@ func (ns *NurseryStore) getChannelBucketWrite(tx kvdb.RwTx, // Serialize the provided channel point and return the bucket matching // the serialized key. var chanBuffer bytes.Buffer - if err := writeOutpoint(&chanBuffer, chanPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanBuffer, chanPoint); err != nil { return nil } @@ -1142,7 +1146,7 @@ func (ns *NurseryStore) createHeightChanBucket(tx kvdb.RwTx, // Serialize the provided channel point, as this generates the name of // the subdirectory corresponding to the channel of interest. var chanBuffer bytes.Buffer - if err := writeOutpoint(&chanBuffer, chanPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanBuffer, chanPoint); err != nil { return nil, err } chanBytes := chanBuffer.Bytes() @@ -1168,7 +1172,7 @@ func (ns *NurseryStore) getHeightChanBucketWrite(tx kvdb.RwTx, // Serialize the provided channel point, which generates the key for // looking up the proper height-channel bucket inside the height bucket. var chanBuffer bytes.Buffer - if err := writeOutpoint(&chanBuffer, chanPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanBuffer, chanPoint); err != nil { return nil } chanBytes := chanBuffer.Bytes() @@ -1312,7 +1316,7 @@ func (ns *NurseryStore) removeOutputFromHeight(tx kvdb.RwTx, height uint32, } var chanBuffer bytes.Buffer - if err := writeOutpoint(&chanBuffer, chanPoint); err != nil { + if err := graphdb.WriteOutpoint(&chanBuffer, chanPoint); err != nil { return err } diff --git a/contractcourt/utils_test.go b/contractcourt/utils_test.go index 9a3c5308eb..994bc57a88 100644 --- a/contractcourt/utils_test.go +++ b/contractcourt/utils_test.go @@ -65,12 +65,9 @@ func copyChannelState(t *testing.T, state *channeldb.OpenChannel) ( return nil, err } - newDb, err := channeldb.Open(tempDbPath) - if err != nil { - return nil, err - } + newDB := channeldb.OpenForTesting(t, tempDbPath) - chans, err := newDb.ChannelStateDB().FetchAllChannels() + chans, err := newDB.ChannelStateDB().FetchAllChannels() if err != nil { return nil, err } diff --git a/contractcourt/utxonursery.go b/contractcourt/utxonursery.go index b7b4d33a8b..afd6c18c99 100644 --- a/contractcourt/utxonursery.go +++ b/contractcourt/utxonursery.go @@ -16,6 +16,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnutils" @@ -1466,10 +1467,10 @@ func (k *kidOutput) Encode(w io.Writer) error { } op := k.OutPoint() - if err := writeOutpoint(w, &op); err != nil { + if err := graphdb.WriteOutpoint(w, &op); err != nil { return err } - if err := writeOutpoint(w, k.OriginChanPoint()); err != nil { + if err := graphdb.WriteOutpoint(w, k.OriginChanPoint()); err != nil { return err } @@ -1521,11 +1522,12 @@ func (k *kidOutput) Decode(r io.Reader) error { } k.amt = btcutil.Amount(byteOrder.Uint64(scratch[:])) - if err := readOutpoint(io.LimitReader(r, 40), &k.outpoint); err != nil { + err := graphdb.ReadOutpoint(io.LimitReader(r, 40), &k.outpoint) + if err != nil { return err } - err := readOutpoint(io.LimitReader(r, 40), &k.originChanPoint) + err = graphdb.ReadOutpoint(io.LimitReader(r, 40), &k.originChanPoint) if err != nil { return err } @@ -1577,40 +1579,6 @@ func (k *kidOutput) Decode(r io.Reader) error { return nil } -// TODO(bvu): copied from channeldb, remove repetition -func writeOutpoint(w io.Writer, o *wire.OutPoint) error { - // TODO(roasbeef): make all scratch buffers on the stack - scratch := make([]byte, 4) - - // TODO(roasbeef): write raw 32 bytes instead of wasting the extra - // byte. - if err := wire.WriteVarBytes(w, 0, o.Hash[:]); err != nil { - return err - } - - byteOrder.PutUint32(scratch, o.Index) - _, err := w.Write(scratch) - return err -} - -// TODO(bvu): copied from channeldb, remove repetition -func readOutpoint(r io.Reader, o *wire.OutPoint) error { - scratch := make([]byte, 4) - - txid, err := wire.ReadVarBytes(r, 0, 32, "prevout") - if err != nil { - return err - } - copy(o.Hash[:], txid) - - if _, err := r.Read(scratch); err != nil { - return err - } - o.Index = byteOrder.Uint32(scratch) - - return nil -} - // Compile-time constraint to ensure kidOutput implements the // Input interface. diff --git a/discovery/chan_series.go b/discovery/chan_series.go index b7b9af9890..696a908c4b 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -4,7 +4,7 @@ import ( "time" "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/lightningnetwork/lnd/channeldb" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/netann" "github.com/lightningnetwork/lnd/routing/route" @@ -36,7 +36,7 @@ type ChannelGraphTimeSeries interface { // ID's represents the ID's that we don't know of which were in the // passed superSet. FilterKnownChanIDs(chain chainhash.Hash, - superSet []channeldb.ChannelUpdateInfo, + superSet []graphdb.ChannelUpdateInfo, isZombieChan func(time.Time, time.Time) bool) ( []lnwire.ShortChannelID, error) @@ -45,7 +45,7 @@ type ChannelGraphTimeSeries interface { // grouped by their common block height. We'll use this to to a remote // peer's QueryChannelRange message. FilterChannelRange(chain chainhash.Hash, startHeight, endHeight uint32, - withTimestamps bool) ([]channeldb.BlockChannelRange, error) + withTimestamps bool) ([]graphdb.BlockChannelRange, error) // FetchChanAnns returns a full set of channel announcements as well as // their updates that match the set of specified short channel ID's. @@ -70,12 +70,12 @@ type ChannelGraphTimeSeries interface { // in-protocol channel range queries to quickly and efficiently synchronize our // channel state with all peers. type ChanSeries struct { - graph *channeldb.ChannelGraph + graph *graphdb.ChannelGraph } // NewChanSeries constructs a new ChanSeries backed by a channeldb.ChannelGraph. // The returned ChanSeries implements the ChannelGraphTimeSeries interface. -func NewChanSeries(graph *channeldb.ChannelGraph) *ChanSeries { +func NewChanSeries(graph *graphdb.ChannelGraph) *ChanSeries { return &ChanSeries{ graph: graph, } @@ -200,7 +200,7 @@ func (c *ChanSeries) UpdatesInHorizon(chain chainhash.Hash, // // NOTE: This is part of the ChannelGraphTimeSeries interface. func (c *ChanSeries) FilterKnownChanIDs(_ chainhash.Hash, - superSet []channeldb.ChannelUpdateInfo, + superSet []graphdb.ChannelUpdateInfo, isZombieChan func(time.Time, time.Time) bool) ( []lnwire.ShortChannelID, error) { @@ -226,7 +226,7 @@ func (c *ChanSeries) FilterKnownChanIDs(_ chainhash.Hash, // // NOTE: This is part of the ChannelGraphTimeSeries interface. func (c *ChanSeries) FilterChannelRange(_ chainhash.Hash, startHeight, - endHeight uint32, withTimestamps bool) ([]channeldb.BlockChannelRange, + endHeight uint32, withTimestamps bool) ([]graphdb.BlockChannelRange, error) { return c.graph.FilterChannelRange( diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 284cc42212..41e58c404e 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -19,11 +19,11 @@ import ( "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/graph" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwallet" @@ -1636,7 +1636,6 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { edgesToUpdate []updateTuple ) err := d.cfg.Graph.ForAllOutgoingChannels(func( - _ kvdb.RTx, info *models.ChannelEdgeInfo, edge *models.ChannelEdgePolicy) error { @@ -1686,7 +1685,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { return nil }) - if err != nil && err != channeldb.ErrGraphNoEdgesFound { + if err != nil && !errors.Is(err, graphdb.ErrGraphNoEdgesFound) { return fmt.Errorf("unable to retrieve outgoing channels: %w", err) } @@ -1963,7 +1962,7 @@ func (d *AuthenticatedGossiper) addNode(msg *lnwire.NodeAnnouncement, timestamp := time.Unix(int64(msg.Timestamp), 0) features := lnwire.NewFeatureVector(msg.Features, lnwire.Features) - node := &channeldb.LightningNode{ + node := &models.LightningNode{ HaveNodeAnnouncement: true, LastUpdate: timestamp, Addresses: msg.Addresses, @@ -2121,7 +2120,7 @@ func (d *AuthenticatedGossiper) processZombieUpdate( // come through again. err = d.cfg.Graph.MarkEdgeLive(scid) switch { - case errors.Is(err, channeldb.ErrZombieEdgeNotFound): + case errors.Is(err, graphdb.ErrZombieEdgeNotFound): log.Errorf("edge with chan_id=%v was not found in the "+ "zombie index: %v", err) @@ -2166,7 +2165,7 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // If the channel cannot be found, it is most likely a leftover // message for a channel that was closed, so we can consider it // stale. - if errors.Is(err, channeldb.ErrEdgeNotFound) { + if errors.Is(err, graphdb.ErrEdgeNotFound) { return true } if err != nil { @@ -2186,7 +2185,7 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // If the channel cannot be found, it is most likely a leftover // message for a channel that was closed, so we can consider it // stale. - if errors.Is(err, channeldb.ErrEdgeNotFound) { + if errors.Is(err, graphdb.ErrEdgeNotFound) { return true } if err != nil { @@ -2936,7 +2935,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, case err == nil: break - case errors.Is(err, channeldb.ErrZombieEdge): + case errors.Is(err, graphdb.ErrZombieEdge): err = d.processZombieUpdate(chanInfo, graphScid, upd) if err != nil { log.Debug(err) @@ -2949,11 +2948,11 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // needed to ensure the edge exists in the graph before // applying the update. fallthrough - case errors.Is(err, channeldb.ErrGraphNotFound): + case errors.Is(err, graphdb.ErrGraphNotFound): fallthrough - case errors.Is(err, channeldb.ErrGraphNoEdgesFound): + case errors.Is(err, graphdb.ErrGraphNoEdgesFound): fallthrough - case errors.Is(err, channeldb.ErrEdgeNotFound): + case errors.Is(err, graphdb.ErrEdgeNotFound): // If the edge corresponding to this ChannelUpdate was not // found in the graph, this might be a channel in the process // of being opened, and we haven't processed our own diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 1577ce8c2c..056069940a 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -24,11 +24,11 @@ import ( "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/graph" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lntest/wait" @@ -73,26 +73,11 @@ var ( rebroadcastInterval = time.Hour * 1000000 ) -// makeTestDB creates a new instance of the ChannelDB for testing purposes. -func makeTestDB(t *testing.T) (*channeldb.DB, error) { - // Create channeldb for the first time. - cdb, err := channeldb.Open(t.TempDir()) - if err != nil { - return nil, err - } - - t.Cleanup(func() { - cdb.Close() - }) - - return cdb, nil -} - type mockGraphSource struct { bestHeight uint32 mu sync.Mutex - nodes []channeldb.LightningNode + nodes []models.LightningNode infos map[uint64]models.ChannelEdgeInfo edges map[uint64][]models.ChannelEdgePolicy zombies map[uint64][][33]byte @@ -112,7 +97,7 @@ func newMockRouter(height uint32) *mockGraphSource { var _ graph.ChannelGraphSource = (*mockGraphSource)(nil) -func (r *mockGraphSource) AddNode(node *channeldb.LightningNode, +func (r *mockGraphSource) AddNode(node *models.LightningNode, _ ...batch.SchedulerOption) error { r.mu.Lock() @@ -202,18 +187,19 @@ func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID, return nil } -func (r *mockGraphSource) ForEachNode(func(node *channeldb.LightningNode) error) error { +func (r *mockGraphSource) ForEachNode( + func(node *models.LightningNode) error) error { + return nil } -func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, - i *models.ChannelEdgeInfo, - c *models.ChannelEdgePolicy) error) error { +func (r *mockGraphSource) ForAllOutgoingChannels(cb func( + i *models.ChannelEdgeInfo, c *models.ChannelEdgePolicy) error) error { r.mu.Lock() defer r.mu.Unlock() - chans := make(map[uint64]channeldb.ChannelEdge) + chans := make(map[uint64]graphdb.ChannelEdge) for _, info := range r.infos { info := info @@ -230,7 +216,7 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, } for _, channel := range chans { - if err := cb(nil, channel.Info, channel.Policy1); err != nil { + if err := cb(channel.Info, channel.Policy1); err != nil { return err } } @@ -251,13 +237,13 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( if !ok { pubKeys, isZombie := r.zombies[chanIDInt] if !isZombie { - return nil, nil, nil, channeldb.ErrEdgeNotFound + return nil, nil, nil, graphdb.ErrEdgeNotFound } return &models.ChannelEdgeInfo{ NodeKey1Bytes: pubKeys[0], NodeKey2Bytes: pubKeys[1], - }, nil, nil, channeldb.ErrZombieEdge + }, nil, nil, graphdb.ErrZombieEdge } edges := r.edges[chanID.ToUint64()] @@ -279,7 +265,7 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( } func (r *mockGraphSource) FetchLightningNode( - nodePub route.Vertex) (*channeldb.LightningNode, error) { + nodePub route.Vertex) (*models.LightningNode, error) { for _, node := range r.nodes { if bytes.Equal(nodePub[:], node.PubKeyBytes[:]) { @@ -287,7 +273,7 @@ func (r *mockGraphSource) FetchLightningNode( } } - return nil, channeldb.ErrGraphNodeNotFound + return nil, graphdb.ErrGraphNodeNotFound } // IsStaleNode returns true if the graph source has a node announcement for the @@ -733,10 +719,7 @@ func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) ( notifier := newMockNotifier() router := newMockRouter(startHeight) - db, err := makeTestDB(t) - if err != nil { - return nil, err - } + db := channeldb.OpenForTesting(t, t.TempDir()) waitingProofStore, err := channeldb.NewWaitingProofStore(db) if err != nil { @@ -2319,9 +2302,7 @@ func TestProcessZombieEdgeNowLive(t *testing.T) { // At this point, the channel should still be considered a zombie. _, _, _, err = ctx.router.GetChannelByID(chanID) - if err != channeldb.ErrZombieEdge { - t.Fatalf("channel should still be a zombie") - } + require.ErrorIs(t, err, graphdb.ErrZombieEdge) // Attempting to process the current channel update should fail due to // its edge being considered a zombie and its timestamp not being within @@ -2442,7 +2423,7 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { // to the map of premature ChannelUpdates. Check that nothing // was added to the graph. chanInfo, e1, e2, err := ctx.router.GetChannelByID(batch.chanUpdAnn1.ShortChannelID) - if err != channeldb.ErrEdgeNotFound { + if !errors.Is(err, graphdb.ErrEdgeNotFound) { t.Fatalf("Expected ErrEdgeNotFound, got: %v", err) } if chanInfo != nil { @@ -3482,7 +3463,6 @@ out: const newTimeLockDelta = 100 var edgesToUpdate []EdgeWithInfo err = ctx.router.ForAllOutgoingChannels(func( - _ kvdb.RTx, info *models.ChannelEdgeInfo, edge *models.ChannelEdgePolicy) error { diff --git a/discovery/message_store_test.go b/discovery/message_store_test.go index 36c082e36f..88c40c144b 100644 --- a/discovery/message_store_test.go +++ b/discovery/message_store_test.go @@ -17,19 +17,10 @@ import ( func createTestMessageStore(t *testing.T) *MessageStore { t.Helper() - db, err := channeldb.Open(t.TempDir()) - if err != nil { - t.Fatalf("unable to open db: %v", err) - } - - t.Cleanup(func() { - db.Close() - }) + db := channeldb.OpenForTesting(t, t.TempDir()) store, err := NewMessageStore(db) - if err != nil { - t.Fatalf("unable to initialize message store: %v", err) - } + require.NoError(t, err) return store } diff --git a/discovery/syncer.go b/discovery/syncer.go index 3043b39cd6..745fda24b3 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -11,8 +11,8 @@ import ( "time" "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/graph" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnwire" "golang.org/x/time/rate" @@ -373,7 +373,7 @@ type GossipSyncer struct { // bufferedChanRangeReplies is used in the waitingQueryChanReply to // buffer all the chunked response to our query. - bufferedChanRangeReplies []channeldb.ChannelUpdateInfo + bufferedChanRangeReplies []graphdb.ChannelUpdateInfo // numChanRangeRepliesRcvd is used to track the number of replies // received as part of a QueryChannelRange. This field is primarily used @@ -837,7 +837,7 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro g.prevReplyChannelRange = msg for i, scid := range msg.ShortChanIDs { - info := channeldb.NewChannelUpdateInfo( + info := graphdb.NewChannelUpdateInfo( scid, time.Time{}, time.Time{}, ) @@ -1115,7 +1115,7 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro // this as there's a transport message size limit which we'll need to // adhere to. We also need to make sure all of our replies cover the // expected range of the query. - sendReplyForChunk := func(channelChunk []channeldb.ChannelUpdateInfo, + sendReplyForChunk := func(channelChunk []graphdb.ChannelUpdateInfo, firstHeight, lastHeight uint32, finalChunk bool) error { // The number of blocks contained in the current chunk (the @@ -1164,7 +1164,7 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro var ( firstHeight = query.FirstBlockHeight lastHeight uint32 - channelChunk []channeldb.ChannelUpdateInfo + channelChunk []graphdb.ChannelUpdateInfo ) // chunkSize is the maximum number of SCIDs that we can safely put in a diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index 0ee635a0f2..ebc557525b 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/channeldb" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/lnwire" "github.com/stretchr/testify/require" ) @@ -42,7 +42,7 @@ type mockChannelGraphTimeSeries struct { horizonReq chan horizonQuery horizonResp chan []lnwire.Message - filterReq chan []channeldb.ChannelUpdateInfo + filterReq chan []graphdb.ChannelUpdateInfo filterResp chan []lnwire.ShortChannelID filterRangeReqs chan filterRangeReq @@ -64,7 +64,7 @@ func newMockChannelGraphTimeSeries( horizonReq: make(chan horizonQuery, 1), horizonResp: make(chan []lnwire.Message, 1), - filterReq: make(chan []channeldb.ChannelUpdateInfo, 1), + filterReq: make(chan []graphdb.ChannelUpdateInfo, 1), filterResp: make(chan []lnwire.ShortChannelID, 1), filterRangeReqs: make(chan filterRangeReq, 1), @@ -92,7 +92,7 @@ func (m *mockChannelGraphTimeSeries) UpdatesInHorizon(chain chainhash.Hash, } func (m *mockChannelGraphTimeSeries) FilterKnownChanIDs(chain chainhash.Hash, - superSet []channeldb.ChannelUpdateInfo, + superSet []graphdb.ChannelUpdateInfo, isZombieChan func(time.Time, time.Time) bool) ( []lnwire.ShortChannelID, error) { @@ -102,16 +102,16 @@ func (m *mockChannelGraphTimeSeries) FilterKnownChanIDs(chain chainhash.Hash, } func (m *mockChannelGraphTimeSeries) FilterChannelRange(chain chainhash.Hash, startHeight, endHeight uint32, withTimestamps bool) ( - []channeldb.BlockChannelRange, error) { + []graphdb.BlockChannelRange, error) { m.filterRangeReqs <- filterRangeReq{startHeight, endHeight} reply := <-m.filterRangeResp - channelsPerBlock := make(map[uint32][]channeldb.ChannelUpdateInfo) + channelsPerBlock := make(map[uint32][]graphdb.ChannelUpdateInfo) for _, cid := range reply { channelsPerBlock[cid.BlockHeight] = append( channelsPerBlock[cid.BlockHeight], - channeldb.ChannelUpdateInfo{ + graphdb.ChannelUpdateInfo{ ShortChannelID: cid, }, ) @@ -127,11 +127,11 @@ func (m *mockChannelGraphTimeSeries) FilterChannelRange(chain chainhash.Hash, }) channelRanges := make( - []channeldb.BlockChannelRange, 0, len(channelsPerBlock), + []graphdb.BlockChannelRange, 0, len(channelsPerBlock), ) for _, block := range blocks { channelRanges = append( - channelRanges, channeldb.BlockChannelRange{ + channelRanges, graphdb.BlockChannelRange{ Height: block, Channels: channelsPerBlock[block], }, diff --git a/docs/release-notes/release-notes-0.19.0.md b/docs/release-notes/release-notes-0.19.0.md index 4429f5cf5c..377255bbf9 100644 --- a/docs/release-notes/release-notes-0.19.0.md +++ b/docs/release-notes/release-notes-0.19.0.md @@ -200,6 +200,10 @@ The underlying functionality between those two options remain the same. ## Code Health +* A code refactor that [moves all the graph related DB code out of the + `channeldb` package](https://github.com/lightningnetwork/lnd/pull/9236) and + into the `graph/db` package. + ## Tooling and Documentation * [Improved `lncli create` command help text](https://github.com/lightningnetwork/lnd/pull/9077) diff --git a/funding/manager.go b/funding/manager.go index 1fa90c6932..075be12788 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -22,10 +22,10 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chanacceptor" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/graph" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/labels" diff --git a/funding/manager_test.go b/funding/manager_test.go index 18dd8cf382..525f69f9a5 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -25,10 +25,10 @@ import ( "github.com/lightningnetwork/lnd/chainreg" acpt "github.com/lightningnetwork/lnd/chanacceptor" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lncfg" @@ -427,10 +427,7 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, } dbDir := filepath.Join(tempTestDir, "cdb") - fullDB, err := channeldb.Open(dbDir) - if err != nil { - return nil, err - } + fullDB := channeldb.OpenForTesting(t, dbDir) cdb := fullDB.ChannelStateDB() diff --git a/graph/builder.go b/graph/builder.go index 6930f1a894..8c2ba2e3b8 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -16,9 +16,9 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnutils" @@ -201,10 +201,10 @@ func (b *Builder) Start() error { // then we don't treat this as an explicit error. if _, _, err := b.cfg.Graph.PruneTip(); err != nil { switch { - case errors.Is(err, channeldb.ErrGraphNeverPruned): + case errors.Is(err, graphdb.ErrGraphNeverPruned): fallthrough - case errors.Is(err, channeldb.ErrGraphNotFound): + case errors.Is(err, graphdb.ErrGraphNotFound): // If the graph has never been pruned, then we'll set // the prune height to the current best height of the // chain backend. @@ -256,7 +256,7 @@ func (b *Builder) Start() error { // been applied. channelView, err := b.cfg.Graph.ChannelView() if err != nil && !errors.Is( - err, channeldb.ErrGraphNoEdgesFound, + err, graphdb.ErrGraphNoEdgesFound, ) { return err @@ -294,7 +294,7 @@ func (b *Builder) Start() error { // of "useful" nodes. err = b.cfg.Graph.PruneGraphNodes() if err != nil && - !errors.Is(err, channeldb.ErrGraphNodesNotFound) { + !errors.Is(err, graphdb.ErrGraphNodesNotFound) { return err } @@ -352,8 +352,8 @@ func (b *Builder) syncGraphWithChain() error { switch { // If the graph has never been pruned, or hasn't fully been // created yet, then we don't treat this as an explicit error. - case errors.Is(err, channeldb.ErrGraphNeverPruned): - case errors.Is(err, channeldb.ErrGraphNotFound): + case errors.Is(err, graphdb.ErrGraphNeverPruned): + case errors.Is(err, graphdb.ErrGraphNotFound): default: return err } @@ -400,10 +400,10 @@ func (b *Builder) syncGraphWithChain() error { // as this entails we are back to the point where it hasn't seen // any block or created channels, alas there's nothing left to // prune. - case errors.Is(err, channeldb.ErrGraphNeverPruned): + case errors.Is(err, graphdb.ErrGraphNeverPruned): return nil - case errors.Is(err, channeldb.ErrGraphNotFound): + case errors.Is(err, graphdb.ErrGraphNotFound): return nil case err != nil: @@ -658,7 +658,7 @@ func (b *Builder) pruneZombieChans() error { // With the channels pruned, we'll also attempt to prune any nodes that // were a part of them. err = b.cfg.Graph.PruneGraphNodes() - if err != nil && !errors.Is(err, channeldb.ErrGraphNodesNotFound) { + if err != nil && !errors.Is(err, graphdb.ErrGraphNodesNotFound) { return fmt.Errorf("unable to prune graph nodes: %w", err) } @@ -1165,7 +1165,7 @@ func (b *Builder) processUpdate(msg interface{}, op ...batch.SchedulerOption) error { switch msg := msg.(type) { - case *channeldb.LightningNode: + case *models.LightningNode: // Before we add the node to the database, we'll check to see // if the announcement is "fresh" or not. If it isn't, then // we'll return an error. @@ -1192,7 +1192,7 @@ func (b *Builder) processUpdate(msg interface{}, msg.ChannelID, ) if err != nil && - !errors.Is(err, channeldb.ErrGraphNoEdgesFound) { + !errors.Is(err, graphdb.ErrGraphNoEdgesFound) { return errors.Errorf("unable to check for edge "+ "existence: %v", err) @@ -1344,7 +1344,7 @@ func (b *Builder) processUpdate(msg interface{}, // update the current UTXO filter within our active // FilteredChainView so we are notified if/when this channel is // closed. - filterUpdate := []channeldb.EdgePoint{ + filterUpdate := []graphdb.EdgePoint{ { FundingPkScript: fundingPkScript, OutPoint: *fundingPoint, @@ -1371,7 +1371,7 @@ func (b *Builder) processUpdate(msg interface{}, edge1Timestamp, edge2Timestamp, exists, isZombie, err := b.cfg.Graph.HasChannelEdge(msg.ChannelID) if err != nil && !errors.Is( - err, channeldb.ErrGraphNoEdgesFound, + err, graphdb.ErrGraphNoEdgesFound, ) { return errors.Errorf("unable to check for edge "+ @@ -1517,7 +1517,7 @@ func (b *Builder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate1) bool { // be ignored. // // NOTE: This method is part of the ChannelGraphSource interface. -func (b *Builder) AddNode(node *channeldb.LightningNode, +func (b *Builder) AddNode(node *models.LightningNode, op ...batch.SchedulerOption) error { rMsg := &routingMsg{ @@ -1619,12 +1619,12 @@ func (b *Builder) GetChannelByID(chanID lnwire.ShortChannelID) ( } // FetchLightningNode attempts to look up a target node by its identity public -// key. channeldb.ErrGraphNodeNotFound is returned if the node doesn't exist +// key. graphdb.ErrGraphNodeNotFound is returned if the node doesn't exist // within the graph. // // NOTE: This method is part of the ChannelGraphSource interface. func (b *Builder) FetchLightningNode( - node route.Vertex) (*channeldb.LightningNode, error) { + node route.Vertex) (*models.LightningNode, error) { return b.cfg.Graph.FetchLightningNode(node) } @@ -1633,10 +1633,10 @@ func (b *Builder) FetchLightningNode( // // NOTE: This method is part of the ChannelGraphSource interface. func (b *Builder) ForEachNode( - cb func(*channeldb.LightningNode) error) error { + cb func(*models.LightningNode) error) error { return b.cfg.Graph.ForEachNode( - func(_ kvdb.RTx, n *channeldb.LightningNode) error { + func(_ kvdb.RTx, n *models.LightningNode) error { return cb(n) }) } @@ -1645,11 +1645,11 @@ func (b *Builder) ForEachNode( // the router. // // NOTE: This method is part of the ChannelGraphSource interface. -func (b *Builder) ForAllOutgoingChannels(cb func(kvdb.RTx, - *models.ChannelEdgeInfo, *models.ChannelEdgePolicy) error) error { +func (b *Builder) ForAllOutgoingChannels(cb func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy) error) error { return b.cfg.Graph.ForEachNodeChannel(b.cfg.SelfNode, - func(tx kvdb.RTx, c *models.ChannelEdgeInfo, + func(_ kvdb.RTx, c *models.ChannelEdgeInfo, e *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { @@ -1658,7 +1658,7 @@ func (b *Builder) ForAllOutgoingChannels(cb func(kvdb.RTx, "has no policy") } - return cb(tx, c, e) + return cb(c, e) }, ) } diff --git a/graph/builder_test.go b/graph/builder_test.go index f6c5dcf9cb..24aff26249 100644 --- a/graph/builder_test.go +++ b/graph/builder_test.go @@ -20,8 +20,8 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwire" @@ -93,7 +93,7 @@ func TestIgnoreNodeAnnouncement(t *testing.T) { ctx := createTestCtxFromFile(t, startingBlockHeight, basicGraphFilePath) pub := priv1.PubKey() - node := &channeldb.LightningNode{ + node := &models.LightningNode{ HaveNodeAnnouncement: true, LastUpdate: time.Unix(123, 0), Addresses: testAddrs, @@ -1038,7 +1038,7 @@ func TestIsStaleNode(t *testing.T) { // With the node stub in the database, we'll add the fully node // announcement to the database. - n1 := &channeldb.LightningNode{ + n1 := &models.LightningNode{ HaveNodeAnnouncement: true, LastUpdate: updateTimeStamp, Addresses: testAddrs, @@ -1453,7 +1453,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( privKeyMap := make(map[string]*btcec.PrivateKey) channelIDs := make(map[route.Vertex]map[route.Vertex]uint64) links := make(map[lnwire.ShortChannelID]htlcswitch.ChannelLink) - var source *channeldb.LightningNode + var source *models.LightningNode // First we insert all the nodes within the graph as vertexes. for _, node := range g.Nodes { @@ -1462,7 +1462,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( return nil, err } - dbNode := &channeldb.LightningNode{ + dbNode := &models.LightningNode{ HaveNodeAnnouncement: true, AuthSigBytes: testSig.Serialize(), LastUpdate: testTime, @@ -1593,10 +1593,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( } err = graph.AddChannelEdge(&edgeInfo) - if err != nil && !errors.Is( - err, channeldb.ErrEdgeAlreadyExist, - ) { - + if err != nil && !errors.Is(err, graphdb.ErrEdgeAlreadyExist) { return nil, err } @@ -1753,7 +1750,7 @@ func asymmetricTestChannel(alias1, alias2 string, capacity btcutil.Amount, // assertChannelsPruned ensures that only the given channels are pruned from the // graph out of the set of all channels. -func assertChannelsPruned(t *testing.T, graph *channeldb.ChannelGraph, +func assertChannelsPruned(t *testing.T, graph *graphdb.ChannelGraph, channels []*testChannel, prunedChanIDs ...uint64) { t.Helper() @@ -1835,7 +1832,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, nodeIndex := byte(0) addNodeWithAlias := func(alias string, features *lnwire.FeatureVector) ( - *channeldb.LightningNode, error) { + *models.LightningNode, error) { keyBytes := []byte{ 0, 0, 0, 0, 0, 0, 0, 0, @@ -1850,7 +1847,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, features = lnwire.EmptyFeatureVector() } - dbNode := &channeldb.LightningNode{ + dbNode := &models.LightningNode{ HaveNodeAnnouncement: true, AuthSigBytes: testSig.Serialize(), LastUpdate: testTime, @@ -1959,7 +1956,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, err = graph.AddChannelEdge(&edgeInfo) if err != nil && - !errors.Is(err, channeldb.ErrEdgeAlreadyExist) { + !errors.Is(err, graphdb.ErrEdgeAlreadyExist) { return nil, err } diff --git a/channeldb/addr.go b/graph/db/addr.go similarity index 94% rename from channeldb/addr.go rename to graph/db/addr.go index dd057265c2..f994131582 100644 --- a/channeldb/addr.go +++ b/graph/db/addr.go @@ -1,4 +1,4 @@ -package channeldb +package graphdb import ( "encoding/binary" @@ -121,10 +121,10 @@ func encodeOnionAddr(w io.Writer, addr *tor.OnionAddr) error { return nil } -// deserializeAddr reads the serialized raw representation of an address and +// DeserializeAddr reads the serialized raw representation of an address and // deserializes it into the actual address. This allows us to avoid address // resolution within the channeldb package. -func deserializeAddr(r io.Reader) (net.Addr, error) { +func DeserializeAddr(r io.Reader) (net.Addr, error) { var addrType [1]byte if _, err := r.Read(addrType[:]); err != nil { return nil, err @@ -207,9 +207,9 @@ func deserializeAddr(r io.Reader) (net.Addr, error) { return address, nil } -// serializeAddr serializes an address into its raw bytes representation so that +// SerializeAddr serializes an address into its raw bytes representation so that // it can be deserialized without requiring address resolution. -func serializeAddr(w io.Writer, address net.Addr) error { +func SerializeAddr(w io.Writer, address net.Addr) error { switch addr := address.(type) { case *net.TCPAddr: return encodeTCPAddr(w, addr) diff --git a/channeldb/addr_test.go b/graph/db/addr_test.go similarity index 97% rename from channeldb/addr_test.go rename to graph/db/addr_test.go index c761989c05..2e6e5439cd 100644 --- a/channeldb/addr_test.go +++ b/graph/db/addr_test.go @@ -1,4 +1,4 @@ -package channeldb +package graphdb import ( "bytes" @@ -117,7 +117,7 @@ func TestAddrSerialization(t *testing.T) { var b bytes.Buffer for _, test := range addrTests { - err := serializeAddr(&b, test.expAddr) + err := SerializeAddr(&b, test.expAddr) switch { case err == nil && test.serErr != "": t.Fatalf("expected serialization err for addr %v", @@ -136,7 +136,7 @@ func TestAddrSerialization(t *testing.T) { continue } - addr, err := deserializeAddr(&b) + addr, err := DeserializeAddr(&b) if err != nil { t.Fatalf("unable to deserialize address: %v", err) } diff --git a/channeldb/channel_cache.go b/graph/db/channel_cache.go similarity index 98% rename from channeldb/channel_cache.go rename to graph/db/channel_cache.go index 2f26c185fb..b50bbf4988 100644 --- a/channeldb/channel_cache.go +++ b/graph/db/channel_cache.go @@ -1,4 +1,4 @@ -package channeldb +package graphdb // channelCache is an in-memory cache used to improve the performance of // ChanUpdatesInHorizon. It caches the chan info and edge policies for a diff --git a/channeldb/channel_cache_test.go b/graph/db/channel_cache_test.go similarity index 97% rename from channeldb/channel_cache_test.go rename to graph/db/channel_cache_test.go index 7cb857293b..767958d9a1 100644 --- a/channeldb/channel_cache_test.go +++ b/graph/db/channel_cache_test.go @@ -1,10 +1,10 @@ -package channeldb +package graphdb import ( "reflect" "testing" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" ) // TestChannelCache checks the behavior of the channelCache with respect to diff --git a/graph/db/codec.go b/graph/db/codec.go new file mode 100644 index 0000000000..029f9b93d4 --- /dev/null +++ b/graph/db/codec.go @@ -0,0 +1,39 @@ +package graphdb + +import ( + "encoding/binary" + "io" + + "github.com/btcsuite/btcd/wire" +) + +var ( + // byteOrder defines the preferred byte order, which is Big Endian. + byteOrder = binary.BigEndian +) + +// WriteOutpoint writes an outpoint to the passed writer using the minimal +// amount of bytes possible. +func WriteOutpoint(w io.Writer, o *wire.OutPoint) error { + if _, err := w.Write(o.Hash[:]); err != nil { + return err + } + if err := binary.Write(w, byteOrder, o.Index); err != nil { + return err + } + + return nil +} + +// ReadOutpoint reads an outpoint from the passed reader that was previously +// written using the WriteOutpoint struct. +func ReadOutpoint(r io.Reader, o *wire.OutPoint) error { + if _, err := io.ReadFull(r, o.Hash[:]); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &o.Index); err != nil { + return err + } + + return nil +} diff --git a/graph/db/errors.go b/graph/db/errors.go new file mode 100644 index 0000000000..afdbdb457e --- /dev/null +++ b/graph/db/errors.go @@ -0,0 +1,75 @@ +package graphdb + +import ( + "errors" + "fmt" +) + +var ( + // ErrEdgePolicyOptionalFieldNotFound is an error returned if a channel + // policy field is not found in the db even though its message flags + // indicate it should be. + ErrEdgePolicyOptionalFieldNotFound = fmt.Errorf("optional field not " + + "present") + + // ErrGraphNotFound is returned when at least one of the components of + // graph doesn't exist. + ErrGraphNotFound = fmt.Errorf("graph bucket not initialized") + + // ErrGraphNeverPruned is returned when graph was never pruned. + ErrGraphNeverPruned = fmt.Errorf("graph never pruned") + + // ErrSourceNodeNotSet is returned if the source node of the graph + // hasn't been added The source node is the center node within a + // star-graph. + ErrSourceNodeNotSet = fmt.Errorf("source node does not exist") + + // ErrGraphNodesNotFound is returned in case none of the nodes has + // been added in graph node bucket. + ErrGraphNodesNotFound = fmt.Errorf("no graph nodes exist") + + // ErrGraphNoEdgesFound is returned in case of none of the channel/edges + // has been added in graph edge bucket. + ErrGraphNoEdgesFound = fmt.Errorf("no graph edges exist") + + // ErrGraphNodeNotFound is returned when we're unable to find the target + // node. + ErrGraphNodeNotFound = fmt.Errorf("unable to find node") + + // ErrZombieEdge is an error returned when we attempt to look up an edge + // but it is marked as a zombie within the zombie index. + ErrZombieEdge = errors.New("edge marked as zombie") + + // ErrEdgeNotFound is returned when an edge for the target chanID + // can't be found. + ErrEdgeNotFound = fmt.Errorf("edge not found") + + // ErrEdgeAlreadyExist is returned when edge with specific + // channel id can't be added because it already exist. + ErrEdgeAlreadyExist = fmt.Errorf("edge already exist") + + // ErrNodeAliasNotFound is returned when alias for node can't be found. + ErrNodeAliasNotFound = fmt.Errorf("alias for node not found") + + // ErrClosedScidsNotFound is returned when the closed scid bucket + // hasn't been created. + ErrClosedScidsNotFound = fmt.Errorf("closed scid bucket doesn't exist") + + // ErrZombieEdgeNotFound is an error returned when we attempt to find an + // edge in the zombie index which is not there. + ErrZombieEdgeNotFound = errors.New("edge not found in zombie index") + + // ErrUnknownAddressType is returned when a node's addressType is not + // an expected value. + ErrUnknownAddressType = fmt.Errorf("address type cannot be resolved") +) + +// ErrTooManyExtraOpaqueBytes creates an error which should be returned if the +// caller attempts to write an announcement message which bares too many extra +// opaque bytes. We limit this value in order to ensure that we don't waste +// disk space due to nodes unnecessarily padding out their announcements with +// garbage data. +func ErrTooManyExtraOpaqueBytes(numBytes int) error { + return fmt.Errorf("max allowed number of opaque bytes is %v, received "+ + "%v bytes", MaxAllowedExtraOpaqueBytes, numBytes) +} diff --git a/channeldb/graph.go b/graph/db/graph.go similarity index 93% rename from channeldb/graph.go rename to graph/db/graph.go index d7a4480d03..fc1b26ad0f 100644 --- a/channeldb/graph.go +++ b/graph/db/graph.go @@ -1,4 +1,4 @@ -package channeldb +package graphdb import ( "bytes" @@ -6,22 +6,21 @@ import ( "encoding/binary" "errors" "fmt" - "image/color" "io" "math" "net" "sort" "sync" + "testing" "time" "github.com/btcsuite/btcd/btcec/v2" - "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/aliasmgr" "github.com/lightningnetwork/lnd/batch" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" @@ -124,9 +123,9 @@ var ( // edge's participants. zombieBucket = []byte("zombie-index") - // disabledEdgePolicyBucket is a sub-bucket of the main edgeBucket bucket - // responsible for maintaining an index of disabled edge policies. Each - // entry exists within the bucket as follows: + // disabledEdgePolicyBucket is a sub-bucket of the main edgeBucket + // bucket responsible for maintaining an index of disabled edge + // policies. Each entry exists within the bucket as follows: // // maps: -> []byte{} // @@ -198,11 +197,15 @@ type ChannelGraph struct { // NewChannelGraph allocates a new ChannelGraph backed by a DB instance. The // returned instance has its own unique reject cache and channel cache. -func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, - batchCommitInterval time.Duration, preAllocCacheNumNodes int, - useGraphCache, noMigrations bool) (*ChannelGraph, error) { +func NewChannelGraph(db kvdb.Backend, options ...OptionModifier) (*ChannelGraph, + error) { + + opts := DefaultOptions() + for _, o := range options { + o(opts) + } - if !noMigrations { + if !opts.NoMigration { if err := initChannelGraph(db); err != nil { return nil, err } @@ -210,20 +213,20 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, g := &ChannelGraph{ db: db, - rejectCache: newRejectCache(rejectCacheSize), - chanCache: newChannelCache(chanCacheSize), + rejectCache: newRejectCache(opts.RejectCacheSize), + chanCache: newChannelCache(opts.ChannelCacheSize), } g.chanScheduler = batch.NewTimeScheduler( - db, &g.cacheMu, batchCommitInterval, + db, &g.cacheMu, opts.BatchCommitInterval, ) g.nodeScheduler = batch.NewTimeScheduler( - db, nil, batchCommitInterval, + db, nil, opts.BatchCommitInterval, ) // The graph cache can be turned off (e.g. for mobile users) for a // speed/memory usage tradeoff. - if useGraphCache { - g.graphCache = NewGraphCache(preAllocCacheNumNodes) + if opts.UseGraphCache { + g.graphCache = NewGraphCache(opts.PreAllocCacheNumNodes) startTime := time.Now() log.Debugf("Populating in-memory channel graph, this might " + "take a while...") @@ -349,7 +352,7 @@ func (c *ChannelGraph) Wipe() error { return initChannelGraph(c.db) } -// createChannelDB creates and initializes a fresh version of channeldb. In +// createChannelDB creates and initializes a fresh version of In // the case that the target path has not yet been created or doesn't yet exist, // then the path is created. Additionally, all required top-level buckets used // within the database are created. @@ -410,6 +413,32 @@ func (c *ChannelGraph) NewPathFindTx() (kvdb.RTx, error) { return c.db.BeginReadTx() } +// AddrsForNode returns all known addresses for the target node public key that +// the graph DB is aware of. The returned boolean indicates if the given node is +// unknown to the graph DB or not. +// +// NOTE: this is part of the channeldb.AddrSource interface. +func (c *ChannelGraph) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, + error) { + + pubKey, err := route.NewVertexFromBytes(nodePub.SerializeCompressed()) + if err != nil { + return false, nil, err + } + + node, err := c.FetchLightningNode(pubKey) + // We don't consider it an error if the graph is unaware of the node. + switch { + case err != nil && !errors.Is(err, ErrGraphNodeNotFound): + return false, nil, err + + case errors.Is(err, ErrGraphNodeNotFound): + return false, nil, nil + } + + return true, node.Addresses, nil +} + // ForEachChannel iterates through all the channel edges stored within the // graph and invokes the passed callback for each edge. The callback takes two // edges as since this is a directed graph, both the in/out edges are visited. @@ -442,28 +471,32 @@ func (c *ChannelGraph) ForEachChannel(cb func(*models.ChannelEdgeInfo, // Load edge index, recombine each channel with the policies // loaded above and invoke the callback. - return kvdb.ForAll(edgeIndex, func(k, edgeInfoBytes []byte) error { - var chanID [8]byte - copy(chanID[:], k) - - edgeInfoReader := bytes.NewReader(edgeInfoBytes) - info, err := deserializeChanEdgeInfo(edgeInfoReader) - if err != nil { - return err - } + return kvdb.ForAll( + edgeIndex, func(k, edgeInfoBytes []byte) error { + var chanID [8]byte + copy(chanID[:], k) + + edgeInfoReader := bytes.NewReader(edgeInfoBytes) + info, err := deserializeChanEdgeInfo( + edgeInfoReader, + ) + if err != nil { + return err + } - policy1 := channelMap[channelMapKey{ - nodeKey: info.NodeKey1Bytes, - chanID: chanID, - }] + policy1 := channelMap[channelMapKey{ + nodeKey: info.NodeKey1Bytes, + chanID: chanID, + }] - policy2 := channelMap[channelMapKey{ - nodeKey: info.NodeKey2Bytes, - chanID: chanID, - }] + policy2 := channelMap[channelMapKey{ + nodeKey: info.NodeKey2Bytes, + chanID: chanID, + }] - return cb(&info, policy1, policy2) - }) + return cb(&info, policy1, policy2) + }, + ) }, func() {}) } @@ -571,7 +604,9 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, // We'll iterate over each node, then the set of channels for each // node, and construct a similar callback functiopn signature as the // main funcotin expects. - return c.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error { + return c.ForEachNode(func(tx kvdb.RTx, + node *models.LightningNode) error { + channels := make(map[uint64]*DirectedChannel) err := c.ForEachNodeChannelTx(tx, node.PubKeyBytes, @@ -646,20 +681,27 @@ func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { return nil } - // We iterate over all disabled policies and we add each channel that - // has more than one disabled policy to disabledChanIDs array. - return disabledEdgePolicyIndex.ForEach(func(k, v []byte) error { - chanID := byteOrder.Uint64(k[:8]) - _, edgeFound := chanEdgeFound[chanID] - if edgeFound { - delete(chanEdgeFound, chanID) - disabledChanIDs = append(disabledChanIDs, chanID) - return nil - } + // We iterate over all disabled policies and we add each channel + // that has more than one disabled policy to disabledChanIDs + // array. + return disabledEdgePolicyIndex.ForEach( + func(k, v []byte) error { + chanID := byteOrder.Uint64(k[:8]) + _, edgeFound := chanEdgeFound[chanID] + if edgeFound { + delete(chanEdgeFound, chanID) + disabledChanIDs = append( + disabledChanIDs, chanID, + ) + + return nil + } - chanEdgeFound[chanID] = struct{}{} - return nil - }) + chanEdgeFound[chanID] = struct{}{} + + return nil + }, + ) }, func() { disabledChanIDs = nil chanEdgeFound = make(map[uint64]struct{}) @@ -679,7 +721,7 @@ func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { // TODO(roasbeef): add iterator interface to allow for memory efficient graph // traversal when graph gets mega func (c *ChannelGraph) ForEachNode( - cb func(kvdb.RTx, *LightningNode) error) error { + cb func(kvdb.RTx, *models.LightningNode) error) error { traversal := func(tx kvdb.RTx) error { // First grab the nodes bucket which stores the mapping from @@ -756,8 +798,8 @@ func (c *ChannelGraph) ForEachNodeCacheable(cb func(kvdb.RTx, // as the center node within a star-graph. This method may be used to kick off // a path finding algorithm in order to explore the reachability of another // node based off the source node. -func (c *ChannelGraph) SourceNode() (*LightningNode, error) { - var source *LightningNode +func (c *ChannelGraph) SourceNode() (*models.LightningNode, error) { + var source *models.LightningNode err := kvdb.View(c.db, func(tx kvdb.RTx) error { // First grab the nodes bucket which stores the mapping from // pubKey to node information. @@ -787,7 +829,9 @@ func (c *ChannelGraph) SourceNode() (*LightningNode, error) { // of the graph. The source node is treated as the center node within a // star-graph. This method may be used to kick off a path finding algorithm in // order to explore the reachability of another node based off the source node. -func (c *ChannelGraph) sourceNode(nodes kvdb.RBucket) (*LightningNode, error) { +func (c *ChannelGraph) sourceNode(nodes kvdb.RBucket) (*models.LightningNode, + error) { + selfPub := nodes.Get(sourceKey) if selfPub == nil { return nil, ErrSourceNodeNotSet @@ -806,7 +850,7 @@ func (c *ChannelGraph) sourceNode(nodes kvdb.RBucket) (*LightningNode, error) { // SetSourceNode sets the source node within the graph database. The source // node is to be used as the center of a star-graph within path finding // algorithms. -func (c *ChannelGraph) SetSourceNode(node *LightningNode) error { +func (c *ChannelGraph) SetSourceNode(node *models.LightningNode) error { nodePubBytes := node.PubKeyBytes[:] return kvdb.Update(c.db, func(tx kvdb.RwTx) error { @@ -837,7 +881,7 @@ func (c *ChannelGraph) SetSourceNode(node *LightningNode) error { // channel update. // // TODO(roasbeef): also need sig of announcement -func (c *ChannelGraph) AddLightningNode(node *LightningNode, +func (c *ChannelGraph) AddLightningNode(node *models.LightningNode, op ...batch.SchedulerOption) error { r := &batch.Request{ @@ -863,7 +907,7 @@ func (c *ChannelGraph) AddLightningNode(node *LightningNode, return c.nodeScheduler.Execute(r) } -func addLightningNode(tx kvdb.RwTx, node *LightningNode) error { +func addLightningNode(tx kvdb.RwTx, node *models.LightningNode) error { nodes, err := tx.CreateTopLevelBucket(nodeBucket) if err != nil { return err @@ -1076,7 +1120,7 @@ func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, _, node1Err := fetchLightningNode(nodes, edge.NodeKey1Bytes[:]) switch { case node1Err == ErrGraphNodeNotFound: - node1Shell := LightningNode{ + node1Shell := models.LightningNode{ PubKeyBytes: edge.NodeKey1Bytes, HaveNodeAnnouncement: false, } @@ -1092,7 +1136,7 @@ func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, _, node2Err := fetchLightningNode(nodes, edge.NodeKey2Bytes[:]) switch { case node2Err == ErrGraphNodeNotFound: - node2Shell := LightningNode{ + node2Shell := models.LightningNode{ PubKeyBytes: edge.NodeKey2Bytes, HaveNodeAnnouncement: false, } @@ -1128,7 +1172,7 @@ func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, // Finally we add it to the channel index which maps channel points // (outpoints) to the shorter channel ID's. var b bytes.Buffer - if err := writeOutpoint(&b, &edge.ChannelPoint); err != nil { + if err := WriteOutpoint(&b, &edge.ChannelPoint); err != nil { return err } return chanIndex.Put(b.Bytes(), chanKey[:]) @@ -1309,12 +1353,15 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, return err } - // Next grab the two edge indexes which will also need to be updated. + // Next grab the two edge indexes which will also need to be + // updated. edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) if err != nil { return err } - chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) + chanIndex, err := edges.CreateBucketIfNotExists( + channelPointBucket, + ) if err != nil { return err } @@ -1335,7 +1382,8 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, // if NOT if filter var opBytes bytes.Buffer - if err := writeOutpoint(&opBytes, chanPoint); err != nil { + err := WriteOutpoint(&opBytes, chanPoint) + if err != nil { return err } @@ -1374,7 +1422,9 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, return err } - pruneBucket, err := metaBucket.CreateBucketIfNotExists(pruneLogBucket) + pruneBucket, err := metaBucket.CreateBucketIfNotExists( + pruneLogBucket, + ) if err != nil { return err } @@ -1520,7 +1570,8 @@ func (c *ChannelGraph) pruneGraphNodes(nodes kvdb.RwBucket, // If we reach this point, then there are no longer any edges // that connect this node, so we can delete it. - if err := c.deleteLightningNode(nodes, nodePubKey[:]); err != nil { + err := c.deleteLightningNode(nodes, nodePubKey[:]) + if err != nil { if errors.Is(err, ErrGraphNodeNotFound) || errors.Is(err, ErrGraphNodesNotFound) { @@ -1587,7 +1638,9 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ( if err != nil { return err } - chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) + chanIndex, err := edges.CreateBucketIfNotExists( + channelPointBucket, + ) if err != nil { return err } @@ -1635,7 +1688,9 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ( return err } - pruneBucket, err := metaBucket.CreateBucketIfNotExists(pruneLogBucket) + pruneBucket, err := metaBucket.CreateBucketIfNotExists( + pruneLogBucket, + ) if err != nil { return err } @@ -1650,9 +1705,9 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ( // the keys in a second loop. var pruneKeys [][]byte pruneCursor := pruneBucket.ReadWriteCursor() + //nolint:lll for k, _ := pruneCursor.Seek(pruneKeyStart[:]); k != nil && bytes.Compare(k, pruneKeyEnd[:]) <= 0; k, _ = pruneCursor.Next() { - pruneKeys = append(pruneKeys, k) } @@ -1807,7 +1862,7 @@ func (c *ChannelGraph) ChannelID(chanPoint *wire.OutPoint) (uint64, error) { // getChanID returns the assigned channel ID for a given channel point. func getChanID(tx kvdb.RTx, chanPoint *wire.OutPoint) (uint64, error) { var b bytes.Buffer - if err := writeOutpoint(&b, chanPoint); err != nil { + if err := WriteOutpoint(&b, chanPoint); err != nil { return 0, err } @@ -1891,11 +1946,11 @@ type ChannelEdge struct { // Node1 is "node 1" in the channel. This is the node that would have // produced Policy1 if it exists. - Node1 *LightningNode + Node1 *models.LightningNode // Node2 is "node 2" in the channel. This is the node that would have // produced Policy2 if it exists. - Node2 *LightningNode + Node2 *models.LightningNode } // ChanUpdatesInHorizon returns all the known channel edges which have at least @@ -1948,6 +2003,8 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, // With our start and end times constructed, we'll step through // the index collecting the info and policy of each update of // each channel that has a last update within the time range. + // + //nolint:lll for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil && bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() { @@ -2051,9 +2108,9 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, // nodes to quickly determine if they have the same set of up to date node // announcements. func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, - endTime time.Time) ([]LightningNode, error) { + endTime time.Time) ([]models.LightningNode, error) { - var nodesInHorizon []LightningNode + var nodesInHorizon []models.LightningNode err := kvdb.View(c.db, func(tx kvdb.RTx) error { nodes := tx.ReadBucket(nodeBucket) @@ -2081,6 +2138,8 @@ func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, // With our start and end times constructed, we'll step through // the index collecting info for each node within the time // range. + // + //nolint:lll for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil && bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() { @@ -2317,6 +2376,8 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, // We'll now iterate through the database, and find each // channel ID that resides within the specified range. + // + //nolint:lll for k, v := cursor.Seek(chanIDStart[:]); k != nil && bytes.Compare(k, chanIDEnd[:]) <= 0; k, v = cursor.Next() { // Don't send alias SCIDs during gossip sync. @@ -2549,7 +2610,9 @@ func delEdgeUpdateIndexEntry(edgesBucket kvdb.RwBucket, chanID uint64, // would have been created by both edges: we'll alternate the update // times, as one may had overridden the other. if edge1 != nil { - byteOrder.PutUint64(indexKey[:8], uint64(edge1.LastUpdate.Unix())) + byteOrder.PutUint64( + indexKey[:8], uint64(edge1.LastUpdate.Unix()), + ) if err := updateIndex.Delete(indexKey[:]); err != nil { return err } @@ -2558,7 +2621,9 @@ func delEdgeUpdateIndexEntry(edgesBucket kvdb.RwBucket, chanID uint64, // We'll also attempt to delete the entry that may have been created by // the second edge. if edge2 != nil { - byteOrder.PutUint64(indexKey[:8], uint64(edge2.LastUpdate.Unix())) + byteOrder.PutUint64( + indexKey[:8], uint64(edge2.LastUpdate.Unix()), + ) if err := updateIndex.Delete(indexKey[:]); err != nil { return err } @@ -2625,7 +2690,8 @@ func (c *ChannelGraph) delChannelEdgeUnsafe(edges, edgeIndex, chanIndex, } // As part of deleting the edge we also remove all disabled entries - // from the edgePolicyDisabledIndex bucket. We do that for both directions. + // from the edgePolicyDisabledIndex bucket. We do that for both + // directions. updateEdgePolicyDisabledIndex(edges, cid, false, false) updateEdgePolicyDisabledIndex(edges, cid, true, false) @@ -2635,7 +2701,7 @@ func (c *ChannelGraph) delChannelEdgeUnsafe(edges, edgeIndex, chanIndex, return err } var b bytes.Buffer - if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { + if err := WriteOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { return err } if err := chanIndex.Delete(b.Bytes()); err != nil { @@ -2849,127 +2915,6 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *models.ChannelEdgePolicy, return isUpdate1, nil } -// LightningNode represents an individual vertex/node within the channel graph. -// A node is connected to other nodes by one or more channel edges emanating -// from it. As the graph is directed, a node will also have an incoming edge -// attached to it for each outgoing edge. -type LightningNode struct { - // PubKeyBytes is the raw bytes of the public key of the target node. - PubKeyBytes [33]byte - pubKey *btcec.PublicKey - - // HaveNodeAnnouncement indicates whether we received a node - // announcement for this particular node. If true, the remaining fields - // will be set, if false only the PubKey is known for this node. - HaveNodeAnnouncement bool - - // LastUpdate is the last time the vertex information for this node has - // been updated. - LastUpdate time.Time - - // Address is the TCP address this node is reachable over. - Addresses []net.Addr - - // Color is the selected color for the node. - Color color.RGBA - - // Alias is a nick-name for the node. The alias can be used to confirm - // a node's identity or to serve as a short ID for an address book. - Alias string - - // AuthSigBytes is the raw signature under the advertised public key - // which serves to authenticate the attributes announced by this node. - AuthSigBytes []byte - - // Features is the list of protocol features supported by this node. - Features *lnwire.FeatureVector - - // ExtraOpaqueData is the set of data that was appended to this - // message, some of which we may not actually know how to iterate or - // parse. By holding onto this data, we ensure that we're able to - // properly validate the set of signatures that cover these new fields, - // and ensure we're able to make upgrades to the network in a forwards - // compatible manner. - ExtraOpaqueData []byte - - // TODO(roasbeef): discovery will need storage to keep it's last IP - // address and re-announce if interface changes? - - // TODO(roasbeef): add update method and fetch? -} - -// PubKey is the node's long-term identity public key. This key will be used to -// authenticated any advertisements/updates sent by the node. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the pubkey if absolutely necessary. -func (l *LightningNode) PubKey() (*btcec.PublicKey, error) { - if l.pubKey != nil { - return l.pubKey, nil - } - - key, err := btcec.ParsePubKey(l.PubKeyBytes[:]) - if err != nil { - return nil, err - } - l.pubKey = key - - return key, nil -} - -// AuthSig is a signature under the advertised public key which serves to -// authenticate the attributes announced by this node. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the signature if absolutely necessary. -func (l *LightningNode) AuthSig() (*ecdsa.Signature, error) { - return ecdsa.ParseSignature(l.AuthSigBytes) -} - -// AddPubKey is a setter-link method that can be used to swap out the public -// key for a node. -func (l *LightningNode) AddPubKey(key *btcec.PublicKey) { - l.pubKey = key - copy(l.PubKeyBytes[:], key.SerializeCompressed()) -} - -// NodeAnnouncement retrieves the latest node announcement of the node. -func (l *LightningNode) NodeAnnouncement(signed bool) (*lnwire.NodeAnnouncement, - error) { - - if !l.HaveNodeAnnouncement { - return nil, fmt.Errorf("node does not have node announcement") - } - - alias, err := lnwire.NewNodeAlias(l.Alias) - if err != nil { - return nil, err - } - - nodeAnn := &lnwire.NodeAnnouncement{ - Features: l.Features.RawFeatureVector, - NodeID: l.PubKeyBytes, - RGBColor: l.Color, - Alias: alias, - Addresses: l.Addresses, - Timestamp: uint32(l.LastUpdate.Unix()), - ExtraOpaqueData: l.ExtraOpaqueData, - } - - if !signed { - return nodeAnn, nil - } - - sig, err := lnwire.NewSigFromECDSARawSignature(l.AuthSigBytes) - if err != nil { - return nil, err - } - - nodeAnn.Signature = sig - - return nodeAnn, nil -} - // isPublic determines whether the node is seen as public within the graph from // the source node's point of view. An existing database transaction can also be // specified. @@ -3019,7 +2964,7 @@ func (c *ChannelGraph) isPublic(tx kvdb.RTx, nodePub route.Vertex, // ErrGraphNodeNotFound is returned. An optional transaction may be provided. // If none is provided, then a new one will be created. func (c *ChannelGraph) FetchLightningNodeTx(tx kvdb.RTx, nodePub route.Vertex) ( - *LightningNode, error) { + *models.LightningNode, error) { return c.fetchLightningNode(tx, nodePub) } @@ -3027,8 +2972,8 @@ func (c *ChannelGraph) FetchLightningNodeTx(tx kvdb.RTx, nodePub route.Vertex) ( // FetchLightningNode attempts to look up a target node by its identity public // key. If the node isn't found in the database, then ErrGraphNodeNotFound is // returned. -func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) (*LightningNode, - error) { +func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) ( + *models.LightningNode, error) { return c.fetchLightningNode(nil, nodePub) } @@ -3038,9 +2983,9 @@ func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) (*LightningNode, // returned. An optional transaction may be provided. If none is provided, then // a new one will be created. func (c *ChannelGraph) fetchLightningNode(tx kvdb.RTx, - nodePub route.Vertex) (*LightningNode, error) { + nodePub route.Vertex) (*models.LightningNode, error) { - var node *LightningNode + var node *models.LightningNode fetch := func(tx kvdb.RTx) error { // First grab the nodes bucket which stores the mapping from // pubKey to node information. @@ -3139,7 +3084,9 @@ var _ GraphCacheNode = (*graphCacheNode)(nil) // timestamp of when the data for the node was lasted updated is returned along // with a true boolean. Otherwise, an empty time.Time is returned with a false // boolean. -func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, error) { +func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, + error) { + var ( updateTime time.Time exists bool @@ -3216,7 +3163,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // as its prefix. This indicates that we've stepped over into // another node's edges, so we can terminate our scan. edgeCursor := edges.ReadCursor() - for nodeEdge, _ := edgeCursor.Seek(nodeStart[:]); bytes.HasPrefix(nodeEdge, nodePub); nodeEdge, _ = edgeCursor.Next() { + for nodeEdge, _ := edgeCursor.Seek(nodeStart[:]); bytes.HasPrefix(nodeEdge, nodePub); nodeEdge, _ = edgeCursor.Next() { //nolint:lll // If the prefix still matches, the channel id is // returned in nodeEdge. Channel id is used to lookup // the node at the other end of the channel and both @@ -3308,8 +3255,8 @@ func (c *ChannelGraph) ForEachNodeChannelTx(tx kvdb.RTx, // one of the nodes, and wishes to obtain the full LightningNode for the other // end of the channel. func (c *ChannelGraph) FetchOtherNode(tx kvdb.RTx, - channel *models.ChannelEdgeInfo, thisNodeKey []byte) (*LightningNode, - error) { + channel *models.ChannelEdgeInfo, thisNodeKey []byte) ( + *models.LightningNode, error) { // Ensure that the node passed in is actually a member of the channel. var targetNodeBytes [33]byte @@ -3322,7 +3269,7 @@ func (c *ChannelGraph) FetchOtherNode(tx kvdb.RTx, return nil, fmt.Errorf("node not participating in this channel") } - var targetNode *LightningNode + var targetNode *models.LightningNode fetchNodeFunc := func(tx kvdb.RTx) error { // First grab the nodes bucket which stores the mapping from // pubKey to node information. @@ -3345,7 +3292,9 @@ func (c *ChannelGraph) FetchOtherNode(tx kvdb.RTx, // otherwise we can use the existing db transaction. var err error if tx == nil { - err = kvdb.View(c.db, fetchNodeFunc, func() { targetNode = nil }) + err = kvdb.View(c.db, fetchNodeFunc, func() { + targetNode = nil + }) } else { err = fetchNodeFunc(tx) } @@ -3413,7 +3362,7 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint) ( return ErrGraphNoEdgesFound } var b bytes.Buffer - if err := writeOutpoint(&b, op); err != nil { + if err := WriteOutpoint(&b, op); err != nil { return err } chanID := chanIndex.Get(b.Bytes()) @@ -3655,37 +3604,41 @@ func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) { // Once we have the proper bucket, we'll range over each key // (which is the channel point for the channel) and decode it, // accumulating each entry. - return chanIndex.ForEach(func(chanPointBytes, chanID []byte) error { - chanPointReader := bytes.NewReader(chanPointBytes) + return chanIndex.ForEach( + func(chanPointBytes, chanID []byte) error { + chanPointReader := bytes.NewReader( + chanPointBytes, + ) - var chanPoint wire.OutPoint - err := readOutpoint(chanPointReader, &chanPoint) - if err != nil { - return err - } + var chanPoint wire.OutPoint + err := ReadOutpoint(chanPointReader, &chanPoint) + if err != nil { + return err + } - edgeInfo, err := fetchChanEdgeInfo( - edgeIndex, chanID, - ) - if err != nil { - return err - } + edgeInfo, err := fetchChanEdgeInfo( + edgeIndex, chanID, + ) + if err != nil { + return err + } - pkScript, err := genMultiSigP2WSH( - edgeInfo.BitcoinKey1Bytes[:], - edgeInfo.BitcoinKey2Bytes[:], - ) - if err != nil { - return err - } + pkScript, err := genMultiSigP2WSH( + edgeInfo.BitcoinKey1Bytes[:], + edgeInfo.BitcoinKey2Bytes[:], + ) + if err != nil { + return err + } - edgePoints = append(edgePoints, EdgePoint{ - FundingPkScript: pkScript, - OutPoint: chanPoint, - }) + edgePoints = append(edgePoints, EdgePoint{ + FundingPkScript: pkScript, + OutPoint: chanPoint, + }) - return nil - }) + return nil + }, + ) }, func() { edgePoints = nil }); err != nil { @@ -3945,7 +3898,7 @@ func (c *ChannelGraph) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) { } func putLightningNode(nodeBucket kvdb.RwBucket, aliasBucket kvdb.RwBucket, // nolint:dupl - updateIndex kvdb.RwBucket, node *LightningNode) error { + updateIndex kvdb.RwBucket, node *models.LightningNode) error { var ( scratch [16]byte @@ -4016,7 +3969,7 @@ func putLightningNode(nodeBucket kvdb.RwBucket, aliasBucket kvdb.RwBucket, // no } for _, address := range node.Addresses { - if err := serializeAddr(&b, address); err != nil { + if err := SerializeAddr(&b, address); err != nil { return err } } @@ -4074,11 +4027,11 @@ func putLightningNode(nodeBucket kvdb.RwBucket, aliasBucket kvdb.RwBucket, // no } func fetchLightningNode(nodeBucket kvdb.RBucket, - nodePub []byte) (LightningNode, error) { + nodePub []byte) (models.LightningNode, error) { nodeBytes := nodeBucket.Get(nodePub) if nodeBytes == nil { - return LightningNode{}, ErrGraphNodeNotFound + return models.LightningNode{}, ErrGraphNodeNotFound } nodeReader := bytes.NewReader(nodeBytes) @@ -4141,9 +4094,9 @@ func deserializeLightningNodeCacheable(r io.Reader) (*graphCacheNode, error) { return node, nil } -func deserializeLightningNode(r io.Reader) (LightningNode, error) { +func deserializeLightningNode(r io.Reader) (models.LightningNode, error) { var ( - node LightningNode + node models.LightningNode scratch [8]byte err error ) @@ -4153,18 +4106,18 @@ func deserializeLightningNode(r io.Reader) (LightningNode, error) { node.Features = lnwire.EmptyFeatureVector() if _, err := r.Read(scratch[:]); err != nil { - return LightningNode{}, err + return models.LightningNode{}, err } unix := int64(byteOrder.Uint64(scratch[:])) node.LastUpdate = time.Unix(unix, 0) if _, err := io.ReadFull(r, node.PubKeyBytes[:]); err != nil { - return LightningNode{}, err + return models.LightningNode{}, err } if _, err := r.Read(scratch[:2]); err != nil { - return LightningNode{}, err + return models.LightningNode{}, err } hasNodeAnn := byteOrder.Uint16(scratch[:2]) @@ -4174,8 +4127,8 @@ func deserializeLightningNode(r io.Reader) (LightningNode, error) { node.HaveNodeAnnouncement = false } - // The rest of the data is optional, and will only be there if we got a node - // announcement for this node. + // The rest of the data is optional, and will only be there if we got a + // node announcement for this node. if !node.HaveNodeAnnouncement { return node, nil } @@ -4183,35 +4136,35 @@ func deserializeLightningNode(r io.Reader) (LightningNode, error) { // We did get a node announcement for this node, so we'll have the rest // of the data available. if err := binary.Read(r, byteOrder, &node.Color.R); err != nil { - return LightningNode{}, err + return models.LightningNode{}, err } if err := binary.Read(r, byteOrder, &node.Color.G); err != nil { - return LightningNode{}, err + return models.LightningNode{}, err } if err := binary.Read(r, byteOrder, &node.Color.B); err != nil { - return LightningNode{}, err + return models.LightningNode{}, err } node.Alias, err = wire.ReadVarString(r, 0) if err != nil { - return LightningNode{}, err + return models.LightningNode{}, err } err = node.Features.Decode(r) if err != nil { - return LightningNode{}, err + return models.LightningNode{}, err } if _, err := r.Read(scratch[:2]); err != nil { - return LightningNode{}, err + return models.LightningNode{}, err } numAddresses := int(byteOrder.Uint16(scratch[:2])) var addresses []net.Addr for i := 0; i < numAddresses; i++ { - address, err := deserializeAddr(r) + address, err := DeserializeAddr(r) if err != nil { - return LightningNode{}, err + return models.LightningNode{}, err } addresses = append(addresses, address) } @@ -4219,7 +4172,7 @@ func deserializeLightningNode(r io.Reader) (LightningNode, error) { node.AuthSigBytes, err = wire.ReadVarBytes(r, 0, 80, "sig") if err != nil { - return LightningNode{}, err + return models.LightningNode{}, err } // We'll try and see if there are any opaque bytes left, if not, then @@ -4231,7 +4184,7 @@ func deserializeLightningNode(r io.Reader) (LightningNode, error) { case err == io.ErrUnexpectedEOF: case err == io.EOF: case err != nil: - return LightningNode{}, err + return models.LightningNode{}, err } return node, nil @@ -4281,10 +4234,11 @@ func putChanEdgeInfo(edgeIndex kvdb.RwBucket, return err } - if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { + if err := WriteOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { return err } - if err := binary.Write(&b, byteOrder, uint64(edgeInfo.Capacity)); err != nil { + err := binary.Write(&b, byteOrder, uint64(edgeInfo.Capacity)) + if err != nil { return err } if _, err := b.Write(chanID[:]); err != nil { @@ -4297,7 +4251,7 @@ func putChanEdgeInfo(edgeIndex kvdb.RwBucket, if len(edgeInfo.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { return ErrTooManyExtraOpaqueBytes(len(edgeInfo.ExtraOpaqueData)) } - err := wire.WriteVarBytes(&b, 0, edgeInfo.ExtraOpaqueData) + err = wire.WriteVarBytes(&b, 0, edgeInfo.ExtraOpaqueData) if err != nil { return err } @@ -4365,7 +4319,7 @@ func deserializeChanEdgeInfo(r io.Reader) (models.ChannelEdgeInfo, error) { } edgeInfo.ChannelPoint = wire.OutPoint{} - if err := readOutpoint(r, &edgeInfo.ChannelPoint); err != nil { + if err := ReadOutpoint(r, &edgeInfo.ChannelPoint); err != nil { return models.ChannelEdgeInfo{}, err } if err := binary.Read(r, byteOrder, &edgeInfo.Capacity); err != nil { @@ -4610,10 +4564,14 @@ func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy, if err := binary.Write(w, byteOrder, uint64(edge.MinHTLC)); err != nil { return err } - if err := binary.Write(w, byteOrder, uint64(edge.FeeBaseMSat)); err != nil { + err = binary.Write(w, byteOrder, uint64(edge.FeeBaseMSat)) + if err != nil { return err } - if err := binary.Write(w, byteOrder, uint64(edge.FeeProportionalMillionths)); err != nil { + err = binary.Write( + w, byteOrder, uint64(edge.FeeProportionalMillionths), + ) + if err != nil { return err } @@ -4745,3 +4703,34 @@ func deserializeChanEdgePolicyRaw(r io.Reader) (*models.ChannelEdgePolicy, return edge, nil } + +// MakeTestGraph creates a new instance of the ChannelGraph for testing +// purposes. +func MakeTestGraph(t testing.TB, modifiers ...OptionModifier) (*ChannelGraph, + error) { + + opts := DefaultOptions() + for _, modifier := range modifiers { + modifier(opts) + } + + // Next, create channelgraph for the first time. + backend, backendCleanup, err := kvdb.GetTestBackend(t.TempDir(), "cgr") + if err != nil { + backendCleanup() + return nil, err + } + + graph, err := NewChannelGraph(backend) + if err != nil { + backendCleanup() + return nil, err + } + + t.Cleanup(func() { + _ = backend.Close() + backendCleanup() + }) + + return graph, nil +} diff --git a/channeldb/graph_cache.go b/graph/db/graph_cache.go similarity index 99% rename from channeldb/graph_cache.go rename to graph/db/graph_cache.go index 9bd2a82658..2b10a0a15a 100644 --- a/channeldb/graph_cache.go +++ b/graph/db/graph_cache.go @@ -1,11 +1,11 @@ -package channeldb +package graphdb import ( "fmt" "sync" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" diff --git a/channeldb/graph_cache_test.go b/graph/db/graph_cache_test.go similarity index 98% rename from channeldb/graph_cache_test.go rename to graph/db/graph_cache_test.go index f7ed5cee60..3f140c4c5f 100644 --- a/channeldb/graph_cache_test.go +++ b/graph/db/graph_cache_test.go @@ -1,10 +1,10 @@ -package channeldb +package graphdb import ( "encoding/hex" "testing" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" diff --git a/channeldb/graph_test.go b/graph/db/graph_test.go similarity index 95% rename from channeldb/graph_test.go rename to graph/db/graph_test.go index 89197a0a80..82c4965fba 100644 --- a/channeldb/graph_test.go +++ b/graph/db/graph_test.go @@ -1,4 +1,4 @@ -package channeldb +package graphdb import ( "bytes" @@ -21,8 +21,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -37,13 +36,15 @@ var ( "[2001:db8:85a3:0:0:8a2e:370:7334]:80") testAddrs = []net.Addr{testAddr, anotherAddr} - testRBytes, _ = hex.DecodeString("8ce2bc69281ce27da07e6683571319d18e949ddfa2965fb6caa1bf0314f882d7") - testSBytes, _ = hex.DecodeString("299105481d63e0f4bc2a88121167221b6700d72a0ead154c03be696a292d24ae") - testRScalar = new(btcec.ModNScalar) - testSScalar = new(btcec.ModNScalar) - _ = testRScalar.SetByteSlice(testRBytes) - _ = testSScalar.SetByteSlice(testSBytes) - testSig = ecdsa.NewSignature(testRScalar, testSScalar) + testRBytes, _ = hex.DecodeString("8ce2bc69281ce27da07e6683571319d18" + + "e949ddfa2965fb6caa1bf0314f882d7") + testSBytes, _ = hex.DecodeString("299105481d63e0f4bc2a88121167221b6" + + "700d72a0ead154c03be696a292d24ae") + testRScalar = new(btcec.ModNScalar) + testSScalar = new(btcec.ModNScalar) + _ = testRScalar.SetByteSlice(testRBytes) + _ = testSScalar.SetByteSlice(testSBytes) + testSig = ecdsa.NewSignature(testRScalar, testSScalar) testFeatures = lnwire.NewFeatureVector( lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired), @@ -51,45 +52,27 @@ var ( ) testPub = route.Vertex{2, 202, 4} -) -// MakeTestGraph creates a new instance of the ChannelGraph for testing purposes. -func MakeTestGraph(t testing.TB, modifiers ...OptionModifier) (*ChannelGraph, error) { - opts := DefaultOptions() - for _, modifier := range modifiers { - modifier(&opts) + key = [chainhash.HashSize]byte{ + 0x81, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, + 0x68, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, + 0xd, 0xe7, 0x93, 0xe4, 0xb7, 0x25, 0xb8, 0x4d, + 0x1e, 0xb, 0x4c, 0xf9, 0x9e, 0xc5, 0x8c, 0xe9, } - - // Next, create channelgraph for the first time. - backend, backendCleanup, err := kvdb.GetTestBackend(t.TempDir(), "cgr") - if err != nil { - backendCleanup() - return nil, err - } - - graph, err := NewChannelGraph( - backend, opts.RejectCacheSize, opts.ChannelCacheSize, - opts.BatchCommitInterval, opts.PreAllocCacheNumNodes, - true, false, - ) - if err != nil { - backendCleanup() - return nil, err + rev = [chainhash.HashSize]byte{ + 0x51, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, + 0x48, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, + 0x2d, 0xe7, 0x93, 0xe4, } +) - t.Cleanup(func() { - _ = backend.Close() - backendCleanup() - }) - - return graph, nil -} +func createLightningNode(_ kvdb.Backend, priv *btcec.PrivateKey) ( + *models.LightningNode, error) { -func createLightningNode(db kvdb.Backend, priv *btcec.PrivateKey) (*LightningNode, error) { updateTime := prand.Int63() pub := priv.PubKey().SerializeCompressed() - n := &LightningNode{ + n := &models.LightningNode{ HaveNodeAnnouncement: true, AuthSigBytes: testSig.Serialize(), LastUpdate: time.Unix(updateTime, 0), @@ -103,7 +86,7 @@ func createLightningNode(db kvdb.Backend, priv *btcec.PrivateKey) (*LightningNod return n, nil } -func createTestVertex(db kvdb.Backend) (*LightningNode, error) { +func createTestVertex(db kvdb.Backend) (*models.LightningNode, error) { priv, err := btcec.NewPrivateKey() if err != nil { return nil, err @@ -120,7 +103,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { // We'd like to test basic insertion/deletion for vertexes from the // graph, so we'll create a test vertex to start with. - node := &LightningNode{ + node := &models.LightningNode{ HaveNodeAnnouncement: true, AuthSigBytes: testSig.Serialize(), LastUpdate: time.Unix(1232342, 0), @@ -144,7 +127,8 @@ func TestNodeInsertionAndDeletion(t *testing.T) { dbNode, err := graph.FetchLightningNode(testPub) require.NoError(t, err, "unable to locate node") - if _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes); err != nil { + _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes) + if err != nil { t.Fatalf("unable to query for node: %v", err) } else if !exists { t.Fatalf("node should be found but wasn't") @@ -180,7 +164,7 @@ func TestPartialNode(t *testing.T) { // We want to be able to insert nodes into the graph that only has the // PubKey set. - node := &LightningNode{ + node := &models.LightningNode{ HaveNodeAnnouncement: false, PubKeyBytes: testPub, } @@ -195,7 +179,8 @@ func TestPartialNode(t *testing.T) { dbNode, err := graph.FetchLightningNode(testPub) require.NoError(t, err, "unable to locate node") - if _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes); err != nil { + _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes) + if err != nil { t.Fatalf("unable to query for node: %v", err) } else if !exists { t.Fatalf("node should be found but wasn't") @@ -203,7 +188,7 @@ func TestPartialNode(t *testing.T) { // The two nodes should match exactly! (with default values for // LastUpdate and db set to satisfy compareNodes()) - node = &LightningNode{ + node = &models.LightningNode{ HaveNodeAnnouncement: false, LastUpdate: time.Unix(0, 0), PubKeyBytes: testPub, @@ -366,7 +351,8 @@ func TestEdgeInsertionDeletion(t *testing.T) { // Ensure that any query attempts to lookup the delete channel edge are // properly deleted. - if _, _, _, err := graph.FetchChannelEdgesByOutpoint(&outpoint); err == nil { + _, _, _, err = graph.FetchChannelEdgesByOutpoint(&outpoint) + if err == nil { t.Fatalf("channel edge not deleted") } if _, _, _, err := graph.FetchChannelEdgesByID(chanID); err == nil { @@ -386,7 +372,7 @@ func TestEdgeInsertionDeletion(t *testing.T) { } func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, - node1, node2 *LightningNode) (models.ChannelEdgeInfo, + node1, node2 *models.LightningNode) (models.ChannelEdgeInfo, lnwire.ShortChannelID) { shortChanID := lnwire.ShortChannelID{ @@ -548,12 +534,8 @@ func TestDisconnectBlockAtHeight(t *testing.T) { // at height 155. hash, h, err := graph.PruneTip() require.NoError(t, err, "unable to get prune tip") - if !blockHash.IsEqual(hash) { - t.Fatalf("expected best block to be %x, was %x", blockHash, hash) - } - if h != height-1 { - t.Fatalf("expected best block height to be %d, was %d", height-1, h) - } + require.True(t, blockHash.IsEqual(hash)) + require.Equal(t, h, height-1) } func assertEdgeInfoEqual(t *testing.T, e1 *models.ChannelEdgeInfo, @@ -587,20 +569,19 @@ func assertEdgeInfoEqual(t *testing.T, e1 *models.ChannelEdgeInfo, e2.Features) } - if !bytes.Equal(e1.AuthProof.NodeSig1Bytes, e2.AuthProof.NodeSig1Bytes) { - t.Fatalf("nodesig1 doesn't match: %v vs %v", - spew.Sdump(e1.AuthProof.NodeSig1Bytes), - spew.Sdump(e2.AuthProof.NodeSig1Bytes)) - } - if !bytes.Equal(e1.AuthProof.NodeSig2Bytes, e2.AuthProof.NodeSig2Bytes) { - t.Fatalf("nodesig2 doesn't match") - } - if !bytes.Equal(e1.AuthProof.BitcoinSig1Bytes, e2.AuthProof.BitcoinSig1Bytes) { - t.Fatalf("bitcoinsig1 doesn't match") - } - if !bytes.Equal(e1.AuthProof.BitcoinSig2Bytes, e2.AuthProof.BitcoinSig2Bytes) { - t.Fatalf("bitcoinsig2 doesn't match") - } + require.True(t, bytes.Equal( + e1.AuthProof.NodeSig1Bytes, e2.AuthProof.NodeSig1Bytes, + )) + require.True(t, bytes.Equal( + e1.AuthProof.NodeSig2Bytes, e2.AuthProof.NodeSig2Bytes, + )) + require.True(t, bytes.Equal( + e1.AuthProof.BitcoinSig1Bytes, + e2.AuthProof.BitcoinSig1Bytes, + )) + require.True(t, bytes.Equal( + e1.AuthProof.BitcoinSig2Bytes, e2.AuthProof.BitcoinSig2Bytes, + )) if e1.ChannelPoint != e2.ChannelPoint { t.Fatalf("channel point match: %v vs %v", e1.ChannelPoint, @@ -618,7 +599,7 @@ func assertEdgeInfoEqual(t *testing.T, e1 *models.ChannelEdgeInfo, } } -func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) ( +func createChannelEdge(db kvdb.Backend, node1, node2 *models.LightningNode) ( *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) { @@ -779,7 +760,9 @@ func TestEdgeInfoUpdates(t *testing.T) { // Next, attempt to query the channel edges according to the outpoint // of the channel. - dbEdgeInfo, dbEdge1, dbEdge2, err = graph.FetchChannelEdgesByOutpoint(&outpoint) + dbEdgeInfo, dbEdge1, dbEdge2, err = graph.FetchChannelEdgesByOutpoint( + &outpoint, + ) require.NoError(t, err, "unable to fetch channel by ID") if err := compareEdgePolicies(dbEdge1, edge1); err != nil { t.Fatalf("edge doesn't match: %v", err) @@ -790,7 +773,7 @@ func TestEdgeInfoUpdates(t *testing.T) { assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) } -func assertNodeInCache(t *testing.T, g *ChannelGraph, n *LightningNode, +func assertNodeInCache(t *testing.T, g *ChannelGraph, n *models.LightningNode, expectedFeatures *lnwire.FeatureVector) { // Let's check the internal view first. @@ -1109,11 +1092,13 @@ func TestGraphTraversalCacheable(t *testing.T) { // Create a map of all nodes with the iteration we know works (because // it is tested in another test). nodeMap := make(map[route.Vertex]struct{}) - err = graph.ForEachNode(func(tx kvdb.RTx, n *LightningNode) error { - nodeMap[n.PubKeyBytes] = struct{}{} + err = graph.ForEachNode( + func(tx kvdb.RTx, n *models.LightningNode) error { + nodeMap[n.PubKeyBytes] = struct{}{} - return nil - }) + return nil + }, + ) require.NoError(t, err) require.Len(t, nodeMap, numNodes) @@ -1182,8 +1167,8 @@ func TestGraphCacheTraversal(t *testing.T) { delete(chanIndex, d.ChannelID) if !d.OutPolicySet || d.InPolicy == nil { - return fmt.Errorf("channel policy not " + - "present") + return fmt.Errorf("channel policy " + + "not present") } // The incoming edge should also indicate that @@ -1212,9 +1197,9 @@ func TestGraphCacheTraversal(t *testing.T) { } func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes, - numChannels int) (map[uint64]struct{}, []*LightningNode) { + numChannels int) (map[uint64]struct{}, []*models.LightningNode) { - nodes := make([]*LightningNode, numNodes) + nodes := make([]*models.LightningNode, numNodes) nodeIndex := map[string]struct{}{} for i := 0; i < numNodes; i++ { node, err := createTestVertex(graph.db) @@ -1232,10 +1217,12 @@ func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes, // Iterate over each node as returned by the graph, if all nodes are // reached, then the map created above should be empty. - err := graph.ForEachNode(func(_ kvdb.RTx, node *LightningNode) error { - delete(nodeIndex, node.Alias) - return nil - }) + err := graph.ForEachNode( + func(_ kvdb.RTx, node *models.LightningNode) error { + delete(nodeIndex, node.Alias) + return nil + }, + ) require.NoError(t, err) require.Len(t, nodeIndex, 0) @@ -1245,7 +1232,9 @@ func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes, for n := 0; n < numNodes-1; n++ { node1 := nodes[n] node2 := nodes[n+1] - if bytes.Compare(node1.PubKeyBytes[:], node2.PubKeyBytes[:]) == -1 { + if bytes.Compare( + node1.PubKeyBytes[:], node2.PubKeyBytes[:], + ) == -1 { node1, node2 = node2, node1 } @@ -1299,8 +1288,8 @@ func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes, return chanIndex, nodes } -func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash, - blockHeight uint32) { +func assertPruneTip(t *testing.T, graph *ChannelGraph, + blockHash *chainhash.Hash, blockHeight uint32) { pruneHash, pruneHeight, err := graph.PruneTip() if err != nil { @@ -1340,10 +1329,12 @@ func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) { numNodes := 0 - err := graph.ForEachNode(func(_ kvdb.RTx, _ *LightningNode) error { - numNodes++ - return nil - }) + err := graph.ForEachNode( + func(_ kvdb.RTx, _ *models.LightningNode) error { + numNodes++ + return nil + }, + ) if err != nil { _, _, line, _ := runtime.Caller(1) t.Fatalf("line %v: unable to scan nodes: %v", line, err) @@ -1351,7 +1342,8 @@ func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) { if numNodes != n { _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: expected %v nodes, got %v", line, n, numNodes) + t.Fatalf("line %v: expected %v nodes, got %v", line, n, + numNodes) } } @@ -1375,7 +1367,9 @@ func assertChanViewEqual(t *testing.T, a []EdgePoint, b []EdgePoint) { } } -func assertChanViewEqualChanPoints(t *testing.T, a []EdgePoint, b []*wire.OutPoint) { +func assertChanViewEqualChanPoints(t *testing.T, a []EdgePoint, + b []*wire.OutPoint) { + if len(a) != len(b) { _, _, line, _ := runtime.Caller(1) t.Fatalf("line %v: chan views don't match", line) @@ -1411,7 +1405,7 @@ func TestGraphPruning(t *testing.T) { // and enough edges to create a fully connected graph. The graph will // be rather simple, representing a straight line. const numNodes = 5 - graphNodes := make([]*LightningNode, numNodes) + graphNodes := make([]*models.LightningNode, numNodes) for i := 0; i < numNodes; i++ { node, err := createTestVertex(graph.db) if err != nil { @@ -1454,13 +1448,17 @@ func TestGraphPruning(t *testing.T) { copy(edgeInfo.NodeKey1Bytes[:], graphNodes[i].PubKeyBytes[:]) copy(edgeInfo.NodeKey2Bytes[:], graphNodes[i+1].PubKeyBytes[:]) copy(edgeInfo.BitcoinKey1Bytes[:], graphNodes[i].PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey2Bytes[:], graphNodes[i+1].PubKeyBytes[:]) + copy( + edgeInfo.BitcoinKey2Bytes[:], + graphNodes[i+1].PubKeyBytes[:], + ) if err := graph.AddChannelEdge(&edgeInfo); err != nil { t.Fatalf("unable to add node: %v", err) } pkScript, err := genMultiSigP2WSH( - edgeInfo.BitcoinKey1Bytes[:], edgeInfo.BitcoinKey2Bytes[:], + edgeInfo.BitcoinKey1Bytes[:], + edgeInfo.BitcoinKey2Bytes[:], ) if err != nil { t.Fatalf("unable to gen multi-sig p2wsh: %v", err) @@ -1792,7 +1790,9 @@ func TestChanUpdatesInHorizon(t *testing.T) { assertEdgeInfoEqual(t, chanExp.Info, chanRet.Info) - err := compareEdgePolicies(chanExp.Policy1, chanRet.Policy1) + err := compareEdgePolicies( + chanExp.Policy1, chanRet.Policy1, + ) if err != nil { t.Fatal(err) } @@ -1829,7 +1829,7 @@ func TestNodeUpdatesInHorizon(t *testing.T) { // We'll create 10 node announcements, each with an update timestamp 10 // seconds after the other. const numNodes = 10 - nodeAnns := make([]LightningNode, 0, numNodes) + nodeAnns := make([]models.LightningNode, 0, numNodes) for i := 0; i < numNodes; i++ { nodeAnn, err := createTestVertex(graph.db) if err != nil { @@ -1855,7 +1855,7 @@ func TestNodeUpdatesInHorizon(t *testing.T) { start time.Time end time.Time - resp []LightningNode + resp []models.LightningNode }{ // If we query for a time range that's strictly below our set // of updates, then we'll get an empty result back. @@ -1899,7 +1899,9 @@ func TestNodeUpdatesInHorizon(t *testing.T) { }, } for _, queryCase := range queryCases { - resp, err := graph.NodeUpdatesInHorizon(queryCase.start, queryCase.end) + resp, err := graph.NodeUpdatesInHorizon( + queryCase.start, queryCase.end, + ) if err != nil { t.Fatalf("unable to query for nodes: %v", err) } @@ -2428,7 +2430,7 @@ func TestFilterChannelRange(t *testing.T) { ) updateTimeSeed := time.Now().Unix() - maybeAddPolicy := func(chanID uint64, node *LightningNode, + maybeAddPolicy := func(chanID uint64, node *models.LightningNode, node2 bool) time.Time { var chanFlags lnwire.ChanUpdateChanFlags @@ -2745,7 +2747,9 @@ func TestIncompleteChannelPolicies(t *testing.T) { } // Ensure that channel is reported with unknown policies. - checkPolicies := func(node *LightningNode, expectedIn, expectedOut bool) { + checkPolicies := func(node *models.LightningNode, expectedIn, + expectedOut bool) { + calls := 0 err := graph.ForEachNodeChannel(node.PubKeyBytes, func(_ kvdb.RTx, _ *models.ChannelEdgeInfo, outEdge, @@ -3066,9 +3070,7 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { node2, err = graph.FetchLightningNode(node2.PubKeyBytes) require.NoError(t, err, "unable to fetch node2") - if node2.HaveNodeAnnouncement { - t.Fatalf("should have shell announcement for node2, but is full") - } + require.False(t, node2.HaveNodeAnnouncement) } // TestNodePruningUpdateIndexDeletion tests that once a node has been removed @@ -3165,7 +3167,7 @@ func TestNodeIsPublic(t *testing.T) { // After creating all of our nodes and edges, we'll add them to each // participant's graph. - nodes := []*LightningNode{aliceNode, bobNode, carolNode} + nodes := []*models.LightningNode{aliceNode, bobNode, carolNode} edges := []*models.ChannelEdgeInfo{&aliceBobEdge, &bobCarolEdge} graphs := []*ChannelGraph{aliceGraph, bobGraph, carolGraph} for _, graph := range graphs { @@ -3183,17 +3185,19 @@ func TestNodeIsPublic(t *testing.T) { // checkNodes is a helper closure that will be used to assert that the // given nodes are seen as public/private within the given graphs. - checkNodes := func(nodes []*LightningNode, graphs []*ChannelGraph, - public bool) { + checkNodes := func(nodes []*models.LightningNode, + graphs []*ChannelGraph, public bool) { t.Helper() for _, node := range nodes { for _, graph := range graphs { - isPublic, err := graph.IsPublicNode(node.PubKeyBytes) + isPublic, err := graph.IsPublicNode( + node.PubKeyBytes, + ) if err != nil { - t.Fatalf("unable to determine if pivot "+ - "is public: %v", err) + t.Fatalf("unable to determine if "+ + "pivot is public: %v", err) } switch { @@ -3224,7 +3228,7 @@ func TestNodeIsPublic(t *testing.T) { } } checkNodes( - []*LightningNode{aliceNode}, + []*models.LightningNode{aliceNode}, []*ChannelGraph{bobGraph, carolGraph}, false, ) @@ -3255,7 +3259,7 @@ func TestNodeIsPublic(t *testing.T) { // With the modifications above, Bob should now be seen as a private // node from both Alice's and Carol's perspective. checkNodes( - []*LightningNode{bobNode}, + []*models.LightningNode{bobNode}, []*ChannelGraph{aliceGraph, carolGraph}, false, ) @@ -3298,8 +3302,8 @@ func TestDisabledChannelIDs(t *testing.T) { disabledChanIds, err := graph.DisabledChannelIDs() require.NoError(t, err, "unable to get disabled channel ids") if len(disabledChanIds) > 0 { - t.Fatalf("expected empty disabled channels, got %v disabled channels", - len(disabledChanIds)) + t.Fatalf("expected empty disabled channels, got %v disabled "+ + "channels", len(disabledChanIds)) } // Add one disabled policy and ensure the channel is still not in the @@ -3311,8 +3315,8 @@ func TestDisabledChannelIDs(t *testing.T) { disabledChanIds, err = graph.DisabledChannelIDs() require.NoError(t, err, "unable to get disabled channel ids") if len(disabledChanIds) > 0 { - t.Fatalf("expected empty disabled channels, got %v disabled channels", - len(disabledChanIds)) + t.Fatalf("expected empty disabled channels, got %v disabled "+ + "channels", len(disabledChanIds)) } // Add second disabled policy and ensure the channel is now in the @@ -3323,12 +3327,15 @@ func TestDisabledChannelIDs(t *testing.T) { } disabledChanIds, err = graph.DisabledChannelIDs() require.NoError(t, err, "unable to get disabled channel ids") - if len(disabledChanIds) != 1 || disabledChanIds[0] != edgeInfo.ChannelID { + if len(disabledChanIds) != 1 || + disabledChanIds[0] != edgeInfo.ChannelID { + t.Fatalf("expected disabled channel with id %v, "+ "got %v", edgeInfo.ChannelID, disabledChanIds) } - // Delete the channel edge and ensure it is removed from the disabled list. + // Delete the channel edge and ensure it is removed from the disabled + // list. if err = graph.DeleteChannelEdges( false, true, edgeInfo.ChannelID, ); err != nil { @@ -3337,8 +3344,8 @@ func TestDisabledChannelIDs(t *testing.T) { disabledChanIds, err = graph.DisabledChannelIDs() require.NoError(t, err, "unable to get disabled channel ids") if len(disabledChanIds) > 0 { - t.Fatalf("expected empty disabled channels, got %v disabled channels", - len(disabledChanIds)) + t.Fatalf("expected empty disabled channels, got %v disabled "+ + "channels", len(disabledChanIds)) } } @@ -3441,7 +3448,9 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { copy(indexKey[:], scratch[:]) byteOrder.PutUint64(indexKey[8:], edge1.ChannelID) - updateIndex, err := edges.CreateBucketIfNotExists(edgeUpdateIndexBucket) + updateIndex, err := edges.CreateBucketIfNotExists( + edgeUpdateIndexBucket, + ) if err != nil { return err } @@ -3574,10 +3583,10 @@ func TestGraphZombieIndex(t *testing.T) { // compareNodes is used to compare two LightningNodes while excluding the // Features struct, which cannot be compared as the semantics for reserializing // the featuresMap have not been defined. -func compareNodes(a, b *LightningNode) error { +func compareNodes(a, b *models.LightningNode) error { if a.LastUpdate != b.LastUpdate { - return fmt.Errorf("node LastUpdate doesn't match: expected %v, \n"+ - "got %v", a.LastUpdate, b.LastUpdate) + return fmt.Errorf("node LastUpdate doesn't match: expected "+ + "%v, got %v", a.LastUpdate, b.LastUpdate) } if !reflect.DeepEqual(a.Addresses, b.Addresses) { return fmt.Errorf("Addresses doesn't match: expected %#v, \n "+ @@ -3596,8 +3605,9 @@ func compareNodes(a, b *LightningNode) error { "got %#v", a.Alias, b.Alias) } if !reflect.DeepEqual(a.HaveNodeAnnouncement, b.HaveNodeAnnouncement) { - return fmt.Errorf("HaveNodeAnnouncement doesn't match: expected %#v, \n "+ - "got %#v", a.HaveNodeAnnouncement, b.HaveNodeAnnouncement) + return fmt.Errorf("HaveNodeAnnouncement doesn't match: "+ + "expected %#v, got %#v", a.HaveNodeAnnouncement, + b.HaveNodeAnnouncement) } if !bytes.Equal(a.ExtraOpaqueData, b.ExtraOpaqueData) { return fmt.Errorf("extra data doesn't match: %v vs %v", @@ -3615,8 +3625,8 @@ func compareEdgePolicies(a, b *models.ChannelEdgePolicy) error { "got %v", a.ChannelID, b.ChannelID) } if !reflect.DeepEqual(a.LastUpdate, b.LastUpdate) { - return fmt.Errorf("edge LastUpdate doesn't match: expected %#v, \n "+ - "got %#v", a.LastUpdate, b.LastUpdate) + return fmt.Errorf("edge LastUpdate doesn't match: "+ + "expected %#v, got %#v", a.LastUpdate, b.LastUpdate) } if a.MessageFlags != b.MessageFlags { return fmt.Errorf("MessageFlags doesn't match: expected %v, "+ @@ -4004,12 +4014,7 @@ func TestGraphLoading(t *testing.T) { defer backend.Close() defer backendCleanup() - opts := DefaultOptions() - graph, err := NewChannelGraph( - backend, opts.RejectCacheSize, opts.ChannelCacheSize, - opts.BatchCommitInterval, opts.PreAllocCacheNumNodes, - true, false, - ) + graph, err := NewChannelGraph(backend) require.NoError(t, err) // Populate the graph with test data. @@ -4019,11 +4024,7 @@ func TestGraphLoading(t *testing.T) { // Recreate the graph. This should cause the graph cache to be // populated. - graphReloaded, err := NewChannelGraph( - backend, opts.RejectCacheSize, opts.ChannelCacheSize, - opts.BatchCommitInterval, opts.PreAllocCacheNumNodes, - true, false, - ) + graphReloaded, err := NewChannelGraph(backend) require.NoError(t, err) // Assert that the cache content is identical. diff --git a/graph/db/log.go b/graph/db/log.go new file mode 100644 index 0000000000..242e78c99a --- /dev/null +++ b/graph/db/log.go @@ -0,0 +1,31 @@ +package graphdb + +import ( + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/build" +) + +// Subsystem defines the logging code for this subsystem. +const Subsystem = "GRDB" + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +func init() { + UseLogger(build.NewSubLogger(Subsystem, nil)) +} + +// DisableLog disables all library log output. Logging output is disabled +// by default until UseLogger is called. +func DisableLog() { + UseLogger(btclog.Disabled) +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/channeldb/models/cached_edge_policy.go b/graph/db/models/cached_edge_policy.go similarity index 100% rename from channeldb/models/cached_edge_policy.go rename to graph/db/models/cached_edge_policy.go diff --git a/channeldb/models/channel.go b/graph/db/models/channel.go similarity index 100% rename from channeldb/models/channel.go rename to graph/db/models/channel.go diff --git a/channeldb/models/channel_auth_proof.go b/graph/db/models/channel_auth_proof.go similarity index 100% rename from channeldb/models/channel_auth_proof.go rename to graph/db/models/channel_auth_proof.go diff --git a/channeldb/models/channel_edge_info.go b/graph/db/models/channel_edge_info.go similarity index 100% rename from channeldb/models/channel_edge_info.go rename to graph/db/models/channel_edge_info.go diff --git a/channeldb/models/channel_edge_policy.go b/graph/db/models/channel_edge_policy.go similarity index 100% rename from channeldb/models/channel_edge_policy.go rename to graph/db/models/channel_edge_policy.go diff --git a/channeldb/models/inbound_fee.go b/graph/db/models/inbound_fee.go similarity index 100% rename from channeldb/models/inbound_fee.go rename to graph/db/models/inbound_fee.go diff --git a/channeldb/models/inbound_fee_test.go b/graph/db/models/inbound_fee_test.go similarity index 100% rename from channeldb/models/inbound_fee_test.go rename to graph/db/models/inbound_fee_test.go diff --git a/graph/db/models/node.go b/graph/db/models/node.go new file mode 100644 index 0000000000..9624154339 --- /dev/null +++ b/graph/db/models/node.go @@ -0,0 +1,133 @@ +package models + +import ( + "fmt" + "image/color" + "net" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/lightningnetwork/lnd/lnwire" +) + +// LightningNode represents an individual vertex/node within the channel graph. +// A node is connected to other nodes by one or more channel edges emanating +// from it. As the graph is directed, a node will also have an incoming edge +// attached to it for each outgoing edge. +type LightningNode struct { + // PubKeyBytes is the raw bytes of the public key of the target node. + PubKeyBytes [33]byte + pubKey *btcec.PublicKey + + // HaveNodeAnnouncement indicates whether we received a node + // announcement for this particular node. If true, the remaining fields + // will be set, if false only the PubKey is known for this node. + HaveNodeAnnouncement bool + + // LastUpdate is the last time the vertex information for this node has + // been updated. + LastUpdate time.Time + + // Address is the TCP address this node is reachable over. + Addresses []net.Addr + + // Color is the selected color for the node. + Color color.RGBA + + // Alias is a nick-name for the node. The alias can be used to confirm + // a node's identity or to serve as a short ID for an address book. + Alias string + + // AuthSigBytes is the raw signature under the advertised public key + // which serves to authenticate the attributes announced by this node. + AuthSigBytes []byte + + // Features is the list of protocol features supported by this node. + Features *lnwire.FeatureVector + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData []byte + + // TODO(roasbeef): discovery will need storage to keep it's last IP + // address and re-announce if interface changes? + + // TODO(roasbeef): add update method and fetch? +} + +// PubKey is the node's long-term identity public key. This key will be used to +// authenticated any advertisements/updates sent by the node. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. +func (l *LightningNode) PubKey() (*btcec.PublicKey, error) { + if l.pubKey != nil { + return l.pubKey, nil + } + + key, err := btcec.ParsePubKey(l.PubKeyBytes[:]) + if err != nil { + return nil, err + } + l.pubKey = key + + return key, nil +} + +// AuthSig is a signature under the advertised public key which serves to +// authenticate the attributes announced by this node. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the signature if absolutely necessary. +func (l *LightningNode) AuthSig() (*ecdsa.Signature, error) { + return ecdsa.ParseSignature(l.AuthSigBytes) +} + +// AddPubKey is a setter-link method that can be used to swap out the public +// key for a node. +func (l *LightningNode) AddPubKey(key *btcec.PublicKey) { + l.pubKey = key + copy(l.PubKeyBytes[:], key.SerializeCompressed()) +} + +// NodeAnnouncement retrieves the latest node announcement of the node. +func (l *LightningNode) NodeAnnouncement(signed bool) (*lnwire.NodeAnnouncement, + error) { + + if !l.HaveNodeAnnouncement { + return nil, fmt.Errorf("node does not have node announcement") + } + + alias, err := lnwire.NewNodeAlias(l.Alias) + if err != nil { + return nil, err + } + + nodeAnn := &lnwire.NodeAnnouncement{ + Features: l.Features.RawFeatureVector, + NodeID: l.PubKeyBytes, + RGBColor: l.Color, + Alias: alias, + Addresses: l.Addresses, + Timestamp: uint32(l.LastUpdate.Unix()), + ExtraOpaqueData: l.ExtraOpaqueData, + } + + if !signed { + return nodeAnn, nil + } + + sig, err := lnwire.NewSigFromECDSARawSignature(l.AuthSigBytes) + if err != nil { + return nil, err + } + + nodeAnn.Signature = sig + + return nodeAnn, nil +} diff --git a/graph/db/options.go b/graph/db/options.go new file mode 100644 index 0000000000..a512ec4bce --- /dev/null +++ b/graph/db/options.go @@ -0,0 +1,100 @@ +package graphdb + +import "time" + +const ( + // DefaultRejectCacheSize is the default number of rejectCacheEntries to + // cache for use in the rejection cache of incoming gossip traffic. This + // produces a cache size of around 1MB. + DefaultRejectCacheSize = 50000 + + // DefaultChannelCacheSize is the default number of ChannelEdges cached + // in order to reply to gossip queries. This produces a cache size of + // around 40MB. + DefaultChannelCacheSize = 20000 + + // DefaultPreAllocCacheNumNodes is the default number of channels we + // assume for mainnet for pre-allocating the graph cache. As of + // September 2021, there currently are 14k nodes in a strictly pruned + // graph, so we choose a number that is slightly higher. + DefaultPreAllocCacheNumNodes = 15000 +) + +// Options holds parameters for tuning and customizing a graph.DB. +type Options struct { + // RejectCacheSize is the maximum number of rejectCacheEntries to hold + // in the rejection cache. + RejectCacheSize int + + // ChannelCacheSize is the maximum number of ChannelEdges to hold in the + // channel cache. + ChannelCacheSize int + + // BatchCommitInterval is the maximum duration the batch schedulers will + // wait before attempting to commit a pending set of updates. + BatchCommitInterval time.Duration + + // PreAllocCacheNumNodes is the number of nodes we expect to be in the + // graph cache, so we can pre-allocate the map accordingly. + PreAllocCacheNumNodes int + + // UseGraphCache denotes whether the in-memory graph cache should be + // used or a fallback version that uses the underlying database for + // path finding. + UseGraphCache bool + + // NoMigration specifies that underlying backend was opened in read-only + // mode and migrations shouldn't be performed. This can be useful for + // applications that use the channeldb package as a library. + NoMigration bool +} + +// DefaultOptions returns an Options populated with default values. +func DefaultOptions() *Options { + return &Options{ + RejectCacheSize: DefaultRejectCacheSize, + ChannelCacheSize: DefaultChannelCacheSize, + PreAllocCacheNumNodes: DefaultPreAllocCacheNumNodes, + UseGraphCache: true, + NoMigration: false, + } +} + +// OptionModifier is a function signature for modifying the default Options. +type OptionModifier func(*Options) + +// WithRejectCacheSize sets the RejectCacheSize to n. +func WithRejectCacheSize(n int) OptionModifier { + return func(o *Options) { + o.RejectCacheSize = n + } +} + +// WithChannelCacheSize sets the ChannelCacheSize to n. +func WithChannelCacheSize(n int) OptionModifier { + return func(o *Options) { + o.ChannelCacheSize = n + } +} + +// WithPreAllocCacheNumNodes sets the PreAllocCacheNumNodes to n. +func WithPreAllocCacheNumNodes(n int) OptionModifier { + return func(o *Options) { + o.PreAllocCacheNumNodes = n + } +} + +// WithBatchCommitInterval sets the batch commit interval for the interval batch +// schedulers. +func WithBatchCommitInterval(interval time.Duration) OptionModifier { + return func(o *Options) { + o.BatchCommitInterval = interval + } +} + +// WithUseGraphCache sets the UseGraphCache option to the given value. +func WithUseGraphCache(use bool) OptionModifier { + return func(o *Options) { + o.UseGraphCache = use + } +} diff --git a/channeldb/reject_cache.go b/graph/db/reject_cache.go similarity index 99% rename from channeldb/reject_cache.go rename to graph/db/reject_cache.go index acadb8780b..2a2721928b 100644 --- a/channeldb/reject_cache.go +++ b/graph/db/reject_cache.go @@ -1,4 +1,4 @@ -package channeldb +package graphdb // rejectFlags is a compact representation of various metadata stored by the // reject cache about a particular channel. diff --git a/channeldb/reject_cache_test.go b/graph/db/reject_cache_test.go similarity index 99% rename from channeldb/reject_cache_test.go rename to graph/db/reject_cache_test.go index 6974f42573..f64c39c33d 100644 --- a/channeldb/reject_cache_test.go +++ b/graph/db/reject_cache_test.go @@ -1,4 +1,4 @@ -package channeldb +package graphdb import ( "reflect" diff --git a/graph/db/setup_test.go b/graph/db/setup_test.go new file mode 100644 index 0000000000..ce0e5c7b95 --- /dev/null +++ b/graph/db/setup_test.go @@ -0,0 +1,11 @@ +package graphdb + +import ( + "testing" + + "github.com/lightningnetwork/lnd/kvdb" +) + +func TestMain(m *testing.M) { + kvdb.RunTests(m) +} diff --git a/channeldb/graphsession/graph_session.go b/graph/graphsession/graph_session.go similarity index 95% rename from channeldb/graphsession/graph_session.go rename to graph/graphsession/graph_session.go index 30f1903287..6976fad79b 100644 --- a/channeldb/graphsession/graph_session.go +++ b/graph/graphsession/graph_session.go @@ -3,7 +3,7 @@ package graphsession import ( "fmt" - "github.com/lightningnetwork/lnd/channeldb" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing" @@ -84,7 +84,7 @@ func (g *session) close() error { // // NOTE: Part of the routing.Graph interface. func (g *session) ForEachNodeChannel(nodePub route.Vertex, - cb func(channel *channeldb.DirectedChannel) error) error { + cb func(channel *graphdb.DirectedChannel) error) error { return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb) } @@ -129,7 +129,7 @@ type graph interface { // NOTE: if a nil tx is provided, then it is expected that the // implementation create a read only tx. ForEachNodeDirectedChannel(tx kvdb.RTx, node route.Vertex, - cb func(channel *channeldb.DirectedChannel) error) error + cb func(channel *graphdb.DirectedChannel) error) error // FetchNodeFeatures returns the features of a given node. If no // features are known for the node, an empty feature vector is returned. @@ -138,4 +138,4 @@ type graph interface { // A compile-time check to ensure that *channeldb.ChannelGraph implements the // graph interface. -var _ graph = (*channeldb.ChannelGraph)(nil) +var _ graph = (*graphdb.ChannelGraph)(nil) diff --git a/graph/interfaces.go b/graph/interfaces.go index 7ae79f9a9f..eb7f56603a 100644 --- a/graph/interfaces.go +++ b/graph/interfaces.go @@ -6,8 +6,8 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/batch" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -23,7 +23,7 @@ type ChannelGraphSource interface { // AddNode is used to add information about a node to the router // database. If the node with this pubkey is not present in an existing // channel, it will be ignored. - AddNode(node *channeldb.LightningNode, + AddNode(node *models.LightningNode, op ...batch.SchedulerOption) error // AddEdge is used to add edge/channel to the topology of the router, @@ -69,8 +69,7 @@ type ChannelGraphSource interface { // ForAllOutgoingChannels is used to iterate over all channels // emanating from the "source" node which is the center of the // star-graph. - ForAllOutgoingChannels(cb func(tx kvdb.RTx, - c *models.ChannelEdgeInfo, + ForAllOutgoingChannels(cb func(c *models.ChannelEdgeInfo, e *models.ChannelEdgePolicy) error) error // CurrentBlockHeight returns the block height from POV of the router @@ -85,10 +84,10 @@ type ChannelGraphSource interface { // FetchLightningNode attempts to look up a target node by its identity // public key. channeldb.ErrGraphNodeNotFound is returned if the node // doesn't exist within the graph. - FetchLightningNode(route.Vertex) (*channeldb.LightningNode, error) + FetchLightningNode(route.Vertex) (*models.LightningNode, error) // ForEachNode is used to iterate over every node in the known graph. - ForEachNode(func(node *channeldb.LightningNode) error) error + ForEachNode(func(node *models.LightningNode) error) error } // DB is an interface describing a persisted Lightning Network graph. @@ -116,7 +115,7 @@ type DB interface { // channel within the known channel graph. The set of UTXO's (along with // their scripts) returned are the ones that need to be watched on // chain to detect channel closes on the resident blockchain. - ChannelView() ([]channeldb.EdgePoint, error) + ChannelView() ([]graphdb.EdgePoint, error) // PruneGraphNodes is a garbage collection method which attempts to // prune out any nodes from the channel graph that are currently @@ -129,7 +128,7 @@ type DB interface { // treated as the center node within a star-graph. This method may be // used to kick off a path finding algorithm in order to explore the // reachability of another node based off the source node. - SourceNode() (*channeldb.LightningNode, error) + SourceNode() (*models.LightningNode, error) // DisabledChannelIDs returns the channel ids of disabled channels. // A channel is disabled when two of the associated ChanelEdgePolicies @@ -142,13 +141,13 @@ type DB interface { // edges that exist at the time of the query. This can be used to // respond to peer queries that are seeking to fill in gaps in their // view of the channel graph. - FetchChanInfos(chanIDs []uint64) ([]channeldb.ChannelEdge, error) + FetchChanInfos(chanIDs []uint64) ([]graphdb.ChannelEdge, error) // ChanUpdatesInHorizon returns all the known channel edges which have // at least one edge that has an update timestamp within the specified // horizon. ChanUpdatesInHorizon(startTime, endTime time.Time) ( - []channeldb.ChannelEdge, error) + []graphdb.ChannelEdge, error) // DeleteChannelEdges removes edges with the given channel IDs from the // database and marks them as zombies. This ensures that we're unable to @@ -200,7 +199,7 @@ type DB interface { // update that node's information. Note that this method is expected to // only be called to update an already present node from a node // announcement, or to insert a node found in a channel update. - AddLightningNode(node *channeldb.LightningNode, + AddLightningNode(node *models.LightningNode, op ...batch.SchedulerOption) error // AddChannelEdge adds a new (undirected, blank) edge to the graph @@ -239,14 +238,14 @@ type DB interface { // FetchLightningNode attempts to look up a target node by its identity // public key. If the node isn't found in the database, then // ErrGraphNodeNotFound is returned. - FetchLightningNode(nodePub route.Vertex) (*channeldb.LightningNode, + FetchLightningNode(nodePub route.Vertex) (*models.LightningNode, error) // ForEachNode iterates through all the stored vertices/nodes in the // graph, executing the passed callback with each node encountered. If // the callback returns an error, then the transaction is aborted and // the iteration stops early. - ForEachNode(cb func(kvdb.RTx, *channeldb.LightningNode) error) error + ForEachNode(cb func(kvdb.RTx, *models.LightningNode) error) error // ForEachNodeChannel iterates through all channels of the given node, // executing the passed callback with an edge info structure and the diff --git a/graph/notifications.go b/graph/notifications.go index 14ea3d127d..76eabdb02f 100644 --- a/graph/notifications.go +++ b/graph/notifications.go @@ -10,8 +10,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" ) @@ -318,7 +317,7 @@ func addToTopologyChange(graph DB, update *TopologyChange, // Any node announcement maps directly to a NetworkNodeUpdate struct. // No further data munging or db queries are required. - case *channeldb.LightningNode: + case *models.LightningNode: pubKey, err := m.PubKey() if err != nil { return err diff --git a/graph/notifications_test.go b/graph/notifications_test.go index 09ebf1211b..39278bf13a 100644 --- a/graph/notifications_test.go +++ b/graph/notifications_test.go @@ -17,8 +17,8 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" @@ -77,14 +77,14 @@ var ( } ) -func createTestNode(t *testing.T) *channeldb.LightningNode { +func createTestNode(t *testing.T) *models.LightningNode { updateTime := prand.Int63() priv, err := btcec.NewPrivateKey() require.NoError(t, err) pub := priv.PubKey().SerializeCompressed() - n := &channeldb.LightningNode{ + n := &models.LightningNode{ HaveNodeAnnouncement: true, LastUpdate: time.Unix(updateTime, 0), Addresses: testAddrs, @@ -99,7 +99,7 @@ func createTestNode(t *testing.T) *channeldb.LightningNode { } func randEdgePolicy(chanID *lnwire.ShortChannelID, - node *channeldb.LightningNode) (*models.ChannelEdgePolicy, error) { + node *models.LightningNode) (*models.ChannelEdgePolicy, error) { InboundFee := models.InboundFee{ Base: prand.Int31() * -1, @@ -315,7 +315,7 @@ func (m *mockChainView) Reset() { m.staleBlocks = make(chan *chainview.FilteredBlock, 10) } -func (m *mockChainView) UpdateFilter(ops []channeldb.EdgePoint, updateHeight uint32) error { +func (m *mockChainView) UpdateFilter(ops []graphdb.EdgePoint, _ uint32) error { m.Lock() defer m.Unlock() @@ -686,7 +686,7 @@ func TestNodeUpdateNotification(t *testing.T) { t.Fatalf("unable to add node: %v", err) } - assertNodeNtfnCorrect := func(t *testing.T, ann *channeldb.LightningNode, + assertNodeNtfnCorrect := func(t *testing.T, ann *models.LightningNode, nodeUpdate *NetworkNodeUpdate) { nodeKey, _ := ann.PubKey() @@ -1019,7 +1019,7 @@ func TestEncodeHexColor(t *testing.T) { type testCtx struct { builder *Builder - graph *channeldb.ChannelGraph + graph *graphdb.ChannelGraph aliases map[string]route.Vertex @@ -1088,7 +1088,7 @@ func (c *testCtx) RestartBuilder(t *testing.T) { // makeTestGraph creates a new instance of a channeldb.ChannelGraph for testing // purposes. -func makeTestGraph(t *testing.T, useCache bool) (*channeldb.ChannelGraph, +func makeTestGraph(t *testing.T, useCache bool) (*graphdb.ChannelGraph, kvdb.Backend, error) { // Create channelgraph for the first time. @@ -1099,11 +1099,8 @@ func makeTestGraph(t *testing.T, useCache bool) (*channeldb.ChannelGraph, t.Cleanup(backendCleanup) - opts := channeldb.DefaultOptions() - graph, err := channeldb.NewChannelGraph( - backend, opts.RejectCacheSize, opts.ChannelCacheSize, - opts.BatchCommitInterval, opts.PreAllocCacheNumNodes, - useCache, false, + graph, err := graphdb.NewChannelGraph( + backend, graphdb.WithUseGraphCache(useCache), ) if err != nil { return nil, nil, err @@ -1113,7 +1110,7 @@ func makeTestGraph(t *testing.T, useCache bool) (*channeldb.ChannelGraph, } type testGraphInstance struct { - graph *channeldb.ChannelGraph + graph *graphdb.ChannelGraph graphBackend kvdb.Backend // aliasMap is a map from a node's alias to its public key. This type is diff --git a/graph/validation_barrier.go b/graph/validation_barrier.go index 98d910d899..a97709605e 100644 --- a/graph/validation_barrier.go +++ b/graph/validation_barrier.go @@ -4,8 +4,7 @@ import ( "fmt" "sync" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -151,7 +150,7 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { case *lnwire.NodeAnnouncement: // TODO(roasbeef): node ann needs to wait on existing channel updates return - case *channeldb.LightningNode: + case *models.LightningNode: return case *lnwire.AnnounceSignatures1: // TODO(roasbeef): need to wait on chan ann? @@ -195,7 +194,7 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { jobDesc = fmt.Sprintf("job=lnwire.ChannelEdgePolicy, scid=%v", msg.ChannelID) - case *channeldb.LightningNode: + case *models.LightningNode: vertex := route.Vertex(msg.PubKeyBytes) signals, ok = v.nodeAnnDependencies[vertex] @@ -291,7 +290,7 @@ func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) { // For all other job types, we'll delete the tracking entries from the // map, as if we reach this point, then all dependants have already // finished executing and we can proceed. - case *channeldb.LightningNode: + case *models.LightningNode: delete(v.nodeAnnDependencies, route.Vertex(msg.PubKeyBytes)) case *lnwire.NodeAnnouncement: delete(v.nodeAnnDependencies, route.Vertex(msg.NodeID)) diff --git a/htlcswitch/circuit.go b/htlcswitch/circuit.go index 700b087f52..eab1cdb200 100644 --- a/htlcswitch/circuit.go +++ b/htlcswitch/circuit.go @@ -5,7 +5,7 @@ import ( "io" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/htlcswitch/circuit_test.go b/htlcswitch/circuit_test.go index 108692719b..ddad11aca9 100644 --- a/htlcswitch/circuit_test.go +++ b/htlcswitch/circuit_test.go @@ -625,9 +625,7 @@ func makeCircuitDB(t *testing.T, path string) *channeldb.DB { path = t.TempDir() } - db, err := channeldb.Open(path) - require.NoError(t, err, "unable to open channel db") - t.Cleanup(func() { db.Close() }) + db := channeldb.OpenForTesting(t, path) return db } diff --git a/htlcswitch/held_htlc_set.go b/htlcswitch/held_htlc_set.go index e098a2e007..c04880dc3d 100644 --- a/htlcswitch/held_htlc_set.go +++ b/htlcswitch/held_htlc_set.go @@ -4,7 +4,7 @@ import ( "errors" "fmt" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" ) // heldHtlcSet keeps track of outstanding intercepted forwards. It exposes diff --git a/htlcswitch/held_htlc_set_test.go b/htlcswitch/held_htlc_set_test.go index a0a5e5bb4b..ca1a1750bc 100644 --- a/htlcswitch/held_htlc_set_test.go +++ b/htlcswitch/held_htlc_set_test.go @@ -3,7 +3,7 @@ package htlcswitch import ( "testing" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/stretchr/testify/require" ) diff --git a/htlcswitch/htlcnotifier.go b/htlcswitch/htlcnotifier.go index d6da5327da..4d4d33374e 100644 --- a/htlcswitch/htlcnotifier.go +++ b/htlcswitch/htlcnotifier.go @@ -7,7 +7,7 @@ import ( "time" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index 0d1fec9aed..c48436173f 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -8,8 +8,8 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/chainntnfs" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index a3f4ff9a9a..d8f55afc69 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -6,8 +6,8 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 5790460055..b8f1ce8edd 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -15,9 +15,9 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog/v2" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 7259723403..4ee538a581 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -25,9 +25,9 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" @@ -2169,7 +2169,7 @@ func newSingleLinkTestHarness(t *testing.T, chanAmt, BaseFee: lnwire.NewMSatFromSatoshis(1), TimeLockDelta: 6, } - invoiceRegistry = newMockRegistry(globalPolicy.TimeLockDelta) + invoiceRegistry = newMockRegistry(t) ) pCache := newMockPreimageCache() @@ -2267,7 +2267,6 @@ func newSingleLinkTestHarness(t *testing.T, chanAmt, t.Cleanup(func() { close(alicePeer.quit) - invoiceRegistry.cleanup() }) harness := singleLinkTestHarness{ diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 0a3364ae27..ce791bef32 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "net" - "os" "path/filepath" "sync" "sync/atomic" @@ -23,10 +22,10 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lnpeer" @@ -216,11 +215,7 @@ func initSwitchWithTempDB(t testing.TB, startingHeight uint32) (*Switch, error) { tempPath := filepath.Join(t.TempDir(), "switchdb") - db, err := channeldb.Open(tempPath) - if err != nil { - return nil, err - } - t.Cleanup(func() { db.Close() }) + db := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(startingHeight, db) if err != nil { @@ -254,9 +249,7 @@ func newMockServer(t testing.TB, name string, startingHeight uint32, t.Cleanup(func() { _ = htlcSwitch.Stop() }) - registry := newMockRegistry(defaultDelta) - - t.Cleanup(func() { registry.cleanup() }) + registry := newMockRegistry(t) return &mockServer{ t: t, @@ -977,37 +970,12 @@ func (f *mockChannelLink) CommitmentCustomBlob() fn.Option[tlv.Blob] { var _ ChannelLink = (*mockChannelLink)(nil) -func newDB() (*channeldb.DB, func(), error) { - // First, create a temporary directory to be used for the duration of - // this test. - tempDirName, err := os.MkdirTemp("", "channeldb") - if err != nil { - return nil, nil, err - } - - // Next, create channeldb for the first time. - cdb, err := channeldb.Open(tempDirName) - if err != nil { - os.RemoveAll(tempDirName) - return nil, nil, err - } - - cleanUp := func() { - cdb.Close() - os.RemoveAll(tempDirName) - } - - return cdb, cleanUp, nil -} - const testInvoiceCltvExpiry = 6 type mockInvoiceRegistry struct { settleChan chan lntypes.Hash registry *invoices.InvoiceRegistry - - cleanup func() } type mockChainNotifier struct { @@ -1024,11 +992,8 @@ func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) ( }, nil } -func newMockRegistry(minDelta uint32) *mockInvoiceRegistry { - cdb, cleanup, err := newDB() - if err != nil { - panic(err) - } +func newMockRegistry(t testing.TB) *mockInvoiceRegistry { + cdb := channeldb.OpenForTesting(t, t.TempDir()) modifierMock := &invoices.MockHtlcModifier{} registry := invoices.NewRegistry( @@ -1046,7 +1011,6 @@ func newMockRegistry(minDelta uint32) *mockInvoiceRegistry { return &mockInvoiceRegistry{ registry: registry, - cleanup: cleanup, } } diff --git a/htlcswitch/packet.go b/htlcswitch/packet.go index 31639dd5d1..e991858519 100644 --- a/htlcswitch/packet.go +++ b/htlcswitch/packet.go @@ -2,7 +2,7 @@ package htlcswitch import ( "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" diff --git a/htlcswitch/payment_result_test.go b/htlcswitch/payment_result_test.go index 664197f765..f6def14652 100644 --- a/htlcswitch/payment_result_test.go +++ b/htlcswitch/payment_result_test.go @@ -101,11 +101,7 @@ func TestNetworkResultStore(t *testing.T) { const numResults = 4 - db, err := channeldb.Open(t.TempDir()) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { db.Close() }) + db := channeldb.OpenForTesting(t, t.TempDir()) store := newNetworkResultStore(db) diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index cbc2a16dae..1a08275ec9 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -15,10 +15,10 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index f47deb0802..abfb8e4d5b 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -16,9 +16,9 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lntest/mock" @@ -1002,9 +1002,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to open channeldb") - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err, "unable to init switch") @@ -1096,9 +1094,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { t.Fatalf(err.Error()) } - cdb2, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to reopen channeldb") - t.Cleanup(func() { cdb2.Close() }) + cdb2 := channeldb.OpenForTesting(t, tempPath) s2, err := initSwitchWithDB(testStartingHeight, cdb2) require.NoError(t, err, "unable reinit switch") @@ -1192,9 +1188,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to open channeldb") - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err, "unable to init switch") @@ -1286,9 +1280,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { t.Fatalf(err.Error()) } - cdb2, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to reopen channeldb") - t.Cleanup(func() { cdb2.Close() }) + cdb2 := channeldb.OpenForTesting(t, tempPath) s2, err := initSwitchWithDB(testStartingHeight, cdb2) require.NoError(t, err, "unable reinit switch") @@ -1385,9 +1377,7 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to open channeldb") - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err, "unable to init switch") @@ -1471,9 +1461,7 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) { t.Fatalf(err.Error()) } - cdb2, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to reopen channeldb") - t.Cleanup(func() { cdb2.Close() }) + cdb2 := channeldb.OpenForTesting(t, tempPath) s2, err := initSwitchWithDB(testStartingHeight, cdb2) require.NoError(t, err, "unable reinit switch") @@ -1541,9 +1529,7 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to open channeldb") - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err, "unable to init switch") @@ -1622,9 +1608,7 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { t.Fatalf(err.Error()) } - cdb2, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to reopen channeldb") - t.Cleanup(func() { cdb2.Close() }) + cdb2 := channeldb.OpenForTesting(t, tempPath) s2, err := initSwitchWithDB(testStartingHeight, cdb2) require.NoError(t, err, "unable reinit switch") @@ -1698,9 +1682,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to open channeldb") - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err, "unable to init switch") @@ -1778,9 +1760,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { t.Fatalf(err.Error()) } - cdb2, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to reopen channeldb") - t.Cleanup(func() { cdb2.Close() }) + cdb2 := channeldb.OpenForTesting(t, tempPath) s2, err := initSwitchWithDB(testStartingHeight, cdb2) require.NoError(t, err, "unable reinit switch") @@ -1870,9 +1850,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { t.Fatalf(err.Error()) } - cdb3, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to reopen channeldb") - t.Cleanup(func() { cdb3.Close() }) + cdb3 := channeldb.OpenForTesting(t, tempPath) s3, err := initSwitchWithDB(testStartingHeight, cdb3) require.NoError(t, err, "unable reinit switch") @@ -3827,9 +3805,7 @@ func newInterceptableSwitchTestContext( tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err, "unable to open channeldb") - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err, "unable to init switch") @@ -4914,9 +4890,7 @@ func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err) - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err) @@ -4990,9 +4964,7 @@ func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) { err = cdb.Close() require.NoError(t, err) - cdb2, err := channeldb.Open(tempPath) - require.NoError(t, err) - t.Cleanup(func() { cdb2.Close() }) + cdb2 := channeldb.OpenForTesting(t, tempPath) s2, err := initSwitchWithDB(testStartingHeight, cdb2) require.NoError(t, err) @@ -5130,9 +5102,7 @@ func testSwitchAliasFailAdd(t *testing.T, zeroConf, private, useAlias bool) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err) - defer cdb.Close() + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err) @@ -5471,9 +5441,7 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) { tempPath := t.TempDir() - cdb, err := channeldb.Open(tempPath) - require.NoError(t, err) - t.Cleanup(func() { cdb.Close() }) + cdb := channeldb.OpenForTesting(t, tempPath) s, err := initSwitchWithDB(testStartingHeight, cdb) require.NoError(t, err) diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index cdb4f1f4ea..0f4b28fb82 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -24,8 +24,8 @@ import ( "github.com/go-errors/errors" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/contractcourt" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/invoices" @@ -251,21 +251,8 @@ func createTestChannel(t *testing.T, alicePrivKey, bobPrivKey []byte, return nil, nil, err } - dbAlice, err := channeldb.Open(t.TempDir()) - if err != nil { - return nil, nil, err - } - t.Cleanup(func() { - require.NoError(t, dbAlice.Close()) - }) - - dbBob, err := channeldb.Open(t.TempDir()) - if err != nil { - return nil, nil, err - } - t.Cleanup(func() { - require.NoError(t, dbBob.Close()) - }) + dbAlice := channeldb.OpenForTesting(t, t.TempDir()) + dbBob := channeldb.OpenForTesting(t, t.TempDir()) estimator := chainfee.NewStaticEstimator(6000, 0) feePerKw, err := estimator.EstimateFeePerKW(1) @@ -403,11 +390,7 @@ func createTestChannel(t *testing.T, alicePrivKey, bobPrivKey []byte, switch err { case nil: case kvdb.ErrDatabaseNotOpen: - dbAlice, err = channeldb.Open(dbAlice.Path()) - if err != nil { - return nil, errors.Errorf("unable to reopen alice "+ - "db: %v", err) - } + dbAlice = channeldb.OpenForTesting(t, dbAlice.Path()) aliceStoredChannels, err = dbAlice.ChannelStateDB(). FetchOpenChannels(aliceKeyPub) @@ -451,7 +434,7 @@ func createTestChannel(t *testing.T, alicePrivKey, bobPrivKey []byte, switch err { case nil: case kvdb.ErrDatabaseNotOpen: - dbBob, err = channeldb.Open(dbBob.Path()) + dbBob = channeldb.OpenForTesting(t, dbBob.Path()) if err != nil { return nil, errors.Errorf("unable to reopen bob "+ "db: %v", err) diff --git a/invoices/interface.go b/invoices/interface.go index c906da1c3f..526378a2b6 100644 --- a/invoices/interface.go +++ b/invoices/interface.go @@ -5,7 +5,7 @@ import ( "time" "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" diff --git a/invoices/invoices_test.go b/invoices/invoices_test.go index b6efd6e557..33e7ebb5bd 100644 --- a/invoices/invoices_test.go +++ b/invoices/invoices_test.go @@ -11,9 +11,9 @@ import ( "time" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/feature" + "github.com/lightningnetwork/lnd/graph/db/models" invpkg "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" diff --git a/invoices/sql_store.go b/invoices/sql_store.go index eb465eabb4..e848297d9c 100644 --- a/invoices/sql_store.go +++ b/invoices/sql_store.go @@ -11,8 +11,8 @@ import ( "time" "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" diff --git a/invoices/update_invoice.go b/invoices/update_invoice.go index a2de1b8f21..5286003e4b 100644 --- a/invoices/update_invoice.go +++ b/invoices/update_invoice.go @@ -5,7 +5,7 @@ import ( "fmt" "time" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" ) diff --git a/itest/lnd_channel_policy_test.go b/itest/lnd_channel_policy_test.go index bb05209753..9aa0f09b88 100644 --- a/itest/lnd_channel_policy_test.go +++ b/itest/lnd_channel_policy_test.go @@ -254,7 +254,8 @@ func testUpdateChannelPolicy(ht *lntest.HarnessTest) { ChanPoint: chanPoint, }, } - bob.RPC.UpdateChannelPolicy(req) + updateResp := bob.RPC.UpdateChannelPolicy(req) + require.Empty(ht, updateResp.FailedUpdates, 0) // Wait for all nodes to have seen the policy update done by Bob. assertNodesPolicyUpdate(ht, nodes, bob, expectedPolicy, chanPoint) diff --git a/lnrpc/devrpc/config_active.go b/lnrpc/devrpc/config_active.go index da5cd5be97..c5d43c194b 100644 --- a/lnrpc/devrpc/config_active.go +++ b/lnrpc/devrpc/config_active.go @@ -5,7 +5,7 @@ package devrpc import ( "github.com/btcsuite/btcd/chaincfg" - "github.com/lightningnetwork/lnd/channeldb" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/htlcswitch" ) @@ -16,6 +16,6 @@ import ( // also be specified. type Config struct { ActiveNetParams *chaincfg.Params - GraphDB *channeldb.ChannelGraph + GraphDB *graphdb.ChannelGraph Switch *htlcswitch.Switch } diff --git a/lnrpc/devrpc/dev_server.go b/lnrpc/devrpc/dev_server.go index ad135e8dfb..ebd4591abd 100644 --- a/lnrpc/devrpc/dev_server.go +++ b/lnrpc/devrpc/dev_server.go @@ -16,9 +16,8 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" @@ -227,7 +226,7 @@ func (s *Server) ImportGraph(ctx context.Context, var err error for _, rpcNode := range graph.Nodes { - node := &channeldb.LightningNode{ + node := &models.LightningNode{ HaveNodeAnnouncement: true, LastUpdate: time.Unix( int64(rpcNode.LastUpdate), 0, diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index df274f2da7..59f7df610b 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -18,7 +18,8 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" @@ -75,7 +76,7 @@ type AddInvoiceConfig struct { ChanDB *channeldb.ChannelStateDB // Graph holds a reference to the ChannelGraph database. - Graph *channeldb.ChannelGraph + Graph *graphdb.ChannelGraph // GenInvoiceFeatures returns a feature containing feature bits that // should be advertised on freshly generated invoices. diff --git a/lnrpc/invoicesrpc/addinvoice_test.go b/lnrpc/invoicesrpc/addinvoice_test.go index 76a529f8c6..546b9cc725 100644 --- a/lnrpc/invoicesrpc/addinvoice_test.go +++ b/lnrpc/invoicesrpc/addinvoice_test.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/zpay32" "github.com/stretchr/testify/mock" diff --git a/lnrpc/invoicesrpc/config_active.go b/lnrpc/invoicesrpc/config_active.go index 9568ed8919..14799c67ba 100644 --- a/lnrpc/invoicesrpc/config_active.go +++ b/lnrpc/invoicesrpc/config_active.go @@ -6,6 +6,7 @@ package invoicesrpc import ( "github.com/btcsuite/btcd/chaincfg" "github.com/lightningnetwork/lnd/channeldb" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/macaroons" @@ -53,7 +54,7 @@ type Config struct { // GraphDB is a global database instance which is needed to access the // channel graph. - GraphDB *channeldb.ChannelGraph + GraphDB *graphdb.ChannelGraph // ChanStateDB is a possibly replicated db instance which contains the // channel state diff --git a/lnrpc/routerrpc/forward_interceptor.go b/lnrpc/routerrpc/forward_interceptor.go index 614a11888c..7bc366dce0 100644 --- a/lnrpc/routerrpc/forward_interceptor.go +++ b/lnrpc/routerrpc/forward_interceptor.go @@ -3,8 +3,8 @@ package routerrpc import ( "errors" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" diff --git a/lnwallet/channel.go b/lnwallet/channel.go index fe4351c476..54ef442a9b 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -25,8 +25,8 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 8fc3d8b468..e3b7d07f90 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -25,8 +25,8 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" diff --git a/lnwallet/payment_descriptor.go b/lnwallet/payment_descriptor.go index a8edb1e7e6..49b79a139d 100644 --- a/lnwallet/payment_descriptor.go +++ b/lnwallet/payment_descriptor.go @@ -4,7 +4,7 @@ import ( "crypto/sha256" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" diff --git a/lnwallet/test/test_interface.go b/lnwallet/test/test_interface.go index c85aa20668..c006aa2e50 100644 --- a/lnwallet/test/test_interface.go +++ b/lnwallet/test/test_interface.go @@ -311,16 +311,14 @@ func loadTestCredits(miner *rpctest.Harness, w *lnwallet.LightningWallet, // createTestWallet creates a test LightningWallet will a total of 20BTC // available for funding channels. -func createTestWallet(tempTestDir string, miningNode *rpctest.Harness, - netParams *chaincfg.Params, notifier chainntnfs.ChainNotifier, - wc lnwallet.WalletController, keyRing keychain.SecretKeyRing, - signer input.Signer, bio lnwallet.BlockChainIO) (*lnwallet.LightningWallet, error) { +func createTestWallet(t *testing.T, tempTestDir string, + miningNode *rpctest.Harness, netParams *chaincfg.Params, + notifier chainntnfs.ChainNotifier, wc lnwallet.WalletController, + keyRing keychain.SecretKeyRing, signer input.Signer, + bio lnwallet.BlockChainIO) *lnwallet.LightningWallet { dbDir := filepath.Join(tempTestDir, "cdb") - fullDB, err := channeldb.Open(dbDir) - if err != nil { - return nil, err - } + fullDB := channeldb.OpenForTesting(t, dbDir) cfg := lnwallet.Config{ Database: fullDB.ChannelStateDB(), @@ -335,20 +333,18 @@ func createTestWallet(tempTestDir string, miningNode *rpctest.Harness, } wallet, err := lnwallet.NewLightningWallet(cfg) - if err != nil { - return nil, err - } + require.NoError(t, err) - if err := wallet.Startup(); err != nil { - return nil, err - } + require.NoError(t, wallet.Startup()) + + t.Cleanup(func() { + require.NoError(t, wallet.Shutdown()) + }) // Load our test wallet with 20 outputs each holding 4BTC. - if err := loadTestCredits(miningNode, wallet, 20, 4); err != nil { - return nil, err - } + require.NoError(t, loadTestCredits(miningNode, wallet, 20, 4)) - return wallet, nil + return wallet } func testGetRecoveryInfo(miner *rpctest.Harness, @@ -3206,9 +3202,7 @@ func TestLightningWallet(t *testing.T, targetBackEnd string) { rpcConfig := miningNode.RPCConfig() - tempDir := t.TempDir() - db, err := channeldb.Open(tempDir) - require.NoError(t, err, "unable to create db") + db := channeldb.OpenForTesting(t, t.TempDir()) testCfg := channeldb.CacheConfig{ QueryDisable: false, } @@ -3450,20 +3444,16 @@ func runTests(t *testing.T, walletDriver *lnwallet.WalletDriver, } // Funding via 20 outputs with 4BTC each. - alice, err := createTestWallet( - tempTestDirAlice, miningNode, netParams, + alice := createTestWallet( + t, tempTestDirAlice, miningNode, netParams, chainNotifier, aliceWalletController, aliceKeyRing, aliceSigner, bio, ) - require.NoError(t, err, "unable to create test ln wallet") - defer alice.Shutdown() - bob, err := createTestWallet( - tempTestDirBob, miningNode, netParams, + bob := createTestWallet( + t, tempTestDirBob, miningNode, netParams, chainNotifier, bobWalletController, bobKeyRing, bobSigner, bio, ) - require.NoError(t, err, "unable to create test ln wallet") - defer bob.Shutdown() // Both wallets should now have 80BTC available for // spending. diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go index 2253232962..ff9adfbd79 100644 --- a/lnwallet/test_utils.go +++ b/lnwallet/test_utils.go @@ -243,21 +243,8 @@ func CreateTestChannels(t *testing.T, chanType channeldb.ChannelType, return nil, nil, err } - dbAlice, err := channeldb.Open(t.TempDir(), dbModifiers...) - if err != nil { - return nil, nil, err - } - t.Cleanup(func() { - require.NoError(t, dbAlice.Close()) - }) - - dbBob, err := channeldb.Open(t.TempDir(), dbModifiers...) - if err != nil { - return nil, nil, err - } - t.Cleanup(func() { - require.NoError(t, dbBob.Close()) - }) + dbAlice := channeldb.OpenForTesting(t, t.TempDir(), dbModifiers...) + dbBob := channeldb.OpenForTesting(t, t.TempDir(), dbModifiers...) estimator := chainfee.NewStaticEstimator(6000, 0) feePerKw, err := estimator.EstimateFeePerKW(1) diff --git a/lnwallet/transactions_test.go b/lnwallet/transactions_test.go index 7912772962..135d1866bc 100644 --- a/lnwallet/transactions_test.go +++ b/lnwallet/transactions_test.go @@ -914,11 +914,8 @@ func createTestChannelsForVectors(tc *testContext, chanType channeldb.ChannelTyp ) // Create temporary databases. - dbRemote, err := channeldb.Open(t.TempDir()) - require.NoError(t, err) - - dbLocal, err := channeldb.Open(t.TempDir()) - require.NoError(t, err) + dbRemote := channeldb.OpenForTesting(t, t.TempDir()) + dbLocal := channeldb.OpenForTesting(t, t.TempDir()) // Create the initial commitment transactions for the channel. feePerKw := chainfee.SatPerKWeight(feeRate) diff --git a/log.go b/log.go index 343d86f38e..795fb4d729 100644 --- a/log.go +++ b/log.go @@ -21,6 +21,7 @@ import ( "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/graph" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/healthcheck" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/invoices" @@ -194,6 +195,7 @@ func SetupLoggers(root *build.SubLoggerManager, interceptor signal.Interceptor) AddSubLogger( root, blindedpath.Subsystem, interceptor, blindedpath.UseLogger, ) + AddV1SubLogger(root, graphdb.Subsystem, interceptor, graphdb.UseLogger) } // AddSubLogger is a helper method to conveniently create and register the diff --git a/netann/chan_status_manager.go b/netann/chan_status_manager.go index c4db4009dc..feb3a5dd19 100644 --- a/netann/chan_status_manager.go +++ b/netann/chan_status_manager.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -195,7 +196,7 @@ func (m *ChanStatusManager) start() error { // have been pruned from the channel graph but not yet from our // set of channels. We'll skip it as we can't determine its // initial state. - case errors.Is(err, channeldb.ErrEdgeNotFound): + case errors.Is(err, graphdb.ErrEdgeNotFound): log.Warnf("Unable to find channel policies for %v, "+ "skipping. This is typical if the channel is "+ "in the process of closing.", c.FundingOutpoint) @@ -580,7 +581,7 @@ func (m *ChanStatusManager) disableInactiveChannels() { // that the channel has been closed. Thus we remove the // outpoint from the set of tracked outpoints to prevent // further attempts. - if errors.Is(err, channeldb.ErrEdgeNotFound) { + if errors.Is(err, graphdb.ErrEdgeNotFound) { log.Debugf("Removing channel(%v) from "+ "consideration for passive disabling", outpoint) diff --git a/netann/chan_status_manager_test.go b/netann/chan_status_manager_test.go index c709a95f27..320981d630 100644 --- a/netann/chan_status_manager_test.go +++ b/netann/chan_status_manager_test.go @@ -14,7 +14,8 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/netann" @@ -168,7 +169,7 @@ func (g *mockGraph) FetchChannelEdgesByOutpoint( info, ok := g.chanInfos[*op] if !ok { - return nil, nil, nil, channeldb.ErrEdgeNotFound + return nil, nil, nil, graphdb.ErrEdgeNotFound } pol1 := g.chanPols1[*op] @@ -697,7 +698,7 @@ var stateMachineTests = []stateMachineTest{ // Request that they be enabled, which should return an // error as the graph doesn't have an edge for them. h.assertEnables( - unknownChans, channeldb.ErrEdgeNotFound, false, + unknownChans, graphdb.ErrEdgeNotFound, false, ) // No updates should be sent as a result of the failure. h.assertNoUpdates(h.safeDisableTimeout) @@ -717,7 +718,7 @@ var stateMachineTests = []stateMachineTest{ // Request that they be disabled, which should return an // error as the graph doesn't have an edge for them. h.assertDisables( - unknownChans, channeldb.ErrEdgeNotFound, false, + unknownChans, graphdb.ErrEdgeNotFound, false, ) // No updates should be sent as a result of the failure. h.assertNoUpdates(h.safeDisableTimeout) @@ -747,7 +748,9 @@ var stateMachineTests = []stateMachineTest{ // Check that trying to enable the channel with unknown // edges results in a failure. - h.assertEnables(newChans, channeldb.ErrEdgeNotFound, false) + h.assertEnables( + newChans, graphdb.ErrEdgeNotFound, false, + ) // Now, insert edge policies for the channel into the // graph, starting with the channel enabled, and mark @@ -794,7 +797,9 @@ var stateMachineTests = []stateMachineTest{ // Check that trying to enable the channel with unknown // edges results in a failure. - h.assertDisables(rmChans, channeldb.ErrEdgeNotFound, false) + h.assertDisables( + rmChans, graphdb.ErrEdgeNotFound, false, + ) }, }, { diff --git a/netann/channel_announcement.go b/netann/channel_announcement.go index 9644a523ff..571e0584dc 100644 --- a/netann/channel_announcement.go +++ b/netann/channel_announcement.go @@ -9,7 +9,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/netann/channel_announcement_test.go b/netann/channel_announcement_test.go index 61db16b16e..e8c5799b41 100644 --- a/netann/channel_announcement_test.go +++ b/netann/channel_announcement_test.go @@ -10,7 +10,7 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" diff --git a/netann/channel_update.go b/netann/channel_update.go index af91abdd24..efc5cf61e4 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -9,7 +9,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" diff --git a/netann/interface.go b/netann/interface.go index d6cdb46d0e..aa559435d4 100644 --- a/netann/interface.go +++ b/netann/interface.go @@ -3,7 +3,7 @@ package netann import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" ) // DB abstracts the required database functionality needed by the diff --git a/peer/brontide.go b/peer/brontide.go index bb4c9f96db..40d3f94c43 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -22,13 +22,14 @@ import ( "github.com/lightningnetwork/lnd/buffer" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/feature" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/funding" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" @@ -235,7 +236,7 @@ type Config struct { // ChannelGraph is a pointer to the channel graph which is used to // query information about the set of known active channels. - ChannelGraph *channeldb.ChannelGraph + ChannelGraph *graphdb.ChannelGraph // ChainArb is used to subscribe to channel events, update contract signals, // and force close channels. @@ -1098,7 +1099,7 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( info, p1, p2, err := graph.FetchChannelEdgesByOutpoint( &chanPoint, ) - if err != nil && !errors.Is(err, channeldb.ErrEdgeNotFound) { + if err != nil && !errors.Is(err, graphdb.ErrEdgeNotFound) { return nil, err } diff --git a/peer/test_utils.go b/peer/test_utils.go index 9034bb5a96..eb510a53b1 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -19,9 +19,11 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/fn" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/channels" "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lntypes" @@ -201,13 +203,7 @@ func createTestPeerWithChannel(t *testing.T, updateChan func(a, return nil, err } - dbBob, err := channeldb.Open(t.TempDir()) - if err != nil { - return nil, err - } - t.Cleanup(func() { - require.NoError(t, dbBob.Close()) - }) + dbBob := channeldb.OpenForTesting(t, t.TempDir()) feePerKw, err := estimator.EstimateFeePerKW(1) if err != nil { @@ -607,11 +603,22 @@ func createTestPeer(t *testing.T) *peerTestCtx { const chanActiveTimeout = time.Minute - dbAlice, err := channeldb.Open(t.TempDir()) - require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, dbAlice.Close()) + dbPath := t.TempDir() + + graphBackend, err := kvdb.GetBoltBackend(&kvdb.BoltBackendConfig{ + DBPath: dbPath, + DBFileName: "graph.db", + NoFreelistSync: true, + AutoCompact: false, + AutoCompactMinAge: kvdb.DefaultBoltAutoCompactMinAge, + DBTimeout: kvdb.DefaultDBTimeout, }) + require.NoError(t, err) + + dbAliceGraph, err := graphdb.NewChannelGraph(graphBackend) + require.NoError(t, err) + + dbAliceChannel := channeldb.OpenForTesting(t, dbPath) nodeSignerAlice := netann.NewNodeSigner(aliceKeySigner) @@ -620,8 +627,8 @@ func createTestPeer(t *testing.T) *peerTestCtx { ChanStatusSampleInterval: 30 * time.Second, ChanEnableTimeout: chanActiveTimeout, ChanDisableTimeout: 2 * time.Minute, - DB: dbAlice.ChannelStateDB(), - Graph: dbAlice.ChannelGraph(), + DB: dbAliceChannel.ChannelStateDB(), + Graph: dbAliceGraph, MessageSigner: nodeSignerAlice, OurPubKey: aliceKeyPub, OurKeyLoc: testKeyLoc, @@ -663,7 +670,7 @@ func createTestPeer(t *testing.T) *peerTestCtx { mockSwitch := &mockMessageSwitch{} // TODO(yy): change ChannelNotifier to be an interface. - channelNotifier := channelnotifier.New(dbAlice.ChannelStateDB()) + channelNotifier := channelnotifier.New(dbAliceChannel.ChannelStateDB()) require.NoError(t, channelNotifier.Start()) t.Cleanup(func() { require.NoError(t, channelNotifier.Stop(), @@ -707,7 +714,7 @@ func createTestPeer(t *testing.T) *peerTestCtx { Switch: mockSwitch, ChanActiveTimeout: chanActiveTimeout, InterceptSwitch: interceptableSwitch, - ChannelDB: dbAlice.ChannelStateDB(), + ChannelDB: dbAliceChannel.ChannelStateDB(), FeeEstimator: estimator, Wallet: wallet, ChainNotifier: notifier, @@ -749,7 +756,7 @@ func createTestPeer(t *testing.T) *peerTestCtx { mockSwitch: mockSwitch, peer: alicePeer, notifier: notifier, - db: dbAlice, + db: dbAliceChannel, privKey: aliceKeyPriv, mockConn: mockConn, customChan: receivedCustomChan, diff --git a/pilot.go b/pilot.go index 2a37b080d0..11333a0722 100644 --- a/pilot.go +++ b/pilot.go @@ -282,7 +282,7 @@ func initAutoPilot(svr *server, cfg *lncfg.AutoPilot, ChannelInfo: func(chanPoint wire.OutPoint) ( *autopilot.LocalChannel, error) { - channel, err := svr.chanStateDB.FetchChannel(nil, chanPoint) + channel, err := svr.chanStateDB.FetchChannel(chanPoint) if err != nil { return nil, err } diff --git a/routing/additional_edge.go b/routing/additional_edge.go index 5f2d42eebc..80061206e6 100644 --- a/routing/additional_edge.go +++ b/routing/additional_edge.go @@ -4,7 +4,7 @@ import ( "errors" "fmt" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) diff --git a/routing/bandwidth.go b/routing/bandwidth.go index 3b80dadc7c..12e82131dc 100644 --- a/routing/bandwidth.go +++ b/routing/bandwidth.go @@ -3,8 +3,8 @@ package routing import ( "fmt" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -96,7 +96,7 @@ func newBandwidthManager(graph Graph, sourceNode route.Vertex, // First, we'll collect the set of outbound edges from the target // source node and add them to our bandwidth manager's map of channels. err := graph.ForEachNodeChannel(sourceNode, - func(channel *channeldb.DirectedChannel) error { + func(channel *graphdb.DirectedChannel) error { shortID := lnwire.NewShortChanIDFromInt( channel.ChannelID, ) diff --git a/routing/blindedpath/blinded_path.go b/routing/blindedpath/blinded_path.go index bc14daa40a..a1f9db7b6b 100644 --- a/routing/blindedpath/blinded_path.go +++ b/routing/blindedpath/blinded_path.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/btcutil" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" diff --git a/routing/blindedpath/blinded_path_test.go b/routing/blindedpath/blinded_path_test.go index 51d028eafb..35db89afb7 100644 --- a/routing/blindedpath/blinded_path_test.go +++ b/routing/blindedpath/blinded_path_test.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" diff --git a/routing/blinding.go b/routing/blinding.go index 270f998d9f..7c84063469 100644 --- a/routing/blinding.go +++ b/routing/blinding.go @@ -6,8 +6,8 @@ import ( "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) diff --git a/routing/blinding_test.go b/routing/blinding_test.go index 950cb02107..410dfaf643 100644 --- a/routing/blinding_test.go +++ b/routing/blinding_test.go @@ -6,8 +6,8 @@ import ( "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" diff --git a/routing/chainview/bitcoind.go b/routing/chainview/bitcoind.go index cbddb37c10..b528091acb 100644 --- a/routing/chainview/bitcoind.go +++ b/routing/chainview/bitcoind.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcwallet/chain" "github.com/btcsuite/btcwallet/wtxmgr" "github.com/lightningnetwork/lnd/blockcache" - "github.com/lightningnetwork/lnd/channeldb" + graphdb "github.com/lightningnetwork/lnd/graph/db" ) // BitcoindFilteredChainView is an implementation of the FilteredChainView @@ -448,7 +448,7 @@ func (b *BitcoindFilteredChainView) chainFilterer() { // rewound to ensure all relevant notifications are dispatched. // // NOTE: This is part of the FilteredChainView interface. -func (b *BitcoindFilteredChainView) UpdateFilter(ops []channeldb.EdgePoint, +func (b *BitcoindFilteredChainView) UpdateFilter(ops []graphdb.EdgePoint, updateHeight uint32) error { newUtxos := make([]wire.OutPoint, len(ops)) diff --git a/routing/chainview/btcd.go b/routing/chainview/btcd.go index cf08abafe9..2a06fd179f 100644 --- a/routing/chainview/btcd.go +++ b/routing/chainview/btcd.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcd/rpcclient" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/blockcache" - "github.com/lightningnetwork/lnd/channeldb" + graphdb "github.com/lightningnetwork/lnd/graph/db" ) // BtcdFilteredChainView is an implementation of the FilteredChainView @@ -456,7 +456,7 @@ type filterUpdate struct { // rewound to ensure all relevant notifications are dispatched. // // NOTE: This is part of the FilteredChainView interface. -func (b *BtcdFilteredChainView) UpdateFilter(ops []channeldb.EdgePoint, +func (b *BtcdFilteredChainView) UpdateFilter(ops []graphdb.EdgePoint, updateHeight uint32) error { newUtxos := make([]wire.OutPoint, len(ops)) diff --git a/routing/chainview/interface.go b/routing/chainview/interface.go index cfa9fccf48..454c2ee619 100644 --- a/routing/chainview/interface.go +++ b/routing/chainview/interface.go @@ -3,7 +3,7 @@ package chainview import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" + graphdb "github.com/lightningnetwork/lnd/graph/db" ) // FilteredChainView represents a subscription to a certain subset of the @@ -43,7 +43,7 @@ type FilteredChainView interface { // relevant notifications are dispatched, meaning blocks with a height // lower than the best known height might be sent over the // FilteredBlocks() channel. - UpdateFilter(ops []channeldb.EdgePoint, updateHeight uint32) error + UpdateFilter(ops []graphdb.EdgePoint, updateHeight uint32) error // FilterBlock takes a block hash, and returns a FilteredBlocks which // is the result of applying the current registered UTXO sub-set on the diff --git a/routing/chainview/interface_test.go b/routing/chainview/interface_test.go index 3d97ff2c35..b953389afd 100644 --- a/routing/chainview/interface_test.go +++ b/routing/chainview/interface_test.go @@ -21,7 +21,7 @@ import ( _ "github.com/btcsuite/btcwallet/walletdb/bdb" // Required to register the boltdb walletdb implementation. "github.com/lightninglabs/neutrino" "github.com/lightningnetwork/lnd/blockcache" - "github.com/lightningnetwork/lnd/channeldb" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/unittest" "github.com/lightningnetwork/lnd/lntest/wait" @@ -218,7 +218,7 @@ func testFilterBlockNotifications(node *rpctest.Harness, require.NoError(t, err, "unable to get current height") // Now we'll add both outpoints to the current filter. - filter := []channeldb.EdgePoint{ + filter := []graphdb.EdgePoint{ {FundingPkScript: targetScript, OutPoint: *outPoint1}, {FundingPkScript: targetScript, OutPoint: *outPoint2}, } @@ -328,7 +328,7 @@ func testUpdateFilterBackTrack(node *rpctest.Harness, // After the block has been mined+notified we'll update the filter with // a _prior_ height so a "rewind" occurs. - filter := []channeldb.EdgePoint{ + filter := []graphdb.EdgePoint{ {FundingPkScript: testScript, OutPoint: *outPoint}, } err = chainView.UpdateFilter(filter, uint32(currentHeight)) @@ -417,7 +417,7 @@ func testFilterSingleBlock(node *rpctest.Harness, chainView FilteredChainView, // Now we'll manually trigger filtering the block generated above. // First, we'll add the two outpoints to our filter. - filter := []channeldb.EdgePoint{ + filter := []graphdb.EdgePoint{ {FundingPkScript: testScript, OutPoint: *outPoint1}, {FundingPkScript: testScript, OutPoint: *outPoint2}, } diff --git a/routing/chainview/neutrino.go b/routing/chainview/neutrino.go index 21f04ae95b..8a6d418363 100644 --- a/routing/chainview/neutrino.go +++ b/routing/chainview/neutrino.go @@ -12,7 +12,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightninglabs/neutrino" "github.com/lightningnetwork/lnd/blockcache" - "github.com/lightningnetwork/lnd/channeldb" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/lntypes" ) @@ -320,7 +320,7 @@ func (c *CfFilteredChainView) FilterBlock(blockHash *chainhash.Hash) (*FilteredB // rewound to ensure all relevant notifications are dispatched. // // NOTE: This is part of the FilteredChainView interface. -func (c *CfFilteredChainView) UpdateFilter(ops []channeldb.EdgePoint, +func (c *CfFilteredChainView) UpdateFilter(ops []graphdb.EdgePoint, updateHeight uint32) error { log.Tracef("Updating chain filter with new UTXO's: %v", ops) diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index 2baad92f1e..5fb271afa6 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -47,16 +47,13 @@ var ( func TestControlTowerSubscribeUnknown(t *testing.T) { t.Parallel() - db, err := initDB(t, false) - require.NoError(t, err, "unable to init db") + db := initDB(t, false) pControl := NewControlTower(channeldb.NewPaymentControl(db)) // Subscription should fail when the payment is not known. - _, err = pControl.SubscribePayment(lntypes.Hash{1}) - if err != channeldb.ErrPaymentNotInitiated { - t.Fatal("expected subscribe to fail for unknown payment") - } + _, err := pControl.SubscribePayment(lntypes.Hash{1}) + require.ErrorIs(t, err, channeldb.ErrPaymentNotInitiated) } // TestControlTowerSubscribeSuccess tests that payment updates for a @@ -64,8 +61,7 @@ func TestControlTowerSubscribeUnknown(t *testing.T) { func TestControlTowerSubscribeSuccess(t *testing.T) { t.Parallel() - db, err := initDB(t, false) - require.NoError(t, err, "unable to init db") + db := initDB(t, false) pControl := NewControlTower(channeldb.NewPaymentControl(db)) @@ -184,8 +180,7 @@ func TestPaymentControlSubscribeFail(t *testing.T) { func TestPaymentControlSubscribeAllSuccess(t *testing.T) { t.Parallel() - db, err := initDB(t, true) - require.NoError(t, err, "unable to init db: %v") + db := initDB(t, true) pControl := NewControlTower(channeldb.NewPaymentControl(db)) @@ -298,8 +293,7 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) { func TestPaymentControlSubscribeAllImmediate(t *testing.T) { t.Parallel() - db, err := initDB(t, true) - require.NoError(t, err, "unable to init db: %v") + db := initDB(t, true) pControl := NewControlTower(channeldb.NewPaymentControl(db)) @@ -336,8 +330,7 @@ func TestPaymentControlSubscribeAllImmediate(t *testing.T) { func TestPaymentControlUnsubscribeSuccess(t *testing.T) { t.Parallel() - db, err := initDB(t, true) - require.NoError(t, err, "unable to init db: %v") + db := initDB(t, true) pControl := NewControlTower(channeldb.NewPaymentControl(db)) @@ -406,8 +399,7 @@ func TestPaymentControlUnsubscribeSuccess(t *testing.T) { func testPaymentControlSubscribeFail(t *testing.T, registerAttempt, keepFailedPaymentAttempts bool) { - db, err := initDB(t, keepFailedPaymentAttempts) - require.NoError(t, err, "unable to init db") + db := initDB(t, keepFailedPaymentAttempts) pControl := NewControlTower(channeldb.NewPaymentControl(db)) @@ -525,17 +517,12 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt, } } -func initDB(t *testing.T, keepFailedPaymentAttempts bool) (*channeldb.DB, error) { - db, err := channeldb.Open( - t.TempDir(), channeldb.OptionKeepFailedPaymentAttempts( +func initDB(t *testing.T, keepFailedPaymentAttempts bool) *channeldb.DB { + return channeldb.OpenForTesting( + t, t.TempDir(), channeldb.OptionKeepFailedPaymentAttempts( keepFailedPaymentAttempts, ), ) - if err != nil { - return nil, err - } - - return db, err } func genInfo() (*channeldb.PaymentCreationInfo, *channeldb.HTLCAttemptInfo, diff --git a/routing/graph.go b/routing/graph.go index 2b1c85bc8f..7608ee92bb 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/channeldb" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -15,7 +15,7 @@ type Graph interface { // ForEachNodeChannel calls the callback for every channel of the given // node. ForEachNodeChannel(nodePub route.Vertex, - cb func(channel *channeldb.DirectedChannel) error) error + cb func(channel *graphdb.DirectedChannel) error) error // FetchNodeFeatures returns the features of the given node. FetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 085ac1a9f3..315b0dff22 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -338,11 +338,11 @@ var _ GraphSessionFactory = (*mockGraphSessionFactory)(nil) var _ Graph = (*mockGraphSessionFactory)(nil) type mockGraphSessionFactoryChanDB struct { - graph *channeldb.ChannelGraph + graph *graphdb.ChannelGraph } func newMockGraphSessionFactoryFromChanDB( - graph *channeldb.ChannelGraph) *mockGraphSessionFactoryChanDB { + graph *graphdb.ChannelGraph) *mockGraphSessionFactoryChanDB { return &mockGraphSessionFactoryChanDB{ graph: graph, @@ -368,11 +368,11 @@ func (g *mockGraphSessionFactoryChanDB) NewGraphSession() (Graph, func() error, var _ GraphSessionFactory = (*mockGraphSessionFactoryChanDB)(nil) type mockGraphSessionChanDB struct { - graph *channeldb.ChannelGraph + graph *graphdb.ChannelGraph tx kvdb.RTx } -func newMockGraphSessionChanDB(graph *channeldb.ChannelGraph) Graph { +func newMockGraphSessionChanDB(graph *graphdb.ChannelGraph) Graph { return &mockGraphSessionChanDB{ graph: graph, } @@ -392,7 +392,7 @@ func (g *mockGraphSessionChanDB) close() error { } func (g *mockGraphSessionChanDB) ForEachNodeChannel(nodePub route.Vertex, - cb func(channel *channeldb.DirectedChannel) error) error { + cb func(channel *graphdb.DirectedChannel) error) error { return g.graph.ForEachNodeDirectedChannel(g.tx, nodePub, cb) } diff --git a/routing/localchans/manager.go b/routing/localchans/manager.go index 2639928e02..d7380439ac 100644 --- a/routing/localchans/manager.go +++ b/routing/localchans/manager.go @@ -10,10 +10,9 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/fn" - "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing" @@ -40,14 +39,13 @@ type Manager struct { // ForAllOutgoingChannels is required to iterate over all our local // channels. The ChannelEdgePolicy parameter may be nil. - ForAllOutgoingChannels func(cb func(kvdb.RTx, - *models.ChannelEdgeInfo, + ForAllOutgoingChannels func(cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy) error) error // FetchChannel is used to query local channel parameters. Optionally an // existing db tx can be supplied. - FetchChannel func(tx kvdb.RTx, chanPoint wire.OutPoint) ( - *channeldb.OpenChannel, error) + FetchChannel func(chanPoint wire.OutPoint) (*channeldb.OpenChannel, + error) // AddEdge is used to add edge/channel to the topology of the router. AddEdge func(edge *models.ChannelEdgeInfo) error @@ -83,9 +81,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, policiesToUpdate := make(map[wire.OutPoint]models.ForwardingPolicy) // NOTE: edge may be nil when this function is called. - processChan := func( - tx kvdb.RTx, - info *models.ChannelEdgeInfo, + processChan := func(info *models.ChannelEdgeInfo, edge *models.ChannelEdgePolicy) error { // If we have a channel filter, and this channel isn't a part @@ -114,9 +110,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, } // Apply the new policy to the edge. - err := r.updateEdge( - tx, info.ChannelPoint, edge, newSchema, - ) + err := r.updateEdge(info.ChannelPoint, edge, newSchema) if err != nil { failedUpdates = append(failedUpdates, makeFailureItem(info.ChannelPoint, @@ -164,7 +158,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, // Construct a list of failed policy updates. for chanPoint := range unprocessedChans { - channel, err := r.FetchChannel(nil, chanPoint) + channel, err := r.FetchChannel(chanPoint) switch { case errors.Is(err, channeldb.ErrChannelNotFound): failedUpdates = append(failedUpdates, @@ -203,7 +197,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, channel, newSchema, ) if failedUpdate == nil { - err = processChan(nil, info, edge) + err = processChan(info, edge) if err != nil { return nil, err } @@ -261,7 +255,7 @@ func (r *Manager) createMissingEdge(channel *channeldb.OpenChannel, // Validate the newly created edge policy with the user defined new // schema before adding the edge to the database. - err = r.updateEdge(nil, channel.FundingOutpoint, edge, newSchema) + err = r.updateEdge(channel.FundingOutpoint, edge, newSchema) if err != nil { return nil, nil, makeFailureItem( info.ChannelPoint, @@ -351,11 +345,11 @@ func (r *Manager) createEdge(channel *channeldb.OpenChannel, } // updateEdge updates the given edge with the new schema. -func (r *Manager) updateEdge(tx kvdb.RTx, chanPoint wire.OutPoint, +func (r *Manager) updateEdge(chanPoint wire.OutPoint, edge *models.ChannelEdgePolicy, newSchema routing.ChannelPolicy) error { - channel, err := r.FetchChannel(tx, chanPoint) + channel, err := r.FetchChannel(chanPoint) if err != nil { return err } diff --git a/routing/localchans/manager_test.go b/routing/localchans/manager_test.go index f4512fab2c..d428a846be 100644 --- a/routing/localchans/manager_test.go +++ b/routing/localchans/manager_test.go @@ -11,10 +11,9 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/discovery" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing" @@ -123,20 +122,19 @@ func TestManager(t *testing.T) { return nil } - forAllOutgoingChannels := func(cb func(kvdb.RTx, - *models.ChannelEdgeInfo, + forAllOutgoingChannels := func(cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy) error) error { for _, c := range channelSet { - if err := cb(nil, c.edgeInfo, ¤tPolicy); err != nil { + if err := cb(c.edgeInfo, ¤tPolicy); err != nil { return err } } return nil } - fetchChannel := func(tx kvdb.RTx, chanPoint wire.OutPoint) ( - *channeldb.OpenChannel, error) { + fetchChannel := func(chanPoint wire.OutPoint) (*channeldb.OpenChannel, + error) { if chanPoint == chanPointMissing { return &channeldb.OpenChannel{}, channeldb.ErrChannelNotFound diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go index de03412343..cab7c97266 100644 --- a/routing/mock_graph_test.go +++ b/routing/mock_graph_test.go @@ -7,8 +7,8 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -166,12 +166,12 @@ func (m *mockGraph) addChannel(id uint64, node1id, node2id byte, // // NOTE: Part of the Graph interface. func (m *mockGraph) ForEachNodeChannel(nodePub route.Vertex, - cb func(channel *channeldb.DirectedChannel) error) error { + cb func(channel *graphdb.DirectedChannel) error) error { // Look up the mock node. node, ok := m.nodes[nodePub] if !ok { - return channeldb.ErrGraphNodeNotFound + return graphdb.ErrGraphNodeNotFound } // Iterate over all of its channels. @@ -188,7 +188,7 @@ func (m *mockGraph) ForEachNodeChannel(nodePub route.Vertex, // Call the per channel callback. err := cb( - &channeldb.DirectedChannel{ + &graphdb.DirectedChannel{ ChannelID: channel.id, IsNode1: nodePub == node1, OtherNode: peer, diff --git a/routing/mock_test.go b/routing/mock_test.go index 35ac3ecd99..3cdb5ebaf2 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -8,8 +8,8 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" diff --git a/routing/pathfind.go b/routing/pathfind.go index 43eae71036..db474e1e80 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -11,10 +11,10 @@ import ( "github.com/btcsuite/btcd/btcutil" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/feature" "github.com/lightningnetwork/lnd/fn" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" @@ -496,7 +496,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, g Graph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { var max, total lnwire.MilliSatoshi - cb := func(channel *channeldb.DirectedChannel) error { + cb := func(channel *graphdb.DirectedChannel) error { if !channel.OutPolicySet { return nil } @@ -1299,7 +1299,7 @@ func processNodeForBlindedPath(g Graph, node route.Vertex, // Now, iterate over the node's channels in search for paths to this // node that can be used for blinded paths err = g.ForEachNodeChannel(node, - func(channel *channeldb.DirectedChannel) error { + func(channel *graphdb.DirectedChannel) error { // Keep track of how many incoming channels this node // has. We only use a node as an introduction node if it // has channels other than the one that lead us to it. diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 72f71600dd..81708d3930 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -21,9 +21,9 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" switchhop "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/kvdb" @@ -155,7 +155,7 @@ type testChan struct { // makeTestGraph creates a new instance of a channeldb.ChannelGraph for testing // purposes. -func makeTestGraph(t *testing.T, useCache bool) (*channeldb.ChannelGraph, +func makeTestGraph(t *testing.T, useCache bool) (*graphdb.ChannelGraph, kvdb.Backend, error) { // Create channelgraph for the first time. @@ -166,11 +166,8 @@ func makeTestGraph(t *testing.T, useCache bool) (*channeldb.ChannelGraph, t.Cleanup(backendCleanup) - opts := channeldb.DefaultOptions() - graph, err := channeldb.NewChannelGraph( - backend, opts.RejectCacheSize, opts.ChannelCacheSize, - opts.BatchCommitInterval, opts.PreAllocCacheNumNodes, - useCache, false, + graph, err := graphdb.NewChannelGraph( + backend, graphdb.WithUseGraphCache(useCache), ) if err != nil { return nil, nil, err @@ -217,7 +214,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( privKeyMap := make(map[string]*btcec.PrivateKey) channelIDs := make(map[route.Vertex]map[route.Vertex]uint64) links := make(map[lnwire.ShortChannelID]htlcswitch.ChannelLink) - var source *channeldb.LightningNode + var source *models.LightningNode // First we insert all the nodes within the graph as vertexes. for _, node := range g.Nodes { @@ -226,7 +223,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( return nil, err } - dbNode := &channeldb.LightningNode{ + dbNode := &models.LightningNode{ HaveNodeAnnouncement: true, AuthSigBytes: testSig.Serialize(), LastUpdate: testTime, @@ -357,7 +354,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( } err = graph.AddChannelEdge(&edgeInfo) - if err != nil && err != channeldb.ErrEdgeAlreadyExist { + if err != nil && !errors.Is(err, graphdb.ErrEdgeAlreadyExist) { return nil, err } @@ -477,7 +474,7 @@ type testChannel struct { } type testGraphInstance struct { - graph *channeldb.ChannelGraph + graph *graphdb.ChannelGraph graphBackend kvdb.Backend // aliasMap is a map from a node's alias to its public key. This type is @@ -539,7 +536,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, nodeIndex := byte(0) addNodeWithAlias := func(alias string, features *lnwire.FeatureVector) ( - *channeldb.LightningNode, error) { + *models.LightningNode, error) { keyBytes := []byte{ 0, 0, 0, 0, 0, 0, 0, 0, @@ -554,7 +551,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, features = lnwire.EmptyFeatureVector() } - dbNode := &channeldb.LightningNode{ + dbNode := &models.LightningNode{ HaveNodeAnnouncement: true, AuthSigBytes: testSig.Serialize(), LastUpdate: testTime, @@ -665,7 +662,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, } err = graph.AddChannelEdge(&edgeInfo) - if err != nil && err != channeldb.ErrEdgeAlreadyExist { + if err != nil && !errors.Is(err, graphdb.ErrEdgeAlreadyExist) { return nil, err } @@ -1210,7 +1207,7 @@ func runPathFindingWithAdditionalEdges(t *testing.T, useCache bool) { dogePubKey, err := btcec.ParsePubKey(dogePubKeyBytes) require.NoError(t, err, "unable to parse public key from bytes") - doge := &channeldb.LightningNode{} + doge := &models.LightningNode{} doge.AddPubKey(dogePubKey) doge.Alias = "doge" copy(doge.PubKeyBytes[:], dogePubKeyBytes) @@ -3026,7 +3023,7 @@ func runInboundFees(t *testing.T, useCache bool) { type pathFindingTestContext struct { t *testing.T - graph *channeldb.ChannelGraph + graph *graphdb.ChannelGraph restrictParams RestrictParams bandwidthHints bandwidthHints pathFindingConfig PathFindingConfig @@ -3108,7 +3105,7 @@ func (c *pathFindingTestContext) assertPath(path []*unifiedEdge, // dbFindPath calls findPath after getting a db transaction from the database // graph. -func dbFindPath(graph *channeldb.ChannelGraph, +func dbFindPath(graph *graphdb.ChannelGraph, additionalEdges map[route.Vertex][]AdditionalEdge, bandwidthHints bandwidthHints, r *RestrictParams, cfg *PathFindingConfig, @@ -3148,7 +3145,7 @@ func dbFindPath(graph *channeldb.ChannelGraph, // dbFindBlindedPaths calls findBlindedPaths after getting a db transaction from // the database graph. -func dbFindBlindedPaths(graph *channeldb.ChannelGraph, +func dbFindBlindedPaths(graph *graphdb.ChannelGraph, restrictions *blindedPathRestrictions) ([][]blindedHop, error) { sourceNode, err := graph.SourceNode() diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 93b214bdea..267ce3965d 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -10,8 +10,8 @@ import ( "github.com/davecgh/go-spew/spew" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" diff --git a/routing/payment_session.go b/routing/payment_session.go index 3cbacad89c..0afdf822fb 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -6,7 +6,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btclog/v2" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/netann" diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index ccee9bc449..d5f1a6af41 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -2,9 +2,8 @@ package routing import ( "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" @@ -24,7 +23,7 @@ type SessionSource struct { GraphSessionFactory GraphSessionFactory // SourceNode is the graph's source node. - SourceNode *channeldb.LightningNode + SourceNode *models.LightningNode // GetLink is a method that allows querying the lower link layer // to determine the up to date available bandwidth at a prospective link @@ -101,7 +100,7 @@ func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) ( // we'll need to look at the next hint's start node. If // we've reached the end of the hints list, we can // assume we've reached the destination. - endNode := &channeldb.LightningNode{} + endNode := &models.LightningNode{} if i != len(routeHint)-1 { endNode.AddPubKey(routeHint[i+1].NodeID) } else { diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index d510f77f96..278e090440 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -4,8 +4,7 @@ import ( "testing" "time" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -89,7 +88,7 @@ func TestUpdateAdditionalEdge(t *testing.T) { // Create a minimal test node using the private key priv1. pub := priv1.PubKey().SerializeCompressed() - testNode := &channeldb.LightningNode{} + testNode := &models.LightningNode{} copy(testNode.PubKeyBytes[:], pub) nodeID, err := testNode.PubKey() diff --git a/routing/router.go b/routing/router.go index 1fea60ddbe..b92aa15023 100644 --- a/routing/router.go +++ b/routing/router.go @@ -18,9 +18,9 @@ import ( sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/amp" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" diff --git a/routing/router_test.go b/routing/router_test.go index 6f6f2e4342..db72bf266c 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -22,10 +22,11 @@ import ( "github.com/go-errors/errors" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/graph" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" @@ -68,7 +69,7 @@ type testCtx struct { graphBuilder *mockGraphBuilder - graph *channeldb.ChannelGraph + graph *graphdb.ChannelGraph aliases map[string]route.Vertex @@ -191,7 +192,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, return ctx } -func createTestNode() (*channeldb.LightningNode, error) { +func createTestNode() (*models.LightningNode, error) { updateTime := rand.Int63() priv, err := btcec.NewPrivateKey() @@ -200,7 +201,7 @@ func createTestNode() (*channeldb.LightningNode, error) { } pub := priv.PubKey().SerializeCompressed() - n := &channeldb.LightningNode{ + n := &models.LightningNode{ HaveNodeAnnouncement: true, LastUpdate: time.Unix(updateTime, 0), Addresses: testAddrs, @@ -2898,7 +2899,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { // Now check that we can update the node info for the partial node // without messing up the channel graph. - n1 := &channeldb.LightningNode{ + n1 := &models.LightningNode{ HaveNodeAnnouncement: true, LastUpdate: time.Unix(123, 0), Addresses: testAddrs, @@ -2911,7 +2912,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { require.NoError(t, ctx.graph.AddLightningNode(n1)) - n2 := &channeldb.LightningNode{ + n2 := &models.LightningNode{ HaveNodeAnnouncement: true, LastUpdate: time.Unix(123, 0), Addresses: testAddrs, diff --git a/routing/unified_edges.go b/routing/unified_edges.go index 6c44372e99..c2e008e473 100644 --- a/routing/unified_edges.go +++ b/routing/unified_edges.go @@ -4,8 +4,8 @@ import ( "math" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -95,7 +95,7 @@ func (u *nodeEdgeUnifier) addPolicy(fromNode route.Vertex, // addGraphPolicies adds all policies that are known for the toNode in the // graph. func (u *nodeEdgeUnifier) addGraphPolicies(g Graph) error { - cb := func(channel *channeldb.DirectedChannel) error { + cb := func(channel *graphdb.DirectedChannel) error { // If there is no edge policy for this candidate node, skip. // Note that we are searching backwards so this node would have // come prior to the pivot node in the route. diff --git a/routing/unified_edges_test.go b/routing/unified_edges_test.go index 82605e9b37..8fc79031ac 100644 --- a/routing/unified_edges_test.go +++ b/routing/unified_edges_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" diff --git a/rpcserver.go b/rpcserver.go index d23066a1fb..d7d2e0186c 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -41,8 +41,6 @@ import ( "github.com/lightningnetwork/lnd/chanbackup" "github.com/lightningnetwork/lnd/chanfitness" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/graphsession" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/contractcourt" @@ -51,6 +49,9 @@ import ( "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/graph" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/graph/graphsession" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" @@ -2691,7 +2692,7 @@ func (r *rpcServer) CloseChannel(in *lnrpc.CloseChannelRequest, // First, we'll fetch the channel as is, as we'll need to examine it // regardless of if this is a force close or not. - channel, err := r.server.chanStateDB.FetchChannel(nil, *chanPoint) + channel, err := r.server.chanStateDB.FetchChannel(*chanPoint) if err != nil { return err } @@ -3037,7 +3038,7 @@ func createRPCCloseUpdate( // abandonChanFromGraph attempts to remove a channel from the channel graph. If // we can't find the chanID in the graph, then we assume it has already been // removed, and will return a nop. -func abandonChanFromGraph(chanGraph *channeldb.ChannelGraph, +func abandonChanFromGraph(chanGraph *graphdb.ChannelGraph, chanPoint *wire.OutPoint) error { // First, we'll obtain the channel ID. If we can't locate this, then @@ -3045,7 +3046,7 @@ func abandonChanFromGraph(chanGraph *channeldb.ChannelGraph, // the graph, so we'll return a nil error. chanID, err := chanGraph.ChannelID(chanPoint) switch { - case errors.Is(err, channeldb.ErrEdgeNotFound): + case errors.Is(err, graphdb.ErrEdgeNotFound): return nil case err != nil: return err @@ -3139,7 +3140,7 @@ func (r *rpcServer) AbandonChannel(_ context.Context, return nil, err } - dbChan, err := r.server.chanStateDB.FetchChannel(nil, *chanPoint) + dbChan, err := r.server.chanStateDB.FetchChannel(*chanPoint) switch { // If the channel isn't found in the set of open channels, then we can // continue on as it can't be loaded into the link/peer. @@ -6532,7 +6533,9 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, // First iterate through all the known nodes (connected or unconnected // within the graph), collating their current state into the RPC // response. - err := graph.ForEachNode(func(_ kvdb.RTx, node *channeldb.LightningNode) error { + err := graph.ForEachNode(func(_ kvdb.RTx, + node *models.LightningNode) error { + lnNode := marshalNode(node) resp.Nodes = append(resp.Nodes, lnNode) @@ -6562,7 +6565,7 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, return nil }) - if err != nil && err != channeldb.ErrGraphNoEdgesFound { + if err != nil && !errors.Is(err, graphdb.ErrGraphNoEdgesFound) { return nil, err } @@ -6808,7 +6811,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, // be returned. node, err := graph.FetchLightningNode(pubKey) switch { - case err == channeldb.ErrGraphNodeNotFound: + case errors.Is(err, graphdb.ErrGraphNodeNotFound): return nil, status.Error(codes.NotFound, err.Error()) case err != nil: return nil, err @@ -6860,7 +6863,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, }, nil } -func marshalNode(node *channeldb.LightningNode) *lnrpc.LightningNode { +func marshalNode(node *models.LightningNode) *lnrpc.LightningNode { nodeAddrs := make([]*lnrpc.NodeAddress, len(node.Addresses)) for i, addr := range node.Addresses { nodeAddr := &lnrpc.NodeAddress{ @@ -6931,7 +6934,7 @@ func (r *rpcServer) GetNetworkInfo(ctx context.Context, // each node so we can measure the graph diameter and degree stats // below. err := graph.ForEachNodeCached(func(node route.Vertex, - edges map[uint64]*channeldb.DirectedChannel) error { + edges map[uint64]*graphdb.DirectedChannel) error { // Increment the total number of nodes with each iteration. numNodes++ diff --git a/rpcserver_test.go b/rpcserver_test.go index 53ec6d0ac3..b4b66e719c 100644 --- a/rpcserver_test.go +++ b/rpcserver_test.go @@ -41,11 +41,7 @@ func (m *mockDataParser) InlineParseCustomData(msg proto.Message) error { func TestAuxDataParser(t *testing.T) { // We create an empty channeldb, so we can fetch some channels. - cdb, err := channeldb.Open(t.TempDir()) - require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, cdb.Close()) - }) + cdb := channeldb.OpenForTesting(t, t.TempDir()) r := &rpcServer{ server: &server{ diff --git a/server.go b/server.go index c186d36560..d9e86db4b5 100644 --- a/server.go +++ b/server.go @@ -33,8 +33,6 @@ import ( "github.com/lightningnetwork/lnd/chanbackup" "github.com/lightningnetwork/lnd/chanfitness" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/graphsession" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/cluster" @@ -44,6 +42,9 @@ import ( "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/graph" + graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/graph/graphsession" "github.com/lightningnetwork/lnd/healthcheck" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch/hop" @@ -258,11 +259,11 @@ type server struct { fundingMgr *funding.Manager - graphDB *channeldb.ChannelGraph + graphDB *graphdb.ChannelGraph chanStateDB *channeldb.ChannelStateDB - addrSource chanbackup.AddressSource + addrSource channeldb.AddrSource // miscDB is the DB that contains all "other" databases within the main // channel DB that haven't been separated out yet. @@ -606,12 +607,14 @@ func newServer(cfg *Config, listenAddrs []net.Addr, HtlcInterceptor: invoiceHtlcModifier, } + addrSource := channeldb.NewMultiAddrSource(dbs.ChanStateDB, dbs.GraphDB) + s := &server{ cfg: cfg, implCfg: implCfg, - graphDB: dbs.GraphDB.ChannelGraph(), + graphDB: dbs.GraphDB, chanStateDB: dbs.ChanStateDB.ChannelStateDB(), - addrSource: dbs.ChanStateDB, + addrSource: addrSource, miscDB: dbs.ChanStateDB, invoicesDB: dbs.InvoiceDB, cc: cc, @@ -763,7 +766,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, IsChannelActive: s.htlcSwitch.HasActiveLink, ApplyChannelUpdate: s.applyChannelUpdate, DB: s.chanStateDB, - Graph: dbs.GraphDB.ChannelGraph(), + Graph: dbs.GraphDB, } chanStatusMgr, err := netann.NewChanStatusManager(chanStatusMgrCfg) @@ -853,10 +856,6 @@ func newServer(cfg *Config, listenAddrs []net.Addr, selfAddrs := make([]net.Addr, 0, len(externalIPs)) selfAddrs = append(selfAddrs, externalIPs...) - // As the graph can be obtained at anytime from the network, we won't - // replicate it, and instead it'll only be stored locally. - chanGraph := dbs.GraphDB.ChannelGraph() - // We'll now reconstruct a node announcement based on our current // configuration so we can send it out as a sort of heart beat within // the network. @@ -878,7 +877,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, if err != nil { return nil, err } - selfNode := &channeldb.LightningNode{ + selfNode := &models.LightningNode{ HaveNodeAnnouncement: true, LastUpdate: time.Now(), Addresses: selfAddrs, @@ -915,7 +914,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, // Finally, we'll update the representation on disk, and update our // cached in-memory version as well. - if err := chanGraph.SetSourceNode(selfNode); err != nil { + if err := dbs.GraphDB.SetSourceNode(selfNode); err != nil { return nil, fmt.Errorf("can't set self node: %w", err) } s.currentNodeAnn = nodeAnn @@ -1015,13 +1014,13 @@ func newServer(cfg *Config, listenAddrs []net.Addr, MinProbability: routingConfig.MinRouteProbability, } - sourceNode, err := chanGraph.SourceNode() + sourceNode, err := dbs.GraphDB.SourceNode() if err != nil { return nil, fmt.Errorf("error getting source node: %w", err) } paymentSessionSource := &routing.SessionSource{ GraphSessionFactory: graphsession.NewGraphSessionFactory( - chanGraph, + dbs.GraphDB, ), SourceNode: sourceNode, MissionControl: s.defaultMC, @@ -1038,7 +1037,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.graphBuilder, err = graph.NewBuilder(&graph.Config{ SelfNode: selfNode.PubKeyBytes, - Graph: chanGraph, + Graph: dbs.GraphDB, Chain: cc.ChainIO, ChainView: cc.ChainView, Notifier: cc.ChainNotifier, @@ -1055,7 +1054,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.chanRouter, err = routing.New(routing.Config{ SelfNode: selfNode.PubKeyBytes, - RoutingGraph: graphsession.NewRoutingGraph(chanGraph), + RoutingGraph: graphsession.NewRoutingGraph(dbs.GraphDB), Chain: cc.ChainIO, Payer: s.htlcSwitch, Control: s.controlTower, @@ -1131,17 +1130,17 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.localChanMgr = &localchans.Manager{ SelfPub: nodeKeyDesc.PubKey, DefaultRoutingPolicy: cc.RoutingPolicy, - ForAllOutgoingChannels: func(cb func(kvdb.RTx, - *models.ChannelEdgeInfo, *models.ChannelEdgePolicy) error) error { + ForAllOutgoingChannels: func(cb func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy) error) error { return s.graphDB.ForEachNodeChannel(selfVertex, - func(tx kvdb.RTx, c *models.ChannelEdgeInfo, + func(_ kvdb.RTx, c *models.ChannelEdgeInfo, e *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { // NOTE: The invoked callback here may // receive a nil channel policy. - return cb(tx, c, e) + return cb(c, e) }, ) }, @@ -1389,7 +1388,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, info, e1, e2, err := s.graphDB.FetchChannelEdgesByID( scid.ToUint64(), ) - if errors.Is(err, channeldb.ErrEdgeNotFound) { + if errors.Is(err, graphdb.ErrEdgeNotFound) { // This is unlikely but there is a slim chance of this // being hit if lnd was killed via SIGKILL and the // funding manager was stepping through the delete @@ -1628,7 +1627,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, // static backup of the latest channel state. chanNotifier := &channelNotifier{ chanNotifier: s.channelNotifier, - addrs: dbs.ChanStateDB, + addrs: s.addrSource, } backupFile := chanbackup.NewMultiFile(cfg.BackupFilePath) startingChans, err := chanbackup.FetchStaticChanBackups( @@ -3185,7 +3184,7 @@ func (s *server) createNewHiddenService() error { // Finally, we'll update the on-disk version of our announcement so it // will eventually propagate to nodes in the network. - selfNode := &channeldb.LightningNode{ + selfNode := &models.LightningNode{ HaveNodeAnnouncement: true, LastUpdate: time.Unix(int64(newNodeAnn.Timestamp), 0), Addresses: newNodeAnn.Addresses, @@ -3448,7 +3447,7 @@ func (s *server) establishPersistentConnections() error { nodeAddrsMap[pubStr] = n return nil }) - if err != nil && err != channeldb.ErrGraphNoEdgesFound { + if err != nil && !errors.Is(err, graphdb.ErrGraphNoEdgesFound) { return err } diff --git a/subrpcserver_config.go b/subrpcserver_config.go index 9e9295931b..30755c05e4 100644 --- a/subrpcserver_config.go +++ b/subrpcserver_config.go @@ -12,6 +12,7 @@ import ( "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lncfg" @@ -112,7 +113,7 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config, chanRouter *routing.ChannelRouter, routerBackend *routerrpc.RouterBackend, nodeSigner *netann.NodeSigner, - graphDB *channeldb.ChannelGraph, + graphDB *graphdb.ChannelGraph, chanStateDB *channeldb.ChannelStateDB, sweeper *sweep.UtxoSweeper, tower *watchtower.Standalone, diff --git a/witness_beacon.go b/witness_beacon.go index 2bc3c08509..fe78f3665f 100644 --- a/witness_beacon.go +++ b/witness_beacon.go @@ -5,8 +5,8 @@ import ( "sync" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/contractcourt" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lntypes"