From d24002158c94cc835b6433f5fa23941b6f778b0a Mon Sep 17 00:00:00 2001 From: Farber98 Date: Fri, 6 Dec 2024 12:32:24 -0300 Subject: [PATCH] address feedback --- pkg/solana/txm/pendingtx.go | 6 ------ pkg/solana/txm/pendingtx_test.go | 6 ++++-- pkg/solana/txm/txm.go | 21 ++++++--------------- pkg/solana/txm/txm_internal_test.go | 9 ++++----- 4 files changed, 14 insertions(+), 28 deletions(-) diff --git a/pkg/solana/txm/pendingtx.go b/pkg/solana/txm/pendingtx.go index 262b60936..9230b99db 100644 --- a/pkg/solana/txm/pendingtx.go +++ b/pkg/solana/txm/pendingtx.go @@ -212,12 +212,6 @@ func (c *pendingTxContext) ListAllSigs() []solana.Signature { return maps.Keys(c.sigToID) } -func (c *pendingTxContext) ListAllTxsIDs() []string { - c.lock.RLock() - defer c.lock.RUnlock() - return maps.Values(c.sigToID) -} - // ListAllExpiredBroadcastedTxs returns all the txes that are in broadcasted state and have expired for given block height compared against their lastValidBlockHeight. // Passing maxUint64 as currHeight will return all broadcasted txes. func (c *pendingTxContext) ListAllExpiredBroadcastedTxs(currBlockHeight uint64) []pendingTx { diff --git a/pkg/solana/txm/pendingtx_test.go b/pkg/solana/txm/pendingtx_test.go index 531435c8d..10bf4cb0c 100644 --- a/pkg/solana/txm/pendingtx_test.go +++ b/pkg/solana/txm/pendingtx_test.go @@ -48,8 +48,10 @@ func TestPendingTxContext_add_remove_multiple(t *testing.T) { // cannot add signature for non existent ID require.Error(t, txs.AddSignature(uuid.New().String(), solana.Signature{})) - // return list of txsIds - list := txs.ListAllTxsIDs() + list := make([]string, 0, n) + for _, id := range txs.sigToID { + list = append(list, id) + } assert.Equal(t, n, len(list)) // stop all sub processes diff --git a/pkg/solana/txm/txm.go b/pkg/solana/txm/txm.go index afc3b9213..ac2bed40f 100644 --- a/pkg/solana/txm/txm.go +++ b/pkg/solana/txm/txm.go @@ -570,24 +570,21 @@ func (txm *Txm) handleFinalizedSignatureStatus(sig solanaGo.Signature) { // An expired tx is one where it's blockhash lastValidBlockHeight is smaller than the current slot height. // If any error occurs during rebroadcast attempt, they are discarded, and the function continues with the next transaction. func (txm *Txm) rebroadcastExpiredTxs(ctx context.Context, client client.ReaderWriter) { - currBlockHeight, err := client.GetLatestBlock(ctx) - if err != nil || currBlockHeight == nil || currBlockHeight.BlockHeight == nil { + currBlock, err := client.GetLatestBlock(ctx) + if err != nil || currBlock == nil || currBlock.BlockHeight == nil { txm.lggr.Errorw("failed to get current block height", "error", err) return } // Rebroadcast all expired txes - for _, tx := range txm.txs.ListAllExpiredBroadcastedTxs(*currBlockHeight.BlockHeight) { - txm.lggr.Debugw("transaction expired, rebroadcasting", "id", tx.id, "signature", tx.signatures, "lastValidBlockHeight", tx.lastValidBlockHeight, "currentBlockHeight", *currBlockHeight.BlockHeight) - if len(tx.signatures) == 0 { // prevent panic, shouldn't happen. - txm.lggr.Errorw("no signatures found for expired transaction", "id", tx.id) - continue - } + for _, tx := range txm.txs.ListAllExpiredBroadcastedTxs(*currBlock.BlockHeight) { + txm.lggr.Debugw("transaction expired, rebroadcasting", "id", tx.id, "signature", tx.signatures, "lastValidBlockHeight", tx.lastValidBlockHeight, "currentBlockHeight", *currBlock.BlockHeight) // Removes all signatures associated to tx and cancels context. _, err := txm.txs.Remove(tx.id) if err != nil { txm.lggr.Errorw("failed to remove expired transaction", "id", tx.id, "error", err) continue } + tx.cfg.BaseComputeUnitPrice = txm.fee.BaseComputeUnitPrice() // update compute unit price (priority fee) for rebroadcast rebroadcastTx := pendingTx{ tx: tx.tx, cfg: tx.cfg, @@ -720,17 +717,11 @@ func (txm *Txm) Enqueue(ctx context.Context, accountID string, tx *solanaGo.Tran } msg := pendingTx{ + id: id, tx: *tx, cfg: cfg, } - // If ID was not set by caller, create one. - if txID != nil && *txID != "" { - msg.id = *txID - } else { - msg.id = uuid.New().String() - } - select { case txm.chSend <- msg: default: diff --git a/pkg/solana/txm/txm_internal_test.go b/pkg/solana/txm/txm_internal_test.go index b7680e8ad..3f8e0c070 100644 --- a/pkg/solana/txm/txm_internal_test.go +++ b/pkg/solana/txm/txm_internal_test.go @@ -978,7 +978,6 @@ func TestTxm_compute_unit_limit_estimation(t *testing.T) { cfg.Chain.TxRetentionTimeout = relayconfig.MustNewDuration(5 * time.Second) mc := mocks.NewReaderWriter(t) mc.On("GetLatestBlock", mock.Anything).Return(&rpc.GetBlockResult{}, nil).Maybe() - mc.On("SlotHeight", mock.Anything).Return(uint64(0), nil).Maybe() // mock solana keystore mkey := keyMocks.NewSimpleKeystore(t) @@ -1292,7 +1291,7 @@ func TestTxm_ExpirationRebroadcast(t *testing.T) { txExpirationRebroadcast := true statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} - // Mock getLatestBlock to return a value greater than 0 + // Mock getLatestBlock to return a value greater than 0 for blockHeight getLatestBlockFunc := func() (*rpc.GetBlockResult, error) { val := uint64(1500) return &rpc.GetBlockResult{ @@ -1304,14 +1303,14 @@ func TestTxm_ExpirationRebroadcast(t *testing.T) { latestBlockhashFunc := func() (*rpc.GetLatestBlockhashResult, error) { defer func() { callCount++ }() if callCount < 1 { - // To force rebroadcast, first call needs to be smaller than slotHeight + // To force rebroadcast, first call needs to be smaller than blockHeight return &rpc.GetLatestBlockhashResult{ Value: &rpc.LatestBlockhashResult{ LastValidBlockHeight: uint64(1000), }, }, nil } - // following rebroadcast call will go through because lastValidBlockHeight is bigger than slotHeight + // following rebroadcast call will go through because lastValidBlockHeight is bigger than blockHeight return &rpc.GetLatestBlockhashResult{ Value: &rpc.LatestBlockhashResult{ LastValidBlockHeight: uint64(2000), @@ -1438,7 +1437,7 @@ func TestTxm_ExpirationRebroadcast(t *testing.T) { } // Mock LatestBlockhash to return an invalid blockhash in the first 3 attempts (initial + 2 rebroadcasts) - // the last one is valid because it is greater than the slotHeight + // the last one is valid because it is greater than the blockHeight expectedRebroadcastsCount := 3 callCount := 0 latestBlockhashFunc := func() (*rpc.GetLatestBlockhashResult, error) {