diff --git a/x/crosschain/keeper/cctx.go b/x/crosschain/keeper/cctx.go index f70cc3ecf8..32d002bc7d 100644 --- a/x/crosschain/keeper/cctx.go +++ b/x/crosschain/keeper/cctx.go @@ -33,22 +33,18 @@ func (k Keeper) SetCctxAndNonceToCctxAndInboundHashToCctx( cctx types.CrossChainTx, tssPubkey string, ) { - // set mapping nonce => cctxIndex - - if cctx.CctxStatus.Status == types.CctxStatus_PendingOutbound || - cctx.CctxStatus.Status == types.CctxStatus_PendingRevert { - k.GetObserverKeeper().SetNonceToCctx(ctx, observerTypes.NonceToCctx{ - ChainId: cctx.GetCurrentOutboundParam().ReceiverChainId, - // #nosec G115 always in range - Nonce: int64(cctx.GetCurrentOutboundParam().TssNonce), - CctxIndex: cctx.Index, - Tss: tssPubkey, - }) - } - + k.UpdateNonceToCCTX(ctx, cctx, tssPubkey) k.SetCrossChainTx(ctx, cctx) + k.UpdateInboundHashToCCTX(ctx, cctx) + k.UpdateZetaAccounting(ctx, cctx) +} - // set mapping inboundHash -> cctxIndex +// UpdateInboundHashToCCTX updates the mapping between an inbound hash and a cctx index. +// A new index is added to the list of cctx indexes if it is not already present +func (k Keeper) UpdateInboundHashToCCTX( + ctx sdk.Context, + cctx types.CrossChainTx, +) { in, _ := k.GetInboundHashToCctx(ctx, cctx.InboundParams.ObservedHash) in.InboundHash = cctx.InboundParams.ObservedHash found := false @@ -62,12 +58,36 @@ func (k Keeper) SetCctxAndNonceToCctxAndInboundHashToCctx( in.CctxIndex = append(in.CctxIndex, cctx.Index) } k.SetInboundHashToCctx(ctx, in) +} +func (k Keeper) UpdateZetaAccounting( + ctx sdk.Context, + cctx types.CrossChainTx, +) { if cctx.CctxStatus.Status == types.CctxStatus_Aborted && cctx.InboundParams.CoinType == coin.CoinType_Zeta { k.AddZetaAbortedAmount(ctx, GetAbortedAmount(cctx)) } } +// UpdateNonceToCCTX updates the mapping between a nonce and a cctx index if the cctx is in a PendingOutbound or PendingRevert state +func (k Keeper) UpdateNonceToCCTX( + ctx sdk.Context, + cctx types.CrossChainTx, + tssPubkey string, +) { + // set mapping nonce => cctxIndex + if cctx.CctxStatus.Status == types.CctxStatus_PendingOutbound || + cctx.CctxStatus.Status == types.CctxStatus_PendingRevert { + k.GetObserverKeeper().SetNonceToCctx(ctx, observerTypes.NonceToCctx{ + ChainId: cctx.GetCurrentOutboundParam().ReceiverChainId, + // #nosec G115 always in range + Nonce: int64(cctx.GetCurrentOutboundParam().TssNonce), + CctxIndex: cctx.Index, + Tss: tssPubkey, + }) + } +} + // SetCrossChainTx set a specific cctx in the store from its index func (k Keeper) SetCrossChainTx(ctx sdk.Context, cctx types.CrossChainTx) { // only set the update timestamp if the block height is >0 to allow diff --git a/x/crosschain/keeper/cctx_test.go b/x/crosschain/keeper/cctx_test.go index d1e3f6fafc..22f3c61049 100644 --- a/x/crosschain/keeper/cctx_test.go +++ b/x/crosschain/keeper/cctx_test.go @@ -453,3 +453,206 @@ func Test_NewCCTX(t *testing.T) { require.Equal(t, types.ProtocolContractVersion_V1, cctx.ProtocolContractVersion) }) } + +func TestKeeper_UpdateNonceToCCTX(t *testing.T) { + t.Run("should set nonce to cctx if status is PendingOutbound", func(t *testing.T) { + // Arrange + k, ctx, _, _ := keepertest.CrosschainKeeper(t) + chainID := chains.Ethereum.ChainId + nonce := uint64(10) + + cctx := types.CrossChainTx{Index: "test", + OutboundParams: []*types.OutboundParams{{ReceiverChainId: chainID, TssNonce: nonce}}, + CctxStatus: &types.Status{Status: types.CctxStatus_PendingOutbound}, + } + tssPubkey := "test-tss-pubkey" + + // Act + k.UpdateNonceToCCTX(ctx, cctx, tssPubkey) + + // Assert + nonceToCctx, found := k.GetObserverKeeper().GetNonceToCctx(ctx, tssPubkey, chainID, int64(nonce)) + require.True(t, found) + require.Equal(t, cctx.Index, nonceToCctx.CctxIndex) + require.Equal(t, tssPubkey, nonceToCctx.Tss) + require.Equal(t, chainID, nonceToCctx.ChainId) + }) + + t.Run("should set nonce to cctx if status is PendingRevert", func(t *testing.T) { + // Arrange + k, ctx, _, _ := keepertest.CrosschainKeeper(t) + chainID := chains.Ethereum.ChainId + nonce := uint64(10) + + cctx := types.CrossChainTx{Index: "test", + OutboundParams: []*types.OutboundParams{{ReceiverChainId: chainID, TssNonce: nonce}}, + CctxStatus: &types.Status{Status: types.CctxStatus_PendingRevert}, + } + tssPubkey := "test-tss-pubkey" + + // Act + k.UpdateNonceToCCTX(ctx, cctx, tssPubkey) + + // Assert + nonceToCctx, found := k.GetObserverKeeper().GetNonceToCctx(ctx, tssPubkey, chainID, int64(nonce)) + require.True(t, found) + require.Equal(t, cctx.Index, nonceToCctx.CctxIndex) + require.Equal(t, tssPubkey, nonceToCctx.Tss) + require.Equal(t, chainID, nonceToCctx.ChainId) + }) + + t.Run("should not set nonce to cctx if status is not PendingOutbound or PendingRevert", func(t *testing.T) { + // Arrange + k, ctx, _, _ := keepertest.CrosschainKeeper(t) + chainID := chains.Ethereum.ChainId + nonce := uint64(10) + + cctx := types.CrossChainTx{Index: "test", + OutboundParams: []*types.OutboundParams{{ReceiverChainId: chainID, TssNonce: nonce}}, + CctxStatus: &types.Status{Status: types.CctxStatus_Aborted}, + } + tssPubkey := "test-tss-pubkey" + + // Act + k.UpdateNonceToCCTX(ctx, cctx, tssPubkey) + + // Assert + _, found := k.GetObserverKeeper().GetNonceToCctx(ctx, tssPubkey, chainID, int64(nonce)) + require.False(t, found) + }) +} + +func TestKeeper_UpdateInboundHashToCCTX(t *testing.T) { + t.Run( + "should update inbound hash to cctx mapping if new cctx index is found for the same inbound hash", + func(t *testing.T) { + // Arrange + k, ctx, _, _ := keepertest.CrosschainKeeper(t) + inboundHash := sample.Hash().String() + index1 := sample.ZetaIndex(t) + index2 := sample.ZetaIndex(t) + + inboundHashToCctx := types.InboundHashToCctx{ + InboundHash: inboundHash, + CctxIndex: []string{index1}, + } + k.SetInboundHashToCctx(ctx, inboundHashToCctx) + cctx := types.CrossChainTx{Index: index2, InboundParams: &types.InboundParams{ObservedHash: inboundHash}} + + // Act + k.UpdateInboundHashToCCTX(ctx, cctx) + + // Assert + inboundHashToCctx, found := k.GetInboundHashToCctx(ctx, inboundHash) + require.True(t, found) + require.Equal(t, inboundHash, inboundHashToCctx.InboundHash) + require.Equal(t, 2, len(inboundHashToCctx.CctxIndex)) + require.Contains(t, inboundHashToCctx.CctxIndex, index1) + require.Contains(t, inboundHashToCctx.CctxIndex, index2) + }, + ) + + t.Run("should do nothing if the cctx index is already in the mapping", func(t *testing.T) { + // Arrange + k, ctx, _, _ := keepertest.CrosschainKeeper(t) + inboundHash := sample.Hash().String() + index := sample.ZetaIndex(t) + + inboundHashToCctx := types.InboundHashToCctx{ + InboundHash: inboundHash, + CctxIndex: []string{index}, + } + k.SetInboundHashToCctx(ctx, inboundHashToCctx) + cctx := types.CrossChainTx{Index: index, InboundParams: &types.InboundParams{ObservedHash: inboundHash}} + + // Act + k.UpdateInboundHashToCCTX(ctx, cctx) + + // Assert + inboundHashToCctx, found := k.GetInboundHashToCctx(ctx, inboundHash) + require.True(t, found) + require.Equal(t, inboundHash, inboundHashToCctx.InboundHash) + require.Equal(t, 1, len(inboundHashToCctx.CctxIndex)) + require.Contains(t, inboundHashToCctx.CctxIndex, index) + }) + + t.Run("should add cctx index to mapping if InboundHashToCctx is not found", func(t *testing.T) { + // Arrange + k, ctx, _, _ := keepertest.CrosschainKeeper(t) + inboundHash := sample.Hash().String() + index := sample.ZetaIndex(t) + + cctx := types.CrossChainTx{Index: index, InboundParams: &types.InboundParams{ObservedHash: inboundHash}} + + // Act + k.UpdateInboundHashToCCTX(ctx, cctx) + + // Assert + inboundHashToCctx, found := k.GetInboundHashToCctx(ctx, inboundHash) + require.True(t, found) + require.Equal(t, inboundHash, inboundHashToCctx.InboundHash) + require.Equal(t, 1, len(inboundHashToCctx.CctxIndex)) + require.Contains(t, inboundHashToCctx.CctxIndex, index) + }) +} + +func TestKeeper_UpdateZetaAccounting(t *testing.T) { + t.Run("should update zeta accounting if cctx is aborted and coin type is zeta", func(t *testing.T) { + // Arrange + k, ctx, _, _ := keepertest.CrosschainKeeper(t) + amount := sdkmath.NewUint(100) + cctx := types.CrossChainTx{ + InboundParams: &types.InboundParams{CoinType: coin.CoinType_Zeta}, + CctxStatus: &types.Status{Status: types.CctxStatus_Aborted}, + OutboundParams: []*types.OutboundParams{{Amount: amount}}, + } + k.SetZetaAccounting(ctx, types.ZetaAccounting{AbortedZetaAmount: math.ZeroUint()}) + + // Act + k.UpdateZetaAccounting(ctx, cctx) + + // Assert + zetaAccounting, found := k.GetZetaAccounting(ctx) + require.True(t, found) + require.Equal(t, amount, zetaAccounting.AbortedZetaAmount) + }) + + t.Run("should not update zeta accounting if cctx is not aborted", func(t *testing.T) { + // Arrange + k, ctx, _, _ := keepertest.CrosschainKeeper(t) + amount := sdkmath.NewUint(100) + cctx := types.CrossChainTx{ + InboundParams: &types.InboundParams{CoinType: coin.CoinType_Zeta}, + CctxStatus: &types.Status{Status: types.CctxStatus_PendingOutbound}, + OutboundParams: []*types.OutboundParams{{Amount: amount}}, + } + k.SetZetaAccounting(ctx, types.ZetaAccounting{AbortedZetaAmount: math.ZeroUint()}) + + // Act + k.UpdateZetaAccounting(ctx, cctx) + + // Assert + zetaAccounting, found := k.GetZetaAccounting(ctx) + require.True(t, found) + require.Equal(t, math.ZeroUint(), zetaAccounting.AbortedZetaAmount) + }) + + t.Run("should update to amount if zeta accounting is not set", func(t *testing.T) { + // Arrange + k, ctx, _, _ := keepertest.CrosschainKeeper(t) + amount := sdkmath.NewUint(100) + cctx := types.CrossChainTx{ + InboundParams: &types.InboundParams{CoinType: coin.CoinType_Zeta}, + CctxStatus: &types.Status{Status: types.CctxStatus_Aborted}, + OutboundParams: []*types.OutboundParams{{Amount: amount}}, + } + + // Act + k.UpdateZetaAccounting(ctx, cctx) + + // Assert + zetaAccounting, found := k.GetZetaAccounting(ctx) + require.True(t, found) + require.Equal(t, amount, zetaAccounting.AbortedZetaAmount) + }) +}